|
313 | 313 | "Here, we will use this functionality to create two models, one that adds thresholding of the outputs as post-processing and another that also adds weights in torchscript format." |
314 | 314 | ] |
315 | 315 | }, |
| 316 | + { |
| 317 | + "cell_type": "code", |
| 318 | + "execution_count": null, |
| 319 | + "id": "a210a1f7", |
| 320 | + "metadata": {}, |
| 321 | + "outputs": [], |
| 322 | + "source": [ |
| 323 | + "# get the python file defining the architecture\n", |
| 324 | + "def get_architecture_source(rdf):\n", |
| 325 | + " # here, we need the raw resource, which contains the information from the resource description\n", |
| 326 | + " # before evaluation, e.g. the file and name of the python file with the model architecture\n", |
| 327 | + " raw_resource = bioimageio.core.load_raw_resource_description(rdf)\n", |
| 328 | + " # the source information\n", |
| 329 | + " model_source = raw_resource.source\n", |
| 330 | + " # download the source file if necessary\n", |
| 331 | + " source_file = bioimageio.core.resource_io.utils.resolve_uri(\n", |
| 332 | + " model_source.source_file\n", |
| 333 | + " )\n", |
| 334 | + " # if the source file path does not exist, try combining it with the root path of the model\n", |
| 335 | + " if not os.path.exists(source_file):\n", |
| 336 | + " source_file = os.path.join(raw_resource.root_path, source_file)\n", |
| 337 | + " assert os.path.exists(source_file), source_file\n", |
| 338 | + " class_name = model_source.callable_name\n", |
| 339 | + " return f\"{source_file}:{class_name}\"" |
| 340 | + ] |
| 341 | + }, |
316 | 342 | { |
317 | 343 | "cell_type": "code", |
318 | 344 | "execution_count": null, |
|
340 | 366 | "]\n", |
341 | 367 | "postprocessing = [{\"binarize\": {\"threshold\": threshold}}]\n", |
342 | 368 | "\n", |
343 | | - "# copy the file containing the source code for the model architecture to the model output folder\n", |
344 | | - "model_path = os.path.join(model_resource.root_path, \"unet.py\")\n", |
345 | | - "assert os.path.exists(model_path), model_path\n", |
346 | | - "model_source = f\"{model_path}:UNet2d\"\n", |
| 369 | + "# get the model architecture\n", |
| 370 | + "# note that this is only necessary for pytorch state dict models\n", |
| 371 | + "model_source = get_architecture_source(rdf_doi)\n", |
347 | 372 | "\n", |
348 | 373 | "# we use the `parent` field to indicate that the new model is created based on\n", |
349 | 374 | "# the nucleus segmentation model we have obtained from bioimage.io\n", |
|
0 commit comments