1313from bioimageio .core import export_resource_package , load_raw_resource_description
1414from bioimageio .core .resource_io .nodes import URI
1515from bioimageio .core .resource_io .utils import resolve_local_source , resolve_source
16+ from bioimageio .spec .shared .raw_nodes import ImportableSourceFile , ImportableModule
1617
1718try :
1819 from typing import get_args
@@ -53,67 +54,58 @@ def _infer_weight_type(path):
5354 raise ValueError (f"Could not infer weight type from extension { ext } for weight file { path } " )
5455
5556
56- def _get_weights (original_weight_source , weight_type , source , root , ** kwargs ):
57+ def _get_pytorch_state_dict_weight_kwargs (architecture , model_kwargs , root ):
58+ assert architecture is not None
59+ tmp_archtecture = None
60+ weight_kwargs = {"kwargs" : model_kwargs } if model_kwargs else {}
61+ arch = spec .shared .fields .ImportableSource ().deserialize (architecture )
62+ if isinstance (arch , ImportableSourceFile ):
63+ if os .path .isabs (arch .source_file ):
64+ tmp_archtecture = Path ("this_model_architecture.py" )
65+
66+ copyfile (arch .source_file , root / tmp_archtecture )
67+ arch = ImportableSourceFile (arch .callable_name , tmp_archtecture )
68+
69+ arch_hash = _get_hash (root / arch .source_file )
70+ weight_kwargs ["architecture_sha256" ] = arch_hash
71+ elif isinstance (arch , ImportableModule ):
72+ pass
73+ else :
74+ raise NotImplementedError (arch )
75+
76+ weight_kwargs ["architecture" ] = arch
77+
78+ return weight_kwargs , tmp_archtecture
79+
80+
81+ def _get_weights (original_weight_source , weight_type , root , architecture = None , model_kwargs = None , ** kwargs ):
5782 weight_path = resolve_source (original_weight_source , root )
5883 if weight_type is None :
5984 weight_type = _infer_weight_type (weight_path )
6085 weight_hash = _get_hash (weight_path )
6186
62- tmp_source = None
63- # if we have a ":" (or deprecated "::") this is a python file with class specified,
64- # so we can compute the hash for it
65- if source is not None and ":" in source :
66- source_file , source_class = source .replace ("::" , ":" ).split (":" )
67-
68- # get the source path
69- source_file = _ensure_local (source_file , root )
70- source_hash = _get_hash (root / source_file )
71-
72- # if not relative, create local copy (otherwise this will not work)
73- if os .path .isabs (source_file ):
74- copyfile (source_file , "this_model_architecture.py" )
75- source = f"this_model_architecture.py:{ source_class } "
76- tmp_source = "this_model_architecture.py"
77- else :
78- source = f"{ source_file } :{ source_class } "
79- source = spec .shared .fields .ImportableSource ().deserialize (source )
80- else :
81- source_hash = None
82-
83- if "weight_attachments" in kwargs :
84- attachments = {"attachments" : ["weight_attachments" ]}
85- else :
86- attachments = {}
87-
87+ attachments = {"attachments" : kwargs ["weight_attachments" ]} if "weight_attachments" in kwargs else {}
8888 weight_types = model_spec .raw_nodes .WeightsFormat
8989 weight_source = _ensure_local_or_url (original_weight_source , root )
9090
91+ tmp_archtecture = None
9192 if weight_type == "pytorch_state_dict" :
92- # pytorch-state-dict -> we need a source
93- assert source is not None
93+ # pytorch-state-dict -> we need an architecture definition
94+ weight_kwargs , tmp_file = _get_pytorch_state_dict_weight_kwargs (architecture , model_kwargs , root )
95+ weight_kwargs .update (** attachments )
9496 weights = model_spec .raw_nodes .PytorchStateDictWeightsEntry (
95- source = weight_source , sha256 = weight_hash , ** attachments
97+ source = weight_source , sha256 = weight_hash , ** weight_kwargs
9698 )
97- language = "python"
98- framework = "pytorch"
9999
100100 elif weight_type == "onnx" :
101101 weights = model_spec .raw_nodes .OnnxWeightsEntry (
102102 source = weight_source , sha256 = weight_hash , opset_version = kwargs .get ("opset_version" , 12 ), ** attachments
103103 )
104- language = None
105- framework = None
106104
107105 elif weight_type == "pytorch_script" :
108106 weights = model_spec .raw_nodes .PytorchScriptWeightsEntry (
109107 source = weight_source , sha256 = weight_hash , ** attachments
110108 )
111- if source is None :
112- language = None
113- framework = None
114- else :
115- language = "python"
116- framework = "pytorch"
117109
118110 elif weight_type == "keras_hdf5" :
119111 weights = model_spec .raw_nodes .KerasHdf5WeightsEntry (
@@ -122,8 +114,6 @@ def _get_weights(original_weight_source, weight_type, source, root, **kwargs):
122114 tensorflow_version = kwargs .get ("tensorflow_version" , "1.15" ),
123115 ** attachments ,
124116 )
125- language = "python"
126- framework = "tensorflow"
127117
128118 elif weight_type == "tensorflow_saved_model_bundle" :
129119 weights = model_spec .raw_nodes .TensorflowSavedModelBundleWeightsEntry (
@@ -132,8 +122,6 @@ def _get_weights(original_weight_source, weight_type, source, root, **kwargs):
132122 tensorflow_version = kwargs .get ("tensorflow_version" , "1.15" ),
133123 ** attachments ,
134124 )
135- language = "python"
136- framework = "tensorflow"
137125
138126 elif weight_type == "tensorflow_js" :
139127 weights = model_spec .raw_nodes .TensorflowJsWeightsEntry (
@@ -142,16 +130,14 @@ def _get_weights(original_weight_source, weight_type, source, root, **kwargs):
142130 tensorflow_version = kwargs .get ("tensorflow_version" , "1.15" ),
143131 ** attachments ,
144132 )
145- language = None
146- framework = None
147133
148134 elif weight_type in weight_types :
149135 raise ValueError (f"Weight type { weight_type } is not supported yet in 'build_spec'" )
150136 else :
151137 raise ValueError (f"Invalid weight type { weight_type } , expect one of { weight_types } " )
152138
153139 weights = {weight_type : weights }
154- return weights , language , framework , source , source_hash , tmp_source
140+ return weights , tmp_archtecture
155141
156142
157143def _get_data_range (data_range , dtype ):
@@ -286,13 +272,15 @@ def _get_deepimagej_macro(name, kwargs, export_folder):
286272 else :
287273 raise ValueError (f"Macro { name } is not available, must be one of { macro_names } ." )
288274
289- macro = f"{ name } .ijm"
290275 url = f"https://raw.githubusercontent.com/deepimagej/imagej-macros/master/bioimage.io/{ macro } "
291276
292277 path = os .path .join (export_folder , macro )
293278 # use https://github.com/bioimage-io/core-bioimage-io-python/blob/main/bioimageio/core/resource_io/utils.py#L267
294279 # instead if the implementation is update s.t. an output path is accepted
295280 with requests .get (url , stream = True ) as r :
281+ text = r .text
282+ if text .startswith ("4" ):
283+ raise RuntimeError (f"An error occured when downloading { url } : { r .text } " )
296284 with open (path , "w" ) as f :
297285 f .write (r .text )
298286
@@ -451,18 +439,18 @@ def build_model(
451439 cite : Dict [str , str ],
452440 output_path : Union [str , Path ],
453441 # model specific optional
454- source : Optional [str ] = None ,
442+ architecture : Optional [str ] = None ,
455443 model_kwargs : Optional [Dict [str , Union [int , float , str ]]] = None ,
456444 weight_type : Optional [str ] = None ,
457445 sample_inputs : Optional [List [str ]] = None ,
458446 sample_outputs : Optional [List [str ]] = None ,
459447 # tensor specific
460- input_name : Optional [List [str ]] = None ,
448+ input_names : Optional [List [str ]] = None ,
461449 input_step : Optional [List [List [int ]]] = None ,
462450 input_min_shape : Optional [List [List [int ]]] = None ,
463451 input_axes : Optional [List [str ]] = None ,
464452 input_data_range : Optional [List [List [Union [int , str ]]]] = None ,
465- output_name : Optional [List [str ]] = None ,
453+ output_names : Optional [List [str ]] = None ,
466454 output_reference : Optional [List [str ]] = None ,
467455 output_scale : Optional [List [List [int ]]] = None ,
468456 output_offset : Optional [List [List [int ]]] = None ,
@@ -526,12 +514,12 @@ def build_model(
526514 weight_type: the type of the weights.
527515 sample_inputs: list of sample inputs to demonstrate the model performance.
528516 sample_outputs: list of sample outputs corresponding to sample_inputs.
529- input_name: name of the input tensor .
517+ input_names: names of the input tensors .
530518 input_step: minimal valid increase of the input tensor shape.
531519 input_min_shape: minimal input tensor shape.
532520 input_axes: axes names for the input tensor.
533521 input_data_range: valid data range for the input tensor.
534- output_name: name of the output tensor .
522+ output_names: names of the output tensors .
535523 output_reference: name of the input reference tensor used to cimpute the output tensor shape.
536524 output_scale: multiplicative factor to compute the output tensor shape.
537525 output_offset: additive term to compute the output tensor shape.
@@ -567,7 +555,11 @@ def build_model(
567555 test_outputs = _ensure_local_or_url (test_outputs , root )
568556
569557 n_inputs = len (test_inputs )
570- input_name = n_inputs * [None ] if input_name is None else input_name
558+ if input_names is None :
559+ input_names = [f"input{ i } " for i in range (n_inputs )]
560+ else :
561+ assert len (input_names ) == len (test_inputs )
562+
571563 input_step = n_inputs * [None ] if input_step is None else input_step
572564 input_min_shape = n_inputs * [None ] if input_min_shape is None else input_min_shape
573565 input_axes = n_inputs * [None ] if input_axes is None else input_axes
@@ -577,12 +569,16 @@ def build_model(
577569 inputs = [
578570 _get_input_tensor (root / test_in , name , step , min_shape , data_range , axes , preproc )
579571 for test_in , name , step , min_shape , axes , data_range , preproc in zip (
580- test_inputs , input_name , input_step , input_min_shape , input_axes , input_data_range , preprocessing
572+ test_inputs , input_names , input_step , input_min_shape , input_axes , input_data_range , preprocessing
581573 )
582574 ]
583575
584576 n_outputs = len (test_outputs )
585- output_name = n_outputs * [None ] if output_name is None else output_name
577+ if output_names is None :
578+ output_names = [f"output{ i } " for i in range (n_outputs )]
579+ else :
580+ assert len (output_names ) == len (test_outputs )
581+
586582 output_reference = n_outputs * [None ] if output_reference is None else output_reference
587583 output_scale = n_outputs * [None ] if output_scale is None else output_scale
588584 output_offset = n_outputs * [None ] if output_offset is None else output_offset
@@ -595,7 +591,7 @@ def build_model(
595591 _get_output_tensor (root / test_out , name , reference , scale , offset , axes , data_range , postproc , hal )
596592 for test_out , name , reference , scale , offset , axes , data_range , postproc , hal in zip (
597593 test_outputs ,
598- output_name ,
594+ output_names ,
599595 output_reference ,
600596 output_scale ,
601597 output_offset ,
@@ -628,9 +624,7 @@ def build_model(
628624 covers = _ensure_local (covers , root )
629625
630626 # parse the weights
631- weights , language , framework , source , source_hash , tmp_source = _get_weights (
632- weight_uri , weight_type , source , root , ** weight_kwargs
633- )
627+ weights , tmp_archtecture = _get_weights (weight_uri , weight_type , root , architecture , model_kwargs , ** weight_kwargs )
634628
635629 # validate the sample inputs and outputs (if given)
636630 if sample_inputs is not None :
@@ -692,11 +686,6 @@ def build_model(
692686 "run_mode" : run_mode ,
693687 "sample_inputs" : sample_inputs ,
694688 "sample_outputs" : sample_outputs ,
695- "framework" : framework ,
696- "language" : language ,
697- "source" : source ,
698- "sha256" : source_hash ,
699- "kwargs" : model_kwargs ,
700689 "links" : links ,
701690 }
702691 kwargs = {k : v for k , v in optional_kwargs .items () if v is not None }
@@ -731,8 +720,8 @@ def build_model(
731720 except Exception as e :
732721 raise e
733722 finally :
734- if tmp_source is not None :
735- os .remove (tmp_source )
723+ if tmp_archtecture is not None :
724+ os .remove (tmp_archtecture )
736725
737726 model = load_raw_resource_description (model_package )
738727 return model
@@ -746,12 +735,12 @@ def add_weights(
746735 ** weight_kwargs ,
747736):
748737 """Add weight entry to bioimage.io model."""
749- # we need to patss the weight path as abs path to avoid confusion with different root directories
750- new_weights = _get_weights (Path (weight_uri ).absolute (), weight_type , source = None , root = Path ("." ), ** weight_kwargs )[
751- 0
752- ]
738+ # we need to pass the weight path as abs path to avoid confusion with different root directories
739+ new_weights , tmp_arch = _get_weights (Path (weight_uri ).absolute (), weight_type , root = Path ("." ), ** weight_kwargs )
753740 model .weights .update (new_weights )
754741 if output_path is not None :
755742 model_package = export_resource_package (model , output_path = output_path )
756743 model = load_raw_resource_description (model_package )
744+ if tmp_arch is not None :
745+ os .remove (tmp_arch )
757746 return model
0 commit comments