Skip to content

Commit 2e967e3

Browse files
committed
Moves WeightsDescr down so it won't need forward references
Forward references are not yet properly supported by the interactive docs generation scripts. They would also probably cause trouble on older versions of python.
1 parent 6a5e9e8 commit 2e967e3

File tree

1 file changed

+83
-83
lines changed

1 file changed

+83
-83
lines changed

bioimageio/spec/model/v0_4.py

Lines changed: 83 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -238,89 +238,6 @@ def file(self):
238238
]
239239

240240

241-
class WeightsDescr(Node):
242-
keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
243-
onnx: Optional[OnnxWeightsDescr] = None
244-
pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
245-
tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
246-
tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
247-
None
248-
)
249-
torchscript: Optional[TorchscriptWeightsDescr] = None
250-
251-
@model_validator(mode="after")
252-
def check_one_entry(self) -> Self:
253-
if all(
254-
entry is None
255-
for entry in [
256-
self.keras_hdf5,
257-
self.onnx,
258-
self.pytorch_state_dict,
259-
self.tensorflow_js,
260-
self.tensorflow_saved_model_bundle,
261-
self.torchscript,
262-
]
263-
):
264-
raise ValueError("Missing weights entry")
265-
266-
return self
267-
268-
def __getitem__(
269-
self,
270-
key: WeightsFormat,
271-
):
272-
if key == "keras_hdf5":
273-
ret = self.keras_hdf5
274-
elif key == "onnx":
275-
ret = self.onnx
276-
elif key == "pytorch_state_dict":
277-
ret = self.pytorch_state_dict
278-
elif key == "tensorflow_js":
279-
ret = self.tensorflow_js
280-
elif key == "tensorflow_saved_model_bundle":
281-
ret = self.tensorflow_saved_model_bundle
282-
elif key == "torchscript":
283-
ret = self.torchscript
284-
else:
285-
raise KeyError(key)
286-
287-
if ret is None:
288-
raise KeyError(key)
289-
290-
return ret
291-
292-
@property
293-
def available_formats(self):
294-
return {
295-
**({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
296-
**({} if self.onnx is None else {"onnx": self.onnx}),
297-
**(
298-
{}
299-
if self.pytorch_state_dict is None
300-
else {"pytorch_state_dict": self.pytorch_state_dict}
301-
),
302-
**(
303-
{}
304-
if self.tensorflow_js is None
305-
else {"tensorflow_js": self.tensorflow_js}
306-
),
307-
**(
308-
{}
309-
if self.tensorflow_saved_model_bundle is None
310-
else {
311-
"tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
312-
}
313-
),
314-
**({} if self.torchscript is None else {"torchscript": self.torchscript}),
315-
}
316-
317-
@property
318-
def missing_formats(self):
319-
return {
320-
wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
321-
}
322-
323-
324241
class WeightsEntryDescrBase(FileDescr):
325242
type: ClassVar[WeightsFormat]
326243
weights_format_name: ClassVar[str] # human readable
@@ -544,6 +461,89 @@ def _tfv(cls, value: Any):
544461
return value
545462

546463

464+
class WeightsDescr(Node):
465+
keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
466+
onnx: Optional[OnnxWeightsDescr] = None
467+
pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
468+
tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
469+
tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
470+
None
471+
)
472+
torchscript: Optional[TorchscriptWeightsDescr] = None
473+
474+
@model_validator(mode="after")
475+
def check_one_entry(self) -> Self:
476+
if all(
477+
entry is None
478+
for entry in [
479+
self.keras_hdf5,
480+
self.onnx,
481+
self.pytorch_state_dict,
482+
self.tensorflow_js,
483+
self.tensorflow_saved_model_bundle,
484+
self.torchscript,
485+
]
486+
):
487+
raise ValueError("Missing weights entry")
488+
489+
return self
490+
491+
def __getitem__(
492+
self,
493+
key: WeightsFormat,
494+
):
495+
if key == "keras_hdf5":
496+
ret = self.keras_hdf5
497+
elif key == "onnx":
498+
ret = self.onnx
499+
elif key == "pytorch_state_dict":
500+
ret = self.pytorch_state_dict
501+
elif key == "tensorflow_js":
502+
ret = self.tensorflow_js
503+
elif key == "tensorflow_saved_model_bundle":
504+
ret = self.tensorflow_saved_model_bundle
505+
elif key == "torchscript":
506+
ret = self.torchscript
507+
else:
508+
raise KeyError(key)
509+
510+
if ret is None:
511+
raise KeyError(key)
512+
513+
return ret
514+
515+
@property
516+
def available_formats(self):
517+
return {
518+
**({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
519+
**({} if self.onnx is None else {"onnx": self.onnx}),
520+
**(
521+
{}
522+
if self.pytorch_state_dict is None
523+
else {"pytorch_state_dict": self.pytorch_state_dict}
524+
),
525+
**(
526+
{}
527+
if self.tensorflow_js is None
528+
else {"tensorflow_js": self.tensorflow_js}
529+
),
530+
**(
531+
{}
532+
if self.tensorflow_saved_model_bundle is None
533+
else {
534+
"tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
535+
}
536+
),
537+
**({} if self.torchscript is None else {"torchscript": self.torchscript}),
538+
}
539+
540+
@property
541+
def missing_formats(self):
542+
return {
543+
wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
544+
}
545+
546+
547547
class ParameterizedInputShape(Node):
548548
"""A sequence of valid shapes given by `shape_k = min + k * step for k in {0, 1, ...}`."""
549549

0 commit comments

Comments
 (0)