19
19
//!
20
20
//! <https://arrow.apache.org/docs/format/CanonicalExtensions.html#fixed-shape-tensor>
21
21
22
- use serde:: { Deserialize , Serialize } ;
22
+ use serde_core:: de:: { self , MapAccess , Visitor } ;
23
+ use serde_core:: ser:: SerializeStruct ;
24
+ use serde_core:: { Deserialize , Deserializer , Serialize , Serializer } ;
25
+ use std:: fmt;
23
26
24
27
use crate :: { ArrowError , DataType , extension:: ExtensionType } ;
25
28
@@ -129,7 +132,7 @@ impl FixedShapeTensor {
129
132
}
130
133
131
134
/// Extension type metadata for [`FixedShapeTensor`].
132
- #[ derive( Debug , Clone , PartialEq , Deserialize , Serialize ) ]
135
+ #[ derive( Debug , Clone , PartialEq ) ]
133
136
pub struct FixedShapeTensorMetadata {
134
137
/// The physical shape of the contained tensors.
135
138
shape : Vec < usize > ,
@@ -141,6 +144,143 @@ pub struct FixedShapeTensorMetadata {
141
144
permutations : Option < Vec < usize > > ,
142
145
}
143
146
147
+ impl Serialize for FixedShapeTensorMetadata {
148
+ fn serialize < S > ( & self , serializer : S ) -> Result < S :: Ok , S :: Error >
149
+ where
150
+ S : Serializer ,
151
+ {
152
+ let mut state = serializer. serialize_struct ( "FixedShapeTensorMetadata" , 3 ) ?;
153
+ state. serialize_field ( "shape" , & self . shape ) ?;
154
+ state. serialize_field ( "dim_names" , & self . dim_names ) ?;
155
+ state. serialize_field ( "permutations" , & self . permutations ) ?;
156
+ state. end ( )
157
+ }
158
+ }
159
+
160
+ #[ derive( Debug ) ]
161
+ enum MetadataField {
162
+ Shape ,
163
+ DimNames ,
164
+ Permutations ,
165
+ }
166
+
167
+ struct MetadataFieldVisitor ;
168
+
169
+ impl < ' de > Visitor < ' de > for MetadataFieldVisitor {
170
+ type Value = MetadataField ;
171
+
172
+ fn expecting ( & self , formatter : & mut fmt:: Formatter ) -> fmt:: Result {
173
+ formatter. write_str ( "`shape`, `dim_names`, or `permutations`" )
174
+ }
175
+
176
+ fn visit_str < E > ( self , value : & str ) -> Result < MetadataField , E >
177
+ where
178
+ E : de:: Error ,
179
+ {
180
+ match value {
181
+ "shape" => Ok ( MetadataField :: Shape ) ,
182
+ "dim_names" => Ok ( MetadataField :: DimNames ) ,
183
+ "permutations" => Ok ( MetadataField :: Permutations ) ,
184
+ _ => Err ( de:: Error :: unknown_field (
185
+ value,
186
+ & [ "shape" , "dim_names" , "permutations" ] ,
187
+ ) ) ,
188
+ }
189
+ }
190
+ }
191
+
192
+ impl < ' de > Deserialize < ' de > for MetadataField {
193
+ fn deserialize < D > ( deserializer : D ) -> Result < MetadataField , D :: Error >
194
+ where
195
+ D : Deserializer < ' de > ,
196
+ {
197
+ deserializer. deserialize_identifier ( MetadataFieldVisitor )
198
+ }
199
+ }
200
+
201
+ struct FixedShapeTensorMetadataVisitor ;
202
+
203
+ impl < ' de > Visitor < ' de > for FixedShapeTensorMetadataVisitor {
204
+ type Value = FixedShapeTensorMetadata ;
205
+
206
+ fn expecting ( & self , formatter : & mut fmt:: Formatter ) -> fmt:: Result {
207
+ formatter. write_str ( "struct FixedShapeTensorMetadata" )
208
+ }
209
+
210
+ fn visit_seq < V > ( self , mut seq : V ) -> Result < FixedShapeTensorMetadata , V :: Error >
211
+ where
212
+ V : de:: SeqAccess < ' de > ,
213
+ {
214
+ let shape = seq
215
+ . next_element ( ) ?
216
+ . ok_or_else ( || de:: Error :: invalid_length ( 0 , & self ) ) ?;
217
+ let dim_names = seq
218
+ . next_element ( ) ?
219
+ . ok_or_else ( || de:: Error :: invalid_length ( 1 , & self ) ) ?;
220
+ let permutations = seq
221
+ . next_element ( ) ?
222
+ . ok_or_else ( || de:: Error :: invalid_length ( 2 , & self ) ) ?;
223
+ Ok ( FixedShapeTensorMetadata {
224
+ shape,
225
+ dim_names,
226
+ permutations,
227
+ } )
228
+ }
229
+
230
+ fn visit_map < V > ( self , mut map : V ) -> Result < FixedShapeTensorMetadata , V :: Error >
231
+ where
232
+ V : MapAccess < ' de > ,
233
+ {
234
+ let mut shape = None ;
235
+ let mut dim_names = None ;
236
+ let mut permutations = None ;
237
+
238
+ while let Some ( key) = map. next_key ( ) ? {
239
+ match key {
240
+ MetadataField :: Shape => {
241
+ if shape. is_some ( ) {
242
+ return Err ( de:: Error :: duplicate_field ( "shape" ) ) ;
243
+ }
244
+ shape = Some ( map. next_value ( ) ?) ;
245
+ }
246
+ MetadataField :: DimNames => {
247
+ if dim_names. is_some ( ) {
248
+ return Err ( de:: Error :: duplicate_field ( "dim_names" ) ) ;
249
+ }
250
+ dim_names = Some ( map. next_value ( ) ?) ;
251
+ }
252
+ MetadataField :: Permutations => {
253
+ if permutations. is_some ( ) {
254
+ return Err ( de:: Error :: duplicate_field ( "permutations" ) ) ;
255
+ }
256
+ permutations = Some ( map. next_value ( ) ?) ;
257
+ }
258
+ }
259
+ }
260
+
261
+ let shape = shape. ok_or_else ( || de:: Error :: missing_field ( "shape" ) ) ?;
262
+
263
+ Ok ( FixedShapeTensorMetadata {
264
+ shape,
265
+ dim_names,
266
+ permutations,
267
+ } )
268
+ }
269
+ }
270
+
271
+ impl < ' de > Deserialize < ' de > for FixedShapeTensorMetadata {
272
+ fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
273
+ where
274
+ D : Deserializer < ' de > ,
275
+ {
276
+ deserializer. deserialize_struct (
277
+ "FixedShapeTensorMetadata" ,
278
+ & [ "shape" , "dim_names" , "permutations" ] ,
279
+ FixedShapeTensorMetadataVisitor ,
280
+ )
281
+ }
282
+ }
283
+
144
284
impl FixedShapeTensorMetadata {
145
285
/// Returns metadata for a fixed shape tensor extension type.
146
286
///
@@ -377,9 +517,8 @@ mod tests {
377
517
}
378
518
379
519
#[ test]
380
- #[ should_panic(
381
- expected = "FixedShapeTensor metadata deserialization failed: missing field `shape`"
382
- ) ]
520
+ #[ should_panic( expected = "FixedShapeTensor metadata deserialization failed: \
521
+ unknown field `not-shape`, expected one of `shape`, `dim_names`, `permutations`") ]
383
522
fn invalid_metadata ( ) {
384
523
let fixed_shape_tensor =
385
524
FixedShapeTensor :: try_new ( DataType :: Float32 , [ 100 , 200 , 500 ] , None , None ) . unwrap ( ) ;
0 commit comments