
Commit 4dd5e77

kprokofi and Copilot authored
Simplify OTXModel API (#5013)
Co-authored-by: Copilot <[email protected]>
1 parent 96f98f3 commit 4dd5e77

57 files changed: +343 −658 lines changed

library/src/otx/backend/native/cli/utils.py

Lines changed: 9 additions & 2 deletions
@@ -34,17 +34,21 @@ def get_otx_root_path() -> Path:
 RECIPE_PATH = get_otx_root_path() / "recipe"


-def list_models(task: OTXTaskType | None = None, pattern: str | None = None, print_table: bool = False) -> list[str]:
+def list_models(
+    task: OTXTaskType | None = None, pattern: str | None = None, print_table: bool = False, return_recipes: bool = False
+) -> list[str]:
     """Returns a list of available models for training.

     Args:
         task (OTXTaskType | None, optional): Recipe Filter by Task.
         pattern (Optional[str], optional): A string pattern to filter the list of available models. Defaults to None.
         print_table (bool, optional): Output the recipe information as a Rich.Table.
             This is primarily used for `otx find` in the CLI.
+        return_recipes (bool, optional): If True, return the recipe paths instead of model names.

     Returns:
-        list[str]: A list of available models for pretraining.
+        list[str]: A list of available models or recipes for fine-tuning.

     Example:
         # Return all available model list.
@@ -94,4 +98,7 @@ def list_models(task: OTXTaskType | None = None, pattern: str | None = None, pri
         )
         console.print(table, width=console.width, justify="center")

+    if return_recipes:
+        return recipe_list
+
     return list({Path(recipe).stem for recipe in recipe_list})
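
The flag slots into the existing CLI helper without changing its default behaviour. A small usage sketch (the import path is taken from the file header above; the pattern string is illustrative):

# Hypothetical usage of the new return_recipes flag.
from otx.backend.native.cli.utils import list_models

model_names = list_models(pattern="rtmdet")                         # model names (recipe stems), as before
recipe_paths = list_models(pattern="rtmdet", return_recipes=True)   # full recipe paths instead of names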

library/src/otx/backend/native/engine.py

Lines changed: 9 additions & 3 deletions
@@ -85,12 +85,13 @@ class OTXEngine(Engine):

     def __init__(
         self,
-        model: OTXModel | PathLike,
+        model: OTXModel | PathLike | str,
         data: OTXDataModule | PathLike,
         work_dir: PathLike = "./otx-workspace",
         checkpoint: PathLike | None = None,
         device: DeviceType = DeviceType.auto,
         num_devices: int = 1,
+        task: OTXTaskType | None = None,
         **kwargs,
     ):
         """Initializes the OTX Engine.
@@ -103,6 +104,8 @@ def __init__(
             checkpoint (PathLike | None, optional): Path to the checkpoint file (model weights). Defaults to None.
             device (DeviceType, optional): The device type to use. Defaults to DeviceType.auto.
             num_devices (int, optional): The number of devices to use. If it is 2 or more, it will behave as multi-gpu.
+            task (OTXTaskType | None, optional): The task type to use. Useful when you provide model name
+                and this model can be used for multiple tasks. Defaults to None.
             **kwargs: Additional keyword arguments for pl.Trainer.
         """
         self._cache = TrainerArgumentsCache(**kwargs)
@@ -112,10 +115,13 @@ def __init__(
         if not isinstance(data, (OTXDataModule, str, os.PathLike)):
             msg = f"data should be OTXDataModule or PathLike, but got {type(data)}"
             raise TypeError(msg)
+        if task is not None and isinstance(data, OTXDataModule) and task != data.task:
+            msg = f"task and data.task should be the same, but got {task} and {data.task}"
+            raise ValueError(msg)
         self._auto_configurator = AutoConfigurator(
             data_root=data if isinstance(data, (str, os.PathLike)) else None,
-            task=data.task if isinstance(data, OTXDataModule) else None,
-            model_config_path=None if isinstance(model, OTXModel) else model,
+            task=data.task if isinstance(data, OTXDataModule) else task,
+            model=None if isinstance(model, OTXModel) else model,
         )
         self._datamodule: OTXDataModule = (
             data if isinstance(data, OTXDataModule) else self._auto_configurator.get_datamodule()
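
Taken together, these changes let the Engine be built from a bare model name, with task disambiguating names that map to several tasks. A hedged sketch (the dataset path, model name, and the OTXTaskType import location are assumptions, not verified against the released package):

from otx.backend.native.engine import OTXEngine
from otx.types.task import OTXTaskType  # import location assumed

engine = OTXEngine(
    model="efficientnet_b0",             # a plain model name string is now accepted
    data="path/to/dataset",              # data root; the AutoConfigurator builds the datamodule
    task=OTXTaskType.MULTI_CLASS_CLS,    # disambiguates model names usable for multiple tasks
)
# Note: passing task together with an OTXDataModule whose .task differs now raises ValueError.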

library/src/otx/backend/native/models/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -10,7 +10,7 @@
     TVModel,
     VisionTransformer,
 )
-from .detection import ATSS, RTDETR, SSD, DFine, RTMDet
+from .detection import ATSS, RTDETR, SSD, YOLOX, DEIMDFine, DFine, RTMDet
 from .instance_segmentation import MaskRCNN, MaskRCNNTV, RTMDetInst
 from .keypoint_detection import RTMPose
 from .segmentation import DinoV2Seg, LiteHRNet, SegNext
@@ -19,6 +19,8 @@
     "ATSS",
     "RTDETR",
     "SSD",
+    "YOLOX",
+    "DEIMDFine",
     "DFine",
     "DinoV2Seg",
     "EfficientNet",

library/src/otx/backend/native/models/base.py

Lines changed: 47 additions & 22 deletions
@@ -33,7 +33,6 @@
 from otx.backend.native.utils.utils import (
     ensure_callable,
     is_ckpt_for_finetuning,
-    is_ckpt_from_otx_v1,
     remove_state_dict_prefix,
 )
 from otx.config.data import TileConfig
@@ -102,8 +101,24 @@ def _default_scheduler_callable(
 class OTXModel(LightningModule):
     """Base class for the models used in OTX.

+    This class is a subclass of `LightningModule`. It is not intended to be used directly.
+
     Args:
-        num_classes: Number of classes this model can predict.
+        label_info (LabelInfoTypes | int | Sequence): Information about the labels used in the model.
+            If `int` is given, label info will be constructed from number of classes,
+            if `Sequence` is given, label info will be constructed from the sequence of label names.
+        model_name (str, optional): Name of the model. Defaults to "OTXModel".
+        optimizer (OptimizerCallable, optional): Optimizer callable. Defaults to DefaultOptimizerCallable.
+        scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): Scheduler callable.
+            Defaults to DefaultSchedulerCallable.
+        metric (MetricCallable, optional): Metric callable. Defaults to NullMetricCallable.
+        torch_compile (bool, optional): Whether to use torch compile. Defaults to False.
+        tile_config (TileConfig | dict, optional): Configuration for tiling. Defaults to TileConfig(enable_tiler=False).
+        data_input_params (DataInputParams | dict | None, optional): Parameters for image preprocessing.
+            This parameter contains image input size, mean, and std, that is used to preprocess the input image.
+            If None is given, default parameters for the specific model will be used.
+            In most cases you don't need to set this parameter unless you change the image size or pretrained weights.
+            Defaults to None.

     Attributes:
         explain_mode: If true, `self.predict_step()` will produce a XAI output as well
@@ -118,7 +133,7 @@ class OTXModel(LightningModule):
     def __init__(
         self,
         label_info: LabelInfoTypes | int | Sequence,
-        data_input_params: DataInputParams | dict,
+        data_input_params: DataInputParams | dict | None = None,
         task: OTXTaskType | None = None,
         model_name: str = "OTXModel",
         optimizer: OptimizerCallable = DefaultOptimizerCallable,
@@ -133,7 +148,10 @@ def __init__(
             label_info (LabelInfoTypes | int | Sequence): Information about the labels used in the model.
                 If `int` is given, label info will be constructed from number of classes,
                 if `Sequence` is given, label info will be constructed from the sequence of label names.
-            data_input_params (DataInputParams | dict): Parameters of the input data such as input size, mean, and std.
+            data_input_params (DataInputParams | dict | None, optional): Parameters for image preprocessing.
+                This parameter contains image input size, mean, and std, that is used to preprocess the input image.
+                If None is given, default parameters for the specific model will be used.
+                Defaults to None.
             model_name (str, optional): Name of the model. Defaults to "OTXModel".
             optimizer (OptimizerCallable, optional): Callable for the optimizer. Defaults to DefaultOptimizerCallable.
             scheduler (LRSchedulerCallable | LRSchedulerListCallable): Callable for the learning rate scheduler.
@@ -148,11 +166,17 @@ def __init__(
         super().__init__()

         self._label_info = self._dispatch_label_info(label_info)
+        self.model_name = model_name
         if isinstance(data_input_params, dict):
             data_input_params = DataInputParams(**data_input_params)
+        elif data_input_params is None:
+            data_input_params = (
+                self._default_preprocessing_params[self.model_name]
+                if isinstance(self._default_preprocessing_params, dict)
+                else self._default_preprocessing_params
+            )
         self._check_preprocessing_params(data_input_params)
         self.data_input_params = data_input_params
-        self.model_name = model_name
         self.model = self._create_model()
         self.optimizer_callable = ensure_callable(optimizer)
         self.scheduler_callable = ensure_callable(scheduler)
@@ -455,11 +479,7 @@ def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:

     def load_state_dict_incrementally(self, ckpt: dict[str, Any], *args, **kwargs) -> None:
         """Load state dict incrementally."""
-        ckpt_label_info: LabelInfo | None = (
-            ckpt.get("hyper_parameters", {}).get("label_info")
-            if not is_ckpt_from_otx_v1(ckpt)
-            else self.get_ckpt_label_info_v1(ckpt)
-        )
+        ckpt_label_info: LabelInfo | None = ckpt.get("hyper_parameters", {}).get("label_info")

         if ckpt_label_info is None:
             msg = "Checkpoint should have `label_info`."
@@ -485,7 +505,7 @@ def load_state_dict_incrementally(self, ckpt: dict[str, Any], *args, **kwargs) -
         )

         # Model weights
-        state_dict: dict[str, Any] = ckpt.get("state_dict", {}) if not is_ckpt_from_otx_v1(ckpt) else ckpt
+        state_dict: dict[str, Any] = ckpt.get("state_dict", {})

         if state_dict is None or state_dict == {}:
             msg = "Checkpoint should have `state_dict`."
@@ -501,21 +521,13 @@ def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None:
         If checkpoint's label_info and OTXLitModule's label_info are different,
         load_state_pre_hook for smart weight loading will be registered.
         """
-        if is_ckpt_from_otx_v1(ckpt):
-            msg = "The checkpoint comes from OTXv1, checkpoint keys will be updated automatically."
-            warnings.warn(msg, stacklevel=2)
-            state_dict = self.load_from_otx_v1_ckpt(ckpt)
-        elif is_ckpt_for_finetuning(ckpt):
+        if is_ckpt_for_finetuning(ckpt):
             self.on_load_checkpoint(ckpt)
             state_dict = ckpt["state_dict"]
         else:
             state_dict = ckpt
         return super().load_state_dict(state_dict, *args, **kwargs)

-    def load_from_otx_v1_ckpt(self, ckpt: dict[str, Any]) -> dict:
-        """Load the previous OTX ckpt according to OTX2.0."""
-        raise NotImplementedError
-
     @staticmethod
     def get_ckpt_label_info_v1(ckpt: dict) -> LabelInfo:
         """Generate label info from OTX v1 checkpoint."""
@@ -561,6 +573,15 @@ def _set_label_info(self, label_info: LabelInfoTypes) -> None:

         self._label_info = new_label_info

+    @property
+    @abstractmethod
+    def _default_preprocessing_params(self) -> DataInputParams | dict[str, DataInputParams]:
+        """Parameters for image preprocessing.
+
+        Each model architecture must implement this property, returning a DataInputParams
+        containing the image input size, mean, and std, that is used to preprocess the input image.
+        """
+
     @property
     def num_classes(self) -> int:
         """Returns model's number of classes. Can be redefined at the model's level."""
@@ -594,9 +615,13 @@ def _customize_outputs(

     def forward(
         self,
-        inputs: OTXDataBatch,
-    ) -> OTXPredBatch | OTXBatchLossEntity:
+        inputs: OTXDataBatch | Tensor,
+    ) -> OTXPredBatch | OTXBatchLossEntity | Tensor:
         """Model forward function."""
+        # Simple forward
+        if isinstance(inputs, Tensor):
+            return self.forward_for_tracing(inputs)
+
         # If customize_inputs is overridden
         if isinstance(inputs, OTXTileBatchDataEntity):
             return self.forward_tiles(inputs)
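
Two behavioural points above are worth illustrating: data_input_params can now be omitted because each architecture supplies _default_preprocessing_params, and a raw Tensor passed to forward() is routed to forward_for_tracing(). A minimal, hypothetical subclass sketch (MyToyModel and its internals are invented; it assumes the remaining OTXModel hooks have workable defaults and that the DataInputParams import location is as shown):

import torch
from otx.backend.native.models.base import DataInputParams, OTXModel  # import locations assumed

class MyToyModel(OTXModel):  # invented subclass, for illustration only
    @property
    def _default_preprocessing_params(self) -> DataInputParams:
        # Consulted by OTXModel.__init__ when data_input_params is None.
        return DataInputParams(input_size=(224, 224), mean=(0.0, 0.0, 0.0), std=(255.0, 255.0, 255.0))

    def _create_model(self) -> torch.nn.Module:
        return torch.nn.Conv2d(3, 1, kernel_size=1)  # stand-in for a real architecture

    def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor:
        return self.model(image)

model = MyToyModel(label_info=2)              # no data_input_params: the defaults above are used
out = model(torch.rand(1, 3, 224, 224))       # Tensor input takes the new tracing path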

library/src/otx/backend/native/models/classification/factory.py

Lines changed: 5 additions & 5 deletions
@@ -50,7 +50,7 @@ class MobileNetV3:
     def __new__(
         cls,
         label_info: LabelInfoTypes,
-        data_input_params: DataInputParams | dict,
+        data_input_params: DataInputParams | None = None,
         task: Literal["multi_class", "multi_label", "h_label"] = "multi_class",
         freeze_backbone: bool = False,
         model_name: Literal["mobilenetv3_large", "mobilenetv3_small"] = "mobilenetv3_large",
@@ -120,7 +120,7 @@ class EfficientNet:
     def __new__(
         cls,
         label_info: LabelInfoTypes,
-        data_input_params: DataInputParams,
+        data_input_params: DataInputParams | None = None,
         task: Literal["multi_class", "multi_label", "h_label"] = "multi_class",
         model_name: Literal[
             "efficientnet_b0",
@@ -192,7 +192,7 @@ class TimmModel:
     def __new__(
         cls,
         label_info: LabelInfoTypes,
-        data_input_params: DataInputParams,
+        data_input_params: DataInputParams | None = None,
         task: Literal["multi_class", "multi_label", "h_label"] = "multi_class",
         model_name: str = "tf_efficientnetv2_s.in21k",
         freeze_backbone: bool = False,
@@ -279,7 +279,7 @@ class TVModel:
     def __new__(
         cls,
         label_info: LabelInfoTypes,
-        data_input_params: DataInputParams,
+        data_input_params: DataInputParams | None = None,
         task: Literal["multi_class", "multi_label", "h_label"] = "multi_class",
         model_name: str = "efficientnet_v2_s",
         freeze_backbone: bool = False,
@@ -361,7 +361,7 @@ class VisionTransformer:
     def __new__(
         cls,
         label_info: LabelInfoTypes,
-        data_input_params: DataInputParams,
+        data_input_params: DataInputParams | None = None,
         task: Literal["multi_class", "multi_label", "h_label"] = "multi_class",
         model_name: Literal[
             "vit-tiny",

library/src/otx/backend/native/models/classification/hlabel_models/base.py

Lines changed: 7 additions & 2 deletions
@@ -38,7 +38,8 @@ class OTXHlabelClsModel(OTXModel):

     Args:
         label_info (HLabelInfo): Information about the hierarchical labels.
-        data_input_params (DataInputParams): Parameters for data input.
+        data_input_params (DataInputParams | None, optional): Parameters for image data preprocessing. If None is given,
+            default parameters for the specific model will be used.
         model_name (str, optional): Name of the model. Defaults to "hlabel_classification_model".
         optimizer (OptimizerCallable, optional): Callable for the optimizer. Defaults to DefaultOptimizerCallable.
         scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): Callable for the learning rate scheduler.
@@ -52,7 +53,7 @@ class OTXHlabelClsModel(OTXModel):
     def __init__(
         self,
         label_info: HLabelInfo,
-        data_input_params: DataInputParams,
+        data_input_params: DataInputParams | None = None,
         model_name: str = "hlabel_classification_model",
         freeze_backbone: bool = False,
         optimizer: OptimizerCallable = DefaultOptimizerCallable,
@@ -233,3 +234,7 @@ def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]:
             return self.model(images=image, mode="explain")

         return self.model(images=image, mode="tensor")
+
+    @property
+    def _default_preprocessing_params(self) -> DataInputParams | dict[str, DataInputParams]:
+        return DataInputParams(input_size=(224, 224), mean=(123.675, 116.28, 103.53), std=(58.395, 57.12, 57.375))
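
As a quick sanity check of the new default (values taken from the property above), constructing an h-label classifier without data_input_params should pick up the 224x224 ImageNet-style preprocessing; hlabel_info below is a deliberate placeholder for an HLabelInfo built from your label tree:

from otx.backend.native.models.classification.hlabel_models.efficientnet import EfficientNetHLabelCls

hlabel_info = ...  # an HLabelInfo instance describing your hierarchy (construction omitted)
model = EfficientNetHLabelCls(label_info=hlabel_info)
assert model.data_input_params.input_size == (224, 224)  # falls back to _default_preprocessing_params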

library/src/otx/backend/native/models/classification/hlabel_models/efficientnet.py

Lines changed: 1 addition & 6 deletions
@@ -20,7 +20,6 @@
     AsymmetricAngularLossWithIgnore,
 )
 from otx.backend.native.models.classification.necks.gap import GlobalAveragePooling
-from otx.backend.native.models.utils.support_otx_v1 import OTXv1Helper
 from otx.backend.native.schedulers import LRSchedulerListCallable
 from otx.metrics.accuracy import HLabelClsMetricCallable
 from otx.types.label import HLabelInfo
@@ -37,7 +36,7 @@ class EfficientNetHLabelCls(OTXHlabelClsModel):
     def __init__(
         self,
         label_info: HLabelInfo,
-        data_input_params: DataInputParams,
+        data_input_params: DataInputParams | None = None,
         model_name: Literal[
             "efficientnet_b0",
             "efficientnet_b1",
@@ -86,7 +85,3 @@ def _create_model(self, head_config: dict | None = None) -> nn.Module: # type:
             multiclass_loss=nn.CrossEntropyLoss(),
             multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
         )
-
-    def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
-        """Load the previous OTX ckpt according to OTX2.0."""
-        return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "hlabel", add_prefix)

library/src/otx/backend/native/models/classification/hlabel_models/mobilenet_v3.py

Lines changed: 1 addition & 6 deletions
@@ -21,7 +21,6 @@
     AsymmetricAngularLossWithIgnore,
 )
 from otx.backend.native.models.classification.necks.gap import GlobalAveragePooling
-from otx.backend.native.models.utils.support_otx_v1 import OTXv1Helper
 from otx.backend.native.schedulers import LRSchedulerListCallable
 from otx.data.entity.base import OTXBatchLossEntity
 from otx.data.entity.torch import OTXDataBatch, OTXPredBatch
@@ -41,7 +40,7 @@ class MobileNetV3HLabelCls(OTXHlabelClsModel):
     def __init__(
         self,
         label_info: HLabelInfo,
-        data_input_params: DataInputParams,
+        data_input_params: DataInputParams | None = None,
         model_name: Literal["mobilenetv3_large", "mobilenetv3_small"] = "mobilenetv3_large",
         freeze_backbone: bool = False,
         optimizer: OptimizerCallable = DefaultOptimizerCallable,
@@ -82,10 +81,6 @@ def _create_model(self, head_config: dict | None = None) -> nn.Module: # type:
             multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
         )

-    def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
-        """Load the previous OTX ckpt according to OTX2.0."""
-        return OTXv1Helper.load_cls_mobilenet_v3_ckpt(state_dict, "hlabel", add_prefix)
-
     def _customize_inputs(self, inputs: OTXDataBatch) -> dict[str, Any]:
         if self.training:
             mode = "loss"

library/src/otx/backend/native/models/classification/hlabel_models/timm_model.py

Lines changed: 1 addition & 6 deletions
@@ -20,7 +20,6 @@
     AsymmetricAngularLossWithIgnore,
 )
 from otx.backend.native.models.classification.necks.gap import GlobalAveragePooling
-from otx.backend.native.models.utils.support_otx_v1 import OTXv1Helper
 from otx.backend.native.schedulers import LRSchedulerListCallable
 from otx.metrics.accuracy import HLabelClsMetricCallable
 from otx.types.label import HLabelInfo
@@ -51,7 +50,7 @@ class TimmModelHLabelCls(OTXHlabelClsModel):
     def __init__(
         self,
         label_info: HLabelInfo,
-        data_input_params: DataInputParams,
+        data_input_params: DataInputParams | None = None,
         model_name: str = "tf_efficientnetv2_s.in21k",
         freeze_backbone: bool = False,
         optimizer: OptimizerCallable = DefaultOptimizerCallable,
@@ -85,7 +84,3 @@ def _create_model(self, head_config: dict | None = None) -> nn.Module: # type:
             multiclass_loss=nn.CrossEntropyLoss(),
             multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
         )
-
-    def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
-        """Load the previous OTX ckpt according to OTX2.0."""
-        return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "hlabel", add_prefix)

library/src/otx/backend/native/models/classification/hlabel_models/torchvision_model.py

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ class TVModelHLabelCls(OTXHlabelClsModel):
     def __init__(
         self,
         label_info: HLabelInfo,
-        data_input_params: DataInputParams,
+        data_input_params: DataInputParams | None = None,
         model_name: str = "efficientnet_v2_s",
         freeze_backbone: bool = False,
         optimizer: OptimizerCallable = DefaultOptimizerCallable,
