@@ -61,16 +61,10 @@ def _get_pytorch_state_dict_weight_kwargs(architecture, model_kwargs, root):
6161 tmp_archtecture = None
6262 weight_kwargs = {"kwargs" : model_kwargs } if model_kwargs else {}
6363 if ":" in architecture :
64- arch_file , callable_name = architecture .replace ("::" , ":" ).split (":" )
65-
66- # this goes haywire if we pass an absolute path, so need to copt to a tmp relative path
67- if os .path .isabs (arch_file ):
68- tmp_archtecture = Path ("this_model_architecture.py" )
69- copyfile (arch_file , root / tmp_archtecture )
70- arch = ImportableSourceFile (callable_name , tmp_archtecture )
71- else :
72- arch = ImportableSourceFile (callable_name , Path (arch_file ))
73-
64+ # note: path itself might include : for absolute paths in windows
65+ * arch_file_parts , callable_name = architecture .replace ("::" , ":" ).split (":" )
66+ arch_file = _ensure_local (":" .join (arch_file_parts ), root )
67+ arch = ImportableSourceFile (callable_name , arch_file )
7468 arch_hash = _get_hash (root / arch .source_file )
7569 weight_kwargs ["architecture_sha256" ] = arch_hash
7670 else :
@@ -123,30 +117,21 @@ def _get_weights(
123117 if tensorflow_version is None :
124118 raise ValueError ("tensorflow_version needs to be passed for building a keras model" )
125119 weights = model_spec .raw_nodes .KerasHdf5WeightsEntry (
126- source = weight_source ,
127- sha256 = weight_hash ,
128- tensorflow_version = tensorflow_version ,
129- ** attachments ,
120+ source = weight_source , sha256 = weight_hash , tensorflow_version = tensorflow_version , ** attachments
130121 )
131122
132123 elif weight_type == "tensorflow_saved_model_bundle" :
133124 if tensorflow_version is None :
134125 raise ValueError ("tensorflow_version needs to be passed for building a tensorflow model" )
135126 weights = model_spec .raw_nodes .TensorflowSavedModelBundleWeightsEntry (
136- source = weight_source ,
137- sha256 = weight_hash ,
138- tensorflow_version = tensorflow_version ,
139- ** attachments ,
127+ source = weight_source , sha256 = weight_hash , tensorflow_version = tensorflow_version , ** attachments
140128 )
141129
142130 elif weight_type == "tensorflow_js" :
143131 if tensorflow_version is None :
144132 raise ValueError ("tensorflow_version needs to be passed for building a tensorflow_js model" )
145133 weights = model_spec .raw_nodes .TensorflowJsWeightsEntry (
146- source = weight_source ,
147- sha256 = weight_hash ,
148- tensorflow_version = tensorflow_version ,
149- ** attachments ,
134+ source = weight_source , sha256 = weight_hash , tensorflow_version = tensorflow_version , ** attachments
150135 )
151136
152137 elif weight_type in weight_types :
@@ -519,9 +504,8 @@ def _ensure_local_or_url(source: Union[Path, URI, str, list], root: Path) -> Uni
519504 return [_ensure_local_or_url (s , root ) for s in source ]
520505
521506 local_source = resolve_local_source (source , root )
522- local_source = resolve_local_source (
523- local_source , root , None if isinstance (local_source , URI ) else root / local_source .name
524- )
507+ if not isinstance (local_source , URI ):
508+ local_source = resolve_local_source (local_source , root , root / local_source .name )
525509 return local_source .relative_to (root )
526510
527511
@@ -654,6 +638,7 @@ def build_model(
654638 Only requred for models with onnx weight format.
655639 weight_kwargs: additional keyword arguments for this weight type.
656640 """
641+ assert architecture is None or isinstance (architecture , str )
657642 if root is None :
658643 root = "."
659644 root = Path (root )
0 commit comments