@@ -78,7 +78,16 @@ def _get_pytorch_state_dict_weight_kwargs(architecture, model_kwargs, root):
7878 return weight_kwargs , tmp_archtecture
7979
8080
81- def _get_weights (original_weight_source , weight_type , root , architecture = None , model_kwargs = None , ** kwargs ):
81+ def _get_weights (
82+ original_weight_source ,
83+ weight_type ,
84+ root ,
85+ architecture = None ,
86+ model_kwargs = None ,
87+ tensorflow_version = None ,
88+ opset_version = None ,
89+ ** kwargs ,
90+ ):
8291 weight_path = resolve_source (original_weight_source , root )
8392 if weight_type is None :
8493 weight_type = _infer_weight_type (weight_path )
@@ -98,8 +107,10 @@ def _get_weights(original_weight_source, weight_type, root, architecture=None, m
98107 )
99108
100109 elif weight_type == "onnx" :
110+ if opset_version is None :
111+ raise ValueError ("opset_version needs to be passed for building an onnx model" )
101112 weights = model_spec .raw_nodes .OnnxWeightsEntry (
102- source = weight_source , sha256 = weight_hash , opset_version = kwargs . get ( " opset_version" , 12 ) , ** attachments
113+ source = weight_source , sha256 = weight_hash , opset_version = opset_version , ** attachments
103114 )
104115
105116 elif weight_type == "pytorch_script" :
@@ -108,26 +119,32 @@ def _get_weights(original_weight_source, weight_type, root, architecture=None, m
108119 )
109120
110121 elif weight_type == "keras_hdf5" :
122+ if tensorflow_version is None :
123+ raise ValueError ("tensorflow_version needs to be passed for building a keras model" )
111124 weights = model_spec .raw_nodes .KerasHdf5WeightsEntry (
112125 source = weight_source ,
113126 sha256 = weight_hash ,
114- tensorflow_version = kwargs . get ( " tensorflow_version" , "1.15" ) ,
127+ tensorflow_version = tensorflow_version ,
115128 ** attachments ,
116129 )
117130
118131 elif weight_type == "tensorflow_saved_model_bundle" :
132+ if tensorflow_version is None :
133+ raise ValueError ("tensorflow_version needs to be passed for building a tensorflow model" )
119134 weights = model_spec .raw_nodes .TensorflowSavedModelBundleWeightsEntry (
120135 source = weight_source ,
121136 sha256 = weight_hash ,
122- tensorflow_version = kwargs . get ( " tensorflow_version" , "1.15" ) ,
137+ tensorflow_version = tensorflow_version ,
123138 ** attachments ,
124139 )
125140
126141 elif weight_type == "tensorflow_js" :
142+ if tensorflow_version is None :
143+ raise ValueError ("tensorflow_version needs to be passed for building a tensorflow_js model" )
127144 weights = model_spec .raw_nodes .TensorflowJsWeightsEntry (
128145 source = weight_source ,
129146 sha256 = weight_hash ,
130- tensorflow_version = kwargs . get ( " tensorflow_version" , "1.15" ) ,
147+ tensorflow_version = tensorflow_version ,
131148 ** attachments ,
132149 )
133150
@@ -471,6 +488,8 @@ def build_model(
471488 links : Optional [List [str ]] = None ,
472489 root : Optional [Union [Path , str ]] = None ,
473490 add_deepimagej_config : bool = False ,
491+ tensorflow_version : Optional [str ] = None ,
492+ opset_version : Optional [int ] = None ,
474493 ** weight_kwargs ,
475494):
476495 """Create a zipped bioimage.io model.
@@ -539,7 +558,10 @@ def build_model(
539558 dependencies: relative path to file with dependencies for this model.
540559 root: optional root path for relative paths. This can be helpful when building a spec from another model spec.
541560 add_deepimagej_config: add the deepimagej config to the model.
542- weight_kwargs: keyword arguments for this weight type, e.g. "tensorflow_version".
561+ tensorflow_version: the tensorflow version used for training the model.
562+ Needs to be passed for tensorflow or keras models.
563+ opset_version: the opset version used in this model. Needs to be passed for onnx models.
564+ weight_kwargs: additional keyword arguments for this weight type.
543565 """
544566 if root is None :
545567 root = "."
@@ -624,7 +646,16 @@ def build_model(
624646 covers = _ensure_local (covers , root )
625647
626648 # parse the weights
627- weights , tmp_archtecture = _get_weights (weight_uri , weight_type , root , architecture , model_kwargs , ** weight_kwargs )
649+ weights , tmp_archtecture = _get_weights (
650+ weight_uri ,
651+ weight_type ,
652+ root ,
653+ architecture ,
654+ model_kwargs ,
655+ tensorflow_version = tensorflow_version ,
656+ opset_version = opset_version ,
657+ ** weight_kwargs ,
658+ )
628659
629660 # validate the sample inputs and outputs (if given)
630661 if sample_inputs is not None :
@@ -732,11 +763,24 @@ def add_weights(
732763 weight_uri : Union [str , Path ],
733764 weight_type : Optional [str ] = None ,
734765 output_path : Optional [Union [str , Path ]] = None ,
766+ architecture : Optional [str ] = None ,
767+ model_kwargs : Optional [Dict [str , Union [int , float , str ]]] = None ,
768+ tensorflow_version : Optional [str ] = None ,
769+ opset_version : Optional [str ] = None ,
735770 ** weight_kwargs ,
736771):
737772 """Add weight entry to bioimage.io model."""
738773 # we need to pass the weight path as abs path to avoid confusion with different root directories
739- new_weights , tmp_arch = _get_weights (Path (weight_uri ).absolute (), weight_type , root = Path ("." ), ** weight_kwargs )
774+ new_weights , tmp_arch = _get_weights (
775+ Path (weight_uri ).absolute (),
776+ weight_type ,
777+ root = Path ("." ),
778+ architecture = architecture ,
779+ model_kwargs = model_kwargs ,
780+ tensorflow_version = tensorflow_version ,
781+ opset_version = opset_version ,
782+ ** weight_kwargs ,
783+ )
740784 model .weights .update (new_weights )
741785 if output_path is not None :
742786 model_package = export_resource_package (model , output_path = output_path )
0 commit comments