33import os
44from pathlib import Path
55from typing import Any , Dict , List , Optional , Tuple , Union
6+ from warnings import warn
67
78import imageio
89import numpy as np
@@ -73,6 +74,22 @@ def _get_pytorch_state_dict_weight_kwargs(architecture, model_kwargs, root):
7374 return weight_kwargs , tmp_archtecture
7475
7576
77+ def _get_attachments (attachments , root ):
78+ assert isinstance (attachments , dict )
79+ if "files" in attachments :
80+ afiles = attachments ["files" ]
81+ if isinstance (afiles , str ):
82+ afiles = [afiles ]
83+
84+ if isinstance (afiles , list ):
85+ afiles = _ensure_local_or_url (afiles , root )
86+ else :
87+ raise TypeError (attachments )
88+
89+ attachments ["files" ] = afiles
90+ return attachments
91+
92+
7693def _get_weights (
7794 original_weight_source ,
7895 weight_type ,
@@ -81,67 +98,94 @@ def _get_weights(
8198 model_kwargs = None ,
8299 tensorflow_version = None ,
83100 opset_version = None ,
101+ pytorch_version = None ,
84102 dependencies = None ,
85- ** kwargs ,
103+ attachments = None ,
86104):
87105 weight_path = resolve_source (original_weight_source , root )
88106 if weight_type is None :
89107 weight_type = _infer_weight_type (weight_path )
90108 weight_hash = _get_hash (weight_path )
91109
92- attachments = {"attachments" : kwargs ["weight_attachments" ]} if "weight_attachments" in kwargs else {}
93110 weight_types = model_spec .raw_nodes .WeightsFormat
94111 weight_source = _ensure_local_or_url (original_weight_source , root )
95112
113+ weight_kwargs = {"source" : weight_source , "sha256" : weight_hash }
114+ if attachments is not None :
115+ weight_kwargs ["attachments" ] = _get_attachments (attachments , root )
116+ if dependencies is not None :
117+ weight_kwargs ["dependencies" ] = _get_dependencies (dependencies , root )
118+
96119 tmp_archtecture = None
97120 if weight_type == "pytorch_state_dict" :
98121 # pytorch-state-dict -> we need an architecture definition
99- weight_kwargs , tmp_file = _get_pytorch_state_dict_weight_kwargs (architecture , model_kwargs , root )
100- weight_kwargs .update (** attachments )
101- weights = model_spec .raw_nodes .PytorchStateDictWeightsEntry (
102- source = weight_source , sha256 = weight_hash , ** weight_kwargs
103- )
104- if dependencies is not None :
105- weight_kwargs ["dependencies" ] = _get_dependencies (dependencies , root )
122+ pytorch_weight_kwargs , tmp_file = _get_pytorch_state_dict_weight_kwargs (architecture , model_kwargs , root )
123+ weight_kwargs .update (** pytorch_weight_kwargs )
124+ if pytorch_version is not None :
125+ weight_kwargs ["pytorch_version" ] = pytorch_version
126+ elif dependencies is None :
127+ warn (
128+ "You are building a pytorch model but have neither passed dependencies nor the pytorch_version."
129+ "It may not be possible to create an environmnet where your model can be used."
130+ )
131+ weights = model_spec .raw_nodes .PytorchStateDictWeightsEntry (** weight_kwargs )
106132
107133 elif weight_type == "onnx" :
108- if opset_version is None :
109- raise ValueError ("opset_version needs to be passed for building an onnx model" )
110- weights = model_spec .raw_nodes .OnnxWeightsEntry (
111- source = weight_source , sha256 = weight_hash , opset_version = opset_version , ** attachments
112- )
134+ if opset_version is not None :
135+ weight_kwargs ["opset_version" ] = opset_version
136+ elif dependencies is None :
137+ warn (
138+ "You are building an onnx model but have neither passed dependencies nor the opset_version."
139+ "It may not be possible to create an environmnet where your model can be used."
140+ )
141+ weights = model_spec .raw_nodes .OnnxWeightsEntry (** weight_kwargs )
113142
114143 elif weight_type == "torchscript" :
115- weights = model_spec .raw_nodes .TorchscriptWeightsEntry (source = weight_source , sha256 = weight_hash , ** attachments )
144+ if pytorch_version is not None :
145+ weight_kwargs ["pytorch_version" ] = pytorch_version
146+ elif dependencies is None :
147+ warn (
148+ "You are building a pytorch model but have neither passed dependencies nor the pytorch_version."
149+ "It may not be possible to create an environmnet where your model can be used."
150+ )
151+ weights = model_spec .raw_nodes .TorchscriptWeightsEntry (** weight_kwargs )
116152
117153 elif weight_type == "keras_hdf5" :
118- if tensorflow_version is None :
119- raise ValueError ("tensorflow_version needs to be passed for building a keras model" )
120- weights = model_spec .raw_nodes .KerasHdf5WeightsEntry (
121- source = weight_source , sha256 = weight_hash , tensorflow_version = tensorflow_version , ** attachments
122- )
154+ if tensorflow_version is not None :
155+ weight_kwargs ["tensorflow_version" ] = tensorflow_version
156+ elif dependencies is None :
157+ warn (
158+ "You are building a keras model but have neither passed dependencies nor the tensorflow_version."
159+ "It may not be possible to create an environmnet where your model can be used."
160+ )
161+ weights = model_spec .raw_nodes .KerasHdf5WeightsEntry (** weight_kwargs )
123162
124163 elif weight_type == "tensorflow_saved_model_bundle" :
125- if tensorflow_version is None :
126- raise ValueError ("tensorflow_version needs to be passed for building a tensorflow model" )
127- weights = model_spec .raw_nodes .TensorflowSavedModelBundleWeightsEntry (
128- source = weight_source , sha256 = weight_hash , tensorflow_version = tensorflow_version , ** attachments
129- )
164+ if tensorflow_version is not None :
165+ weight_kwargs ["tensorflow_version" ] = tensorflow_version
166+ elif dependencies is None :
167+ warn (
168+ "You are building a tensorflow model but have neither passed dependencies nor the tensorflow_version."
169+ "It may not be possible to create an environmnet where your model can be used."
170+ )
171+ weights = model_spec .raw_nodes .TensorflowSavedModelBundleWeightsEntry (** weight_kwargs )
130172
131173 elif weight_type == "tensorflow_js" :
132- if tensorflow_version is None :
133- raise ValueError ("tensorflow_version needs to be passed for building a tensorflow_js model" )
134- weights = model_spec .raw_nodes .TensorflowJsWeightsEntry (
135- source = weight_source , sha256 = weight_hash , tensorflow_version = tensorflow_version , ** attachments
136- )
174+ if tensorflow_version is not None :
175+ weight_kwargs ["tensorflow_version" ] = tensorflow_version
176+ elif dependencies is None :
177+ warn (
178+ "You are building a tensorflow model but have neither passed dependencies nor the tensorflow_version."
179+ "It may not be possible to create an environmnet where your model can be used."
180+ )
181+ weights = model_spec .raw_nodes .TensorflowJsWeightsEntry (** weight_kwargs )
137182
138183 elif weight_type in weight_types :
139184 raise ValueError (f"Weight type { weight_type } is not supported yet in 'build_spec'" )
140185 else :
141186 raise ValueError (f"Invalid weight type { weight_type } , expect one of { weight_types } " )
142187
143- weights = {weight_type : weights }
144- return weights , tmp_archtecture
188+ return {weight_type : weights }, tmp_archtecture
145189
146190
147191def _get_data_range (data_range , dtype ):
@@ -563,7 +607,8 @@ def build_model(
563607 add_deepimagej_config : bool = False ,
564608 tensorflow_version : Optional [str ] = None ,
565609 opset_version : Optional [int ] = None ,
566- ** weight_kwargs ,
610+ pytorch_version : Optional [str ] = None ,
611+ weight_attachments : Optional [Dict [str , Union [str , List [str ]]]] = None ,
567612):
568613 """Create a zipped bioimage.io model.
569614
@@ -635,30 +680,18 @@ def build_model(
635680 dependencies: relative path to file with dependencies for this model.
636681 root: optional root path for relative paths. This can be helpful when building a spec from another model spec.
637682 add_deepimagej_config: add the deepimagej config to the model.
638- tensorflow_version: the tensorflow version used for training the model.
639- Only requred for models with tensorflow or keras weight format.
640- opset_version: the opset version used in this model.
641- Only requred for models with onnx weight format.
642- weight_kwargs: additional keyword arguments for this weight type.
683+ tensorflow_version: the tensorflow version for this model. Only for tensorflow or keras weights.
684+ opset_version: the opset version for this model. Only for onnx weights.
685+ pytorch_version: the pytorch version for this model. Only for pytoch_state_dict or torchscript weights.
686+ weight_attachments: extra weight specific attachments.
643687 """
644688 assert architecture is None or isinstance (architecture , str )
645689 if root is None :
646690 root = "."
647691 root = Path (root )
648692
649693 if attachments is not None :
650- assert isinstance (attachments , dict )
651- if "files" in attachments :
652- afiles = attachments ["files" ]
653- if isinstance (afiles , str ):
654- afiles = [afiles ]
655-
656- if isinstance (afiles , list ):
657- afiles = _ensure_local_or_url (afiles , root )
658- else :
659- raise TypeError (attachments )
660-
661- attachments ["files" ] = afiles
694+ attachments = _get_attachments (attachments , root )
662695
663696 #
664697 # generate the model specific fields
@@ -750,8 +783,9 @@ def build_model(
750783 model_kwargs ,
751784 tensorflow_version = tensorflow_version ,
752785 opset_version = opset_version ,
786+ pytorch_version = pytorch_version ,
753787 dependencies = dependencies ,
754- ** weight_kwargs ,
788+ attachments = weight_attachments ,
755789 )
756790
757791 # validate the sample inputs and outputs (if given)
0 commit comments