Skip to content

Commit 0385d87

Browse files
Fix more issues in build_spec and update example notebook
1 parent 6d1ae85 commit 0385d87

File tree

3 files changed

+47
-26
lines changed

3 files changed

+47
-26
lines changed

bioimageio/core/build_spec/build_model.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,23 +59,24 @@ def _get_pytorch_state_dict_weight_kwargs(architecture, model_kwargs, root):
5959
assert architecture is not None
6060
tmp_archtecture = None
6161
weight_kwargs = {"kwargs": model_kwargs} if model_kwargs else {}
62-
arch = spec.shared.fields.ImportableSource().deserialize(architecture)
63-
if isinstance(arch, ImportableSourceFile):
64-
if os.path.isabs(arch.source_file):
65-
tmp_archtecture = Path("this_model_architecture.py")
62+
if ":" in architecture:
63+
arch_file, callable_name = architecture.replace("::", ":").split(":")
6664

67-
copyfile(arch.source_file, root / tmp_archtecture)
68-
arch = ImportableSourceFile(arch.callable_name, tmp_archtecture)
65+
# this goes haywire if we pass an absolute path, so need to copt to a tmp relative path
66+
if os.path.isabs(arch_file):
67+
tmp_archtecture = Path("this_model_architecture.py")
68+
copyfile(arch_file, root / tmp_archtecture)
69+
arch = ImportableSourceFile(callable_name, tmp_archtecture)
70+
else:
71+
arch = ImportableSourceFile(callable_name, Path(arch_file))
6972

7073
arch_hash = _get_hash(root / arch.source_file)
7174
weight_kwargs["architecture_sha256"] = arch_hash
72-
elif isinstance(arch, ImportableModule):
73-
pass
7475
else:
75-
raise NotImplementedError(arch)
76+
arch = spec.shared.fields.ImportableSource().deserialize(architecture)
77+
assert isinstance(arch, ImportableModule)
7678

7779
weight_kwargs["architecture"] = arch
78-
7980
return weight_kwargs, tmp_archtecture
8081

8182

example/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
.ipynb_checkpoints/

example/bioimageio-core-usage.ipynb

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,11 @@
144144
"metadata": {},
145145
"outputs": [],
146146
"source": [
147+
"# the function 'test_model' from 'bioimageio.core.resource_tests' can be used to fully test the model,\n",
148+
"# including running prediction for the test input(s) and checking that they agree with the test output(s)\n",
149+
"# before using a model, it is recommended to check that it properly works with this function\n",
150+
"# 'test_model' returns a dict, if there are any errros they will be in the key \"error\"\n",
151+
"# if the model passes it will be None\n",
147152
"from bioimageio.core.resource_tests import test_model\n",
148153
"test_result = test_model(model_resource)\n",
149154
"if test_result[\"error\"]:\n",
@@ -326,7 +331,7 @@
326331
"\n",
327332
"`bioimageio.core` also implements functionality to create a model package compatible with the [bioimageio model spec](https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/model_spec_latest.md) ready to be shared via\n",
328333
"the [bioimage.io model zoo](https://bioimage.io/#/).\n",
329-
"Here, we will use this functionality to create two models, one that adds thresholding of the outputs as post-processing and another that also adds weights in torchscript format."
334+
"Here, we will use this functionality to create two models, one that adds thresholding as post-processing to the outputs and another one that also adds weights in torchscript format."
330335
]
331336
},
332337
{
@@ -336,20 +341,21 @@
336341
"metadata": {},
337342
"outputs": [],
338343
"source": [
339-
"# get the python file defining the architecture\n",
344+
"# get the python file defining the architecture.\n",
345+
"# this is only required for models with pytorch_state_dict weights\n",
340346
"def get_architecture_source(rdf):\n",
341347
" # here, we need the raw resource, which contains the information from the resource description\n",
342348
" # before evaluation, e.g. the file and name of the python file with the model architecture\n",
343349
" raw_resource = bioimageio.core.load_raw_resource_description(rdf)\n",
344-
" # the source information\n",
345-
" model_source = raw_resource.source\n",
350+
" # the python file defining the architecture for the pytorch weihgts\n",
351+
" model_source = raw_resource.weights[\"pytorch_state_dict\"].architecture\n",
346352
" # download the source file if necessary\n",
347353
" source_file = bioimageio.core.resource_io.utils.resolve_source(\n",
348354
" model_source.source_file\n",
349355
" )\n",
350356
" # if the source file path does not exist, try combining it with the root path of the model\n",
351357
" if not os.path.exists(source_file):\n",
352-
" source_file = os.path.join(raw_resource.root_path, source_file)\n",
358+
" source_file = os.path.join(raw_resource.root_path, os.path.split(source_file)[1])\n",
353359
" assert os.path.exists(source_file), source_file\n",
354360
" class_name = model_source.callable_name\n",
355361
" return f\"{source_file}:{class_name}\""
@@ -390,9 +396,9 @@
390396
"# the nucleus segmentation model we have obtained from bioimage.io\n",
391397
"# this field is optional and only needs to be given for models that are created based on\n",
392398
"# other models from bioimage.io\n",
393-
"# the parent is specified via it's doi and the hash of the weight file\n",
394-
"weight_file = model_resource.weights[\"pytorch_state_dict\"].source\n",
395-
"with open(weight_file, \"rb\") as f:\n",
399+
"# the parent is specified via it's doi and the hash of its rdf file\n",
400+
"rdf_file = os.path.join(model_resource.root_path, \"rdf.yaml\")\n",
401+
"with open(rdf_file, \"rb\") as f:\n",
396402
" weight_hash = hashlib.sha256(f.read()).hexdigest()\n",
397403
"parent = (rdf_doi, weight_hash)\n",
398404
"\n",
@@ -405,10 +411,24 @@
405411
"# for more informantion, check out the function signature\n",
406412
"# https://github.com/bioimage-io/core-bioimage-io-python/blob/main/bioimageio/core/build_spec/build_model.py#L252\n",
407413
"cite = {cite_entry.text: cite_entry.url for cite_entry in model_resource.cite}\n",
414+
"\n",
415+
"# the axes descriptions for the inputs / outputs\n",
416+
"input_axes = [\"bcyx\"]\n",
417+
"output_axes = [\"bcyx\"]\n",
418+
"\n",
419+
"# the pytorch_state_dict weight file\n",
420+
"weight_file = model_resource.weights[\"pytorch_state_dict\"].source\n",
421+
"\n",
422+
"# the path to save the new model with torchscript weights\n",
423+
"zip_path = f\"{model_root}/new_model2.zip\"\n",
424+
"\n",
425+
"# build the model! it will be saved to 'zip_path'\n",
408426
"new_model_raw = build_model(\n",
409-
" weight_file,\n",
427+
" weight_uri=weight_file,\n",
410428
" test_inputs=model_resource.test_inputs,\n",
411429
" test_outputs=[new_output_path],\n",
430+
" input_axes=input_axes,\n",
431+
" output_axes=output_axes,\n",
412432
" output_path=zip_path,\n",
413433
" name=name,\n",
414434
" description=\"nucleus segmentation model with thresholding\",\n",
@@ -419,9 +439,8 @@
419439
" tags=[\"nucleus-segmentation\"],\n",
420440
" cite=cite,\n",
421441
" parent=parent,\n",
422-
" root=model_root,\n",
423-
" source=model_source,\n",
424-
" model_kwargs=model_resource.kwargs,\n",
442+
" architecture=model_source,\n",
443+
" model_kwargs=model_resource.weights[\"pytorch_state_dict\"].kwargs,\n",
425444
" preprocessing=preprocessing,\n",
426445
" postprocessing=postprocessing\n",
427446
")"
@@ -456,17 +475,17 @@
456475
"outputs": [],
457476
"source": [
458477
"# `convert_weigths_to_pytorch_script` creates torchscript weigths based on the weights loaded from pytorch_state_dict\n",
459-
"from bioimageio.core.weight_converter.torch import convert_weights_to_pytorch_script\n",
478+
"from bioimageio.core.weight_converter.torch import convert_weights_to_torchscript\n",
460479
"# `add_weights` adds new weights to the model specification\n",
461480
"from bioimageio.core.build_spec import add_weights\n",
462481
"\n",
463482
"# the path to save the newly created torchscript weights\n",
464483
"weight_path = os.path.join(model_root, \"weights.torchscript\")\n",
465-
"convert_weights_to_pytorch_script(new_model, weight_path)\n",
484+
"convert_weights_to_torchscript(new_model, weight_path)\n",
466485
"\n",
467486
"# the path to save the new model with torchscript weights\n",
468487
"zip_path = f\"{model_root}/new_model2.zip\"\n",
469-
"new_model2_raw = add_weights(new_model_raw, weight_path, weight_type=\"pytorch_script\", output_path=zip_path)"
488+
"new_model2_raw = add_weights(new_model_raw, weight_path, weight_type=\"torchscript\", output_path=zip_path)"
470489
]
471490
},
472491
{
@@ -478,7 +497,7 @@
478497
"source": [
479498
"# load the new model from the zipped package, run prediction and check the result\n",
480499
"new_model = bioimageio.core.load_resource_description(zip_path)\n",
481-
"prediction = predict_numpy(new_model, input_image, weight_format=\"pytorch_script\")\n",
500+
"prediction = predict_numpy(new_model, input_image, weight_format=\"torchscript\")\n",
482501
"show_images(input_image, prediction, names=[\"input\", \"binarized-prediction\"])"
483502
]
484503
},

0 commit comments

Comments
 (0)