|
41 | 41 | ], |
42 | 42 | "source": [ |
43 | 43 | "# version info\n", |
44 | | - "import torch\n", |
| 44 | + "from platform import python_version\n", |
| 45 | + "\n", |
45 | 46 | "import lightning\n", |
| 47 | + "import torch\n", |
| 48 | + "\n", |
46 | 49 | "import cellseg_models_pytorch\n", |
47 | | - "from platform import python_version\n", |
48 | 50 | "\n", |
49 | 51 | "print(\"torch version:\", torch.__version__)\n", |
50 | 52 | "print(\"lightning version:\", lightning.__version__)\n", |
|
86 | 88 | ], |
87 | 89 | "source": [ |
88 | 90 | "from pathlib import Path\n", |
| 91 | + "\n", |
89 | 92 | "from cellseg_models_pytorch.datamodules import PannukeDataModule\n", |
90 | 93 | "\n", |
91 | 94 | "# fold1 and fold2 are used for training, fold3 is used for validation\n", |
|
144 | 147 | } |
145 | 148 | ], |
146 | 149 | "source": [ |
147 | | - "import numpy as np\n", |
148 | 150 | "import matplotlib.pyplot as plt\n", |
| 151 | + "import numpy as np\n", |
149 | 152 | "from skimage.color import label2rgb\n", |
150 | 153 | "\n", |
151 | | - "# filehandler contains methods to read and write images and masks\n", |
152 | | - "from cellseg_models_pytorch.utils import FileHandler\n", |
153 | 154 | "from cellseg_models_pytorch.transforms.functional import (\n", |
154 | | - " gen_stardist_maps,\n", |
155 | 155 | " gen_dist_maps,\n", |
| 156 | + " gen_stardist_maps,\n", |
156 | 157 | ")\n", |
157 | 158 | "\n", |
| 159 | + "# filehandler contains methods to read and write images and masks\n", |
| 160 | + "from cellseg_models_pytorch.utils import FileHandler\n", |
| 161 | + "\n", |
158 | 162 | "img_dir = save_dir / \"train\" / \"images\"\n", |
159 | 163 | "mask_dir = save_dir / \"train\" / \"labels\"\n", |
160 | 164 | "imgs = sorted(img_dir.glob(\"*\"))\n", |
|
216 | 220 | "metadata": {}, |
217 | 221 | "outputs": [], |
218 | 222 | "source": [ |
| 223 | + "from typing import Dict, List, Optional, Tuple\n", |
| 224 | + "\n", |
| 225 | + "import lightning.pytorch as pl\n", |
219 | 226 | "import torch\n", |
220 | 227 | "import torch.nn as nn\n", |
221 | 228 | "import torch.optim as optim\n", |
222 | 229 | "import torchmetrics\n", |
223 | | - "import lightning.pytorch as pl\n", |
224 | | - "from typing import List, Tuple, Dict, Optional\n", |
225 | 230 | "\n", |
226 | 231 | "\n", |
227 | 232 | "class SegmentationExperiment(pl.LightningModule):\n", |
|
392 | 397 | "source": [ |
393 | 398 | "import torch.optim as optim\n", |
394 | 399 | "\n", |
395 | | - "from cellseg_models_pytorch.models import stardist_base_multiclass\n", |
396 | 400 | "from cellseg_models_pytorch.losses import (\n", |
397 | 401 | " MAE,\n", |
398 | 402 | " MSE,\n", |
399 | | - " DiceLoss,\n", |
400 | 403 | " BCELoss,\n", |
401 | 404 | " CELoss,\n", |
| 405 | + " DiceLoss,\n", |
402 | 406 | " JointLoss,\n", |
403 | 407 | " MultiTaskLoss,\n", |
404 | 408 | ")\n", |
| 409 | + "from cellseg_models_pytorch.models import stardist_base_multiclass\n", |
405 | 410 | "\n", |
406 | 411 | "# seed the experiment for reproducibility\n", |
407 | 412 | "pl.seed_everything(42)\n", |
|
888 | 893 | } |
889 | 894 | ], |
890 | 895 | "source": [ |
| 896 | + "import matplotlib.patches as mpatches\n", |
891 | 897 | "import numpy as np\n", |
| 898 | + "\n", |
892 | 899 | "from cellseg_models_pytorch.utils import draw_thing_contours\n", |
893 | | - "import matplotlib.patches as mpatches\n", |
894 | 900 | "\n", |
895 | 901 | "fig, ax = plt.subplots(5, 2, figsize=(10, 17))\n", |
896 | 902 | "ax = ax.flatten()\n", |
|
0 commit comments