@@ -238,89 +238,6 @@ def file(self):
238238]
239239
240240
241- class WeightsDescr (Node ):
242- keras_hdf5 : Optional [KerasHdf5WeightsDescr ] = None
243- onnx : Optional [OnnxWeightsDescr ] = None
244- pytorch_state_dict : Optional [PytorchStateDictWeightsDescr ] = None
245- tensorflow_js : Optional [TensorflowJsWeightsDescr ] = None
246- tensorflow_saved_model_bundle : Optional [TensorflowSavedModelBundleWeightsDescr ] = (
247- None
248- )
249- torchscript : Optional [TorchscriptWeightsDescr ] = None
250-
251- @model_validator (mode = "after" )
252- def check_one_entry (self ) -> Self :
253- if all (
254- entry is None
255- for entry in [
256- self .keras_hdf5 ,
257- self .onnx ,
258- self .pytorch_state_dict ,
259- self .tensorflow_js ,
260- self .tensorflow_saved_model_bundle ,
261- self .torchscript ,
262- ]
263- ):
264- raise ValueError ("Missing weights entry" )
265-
266- return self
267-
268- def __getitem__ (
269- self ,
270- key : WeightsFormat ,
271- ):
272- if key == "keras_hdf5" :
273- ret = self .keras_hdf5
274- elif key == "onnx" :
275- ret = self .onnx
276- elif key == "pytorch_state_dict" :
277- ret = self .pytorch_state_dict
278- elif key == "tensorflow_js" :
279- ret = self .tensorflow_js
280- elif key == "tensorflow_saved_model_bundle" :
281- ret = self .tensorflow_saved_model_bundle
282- elif key == "torchscript" :
283- ret = self .torchscript
284- else :
285- raise KeyError (key )
286-
287- if ret is None :
288- raise KeyError (key )
289-
290- return ret
291-
292- @property
293- def available_formats (self ):
294- return {
295- ** ({} if self .keras_hdf5 is None else {"keras_hdf5" : self .keras_hdf5 }),
296- ** ({} if self .onnx is None else {"onnx" : self .onnx }),
297- ** (
298- {}
299- if self .pytorch_state_dict is None
300- else {"pytorch_state_dict" : self .pytorch_state_dict }
301- ),
302- ** (
303- {}
304- if self .tensorflow_js is None
305- else {"tensorflow_js" : self .tensorflow_js }
306- ),
307- ** (
308- {}
309- if self .tensorflow_saved_model_bundle is None
310- else {
311- "tensorflow_saved_model_bundle" : self .tensorflow_saved_model_bundle
312- }
313- ),
314- ** ({} if self .torchscript is None else {"torchscript" : self .torchscript }),
315- }
316-
317- @property
318- def missing_formats (self ):
319- return {
320- wf for wf in get_args (WeightsFormat ) if wf not in self .available_formats
321- }
322-
323-
324241class WeightsEntryDescrBase (FileDescr ):
325242 type : ClassVar [WeightsFormat ]
326243 weights_format_name : ClassVar [str ] # human readable
@@ -544,6 +461,89 @@ def _tfv(cls, value: Any):
544461 return value
545462
546463
464+ class WeightsDescr (Node ):
465+ keras_hdf5 : Optional [KerasHdf5WeightsDescr ] = None
466+ onnx : Optional [OnnxWeightsDescr ] = None
467+ pytorch_state_dict : Optional [PytorchStateDictWeightsDescr ] = None
468+ tensorflow_js : Optional [TensorflowJsWeightsDescr ] = None
469+ tensorflow_saved_model_bundle : Optional [TensorflowSavedModelBundleWeightsDescr ] = (
470+ None
471+ )
472+ torchscript : Optional [TorchscriptWeightsDescr ] = None
473+
474+ @model_validator (mode = "after" )
475+ def check_one_entry (self ) -> Self :
476+ if all (
477+ entry is None
478+ for entry in [
479+ self .keras_hdf5 ,
480+ self .onnx ,
481+ self .pytorch_state_dict ,
482+ self .tensorflow_js ,
483+ self .tensorflow_saved_model_bundle ,
484+ self .torchscript ,
485+ ]
486+ ):
487+ raise ValueError ("Missing weights entry" )
488+
489+ return self
490+
491+ def __getitem__ (
492+ self ,
493+ key : WeightsFormat ,
494+ ):
495+ if key == "keras_hdf5" :
496+ ret = self .keras_hdf5
497+ elif key == "onnx" :
498+ ret = self .onnx
499+ elif key == "pytorch_state_dict" :
500+ ret = self .pytorch_state_dict
501+ elif key == "tensorflow_js" :
502+ ret = self .tensorflow_js
503+ elif key == "tensorflow_saved_model_bundle" :
504+ ret = self .tensorflow_saved_model_bundle
505+ elif key == "torchscript" :
506+ ret = self .torchscript
507+ else :
508+ raise KeyError (key )
509+
510+ if ret is None :
511+ raise KeyError (key )
512+
513+ return ret
514+
515+ @property
516+ def available_formats (self ):
517+ return {
518+ ** ({} if self .keras_hdf5 is None else {"keras_hdf5" : self .keras_hdf5 }),
519+ ** ({} if self .onnx is None else {"onnx" : self .onnx }),
520+ ** (
521+ {}
522+ if self .pytorch_state_dict is None
523+ else {"pytorch_state_dict" : self .pytorch_state_dict }
524+ ),
525+ ** (
526+ {}
527+ if self .tensorflow_js is None
528+ else {"tensorflow_js" : self .tensorflow_js }
529+ ),
530+ ** (
531+ {}
532+ if self .tensorflow_saved_model_bundle is None
533+ else {
534+ "tensorflow_saved_model_bundle" : self .tensorflow_saved_model_bundle
535+ }
536+ ),
537+ ** ({} if self .torchscript is None else {"torchscript" : self .torchscript }),
538+ }
539+
540+ @property
541+ def missing_formats (self ):
542+ return {
543+ wf for wf in get_args (WeightsFormat ) if wf not in self .available_formats
544+ }
545+
546+
547547class ParameterizedInputShape (Node ):
548548 """A sequence of valid shapes given by `shape_k = min + k * step for k in {0, 1, ...}`."""
549549
0 commit comments