Skip to content

Commit b98fd1b

Browse files
Support more dependency managers in build_model
1 parent d7dccb2 commit b98fd1b

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

bioimageio/core/build_spec/build_model.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def _get_data_range(data_range, dtype):
167167
else:
168168
raise RuntimeError(f"Cannot derived data range for dtype {dtype}")
169169
data_range = (min_, max_)
170-
assert isinstance(data_range, (tuple, list))
170+
assert isinstance(data_range, (tuple, list)), type(data_range)
171171
assert len(data_range) == 2
172172
return data_range
173173

@@ -249,6 +249,17 @@ def _build_cite(cite: Dict[str, str]):
249249
return citation_list
250250

251251

252+
def _get_dependencies(dependencies, root):
253+
if ":" in dependencies:
254+
manager, path = dependencies.split(":")
255+
else:
256+
manager = "conda"
257+
path = dependencies
258+
return model_spec.raw_nodes.Dependencies(
259+
manager=manager, file=_process_uri(path, root)
260+
)
261+
262+
252263
def build_model(
253264
weight_uri: str,
254265
test_inputs: List[Union[str, Path]],
@@ -383,7 +394,7 @@ def build_model(
383394
preprocessing = n_inputs * [None] if preprocessing is None else preprocessing
384395

385396
inputs = [
386-
_get_input_tensor(test_in, name, step, min_shape, axes, data_range, preproc)
397+
_get_input_tensor(test_in, name, step, min_shape, data_range, axes, preproc)
387398
for test_in, name, step, min_shape, axes, data_range, preproc in zip(
388399
test_inputs, input_name, input_step, input_min_shape, input_axes, input_data_range, preprocessing
389400
)
@@ -422,8 +433,8 @@ def build_model(
422433

423434
authors = _build_authors(authors)
424435
cite = _build_cite(cite)
425-
documentation = _process_uri(documentation, root)
426-
covers = [_process_uri(uri, root) for uri in covers]
436+
documentation = _process_uri(documentation, root, download=True)
437+
covers = [_process_uri(uri, root, download=True) for uri in covers]
427438

428439
# parse the weights
429440
weights, language, framework, source, source_hash, tmp_source = _get_weights(
@@ -448,9 +459,7 @@ def build_model(
448459
}
449460
kwargs = {k: v for k, v in optional_kwargs.items() if v is not None}
450461
if dependencies is not None:
451-
kwargs["dependencies"] = model_spec.raw_nodes.Dependencies(
452-
manager="conda", file=_process_uri(dependencies, root)
453-
)
462+
kwargs["dependencies"] = _get_dependencies(dependencies, root)
454463
if parent is not None:
455464
assert len(parent) == 2
456465
kwargs["parent"] = {"uri": parent[0], "sha256": parent[1]}

0 commit comments

Comments
 (0)