|
41 | 41 | }
42 | 42 | ],
43 | 43 | "source": [
44 | | - "import torch\n", |
| 44 | + "from platform import python_version\n", |
| 45 | + "\n", |
45 | 46 | "import lightning\n", |
| 47 | + "import torch\n", |
| 48 | + "\n", |
46 | 49 | "import cellseg_models_pytorch\n", |
47 | | - "from platform import python_version\n", |
48 | 50 | "\n", |
49 | 51 | "print(\"torch version:\", torch.__version__)\n", |
50 | 52 | "print(\"lightning version:\", lightning.__version__)\n", |
|
88 | 90 | ], |
89 | 91 | "source": [ |
90 | 92 | "from pathlib import Path\n", |
| 93 | + "\n", |
91 | 94 | "from cellseg_models_pytorch.datamodules import PannukeDataModule\n", |
92 | 95 | "\n", |
93 | 96 | "fold_split = {\"fold1\": \"train\", \"fold2\": \"train\", \"fold3\": \"valid\"}\n", |
|
150 | 153 | } |
151 | 154 | ], |
152 | 155 | "source": [ |
153 | | - "import numpy as np\n", |
154 | 156 | "import matplotlib.pyplot as plt\n", |
| 157 | + "import numpy as np\n", |
155 | 158 | "from skimage.color import label2rgb\n", |
156 | 159 | "\n", |
| 160 | + "from cellseg_models_pytorch.transforms.functional import gen_flow_maps\n", |
| 161 | + "\n", |
157 | 162 | "# filehandler contains methods to read and write images and masks\n", |
158 | 163 | "from cellseg_models_pytorch.utils import FileHandler\n", |
159 | | - "from cellseg_models_pytorch.transforms.functional import gen_flow_maps\n", |
160 | 164 | "\n", |
161 | 165 | "img_dir = save_dir / \"train\" / \"images\"\n", |
162 | 166 | "mask_dir = save_dir / \"train\" / \"labels\"\n", |
|
219 | 223 | "metadata": {}, |
220 | 224 | "outputs": [], |
221 | 225 | "source": [ |
| 226 | + "from typing import Dict, List, Tuple\n", |
| 227 | + "\n", |
| 228 | + "import lightning.pytorch as pl\n", |
222 | 229 | "import torch\n", |
223 | 230 | "import torch.nn as nn\n", |
224 | 231 | "import torch.optim as optim\n", |
225 | | - "import lightning.pytorch as pl\n", |
226 | | - "from typing import List, Tuple, Dict\n", |
227 | 232 | "\n", |
228 | 233 | "from cellseg_models_pytorch.losses import MultiTaskLoss\n", |
229 | 234 | "\n", |
|
388 | 393 | } |
389 | 394 | ], |
390 | 395 | "source": [ |
391 | | - "from cellseg_models_pytorch.losses import JointLoss, CELoss, DiceLoss, SSIM, MSE\n", |
| 396 | + "from cellseg_models_pytorch.losses import MSE, SSIM, CELoss, DiceLoss, JointLoss\n", |
392 | 397 | "from cellseg_models_pytorch.models import cellpose_base\n", |
393 | 398 | "\n", |
394 | 399 | "# initialize hovernet\n", |
|
464 | 469 | }, |
465 | 470 | { |
466 | 471 | "cell_type": "code", |
467 | | - "execution_count": 8, |
| 472 | + "execution_count": null, |
468 | 473 | "metadata": {}, |
469 | | - "outputs": [ |
470 | | - { |
471 | | - "name": "stderr", |
472 | | - "output_type": "stream", |
473 | | - "text": [ |
474 | | - "You are using a CUDA device ('NVIDIA GeForce RTX 3080 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", |
475 | | - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n" |
476 | | - ] |
477 | | - }, |
478 | | - { |
479 | | - "name": "stdout", |
480 | | - "output_type": "stream", |
481 | | - "text": [ |
482 | | - "Found all folds. Skip downloading.\n", |
483 | | - "Found processed pannuke data. If in need of a re-download, please empty the `save_dir` folder.\n" |
484 | | - ] |
485 | | - }, |
486 | | - { |
487 | | - "name": "stderr", |
488 | | - "output_type": "stream", |
489 | | - "text": [ |
490 | | - "\n", |
491 | | - " | Name | Type | Params | Mode \n", |
492 | | - "----------------------------------------------------\n", |
493 | | - "0 | model | CellPoseUnet | 92.4 M | train\n", |
494 | | - "1 | criterion | MultiTaskLoss | 0 | train\n", |
495 | | - "----------------------------------------------------\n", |
496 | | - "3.8 M Trainable params\n", |
497 | | - "88.6 M Non-trainable params\n", |
498 | | - "92.4 M Total params\n", |
499 | | - "369.433 Total estimated model params size (MB)\n" |
500 | | - ] |
501 | | - }, |
502 | | - { |
503 | | - "data": { |
504 | | - "application/vnd.jupyter.widget-view+json": { |
505 | | - "model_id": "6d39790609e94710a1c2db412e95e672", |
506 | | - "version_major": 2, |
507 | | - "version_minor": 0 |
508 | | - }, |
509 | | - "text/plain": [ |
510 | | - "Sanity Checking: | | 0/? [00:00<?, ?it/s]" |
511 | | - ] |
512 | | - }, |
513 | | - "metadata": {}, |
514 | | - "output_type": "display_data" |
515 | | - }, |
516 | | - { |
517 | | - "data": { |
518 | | - "application/vnd.jupyter.widget-view+json": { |
519 | | - "model_id": "9e149243fda244edad091b1ebdb4bbb5", |
520 | | - "version_major": 2, |
521 | | - "version_minor": 0 |
522 | | - }, |
523 | | - "text/plain": [ |
524 | | - "Training: | | 0/? [00:00<?, ?it/s]" |
525 | | - ] |
526 | | - }, |
527 | | - "metadata": {}, |
528 | | - "output_type": "display_data" |
529 | | - }, |
530 | | - { |
531 | | - "data": { |
532 | | - "application/vnd.jupyter.widget-view+json": { |
533 | | - "model_id": "6915db3487c54dba98ad2bc147b1d75f", |
534 | | - "version_major": 2, |
535 | | - "version_minor": 0 |
536 | | - }, |
537 | | - "text/plain": [ |
538 | | - "Validation: | | 0/? [00:00<?, ?it/s]" |
539 | | - ] |
540 | | - }, |
541 | | - "metadata": {}, |
542 | | - "output_type": "display_data" |
543 | | - }, |
544 | | - { |
545 | | - "name": "stderr", |
546 | | - "output_type": "stream", |
547 | | - "text": [ |
548 | | - "Epoch 0, global step 648: 'val_loss' reached 1.21952 (best 1.21952), saving model to '/home/leos/pannuke/dino_cellpose/epoch=0-step=648.ckpt' as top 1\n" |
549 | | - ] |
550 | | - }, |
551 | | - { |
552 | | - "data": { |
553 | | - "application/vnd.jupyter.widget-view+json": { |
554 | | - "model_id": "93cfe15ce7b444b8b5ccc4e509f86e7b", |
555 | | - "version_major": 2, |
556 | | - "version_minor": 0 |
557 | | - }, |
558 | | - "text/plain": [ |
559 | | - "Validation: | | 0/? [00:00<?, ?it/s]" |
560 | | - ] |
561 | | - }, |
562 | | - "metadata": {}, |
563 | | - "output_type": "display_data" |
564 | | - }, |
565 | | - { |
566 | | - "name": "stderr", |
567 | | - "output_type": "stream", |
568 | | - "text": [ |
569 | | - "Epoch 1, global step 1296: 'val_loss' reached 1.19438 (best 1.19438), saving model to '/home/leos/pannuke/dino_cellpose/epoch=1-step=1296.ckpt' as top 1\n" |
570 | | - ] |
571 | | - }, |
572 | | - { |
573 | | - "data": { |
574 | | - "application/vnd.jupyter.widget-view+json": { |
575 | | - "model_id": "36558f8ef059416d9b6d5db628bf34fb", |
576 | | - "version_major": 2, |
577 | | - "version_minor": 0 |
578 | | - }, |
579 | | - "text/plain": [ |
580 | | - "Validation: | | 0/? [00:00<?, ?it/s]" |
581 | | - ] |
582 | | - }, |
583 | | - "metadata": {}, |
584 | | - "output_type": "display_data" |
585 | | - }, |
586 | | - { |
587 | | - "name": "stderr", |
588 | | - "output_type": "stream", |
589 | | - "text": [ |
590 | | - "Epoch 2, global step 1944: 'val_loss' reached 1.07785 (best 1.07785), saving model to '/home/leos/pannuke/dino_cellpose/epoch=2-step=1944.ckpt' as top 1\n" |
591 | | - ] |
592 | | - }, |
593 | | - { |
594 | | - "data": { |
595 | | - "application/vnd.jupyter.widget-view+json": { |
596 | | - "model_id": "f33dbb9ed6794453950d33f3069861fe", |
597 | | - "version_major": 2, |
598 | | - "version_minor": 0 |
599 | | - }, |
600 | | - "text/plain": [ |
601 | | - "Validation: | | 0/? [00:00<?, ?it/s]" |
602 | | - ] |
603 | | - }, |
604 | | - "metadata": {}, |
605 | | - "output_type": "display_data" |
606 | | - }, |
607 | | - { |
608 | | - "name": "stderr", |
609 | | - "output_type": "stream", |
610 | | - "text": [ |
611 | | - "Epoch 3, global step 2592: 'val_loss' reached 1.04576 (best 1.04576), saving model to '/home/leos/pannuke/dino_cellpose/epoch=3-step=2592.ckpt' as top 1\n" |
612 | | - ] |
613 | | - }, |
614 | | - { |
615 | | - "data": { |
616 | | - "application/vnd.jupyter.widget-view+json": { |
617 | | - "model_id": "45c76d38563b471ea2996d14d6bdbd8e", |
618 | | - "version_major": 2, |
619 | | - "version_minor": 0 |
620 | | - }, |
621 | | - "text/plain": [ |
622 | | - "Validation: | | 0/? [00:00<?, ?it/s]" |
623 | | - ] |
624 | | - }, |
625 | | - "metadata": {}, |
626 | | - "output_type": "display_data" |
627 | | - }, |
628 | | - { |
629 | | - "name": "stderr", |
630 | | - "output_type": "stream", |
631 | | - "text": [ |
632 | | - "Epoch 4, global step 3240: 'val_loss' reached 0.94926 (best 0.94926), saving model to '/home/leos/pannuke/dino_cellpose/epoch=4-step=3240.ckpt' as top 1\n", |
633 | | - "`Trainer.fit` stopped: `max_epochs=5` reached.\n" |
634 | | - ] |
635 | | - }, |
636 | | - { |
637 | | - "name": "stdout", |
638 | | - "output_type": "stream", |
639 | | - "text": [ |
640 | | - "gg\n" |
641 | | - ] |
642 | | - } |
643 | | - ], |
| 474 | + "outputs": [], |
644 | 475 | "source": [ |
645 | 476 | "# Train the model\n", |
646 | 477 | "trainer.fit(model=experiment, datamodule=pannuke_module)" |
|
688 | 519 | ], |
689 | 520 | "source": [ |
690 | 521 | "import torch.nn.functional as F\n", |
691 | | - "from cellseg_models_pytorch.utils import percentile_normalize_torch\n", |
692 | 522 | "\n", |
| 523 | + "from cellseg_models_pytorch.utils import percentile_normalize_torch\n", |
693 | 524 | "\n", |
694 | 525 | "img_dir = save_dir / \"valid\" / \"images\"\n", |
695 | 526 | "mask_dir = save_dir / \"valid\" / \"labels\"\n", |
|
816 | 647 | } |
817 | 648 | ], |
818 | 649 | "source": [ |
819 | | - "import numpy as np\n", |
820 | | - "from cellseg_models_pytorch.utils import draw_thing_contours\n", |
821 | 650 | "import matplotlib.patches as mpatches\n", |
| 651 | + "import numpy as np\n", |
822 | 652 | "\n", |
| 653 | + "from cellseg_models_pytorch.utils import draw_thing_contours\n", |
823 | 654 | "\n", |
824 | 655 | "fig, ax = plt.subplots(5, 2, figsize=(10, 17))\n", |
825 | 656 | "ax = ax.flatten()\n", |
|