19
19
//!
20
20
//! <https://arrow.apache.org/docs/format/CanonicalExtensions.html#fixed-shape-tensor>
21
21
22
- use serde:: { Deserialize , Serialize } ;
22
+ use serde:: { Deserialize , Deserializer , Serialize , Serializer } ;
23
23
24
24
use crate :: { ArrowError , DataType , extension:: ExtensionType } ;
25
25
@@ -129,7 +129,7 @@ impl FixedShapeTensor {
129
129
}
130
130
131
131
/// Extension type metadata for [`FixedShapeTensor`].
132
- #[ derive( Debug , Clone , PartialEq , Deserialize , Serialize ) ]
132
+ #[ derive( Debug , Clone , PartialEq ) ]
133
133
pub struct FixedShapeTensorMetadata {
134
134
/// The physical shape of the contained tensors.
135
135
shape : Vec < usize > ,
@@ -141,6 +141,144 @@ pub struct FixedShapeTensorMetadata {
141
141
permutations : Option < Vec < usize > > ,
142
142
}
143
143
144
+ impl Serialize for FixedShapeTensorMetadata {
145
+ fn serialize < S > ( & self , serializer : S ) -> Result < S :: Ok , S :: Error >
146
+ where
147
+ S : Serializer ,
148
+ {
149
+ use serde:: ser:: SerializeStruct ;
150
+ let mut state = serializer. serialize_struct ( "FixedShapeTensorMetadata" , 3 ) ?;
151
+ state. serialize_field ( "shape" , & self . shape ) ?;
152
+ state. serialize_field ( "dim_names" , & self . dim_names ) ?;
153
+ state. serialize_field ( "permutations" , & self . permutations ) ?;
154
+ state. end ( )
155
+ }
156
+ }
157
+
158
+ impl < ' de > Deserialize < ' de > for FixedShapeTensorMetadata {
159
+ fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
160
+ where
161
+ D : Deserializer < ' de > ,
162
+ {
163
+ use serde:: de:: { self , MapAccess , Visitor } ;
164
+ use std:: fmt;
165
+
166
+ #[ derive( Debug ) ]
167
+ enum Field {
168
+ Shape ,
169
+ DimNames ,
170
+ Permutations ,
171
+ }
172
+
173
+ impl < ' de > Deserialize < ' de > for Field {
174
+ fn deserialize < D > ( deserializer : D ) -> Result < Field , D :: Error >
175
+ where
176
+ D : Deserializer < ' de > ,
177
+ {
178
+ struct FieldVisitor ;
179
+
180
+ impl < ' de > Visitor < ' de > for FieldVisitor {
181
+ type Value = Field ;
182
+
183
+ fn expecting ( & self , formatter : & mut fmt:: Formatter ) -> fmt:: Result {
184
+ formatter. write_str ( "`shape`, `dim_names`, or `permutations`" )
185
+ }
186
+
187
+ fn visit_str < E > ( self , value : & str ) -> Result < Field , E >
188
+ where
189
+ E : de:: Error ,
190
+ {
191
+ match value {
192
+ "shape" => Ok ( Field :: Shape ) ,
193
+ "dim_names" => Ok ( Field :: DimNames ) ,
194
+ "permutations" => Ok ( Field :: Permutations ) ,
195
+ _ => Err ( de:: Error :: unknown_field (
196
+ value,
197
+ & [ "shape" , "dim_names" , "permutations" ] ,
198
+ ) ) ,
199
+ }
200
+ }
201
+ }
202
+
203
+ deserializer. deserialize_identifier ( FieldVisitor )
204
+ }
205
+ }
206
+
207
+ struct FixedShapeTensorMetadataVisitor ;
208
+
209
+ impl < ' de > Visitor < ' de > for FixedShapeTensorMetadataVisitor {
210
+ type Value = FixedShapeTensorMetadata ;
211
+
212
+ fn expecting ( & self , formatter : & mut fmt:: Formatter ) -> fmt:: Result {
213
+ formatter. write_str ( "struct FixedShapeTensorMetadata" )
214
+ }
215
+
216
+ fn visit_seq < V > ( self , mut seq : V ) -> Result < FixedShapeTensorMetadata , V :: Error >
217
+ where
218
+ V : de:: SeqAccess < ' de > ,
219
+ {
220
+ let shape = seq
221
+ . next_element ( ) ?
222
+ . ok_or_else ( || de:: Error :: invalid_length ( 0 , & self ) ) ?;
223
+ let dim_names = seq
224
+ . next_element ( ) ?
225
+ . ok_or_else ( || de:: Error :: invalid_length ( 1 , & self ) ) ?;
226
+ let permutations = seq
227
+ . next_element ( ) ?
228
+ . ok_or_else ( || de:: Error :: invalid_length ( 2 , & self ) ) ?;
229
+ Ok ( FixedShapeTensorMetadata {
230
+ shape,
231
+ dim_names,
232
+ permutations,
233
+ } )
234
+ }
235
+
236
+ fn visit_map < V > ( self , mut map : V ) -> Result < FixedShapeTensorMetadata , V :: Error >
237
+ where
238
+ V : MapAccess < ' de > ,
239
+ {
240
+ let mut shape = None ;
241
+ let mut dim_names = None ;
242
+ let mut permutations = None ;
243
+
244
+ while let Some ( key) = map. next_key ( ) ? {
245
+ match key {
246
+ Field :: Shape => {
247
+ if shape. is_some ( ) {
248
+ return Err ( de:: Error :: duplicate_field ( "shape" ) ) ;
249
+ }
250
+ shape = Some ( map. next_value ( ) ?) ;
251
+ }
252
+ Field :: DimNames => {
253
+ if dim_names. is_some ( ) {
254
+ return Err ( de:: Error :: duplicate_field ( "dim_names" ) ) ;
255
+ }
256
+ dim_names = Some ( map. next_value ( ) ?) ;
257
+ }
258
+ Field :: Permutations => {
259
+ if permutations. is_some ( ) {
260
+ return Err ( de:: Error :: duplicate_field ( "permutations" ) ) ;
261
+ }
262
+ permutations = Some ( map. next_value ( ) ?) ;
263
+ }
264
+ }
265
+ }
266
+
267
+ let shape = shape. ok_or_else ( || de:: Error :: missing_field ( "shape" ) ) ?;
268
+
269
+ Ok ( FixedShapeTensorMetadata {
270
+ shape,
271
+ dim_names,
272
+ permutations,
273
+ } )
274
+ }
275
+ }
276
+
277
+ const FIELDS : & [ & str ] = & [ "shape" , "dim_names" , "permutations" ] ;
278
+ deserializer. deserialize_struct ( "FixedShapeTensorMetadata" , FIELDS , FixedShapeTensorMetadataVisitor )
279
+ }
280
+ }
281
+
144
282
impl FixedShapeTensorMetadata {
145
283
/// Returns metadata for a fixed shape tensor extension type.
146
284
///
0 commit comments