@@ -53,63 +53,68 @@ def _infer_weight_type(path):
5353 raise ValueError (f"Could not infer weight type from extension { ext } for weight file { path } " )
5454
5555
56- def _get_weights (original_weight_source , weight_type , source , root , ** kwargs ):
57- weight_path = resolve_source (original_weight_source , root )
58- if weight_type is None :
59- weight_type = _infer_weight_type (weight_path )
60- weight_hash = _get_hash (weight_path )
56+ def _get_pytorch_state_dict_weight_kwargs (architecture , model_kwargs , root ):
57+ assert architecture is not None
58+ tmp_archtecture = None
6159
62- tmp_source = None
6360 # if we have a ":" (or deprecated "::") this is a python file with class specified,
6461 # so we can compute the hash for it
65- if source is not None and ":" in source :
66- source_file , source_class = source .replace ("::" , ":" ).split (":" )
62+ if ":" in architecture :
63+ arch_file , arch_class = architecture .replace ("::" , ":" ).split (":" )
6764
6865 # get the source path
69- source_file = _ensure_local (source_file , root )
70- source_hash = _get_hash (root / source_file )
66+ arch_file = _ensure_local (arch_file , root )
67+ arch_hash = _get_hash (root / arch_file )
7168
7269 # if not relative, create local copy (otherwise this will not work)
73- if os .path .isabs (source_file ):
74- copyfile (source_file , "this_model_architecture.py" )
75- source = f"this_model_architecture.py:{ source_class } "
76- tmp_source = "this_model_architecture.py"
70+ if os .path .isabs (arch_file ):
71+ copyfile (arch_file , "this_model_architecture.py" )
72+ arch = f"this_model_architecture.py:{ arch_class } "
73+ tmp_archtecture = "this_model_architecture.py"
7774 else :
78- source = f"{ source_file } :{ source_class } "
79- source = spec .shared .fields .ImportableSource ().deserialize (source )
75+ arch = f"{ arch_file } :{ arch_class } "
76+ arch = spec .shared .fields .Importablearch ().deserialize (arch )
77+
78+ weight_kwargs = {"architecture" : arch , "architecture_sha256" : arch_hash }
79+
80+ # otherwise this is a python class or function name
8081 else :
81- source_hash = None
82+ weight_kwargs = {"architecture" : architecture }
83+
84+ if model_kwargs is not None :
85+ weight_kwargs ["kwargs" ] = model_kwargs
86+
87+ return weight_kwargs , tmp_archtecture
88+
89+
90+ def _get_weights (original_weight_source , weight_type , root , architecture = None , model_kwargs = None , ** kwargs ):
91+ weight_path = resolve_source (original_weight_source , root )
92+ if weight_type is None :
93+ weight_type = _infer_weight_type (weight_path )
94+ weight_hash = _get_hash (weight_path )
8295
8396 attachments = {"attachments" : kwargs ["weight_attachments" ]} if "weight_attachments" in kwargs else {}
8497 weight_types = model_spec .raw_nodes .WeightsFormat
8598 weight_source = _ensure_local_or_url (original_weight_source , root )
8699
100+ tmp_archtecture = None
87101 if weight_type == "pytorch_state_dict" :
88- # pytorch-state-dict -> we need a source
89- assert source is not None
102+ # pytorch-state-dict -> we need an architecture definition
103+ weight_kwargs , tmp_file = _get_pytorch_state_dict_weight_kwargs (architecture , model_kwargs , root )
104+ weight_kwargs .update (** attachments )
90105 weights = model_spec .raw_nodes .PytorchStateDictWeightsEntry (
91- source = weight_source , sha256 = weight_hash , ** attachments
106+ source = weight_source , sha256 = weight_hash , ** weight_kwargs
92107 )
93- language = "python"
94- framework = "pytorch"
95108
96109 elif weight_type == "onnx" :
97110 weights = model_spec .raw_nodes .OnnxWeightsEntry (
98111 source = weight_source , sha256 = weight_hash , opset_version = kwargs .get ("opset_version" , 12 ), ** attachments
99112 )
100- language = None
101- framework = None
102113
103114 elif weight_type == "pytorch_script" :
104115 weights = model_spec .raw_nodes .PytorchScriptWeightsEntry (
105116 source = weight_source , sha256 = weight_hash , ** attachments
106117 )
107- if source is None :
108- language = None
109- framework = None
110- else :
111- language = "python"
112- framework = "pytorch"
113118
114119 elif weight_type == "keras_hdf5" :
115120 weights = model_spec .raw_nodes .KerasHdf5WeightsEntry (
@@ -118,8 +123,6 @@ def _get_weights(original_weight_source, weight_type, source, root, **kwargs):
118123 tensorflow_version = kwargs .get ("tensorflow_version" , "1.15" ),
119124 ** attachments ,
120125 )
121- language = "python"
122- framework = "tensorflow"
123126
124127 elif weight_type == "tensorflow_saved_model_bundle" :
125128 weights = model_spec .raw_nodes .TensorflowSavedModelBundleWeightsEntry (
@@ -128,8 +131,6 @@ def _get_weights(original_weight_source, weight_type, source, root, **kwargs):
128131 tensorflow_version = kwargs .get ("tensorflow_version" , "1.15" ),
129132 ** attachments ,
130133 )
131- language = "python"
132- framework = "tensorflow"
133134
134135 elif weight_type == "tensorflow_js" :
135136 weights = model_spec .raw_nodes .TensorflowJsWeightsEntry (
@@ -138,16 +139,14 @@ def _get_weights(original_weight_source, weight_type, source, root, **kwargs):
138139 tensorflow_version = kwargs .get ("tensorflow_version" , "1.15" ),
139140 ** attachments ,
140141 )
141- language = None
142- framework = None
143142
144143 elif weight_type in weight_types :
145144 raise ValueError (f"Weight type { weight_type } is not supported yet in 'build_spec'" )
146145 else :
147146 raise ValueError (f"Invalid weight type { weight_type } , expect one of { weight_types } " )
148147
149148 weights = {weight_type : weights }
150- return weights , language , framework , source , source_hash , tmp_source
149+ return weights , tmp_archtecture
151150
152151
153152def _get_data_range (data_range , dtype ):
@@ -449,7 +448,7 @@ def build_model(
449448 cite : Dict [str , str ],
450449 output_path : Union [str , Path ],
451450 # model specific optional
452- source : Optional [str ] = None ,
451+ architecture : Optional [str ] = None ,
453452 model_kwargs : Optional [Dict [str , Union [int , float , str ]]] = None ,
454453 weight_type : Optional [str ] = None ,
455454 sample_inputs : Optional [List [str ]] = None ,
@@ -626,9 +625,7 @@ def build_model(
626625 covers = _ensure_local (covers , root )
627626
628627 # parse the weights
629- weights , language , framework , source , source_hash , tmp_source = _get_weights (
630- weight_uri , weight_type , source , root , ** weight_kwargs
631- )
628+ weights , tmp_archtecture = _get_weights (weight_uri , weight_type , root , architecture , model_kwargs , ** weight_kwargs )
632629
633630 # validate the sample inputs and outputs (if given)
634631 if sample_inputs is not None :
@@ -690,11 +687,6 @@ def build_model(
690687 "run_mode" : run_mode ,
691688 "sample_inputs" : sample_inputs ,
692689 "sample_outputs" : sample_outputs ,
693- "framework" : framework ,
694- "language" : language ,
695- "source" : source ,
696- "sha256" : source_hash ,
697- "kwargs" : model_kwargs ,
698690 "links" : links ,
699691 }
700692 kwargs = {k : v for k , v in optional_kwargs .items () if v is not None }
@@ -729,8 +721,8 @@ def build_model(
729721 except Exception as e :
730722 raise e
731723 finally :
732- if tmp_source is not None :
733- os .remove (tmp_source )
724+ if tmp_archtecture is not None :
725+ os .remove (tmp_archtecture )
734726
735727 model = load_raw_resource_description (model_package )
736728 return model
@@ -744,12 +736,12 @@ def add_weights(
744736 ** weight_kwargs ,
745737):
746738 """Add weight entry to bioimage.io model."""
747- # we need to patss the weight path as abs path to avoid confusion with different root directories
748- new_weights = _get_weights (Path (weight_uri ).absolute (), weight_type , source = None , root = Path ("." ), ** weight_kwargs )[
749- 0
750- ]
739+ # we need to pass the weight path as abs path to avoid confusion with different root directories
740+ new_weights , tmp_arch = _get_weights (Path (weight_uri ).absolute (), weight_type , root = Path ("." ), ** weight_kwargs )
751741 model .weights .update (new_weights )
752742 if output_path is not None :
753743 model_package = export_resource_package (model , output_path = output_path )
754744 model = load_raw_resource_description (model_package )
745+ if tmp_arch is not None :
746+ os .remove (tmp_arch )
755747 return model
0 commit comments