Skip to content

Commit be683e4

Browse files
committed
update nodes for model spec v0.4
1 parent 4ac31a0 commit be683e4

File tree

1 file changed

+11
-12
lines changed

1 file changed

+11
-12
lines changed

bioimageio/core/resource_io/nodes.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from marshmallow import missing
77
from marshmallow.utils import _Missing
88

9-
from bioimageio.spec.model.v0_3 import raw_nodes as model_raw_nodes
10-
from bioimageio.spec.rdf.v0_2 import raw_nodes as rdf_raw_nodes
9+
from bioimageio.spec.model import raw_nodes as model_raw_nodes
10+
from bioimageio.spec.rdf import raw_nodes as rdf_raw_nodes
1111
from bioimageio.spec.shared import raw_nodes
1212

1313

@@ -130,6 +130,14 @@ def __post_init__(self):
130130
self.axes = tuple(self.axes)
131131

132132

133+
@dataclass
134+
class ImportedSource(Node):
135+
factory: Callable
136+
137+
def __call__(self, *args, **kwargs):
138+
return self.factory(*args, **kwargs)
139+
140+
133141
@dataclass
134142
class _WeightsEntryBase(Node, model_raw_nodes._WeightsEntryBase):
135143
source: Path = missing
@@ -147,7 +155,7 @@ class OnnxWeightsEntry(_WeightsEntryBase, model_raw_nodes.OnnxWeightsEntry):
147155

148156
@dataclass
149157
class PytorchStateDictWeightsEntry(_WeightsEntryBase, model_raw_nodes.PytorchStateDictWeightsEntry):
150-
pass
158+
architecture: Union[_Missing, ImportedSource] = missing
151159

152160

153161
@dataclass
@@ -175,17 +183,8 @@ class TensorflowSavedModelBundleWeightsEntry(_WeightsEntryBase, model_raw_nodes.
175183
]
176184

177185

178-
@dataclass
179-
class ImportedSource(Node):
180-
factory: Callable
181-
182-
def __call__(self, *args, **kwargs):
183-
return self.factory(*args, **kwargs)
184-
185-
186186
@dataclass
187187
class Model(model_raw_nodes.Model, RDF, Node):
188-
source: Union[_Missing, ImportedSource] = missing
189188
test_inputs: List[Path] = missing
190189
test_outputs: List[Path] = missing
191190
weights: Dict[model_raw_nodes.WeightsFormat, WeightsEntry] = missing

0 commit comments

Comments
 (0)