Skip to content

Commit b5ba695

Browse files
committed
fix(train): ch folder struct, catch module errors
1 parent 16f2372 commit b5ba695

File tree

8 files changed

+33
-26
lines changed

8 files changed

+33
-26
lines changed
Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +0,0 @@
1-
from .callbacks import METRIC_LOOKUP, Accuracy, MeanIoU
2-
from .lightning_experiment import SegmentationExperiment
3-
from .train_metrics import accuracy, confusion_mat, iou
4-
5-
__all__ = [
6-
"SegmentationExperiment",
7-
"confusion_mat",
8-
"accuracy",
9-
"iou",
10-
"Accuracy",
11-
"MeanIoU",
12-
"METRIC_LOOKUP",
13-
]

cellseg_models_pytorch/training/callbacks/metric_callbacks.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,25 @@
1-
from typing import Any, Callable
1+
from typing import Any, Callable, Optional
22

33
import torch
44

5-
from ..train_metrics import accuracy, iou
5+
from ..functional.train_metrics import accuracy, iou
66

77
try:
88
from torchmetrics import Metric
9-
except ImportError:
10-
raise ImportError(
11-
"`torchmetrics` package needed for metrics. `pip install torchmetrics`"
9+
except ModuleNotFoundError:
10+
raise ModuleNotFoundError(
11+
"`torchmetrics` package is required when using metric callbacks. "
12+
"Install with `pip install torchmetrics`"
1213
)
1314

1415

1516
__all__ = ["Accuracy", "MeanIoU"]
1617

1718

1819
class Accuracy(Metric):
20+
higher_is_better: Optional[bool] = True
21+
full_state_update: bool = False
22+
1923
def __init__(
2024
self,
2125
compute_on_step: bool = True,
@@ -55,7 +59,7 @@ def update(
5559
self,
5660
pred: torch.Tensor,
5761
target: torch.Tensor,
58-
activation: str = "sofmax",
62+
activation: str = "softmax",
5963
) -> None:
6064
"""Update the batch accuracy list with one batch accuracy value.
6165
@@ -83,6 +87,9 @@ def compute(self) -> torch.Tensor:
8387

8488

8589
class MeanIoU(Metric):
90+
higher_is_better: Optional[bool] = True
91+
full_state_update: bool = False
92+
8693
def __init__(
8794
self,
8895
compute_on_step: bool = True,
@@ -119,7 +126,7 @@ def update(
119126
self,
120127
pred: torch.Tensor,
121128
target: torch.Tensor,
122-
activation: str = "sofmax",
129+
activation: str = "softmax",
123130
) -> None:
124131
"""Update the batch IoU list with one batch IoU matrix.
125132
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .train_metrics import accuracy, confusion_mat, iou
2+
3+
__all__ = ["confusion_mat", "accuracy", "iou"]
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .lightning_experiment import SegmentationExperiment
2+
3+
__all__ = ["SegmentationExperiment"]

cellseg_models_pytorch/training/lightning_experiment.py renamed to cellseg_models_pytorch/training/lit/lightning_experiment.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
from copy import deepcopy
22
from typing import Any, Dict, List
33

4-
import pytorch_lightning as pl
54
import torch
65
import torch.nn as nn
76

8-
from ..losses import JOINT_SEG_LOSSES, SEG_LOSS_LOOKUP, JointLoss, Loss, MultiTaskLoss
9-
from ..optimizers import OPTIM_LOOKUP, SCHED_LOOKUP, adjust_optim_params
10-
from .callbacks import METRIC_LOOKUP
7+
try:
8+
import pytorch_lightning as pl
9+
except ModuleNotFoundError:
10+
raise ModuleNotFoundError(
11+
"To use the `SegmentationExperiment`, pytorch-lightning is required. "
12+
"Install with `pip install pytorch-lightning`"
13+
)
14+
15+
from ...losses import JOINT_SEG_LOSSES, SEG_LOSS_LOOKUP, JointLoss, Loss, MultiTaskLoss
16+
from ...optimizers import OPTIM_LOOKUP, SCHED_LOOKUP, adjust_optim_params
17+
from ..callbacks import METRIC_LOOKUP
1118

1219

1320
class SegmentationExperiment(pl.LightningModule):

cellseg_models_pytorch/training/tests/test_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
import torch
33

4-
from cellseg_models_pytorch.training.train_metrics import accuracy, iou
4+
from cellseg_models_pytorch.training.functional.train_metrics import accuracy, iou
55

66

77
@pytest.mark.parametrize("metric", [accuracy, iou])

cellseg_models_pytorch/training/tests/test_training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from cellseg_models_pytorch.datamodules.custom_datamodule import CustomDataModule
77
from cellseg_models_pytorch.datasets import SegmentationFolderDataset
88
from cellseg_models_pytorch.models import cellpose_plus
9-
from cellseg_models_pytorch.training import SegmentationExperiment
9+
from cellseg_models_pytorch.training.lit import SegmentationExperiment
1010

1111

1212
# @pytest.mark.parametrize

0 commit comments

Comments
 (0)