1414from anomalib .models .image import Padim as AnomalibPadim
1515
1616from otx .core .model .anomaly import OTXAnomaly
17+ from otx .core .types .label import AnomalyLabelInfo
1718from otx .core .types .task import OTXTaskType
1819
1920if TYPE_CHECKING :
2021 from lightning .pytorch .utilities .types import STEP_OUTPUT
2122 from torch .optim .optimizer import Optimizer
2223
2324 from otx .core .model .anomaly import AnomalyModelInputs , AnomalyModelOutputs
25+ from otx .core .types .label import LabelInfoTypes
2426
2527
2628class Padim (OTXAnomaly , AnomalibPadim ):
@@ -40,6 +42,7 @@ class Padim(OTXAnomaly, AnomalibPadim):
4042
4143 def __init__ (
4244 self ,
45+ label_info : LabelInfoTypes = AnomalyLabelInfo (),
4346 backbone : str = "resnet18" ,
4447 layers : list [str ] = ["layer1" , "layer2" , "layer3" ], # noqa: B006
4548 pre_trained : bool = True ,
@@ -51,7 +54,7 @@ def __init__(
5154 ] = OTXTaskType .ANOMALY_CLASSIFICATION ,
5255 input_size : tuple [int , int ] = (256 , 256 ),
5356 ) -> None :
54- OTXAnomaly .__init__ (self , input_size )
57+ OTXAnomaly .__init__ (self , label_info = label_info , input_size = input_size )
5558 AnomalibPadim .__init__ (
5659 self ,
5760 backbone = backbone ,
0 commit comments