1414from bioimageio .core import export_resource_package , load_raw_resource_description
1515from bioimageio .core .resource_io .nodes import URI
1616from bioimageio .core .resource_io .utils import resolve_local_source , resolve_source
17+ from bioimageio .spec .shared import fields
1718from bioimageio .spec .shared .raw_nodes import ImportableSourceFile , ImportableModule
1819
1920try :
@@ -60,16 +61,10 @@ def _get_pytorch_state_dict_weight_kwargs(architecture, model_kwargs, root):
6061 tmp_archtecture = None
6162 weight_kwargs = {"kwargs" : model_kwargs } if model_kwargs else {}
6263 if ":" in architecture :
63- arch_file , callable_name = architecture .replace ("::" , ":" ).split (":" )
64-
65- # this goes haywire if we pass an absolute path, so need to copt to a tmp relative path
66- if os .path .isabs (arch_file ):
67- tmp_archtecture = Path ("this_model_architecture.py" )
68- copyfile (arch_file , root / tmp_archtecture )
69- arch = ImportableSourceFile (callable_name , tmp_archtecture )
70- else :
71- arch = ImportableSourceFile (callable_name , Path (arch_file ))
72-
64+ # note: path itself might include : for absolute paths in windows
65+ * arch_file_parts , callable_name = architecture .replace ("::" , ":" ).split (":" )
66+ arch_file = _ensure_local (":" .join (arch_file_parts ), root )
67+ arch = ImportableSourceFile (callable_name , arch_file )
7368 arch_hash = _get_hash (root / arch .source_file )
7469 weight_kwargs ["architecture_sha256" ] = arch_hash
7570 else :
@@ -122,30 +117,21 @@ def _get_weights(
122117 if tensorflow_version is None :
123118 raise ValueError ("tensorflow_version needs to be passed for building a keras model" )
124119 weights = model_spec .raw_nodes .KerasHdf5WeightsEntry (
125- source = weight_source ,
126- sha256 = weight_hash ,
127- tensorflow_version = tensorflow_version ,
128- ** attachments ,
120+ source = weight_source , sha256 = weight_hash , tensorflow_version = tensorflow_version , ** attachments
129121 )
130122
131123 elif weight_type == "tensorflow_saved_model_bundle" :
132124 if tensorflow_version is None :
133125 raise ValueError ("tensorflow_version needs to be passed for building a tensorflow model" )
134126 weights = model_spec .raw_nodes .TensorflowSavedModelBundleWeightsEntry (
135- source = weight_source ,
136- sha256 = weight_hash ,
137- tensorflow_version = tensorflow_version ,
138- ** attachments ,
127+ source = weight_source , sha256 = weight_hash , tensorflow_version = tensorflow_version , ** attachments
139128 )
140129
141130 elif weight_type == "tensorflow_js" :
142131 if tensorflow_version is None :
143132 raise ValueError ("tensorflow_version needs to be passed for building a tensorflow_js model" )
144133 weights = model_spec .raw_nodes .TensorflowJsWeightsEntry (
145- source = weight_source ,
146- sha256 = weight_hash ,
147- tensorflow_version = tensorflow_version ,
148- ** attachments ,
134+ source = weight_source , sha256 = weight_hash , tensorflow_version = tensorflow_version , ** attachments
149135 )
150136
151137 elif weight_type in weight_types :
@@ -363,7 +349,7 @@ def get_size(path):
363349 "allow_tiling" : True ,
364350 "model_keys" : None ,
365351 }
366- return {"deepimagej" : config }, attachments
352+ return {"deepimagej" : config }, [ Path ( a ) for a in attachments ]
367353
368354
369355def _write_sample_data (input_paths , output_paths , input_axes , output_axes , export_folder : Path ):
@@ -518,9 +504,8 @@ def _ensure_local_or_url(source: Union[Path, URI, str, list], root: Path) -> Uni
518504 return [_ensure_local_or_url (s , root ) for s in source ]
519505
520506 local_source = resolve_local_source (source , root )
521- local_source = resolve_local_source (
522- local_source , root , None if isinstance (local_source , URI ) else root / local_source .name
523- )
507+ if not isinstance (local_source , URI ):
508+ local_source = resolve_local_source (local_source , root , root / local_source .name )
524509 return local_source .relative_to (root )
525510
526511
@@ -653,10 +638,25 @@ def build_model(
653638 Only requred for models with onnx weight format.
654639 weight_kwargs: additional keyword arguments for this weight type.
655640 """
641+ assert architecture is None or isinstance (architecture , str )
656642 if root is None :
657643 root = "."
658644 root = Path (root )
659645
646+ if attachments is not None :
647+ assert isinstance (attachments , dict )
648+ if "files" in attachments :
649+ afiles = attachments ["files" ]
650+ if isinstance (afiles , str ):
651+ afiles = [afiles ]
652+
653+ if isinstance (afiles , list ):
654+ afiles = _ensure_local_or_url (afiles , root )
655+ else :
656+ raise TypeError (attachments )
657+
658+ attachments ["files" ] = afiles
659+
660660 #
661661 # generate the model specific fields
662662 #
@@ -783,7 +783,7 @@ def build_model(
783783 elif "files" not in attachments :
784784 attachments ["files" ] = ij_attachments
785785 else :
786- attachments ["files" ]. extend ( ij_attachments )
786+ attachments ["files" ] = list ( set ( attachments [ "files" ]) | set ( ij_attachments ) )
787787
788788 if links is None :
789789 links = ["deepimagej/deepimagej" ]
@@ -803,7 +803,6 @@ def build_model(
803803
804804 # optional kwargs, don't pass them if none
805805 optional_kwargs = {
806- "attachments" : attachments ,
807806 "config" : config ,
808807 "git_repo" : git_repo ,
809808 "packaged_by" : packaged_by ,
@@ -814,13 +813,15 @@ def build_model(
814813 }
815814 kwargs = {k : v for k , v in optional_kwargs .items () if v is not None }
816815
816+ if attachments is not None :
817+ kwargs ["attachments" ] = model_spec .raw_nodes .Attachments (** attachments )
817818 if dependencies is not None :
818819 kwargs ["dependencies" ] = _get_dependencies (dependencies , root )
820+ if maintainers is not None :
821+ kwargs ["maintainers" ] = [model_spec .raw_nodes .Maintainer (** m ) for m in maintainers ]
819822 if parent is not None :
820823 assert len (parent ) == 2
821824 kwargs ["parent" ] = {"uri" : parent [0 ], "sha256" : parent [1 ]}
822- if maintainers is not None :
823- kwargs ["maintainers" ] = [model_spec .raw_nodes .Maintainer (** m ) for m in maintainers ]
824825
825826 try :
826827 model = model_spec .raw_nodes .Model (
0 commit comments