Skip to content

Commit 112b2b2

Browse files
Update label info (#3925)
add label info to init Signed-off-by: Ashwin Vaidya <[email protected]>
1 parent 2bcf1b2 commit 112b2b2

File tree

3 files changed

+12
-5
lines changed

3 files changed

+12
-5
lines changed

src/otx/algo/anomaly/padim.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
from anomalib.models.image import Padim as AnomalibPadim
1515

1616
from otx.core.model.anomaly import OTXAnomaly
17+
from otx.core.types.label import AnomalyLabelInfo
1718
from otx.core.types.task import OTXTaskType
1819

1920
if TYPE_CHECKING:
2021
from lightning.pytorch.utilities.types import STEP_OUTPUT
2122
from torch.optim.optimizer import Optimizer
2223

2324
from otx.core.model.anomaly import AnomalyModelInputs, AnomalyModelOutputs
25+
from otx.core.types.label import LabelInfoTypes
2426

2527

2628
class Padim(OTXAnomaly, AnomalibPadim):
@@ -40,6 +42,7 @@ class Padim(OTXAnomaly, AnomalibPadim):
4042

4143
def __init__(
4244
self,
45+
label_info: LabelInfoTypes = AnomalyLabelInfo(),
4346
backbone: str = "resnet18",
4447
layers: list[str] = ["layer1", "layer2", "layer3"], # noqa: B006
4548
pre_trained: bool = True,
@@ -51,7 +54,7 @@ def __init__(
5154
] = OTXTaskType.ANOMALY_CLASSIFICATION,
5255
input_size: tuple[int, int] = (256, 256),
5356
) -> None:
54-
OTXAnomaly.__init__(self, input_size)
57+
OTXAnomaly.__init__(self, label_info=label_info, input_size=input_size)
5558
AnomalibPadim.__init__(
5659
self,
5760
backbone=backbone,

src/otx/algo/anomaly/stfpm.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
from anomalib.models.image.stfpm import Stfpm as AnomalibStfpm
1515

1616
from otx.core.model.anomaly import OTXAnomaly
17+
from otx.core.types.label import AnomalyLabelInfo
1718
from otx.core.types.task import OTXTaskType
1819

1920
if TYPE_CHECKING:
2021
from lightning.pytorch.utilities.types import STEP_OUTPUT
2122
from torch.optim.optimizer import Optimizer
2223

2324
from otx.core.model.anomaly import AnomalyModelInputs, AnomalyModelOutputs
25+
from otx.core.types.label import LabelInfoTypes
2426

2527

2628
class Stfpm(OTXAnomaly, AnomalibStfpm):
@@ -38,6 +40,7 @@ class Stfpm(OTXAnomaly, AnomalibStfpm):
3840

3941
def __init__(
4042
self,
43+
label_info: LabelInfoTypes = AnomalyLabelInfo(),
4144
layers: Sequence[str] = ["layer1", "layer2", "layer3"],
4245
backbone: str = "resnet18",
4346
task: Literal[
@@ -48,7 +51,7 @@ def __init__(
4851
input_size: tuple[int, int] = (256, 256),
4952
**kwargs,
5053
) -> None:
51-
OTXAnomaly.__init__(self, input_size=input_size)
54+
OTXAnomaly.__init__(self, label_info=label_info, input_size=input_size)
5255
AnomalibStfpm.__init__(
5356
self,
5457
backbone=backbone,

src/otx/core/model/anomaly.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@
3939
from lightning.pytorch.callbacks.callback import Callback
4040
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
4141
from torchmetrics import Metric
42-
from otx.core.types.label import AnomalyLabelInfo
42+
43+
from otx.core.types.label import LabelInfoTypes
4344

4445
AnomalyModelInputs: TypeAlias = (
4546
AnomalyClassificationDataBatch | AnomalySegmentationDataBatch | AnomalyDetectionDataBatch
@@ -57,8 +58,8 @@ class OTXAnomaly(OTXModel):
5758
Model input size in the order of height and width. Defaults to None.
5859
"""
5960

60-
def __init__(self, input_size: tuple[int, int]) -> None:
61-
super().__init__(label_info=AnomalyLabelInfo(), input_size=input_size)
61+
def __init__(self, label_info: LabelInfoTypes, input_size: tuple[int, int]) -> None:
62+
super().__init__(label_info=label_info, input_size=input_size)
6263
self.optimizer: list[OptimizerCallable] | OptimizerCallable = None
6364
self.scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = None
6465
self.trainer: Trainer

0 commit comments

Comments
 (0)