66from marshmallow import missing
77from marshmallow .utils import _Missing
88
9- from bioimageio .spec .model . v0_3 import raw_nodes as model_raw_nodes
10- from bioimageio .spec .rdf . v0_2 import raw_nodes as rdf_raw_nodes
9+ from bioimageio .spec .model import raw_nodes as model_raw_nodes
10+ from bioimageio .spec .rdf import raw_nodes as rdf_raw_nodes
1111from bioimageio .spec .shared import raw_nodes
1212
1313
@@ -130,6 +130,14 @@ def __post_init__(self):
130130 self .axes = tuple (self .axes )
131131
132132
133+ @dataclass
134+ class ImportedSource (Node ):
135+ factory : Callable
136+
137+ def __call__ (self , * args , ** kwargs ):
138+ return self .factory (* args , ** kwargs )
139+
140+
133141@dataclass
134142class _WeightsEntryBase (Node , model_raw_nodes ._WeightsEntryBase ):
135143 source : Path = missing
@@ -147,7 +155,7 @@ class OnnxWeightsEntry(_WeightsEntryBase, model_raw_nodes.OnnxWeightsEntry):
147155
148156@dataclass
149157class PytorchStateDictWeightsEntry (_WeightsEntryBase , model_raw_nodes .PytorchStateDictWeightsEntry ):
150- pass
158+ architecture : Union [ _Missing , ImportedSource ] = missing
151159
152160
153161@dataclass
@@ -175,17 +183,8 @@ class TensorflowSavedModelBundleWeightsEntry(_WeightsEntryBase, model_raw_nodes.
175183]
176184
177185
178- @dataclass
179- class ImportedSource (Node ):
180- factory : Callable
181-
182- def __call__ (self , * args , ** kwargs ):
183- return self .factory (* args , ** kwargs )
184-
185-
186186@dataclass
187187class Model (model_raw_nodes .Model , RDF , Node ):
188- source : Union [_Missing , ImportedSource ] = missing
189188 test_inputs : List [Path ] = missing
190189 test_outputs : List [Path ] = missing
191190 weights : Dict [model_raw_nodes .WeightsFormat , WeightsEntry ] = missing
0 commit comments