33import os
44from pathlib import Path
55from typing import Any , Dict , List , Optional , Tuple , Union
6+ from warnings import warn
67
78import imageio
89import numpy as np
@@ -73,6 +74,22 @@ def _get_pytorch_state_dict_weight_kwargs(architecture, model_kwargs, root):
7374 return weight_kwargs , tmp_archtecture
7475
7576
77+ def _get_attachments (attachments , root ):
78+ assert isinstance (attachments , dict )
79+ if "files" in attachments :
80+ afiles = attachments ["files" ]
81+ if isinstance (afiles , str ):
82+ afiles = [afiles ]
83+
84+ if isinstance (afiles , list ):
85+ afiles = _ensure_local_or_url (afiles , root )
86+ else :
87+ raise TypeError (attachments )
88+
89+ attachments ["files" ] = afiles
90+ return attachments
91+
92+
7693def _get_weights (
7794 original_weight_source ,
7895 weight_type ,
@@ -81,67 +98,94 @@ def _get_weights(
8198 model_kwargs = None ,
8299 tensorflow_version = None ,
83100 opset_version = None ,
101+ pytorch_version = None ,
84102 dependencies = None ,
85- ** kwargs ,
103+ attachments = None ,
86104):
87105 weight_path = resolve_source (original_weight_source , root )
88106 if weight_type is None :
89107 weight_type = _infer_weight_type (weight_path )
90108 weight_hash = _get_hash (weight_path )
91109
92- attachments = {"attachments" : kwargs ["weight_attachments" ]} if "weight_attachments" in kwargs else {}
93110 weight_types = model_spec .raw_nodes .WeightsFormat
94111 weight_source = _ensure_local_or_url (original_weight_source , root )
95112
113+ weight_kwargs = {"source" : weight_source , "sha256" : weight_hash }
114+ if attachments is not None :
115+ weight_kwargs ["attachments" ] = _get_attachments (attachments , root )
116+ if dependencies is not None :
117+ weight_kwargs ["dependencies" ] = _get_dependencies (dependencies , root )
118+
96119 tmp_archtecture = None
97120 if weight_type == "pytorch_state_dict" :
98121 # pytorch-state-dict -> we need an architecture definition
99- weight_kwargs , tmp_file = _get_pytorch_state_dict_weight_kwargs (architecture , model_kwargs , root )
100- weight_kwargs .update (** attachments )
101- weights = model_spec .raw_nodes .PytorchStateDictWeightsEntry (
102- source = weight_source , sha256 = weight_hash , ** weight_kwargs
103- )
104- if dependencies is not None :
105- weight_kwargs ["dependencies" ] = _get_dependencies (dependencies , root )
122+ pytorch_weight_kwargs , tmp_file = _get_pytorch_state_dict_weight_kwargs (architecture , model_kwargs , root )
123+ weight_kwargs .update (** pytorch_weight_kwargs )
124+ if pytorch_version is not None :
125+ weight_kwargs ["pytorch_version" ] = pytorch_version
126+ elif dependencies is None :
127+ warn (
128+ "You are building a pytorch model but have neither passed dependencies nor the pytorch_version."
129+ "It may not be possible to create an environmnet where your model can be used."
130+ )
131+ weights = model_spec .raw_nodes .PytorchStateDictWeightsEntry (** weight_kwargs )
106132
107133 elif weight_type == "onnx" :
108- if opset_version is None :
109- raise ValueError ("opset_version needs to be passed for building an onnx model" )
110- weights = model_spec .raw_nodes .OnnxWeightsEntry (
111- source = weight_source , sha256 = weight_hash , opset_version = opset_version , ** attachments
112- )
134+ if opset_version is not None :
135+ weight_kwargs ["opset_version" ] = opset_version
136+ elif dependencies is None :
137+ warn (
138+ "You are building an onnx model but have neither passed dependencies nor the opset_version."
139+ "It may not be possible to create an environmnet where your model can be used."
140+ )
141+ weights = model_spec .raw_nodes .OnnxWeightsEntry (** weight_kwargs )
113142
114143 elif weight_type == "torchscript" :
115- weights = model_spec .raw_nodes .TorchscriptWeightsEntry (source = weight_source , sha256 = weight_hash , ** attachments )
144+ if pytorch_version is not None :
145+ weight_kwargs ["pytorch_version" ] = pytorch_version
146+ elif dependencies is None :
147+ warn (
148+ "You are building a pytorch model but have neither passed dependencies nor the pytorch_version."
149+ "It may not be possible to create an environmnet where your model can be used."
150+ )
151+ weights = model_spec .raw_nodes .TorchscriptWeightsEntry (** weight_kwargs )
116152
117153 elif weight_type == "keras_hdf5" :
118- if tensorflow_version is None :
119- raise ValueError ("tensorflow_version needs to be passed for building a keras model" )
120- weights = model_spec .raw_nodes .KerasHdf5WeightsEntry (
121- source = weight_source , sha256 = weight_hash , tensorflow_version = tensorflow_version , ** attachments
122- )
154+ if tensorflow_version is not None :
155+ weight_kwargs ["tensorflow_version" ] = tensorflow_version
156+ elif dependencies is None :
157+ warn (
158+ "You are building a keras model but have neither passed dependencies nor the tensorflow_version."
159+ "It may not be possible to create an environmnet where your model can be used."
160+ )
161+ weights = model_spec .raw_nodes .KerasHdf5WeightsEntry (** weight_kwargs )
123162
124163 elif weight_type == "tensorflow_saved_model_bundle" :
125- if tensorflow_version is None :
126- raise ValueError ("tensorflow_version needs to be passed for building a tensorflow model" )
127- weights = model_spec .raw_nodes .TensorflowSavedModelBundleWeightsEntry (
128- source = weight_source , sha256 = weight_hash , tensorflow_version = tensorflow_version , ** attachments
129- )
164+ if tensorflow_version is not None :
165+ weight_kwargs ["tensorflow_version" ] = tensorflow_version
166+ elif dependencies is None :
167+ warn (
168+ "You are building a tensorflow model but have neither passed dependencies nor the tensorflow_version."
169+ "It may not be possible to create an environmnet where your model can be used."
170+ )
171+ weights = model_spec .raw_nodes .TensorflowSavedModelBundleWeightsEntry (** weight_kwargs )
130172
131173 elif weight_type == "tensorflow_js" :
132- if tensorflow_version is None :
133- raise ValueError ("tensorflow_version needs to be passed for building a tensorflow_js model" )
134- weights = model_spec .raw_nodes .TensorflowJsWeightsEntry (
135- source = weight_source , sha256 = weight_hash , tensorflow_version = tensorflow_version , ** attachments
136- )
174+ if tensorflow_version is not None :
175+ weight_kwargs ["tensorflow_version" ] = tensorflow_version
176+ elif dependencies is None :
177+ warn (
178+ "You are building a tensorflow model but have neither passed dependencies nor the tensorflow_version."
179+ "It may not be possible to create an environmnet where your model can be used."
180+ )
181+ weights = model_spec .raw_nodes .TensorflowJsWeightsEntry (** weight_kwargs )
137182
138183 elif weight_type in weight_types :
139184 raise ValueError (f"Weight type { weight_type } is not supported yet in 'build_spec'" )
140185 else :
141186 raise ValueError (f"Invalid weight type { weight_type } , expect one of { weight_types } " )
142187
143- weights = {weight_type : weights }
144- return weights , tmp_archtecture
188+ return {weight_type : weights }, tmp_archtecture
145189
146190
147191def _get_data_range (data_range , dtype ):
@@ -589,7 +633,8 @@ def build_model(
589633 add_deepimagej_config : bool = False ,
590634 tensorflow_version : Optional [str ] = None ,
591635 opset_version : Optional [int ] = None ,
592- ** weight_kwargs ,
636+ pytorch_version : Optional [str ] = None ,
637+ weight_attachments : Optional [Dict [str , Union [str , List [str ]]]] = None ,
593638):
594639 """Create a zipped bioimage.io model.
595640
@@ -661,30 +706,18 @@ def build_model(
661706 dependencies: relative path to file with dependencies for this model.
662707 root: optional root path for relative paths. This can be helpful when building a spec from another model spec.
663708 add_deepimagej_config: add the deepimagej config to the model.
664- tensorflow_version: the tensorflow version used for training the model.
665- Only requred for models with tensorflow or keras weight format.
666- opset_version: the opset version used in this model.
667- Only requred for models with onnx weight format.
668- weight_kwargs: additional keyword arguments for this weight type.
709+ tensorflow_version: the tensorflow version for this model. Only for tensorflow or keras weights.
710+ opset_version: the opset version for this model. Only for onnx weights.
711+ pytorch_version: the pytorch version for this model. Only for pytoch_state_dict or torchscript weights.
712+ weight_attachments: extra weight specific attachments.
669713 """
670714 assert architecture is None or isinstance (architecture , str )
671715 if root is None :
672716 root = "."
673717 root = Path (root )
674718
675719 if attachments is not None :
676- assert isinstance (attachments , dict )
677- if "files" in attachments :
678- afiles = attachments ["files" ]
679- if isinstance (afiles , str ):
680- afiles = [afiles ]
681-
682- if isinstance (afiles , list ):
683- afiles = _ensure_local_or_url (afiles , root )
684- else :
685- raise TypeError (attachments )
686-
687- attachments ["files" ] = afiles
720+ attachments = _get_attachments (attachments , root )
688721
689722 #
690723 # generate the model specific fields
@@ -776,8 +809,9 @@ def build_model(
776809 model_kwargs ,
777810 tensorflow_version = tensorflow_version ,
778811 opset_version = opset_version ,
812+ pytorch_version = pytorch_version ,
779813 dependencies = dependencies ,
780- ** weight_kwargs ,
814+ attachments = weight_attachments ,
781815 )
782816
783817 # validate the sample inputs and outputs (if given)
0 commit comments