|
34 | 34 | "import os\n", |
35 | 35 | "\n", |
36 | 36 | "if os.getenv(\"COLAB_RELEASE_TAG\"):\n", |
37 | | - " %pip install bioimageio.core==0.6.7 torch==2.3.1 onnxruntime==1.18.0" |
| 37 | + " %pip install bioimageio.core==0.6.7 torch==2.3.1 onnxruntime==1.18.0" |
38 | 38 | ] |
39 | 39 | }, |
40 | 40 | { |
|
55 | 55 | "from bioimageio.spec.pretty_validation_errors import (\n", |
56 | 56 | " enable_pretty_validation_errors_in_ipynb,\n", |
57 | 57 | ")\n", |
| 58 | + "\n", |
58 | 59 | "enable_pretty_validation_errors_in_ipynb()" |
59 | 60 | ] |
60 | 61 | }, |
|
78 | 79 | "import matplotlib.pyplot as plt\n", |
79 | 80 | "import numpy as np\n", |
80 | 81 | "\n", |
| 82 | + "\n", |
81 | 83 | "# Function to display input and prediction output images\n", |
82 | 84 | "def show_images(sample_tensor, prediction_tensor):\n", |
83 | | - " input_array = sample_tensor.members['input0'].data\n", |
84 | | - " \n", |
| 85 | + " input_array = sample_tensor.members[\"input0\"].data\n", |
| 86 | + "\n", |
85 | 87 | " # Check for the number of channels to enable display\n", |
86 | 88 | " input_array = np.squeeze(input_array)\n", |
87 | | - " if len(input_array.shape)>2:\n", |
| 89 | + " if len(input_array.shape) > 2:\n", |
88 | 90 | " input_array = input_array[0]\n", |
89 | 91 | "\n", |
90 | | - " output_array = prediction_tensor.members['output0'].data\n", |
91 | | - " \n", |
| 92 | + " output_array = prediction_tensor.members[\"output0\"].data\n", |
| 93 | + "\n", |
92 | 94 | " # Check for the number of channels to enable display\n", |
93 | 95 | " output_array = np.squeeze(output_array)\n", |
94 | | - " if len(output_array.shape)>2:\n", |
| 96 | + " if len(output_array.shape) > 2:\n", |
95 | 97 | " output_array = output_array[0]\n", |
96 | 98 | "\n", |
97 | 99 | " plt.figure()\n", |
98 | | - " ax1 = plt.subplot(1,2,1)\n", |
| 100 | + " ax1 = plt.subplot(1, 2, 1)\n", |
99 | 101 | " ax1.set_title(\"Input\")\n", |
100 | | - " ax1.axis('off')\n", |
| 102 | + " ax1.axis(\"off\")\n", |
101 | 103 | " plt.imshow(input_array)\n", |
102 | | - " ax2 = plt.subplot(1,2,2)\n", |
| 104 | + " ax2 = plt.subplot(1, 2, 2)\n", |
103 | 105 | " ax2.set_title(\"Prediction\")\n", |
104 | | - " ax2.axis('off')\n", |
| 106 | + " ax2.axis(\"off\")\n", |
105 | 107 | " plt.imshow(output_array)\n", |
106 | | - " plt.show()\n", |
107 | | - " \n", |
108 | | - " " |
| 108 | + " plt.show()" |
109 | 109 | ] |
110 | 110 | }, |
111 | 111 | { |
|
153 | 153 | "metadata": {}, |
154 | 154 | "outputs": [], |
155 | 155 | "source": [ |
156 | | - "BMZ_MODEL_ID = \"\"#\"affable-shark\"\n", |
157 | | - "BMZ_MODEL_DOI = \"\" #\"10.5281/zenodo.6287342\"\n", |
| 156 | + "BMZ_MODEL_ID = \"\" # \"affable-shark\"\n", |
| 157 | + "BMZ_MODEL_DOI = \"\" # \"10.5281/zenodo.6287342\"\n", |
158 | 158 | "BMZ_MODEL_URL = \"https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/affable-shark/draft/files/rdf.yaml\"" |
159 | 159 | ] |
160 | 160 | }, |
|
178 | 178 | "# Load the model description\n", |
179 | 179 | "# ------------------------------------------------------------------------------\n", |
180 | 180 | "if BMZ_MODEL_ID != \"\":\n", |
181 | | - " model = load_description(BMZ_MODEL_ID) \n", |
182 | | - " print(f\"\\nThe model '{model.name}' with ID '{BMZ_MODEL_ID}' has been correctly loaded.\")\n", |
| 181 | + " model = load_description(BMZ_MODEL_ID)\n", |
| 182 | + " print(\n", |
| 183 | + " f\"\\nThe model '{model.name}' with ID '{BMZ_MODEL_ID}' has been correctly loaded.\"\n", |
| 184 | + " )\n", |
183 | 185 | "elif BMZ_MODEL_DOI != \"\":\n", |
184 | | - " model = load_description(BMZ_MODEL_DOI) \n", |
185 | | - " print(f\"\\nThe model '{model.name}' with DOI '{BMZ_MODEL_DOI}' has been correctly loaded.\")\n", |
| 186 | + " model = load_description(BMZ_MODEL_DOI)\n", |
| 187 | + " print(\n", |
| 188 | + " f\"\\nThe model '{model.name}' with DOI '{BMZ_MODEL_DOI}' has been correctly loaded.\"\n", |
| 189 | + " )\n", |
186 | 190 | "elif BMZ_MODEL_URL != \"\":\n", |
187 | | - " model = load_description(BMZ_MODEL_URL) \n", |
188 | | - " print(f\"\\nThe model '{model.name}' with URL '{BMZ_MODEL_URL}' has been correctly loaded.\")\n", |
| 191 | + " model = load_description(BMZ_MODEL_URL)\n", |
| 192 | + " print(\n", |
| 193 | + " f\"\\nThe model '{model.name}' with URL '{BMZ_MODEL_URL}' has been correctly loaded.\"\n", |
| 194 | + " )\n", |
189 | 195 | "else:\n", |
190 | | - " print('\\nPlease specify a model ID, DOI or URL')\n", |
| 196 | + " print(\"\\nPlease specify a model ID, DOI or URL\")\n", |
191 | 197 | "\n", |
192 | 198 | "if \"draft\" in BMZ_MODEL_ID or \"draft\" in BMZ_MODEL_DOI or \"draft\" in BMZ_MODEL_URL:\n", |
193 | | - " print(f\"\\nThis is the DRAFT version of '{model.name}'. \\nDraft versions have not been reviewed by the Bioimage Model Zoo Team and may contain harmful code. Run with caution.\")\n", |
| 199 | + " print(\n", |
| 200 | + " f\"\\nThis is the DRAFT version of '{model.name}'. \\nDraft versions have not been reviewed by the Bioimage Model Zoo Team and may contain harmful code. Run with caution.\"\n", |
| 201 | + " )\n", |
194 | 202 | "\n", |
195 | 203 | "# To be added later:\n", |
196 | 204 | "# elif model.version != model.lastest_version:\n", |
197 | 205 | "# print('\\nThe loaded version of the model is: ' + model.version, 'but the latest version of the model is: ' + model.lastest_version)\n", |
198 | 206 | "\n", |
199 | | - "# TODO: on the model loading success responses add version loaded\n", |
200 | | - "\n" |
| 207 | + "# TODO: on the model loading success responses add version loaded" |
201 | 208 | ] |
202 | 209 | }, |
203 | 210 | { |
|
216 | 223 | "outputs": [], |
217 | 224 | "source": [ |
218 | 225 | "print(f\"The model '{model.name}' has the following properties and metadata\\n\")\n", |
219 | | - "print(f\" Description:\") \n", |
| 226 | + "print(f\" Description:\")\n", |
220 | 227 | "pprint(model.description)\n", |
221 | 228 | "\n", |
222 | 229 | "print(\"\\n The authors of the model are: \")\n", |
|
246 | 253 | " plt.imshow(cover_data)\n", |
247 | 254 | " plt.xticks([])\n", |
248 | 255 | " plt.yticks([])\n", |
249 | | - " plt.show()\n" |
| 256 | + " plt.show()" |
250 | 257 | ] |
251 | 258 | }, |
252 | 259 | { |
|
269 | 276 | "metadata": {}, |
270 | 277 | "outputs": [], |
271 | 278 | "source": [ |
272 | | - "print(f\"Model '{model.name}' requires {len(model.inputs)} input(s) with the following features:\")\n", |
| 279 | + "print(\n", |
| 280 | + " f\"Model '{model.name}' requires {len(model.inputs)} input(s) with the following features:\"\n", |
| 281 | + ")\n", |
273 | 282 | "for ipt in model.inputs:\n", |
274 | 283 | " print(f\"\\ninput '{ipt.id}' with axes:\")\n", |
275 | 284 | " pprint(ipt.axes)\n", |
|
280 | 289 | " for p in ipt.preprocessing:\n", |
281 | 290 | " print(p)\n", |
282 | 291 | "\n", |
283 | | - "print(\"\\n-------------------------------------------------------------------------------\")\n", |
| 292 | + "print(\n", |
| 293 | + " \"\\n-------------------------------------------------------------------------------\"\n", |
| 294 | + ")\n", |
284 | 295 | "# # and what the model outputs are\n", |
285 | | - "print(f\"Model '{model.name}' requires {len(model.outputs)} output(s) with the following features:\")\n", |
| 296 | + "print(\n", |
| 297 | + " f\"Model '{model.name}' requires {len(model.outputs)} output(s) with the following features:\"\n", |
| 298 | + ")\n", |
286 | 299 | "for out in model.outputs:\n", |
287 | 300 | " print(f\"\\noutput '{out.id}' with axes:\")\n", |
288 | 301 | " pprint(out.axes)\n", |
|
572 | 585 | "\n", |
573 | 586 | " # Check for the number of channels to enable display\n", |
574 | 587 | " input_array = np.squeeze(input_array)\n", |
575 | | - " if len(input_array.shape)>2:\n", |
| 588 | + " if len(input_array.shape) > 2:\n", |
576 | 589 | " input_array = input_array[0]\n", |
577 | | - " \n", |
| 590 | + "\n", |
578 | 591 | " np_input_list.append(input_array)\n", |
579 | 592 | "\n", |
580 | 593 | "\n", |
|
584 | 597 | "\n", |
585 | 598 | " # Check for the number of channels to enable display\n", |
586 | 599 | " output_array = np.squeeze(output_array)\n", |
587 | | - " if len(output_array.shape)>2:\n", |
| 600 | + " if len(output_array.shape) > 2:\n", |
588 | 601 | " output_array = output_array[0]\n", |
589 | | - " \n", |
| 602 | + "\n", |
590 | 603 | " np_output_list.append(output_array)\n", |
591 | 604 | "\n", |
592 | 605 | "plt.imshow(np_input_list[0])" |
|
609 | 622 | "name": "python", |
610 | 623 | "nbconvert_exporter": "python", |
611 | 624 | "pygments_lexer": "ipython3", |
612 | | - "version": "3.11.9" |
| 625 | + "version": "3.9.19" |
613 | 626 | } |
614 | 627 | }, |
615 | 628 | "nbformat": 4, |
|
0 commit comments