|
1 | | -from typing import Any, Callable |
| 1 | +from typing import Any, Callable, Optional |
2 | 2 |
|
3 | 3 | import torch |
4 | 4 |
|
5 | | -from ..train_metrics import accuracy, iou |
| 5 | +from ..functional.train_metrics import accuracy, iou |
6 | 6 |
|
7 | 7 | try: |
8 | 8 | from torchmetrics import Metric |
9 | | -except ImportError: |
10 | | - raise ImportError( |
11 | | - "`torchmetrics` package needed for metrics. `pip install torchmetrics`" |
| 9 | +except ModuleNotFoundError: |
| 10 | + raise ModuleNotFoundError( |
| 11 | + "`torchmetrics` package is required when using metric callbacks. " |
| 12 | + "Install with `pip install torchmetrics`" |
12 | 13 | ) |
13 | 14 |
|
14 | 15 |
|
15 | 16 | __all__ = ["Accuracy", "MeanIoU"] |
16 | 17 |
|
17 | 18 |
|
18 | 19 | class Accuracy(Metric): |
| 20 | + higher_is_better: Optional[bool] = True |
| 21 | + full_state_update: bool = False |
| 22 | + |
19 | 23 | def __init__( |
20 | 24 | self, |
21 | 25 | compute_on_step: bool = True, |
@@ -55,7 +59,7 @@ def update( |
55 | 59 | self, |
56 | 60 | pred: torch.Tensor, |
57 | 61 | target: torch.Tensor, |
58 | | - activation: str = "sofmax", |
| 62 | + activation: str = "softmax", |
59 | 63 | ) -> None: |
60 | 64 | """Update the batch accuracy list with one batch accuracy value. |
61 | 65 |
|
@@ -83,6 +87,9 @@ def compute(self) -> torch.Tensor: |
83 | 87 |
|
84 | 88 |
|
85 | 89 | class MeanIoU(Metric): |
| 90 | + higher_is_better: Optional[bool] = True |
| 91 | + full_state_update: bool = False |
| 92 | + |
86 | 93 | def __init__( |
87 | 94 | self, |
88 | 95 | compute_on_step: bool = True, |
@@ -119,7 +126,7 @@ def update( |
119 | 126 | self, |
120 | 127 | pred: torch.Tensor, |
121 | 128 | target: torch.Tensor, |
122 | | - activation: str = "sofmax", |
| 129 | + activation: str = "softmax", |
123 | 130 | ) -> None: |
124 | 131 | """Update the batch IoU list with one batch IoU matrix. |
125 | 132 |
|
|
0 commit comments