
Commit f550746

Authored by sovrasov and Songki Choi

[HOTFIX][RELEASE1.0] Fix missing classes in the first incremental round & cleanup classification utils (#1839)

* Cleanup
* Fix missing classes in cls checkpoint
* Fix linters
* Add debug comment
* Fix default dump_feature=True for Geti

Co-authored-by: Songki Choi <[email protected]>
1 parent 418850a commit f550746
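
The headline fix lands in otx/mpa/cls/trainer.py below: an `else:` branch becomes an independent `if` when class metadata is written into the checkpoint. A minimal sketch of the bug pattern and the fix — plain Python, not OTX code, with a hypothetical stand-in for the training dataset:

```python
# Hypothetical stand-in for the representative training dataset: it exposes
# both `tasks` and `CLASSES`, as OTX classification datasets can.
class ReprDataset:
    tasks = ["classification"]
    CLASSES = ["cat", "dog"]


repr_ds = ReprDataset()
meta = {}

# Old behavior: `else` means CLASSES is written only when `tasks` is absent,
# so a dataset carrying both attributes loses its classes in the checkpoint
# meta, and the first incremental round starts from a model missing classes.
if hasattr(repr_ds, "tasks"):
    meta["tasks"] = repr_ds.tasks
else:
    meta["CLASSES"] = repr_ds.CLASSES
assert "CLASSES" not in meta  # the missing-classes symptom

# Fixed behavior: each attribute is recorded independently.
if hasattr(repr_ds, "CLASSES"):
    meta["CLASSES"] = repr_ds.CLASSES
assert meta["CLASSES"] == ["cat", "dog"]
```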

File tree

8 files changed: +15 -194 lines changed

otx/algorithms/classification/adapters/mmcls/utils/__init__.py

Lines changed: 2 additions & 9 deletions
```diff
@@ -3,17 +3,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from .builder import build_classifier
-from .config_utils import (
-    patch_config,
-    patch_datasets,
-    patch_evaluation,
-    prepare_for_training,
-)
+from .config_utils import patch_datasets, patch_evaluation
 
 __all__ = [
-    "patch_config",
+    "build_classifier",
     "patch_datasets",
     "patch_evaluation",
-    "prepare_for_training",
-    "build_classifier",
 ]
```

otx/algorithms/classification/adapters/mmcls/utils/config_utils.py

Lines changed: 3 additions & 133 deletions
```diff
@@ -14,153 +14,23 @@
 # See the License for the specific language governing permissions
 # and limitations under the License.
 
-import math
-from typing import List, Optional, Union
+from typing import List, Optional
 
 from mmcv import Config, ConfigDict
 
 from otx.algorithms.common.adapters.mmcv.utils import (
-    get_configs_by_keys,
     get_configs_by_pairs,
     get_dataset_configs,
     get_meta_keys,
-    is_epoch_based_runner,
     patch_color_conversion,
-    prepare_work_dir,
-    remove_from_config,
-    remove_from_configs_by_type,
-    update_config,
-)
-from otx.api.entities.label import Domain, LabelEntity
-from otx.api.utils.argument_checks import (
-    DirectoryPathCheck,
-    check_input_parameters_type,
 )
+from otx.api.entities.label import Domain
+from otx.api.utils.argument_checks import check_input_parameters_type
 from otx.mpa.utils.logger import get_logger
 
 logger = get_logger()
 
 
-@check_input_parameters_type({"work_dir": DirectoryPathCheck})
-def patch_config(
-    config: Config,
-    work_dir: str,
-    labels: List[LabelEntity],
-):  # pylint: disable=too-many-branches
-    """Update config function."""
-
-    # Add training cancelation hook.
-    if "custom_hooks" not in config:
-        config.custom_hooks = []
-    if "CancelTrainingHook" not in {hook.type for hook in config.custom_hooks}:
-        config.custom_hooks.append(ConfigDict({"type": "CancelTrainingHook"}))
-
-    # Remove high level data pipelines definition leaving them only inside `data` section.
-    remove_from_config(config, "train_pipeline")
-    remove_from_config(config, "test_pipeline")
-    remove_from_config(config, "train_pipeline_strong")
-    # Remove cancel interface hook
-    remove_from_configs_by_type(config.custom_hooks, "CancelInterfaceHook")
-
-    config.checkpoint_config.max_keep_ckpts = 5
-    config.checkpoint_config.interval = config.evaluation.get("interval", 1)
-
-    set_data_classes(config, labels)
-
-    config.gpu_ids = range(1)
-    config.work_dir = work_dir
-
-
-@check_input_parameters_type()
-def patch_model_config(
-    config: Config,
-    labels: List[LabelEntity],
-):
-    """Patch model config."""
-    set_num_classes(config, len(labels))
-
-
-@check_input_parameters_type()
-def patch_adaptive_repeat_dataset(
-    config: Union[Config, ConfigDict],
-    num_samples: int,
-    decay: float = -0.002,
-    factor: float = 30,
-):
-    """Patch the repeat times and training epochs adaptively.
-
-    Frequent dataloading inits and evaluation slow down training when the
-    sample size is small. Adjusting epoch and dataset repetition based on
-    empirical exponential decay improves the training time by applying a
-    high repeat value to small datasets and a low repeat value to large ones.
-
-    :param config: mmcv config
-    :param num_samples: number of training samples
-    :param decay: decaying rate
-    :param factor: base repeat factor
-    """
-    data_train = config.data.train
-    if data_train.type == "RepeatDataset" and getattr(data_train, "adaptive_repeat_times", False):
-        if is_epoch_based_runner(config.runner):
-            cur_epoch = config.runner.max_epochs
-            new_repeat = max(round(math.exp(decay * num_samples) * factor), 1)
-            new_epoch = math.ceil(cur_epoch / new_repeat)
-            if new_epoch == 1:
-                return
-            config.runner.max_epochs = new_epoch
-            data_train.times = new_repeat
-
-
-@check_input_parameters_type()
-def prepare_for_training(
-    config: Union[Config, ConfigDict],
-    data_config: ConfigDict,
-) -> Union[Config, ConfigDict]:
-    """Prepare configs for training phase."""
-    prepare_work_dir(config)
-
-    train_num_samples = 0
-    for subset in ["train", "val", "test"]:
-        data_config_ = data_config.data.get(subset)
-        config_ = config.data.get(subset)
-        if data_config_ is None:
-            continue
-        for key in ["otx_dataset", "labels"]:
-            found = get_configs_by_keys(data_config_, key, return_path=True)
-            if len(found) == 0:
-                continue
-            assert len(found) == 1
-            if subset == "train" and key == "otx_dataset":
-                found_value = list(found.values())[0]
-                if found_value:
-                    train_num_samples = len(found_value)
-            update_config(config_, found)
-
-    if train_num_samples > 0:
-        patch_adaptive_repeat_dataset(config, train_num_samples)
-
-    return config
-
-
-@check_input_parameters_type()
-def set_data_classes(config: Config, labels: List[LabelEntity]):
-    """Set data classes into config."""
-    # Save labels in data configs.
-    for subset in ("train", "val", "test"):
-        for cfg in get_dataset_configs(config, subset):
-            cfg.labels = labels
-
-
-@check_input_parameters_type()
-def set_num_classes(config: Config, num_classes: int):
-    """Set num classes."""
-    head_names = ["head"]
-    for head_name in head_names:
-        if head_name in config.model:
-            config.model[head_name].num_classes = num_classes
-
-
 @check_input_parameters_type()
 def patch_datasets(
     config: Config,
```
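
For reference, the removed patch_adaptive_repeat_dataset helper scaled dataset repetition by an exponential decay in the training-set size. A standalone numeric check of the formula it used (plain Python, no OTX imports; the sample counts are arbitrary):

```python
import math


def adaptive_repeat(num_samples: int, decay: float = -0.002, factor: float = 30) -> int:
    # Same formula as the removed helper: many repeats for small datasets,
    # decaying toward 1 as the dataset grows.
    return max(round(math.exp(decay * num_samples) * factor), 1)


print(adaptive_repeat(50))  # 27 -> small dataset, heavy repetition per epoch
print(adaptive_repeat(500))  # 11
print(adaptive_repeat(5000))  # 1 -> large dataset, effectively no repetition
```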

otx/algorithms/classification/tasks/inference.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -215,7 +215,7 @@ def unload(self):
         self.cleanup()
 
     @check_input_parameters_type()
-    def export(self, export_type: ExportType, output_model: ModelEntity, dump_features: bool = False):
+    def export(self, export_type: ExportType, output_model: ModelEntity, dump_features: bool = True):
         """Export function of OTX Classification Task."""
 
         logger.info("Exporting the model")
```

otx/algorithms/detection/tasks/inference.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -238,7 +238,7 @@ def unload(self):
         self.cleanup()
 
     @check_input_parameters_type()
-    def export(self, export_type: ExportType, output_model: ModelEntity, dump_features: bool = False):
+    def export(self, export_type: ExportType, output_model: ModelEntity, dump_features: bool = True):
         """Export function of OTX Detection Task."""
         # copied from OTX inference_task.py
         logger.info("Exporting the model")
```

otx/api/usecases/tasks/interfaces/export_interface.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -20,7 +20,7 @@ class IExportTask(metaclass=abc.ABCMeta):
     """A base interface class for tasks which can export their models."""
 
     @abc.abstractmethod
-    def export(self, export_type: ExportType, output_model: ModelEntity, dump_features: bool = False):
+    def export(self, export_type: ExportType, output_model: ModelEntity, dump_features: bool = True):  # FIXME: False
        """This method defines the interface for export.
 
        Args:
```
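
Across the interface and both task implementations, the dump_features default flips from False to True so that Geti receives exported features without passing the flag; the FIXME records the intent to restore False later. A hedged usage sketch — `task` (any IExportTask implementation) and `output_model` (a prepared ModelEntity) are assumed to already exist:

```python
from otx.api.usecases.tasks.interfaces.export_interface import ExportType

# After this commit, features are dumped by default:
task.export(ExportType.OPENVINO, output_model)
# Callers that relied on the old default must now opt out explicitly:
task.export(ExportType.OPENVINO, output_model, dump_features=False)
```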

otx/mpa/cls/stage.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -153,11 +153,11 @@ def configure_classes(self, cfg):
         data_classes = Stage.get_data_classes(cfg)
 
         if cfg.get("model_classes", []):
-            cfg.model.head.num_classes = len(cfg.model_classes)
+            cfg.model.head.num_classes = len(cfg.model_classes)  # read from prev model label_schema
         elif model_classes:
-            cfg.model.head.num_classes = len(model_classes)
+            cfg.model.head.num_classes = len(model_classes)  # read from ckpt meta
         elif data_classes:
-            cfg.model.head.num_classes = len(data_classes)
+            cfg.model.head.num_classes = len(data_classes)  # read from env label_schema
         self.model_meta["CLASSES"] = model_classes
 
         if not train_data_cfg.get("new_classes", False):  # when train_data_cfg doesn't have 'new_classes' key
```
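
The added comments document where configure_classes sources the head size from, in priority order: the previous model's label schema, then the checkpoint meta, then the environment's label schema. A minimal sketch of that resolution order (not OTX code):

```python
from typing import List, Optional


def resolve_num_classes(
    prev_model_classes: List[str],  # cfg.model_classes: previous model's label schema
    ckpt_meta_classes: List[str],  # model_classes: CLASSES read from the checkpoint meta
    data_classes: List[str],  # classes from the environment's label schema
) -> Optional[int]:
    """Return the head size from the first non-empty class source."""
    for classes in (prev_model_classes, ckpt_meta_classes, data_classes):
        if classes:
            return len(classes)
    return None


# With the trainer fix, the checkpoint meta carries CLASSES, so the first
# incremental round resolves the 2 old classes instead of falling through
# to the 3 classes visible in the data.
assert resolve_num_classes([], ["cat", "dog"], ["cat", "dog", "bird"]) == 2
```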

otx/mpa/cls/trainer.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -62,14 +62,15 @@ def run(self, model_cfg, model_ckpt, data_cfg, **kwargs):  # noqa: C901
         cfg.checkpoint_config.meta = dict(mmcls_version=__version__)
         if hasattr(repr_ds, "tasks"):
             cfg.checkpoint_config.meta["tasks"] = repr_ds.tasks
-        else:
+        if hasattr(repr_ds, "CLASSES"):
             cfg.checkpoint_config.meta["CLASSES"] = repr_ds.CLASSES
         if "task_adapt" in cfg:
             if hasattr(self, "model_tasks"):  # for incremental learning
                 cfg.checkpoint_config.meta.update({"tasks": self.model_tasks})
                 # instead of update(self.old_tasks), update using "self.model_tasks"
-            else:
+            if self.model_classes:
                 cfg.checkpoint_config.meta.update({"CLASSES": self.model_classes})
+                # FIXME: self.dst_classes?
 
         self.configure_samples_per_gpu(cfg, "train", self.distributed)
         self.configure_fp16_optimizer(cfg, self.distributed)
```
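
With the two checks now independent, the checkpoint meta keeps CLASSES alongside tasks, which is what the next incremental round reads back. A small illustrative sketch, assuming the mmcv-style checkpoint layout with a top-level "meta" dict (the path is hypothetical):

```python
import torch

ckpt = torch.load("work_dir/latest.pth", map_location="cpu")  # hypothetical path
classes = ckpt.get("meta", {}).get("CLASSES")
print(classes)  # expected: e.g. ["cat", "dog"] rather than None after this fix
```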

tests/unit/algorithms/classification/adapters/mmcls/test_cls_config_builder.py

Lines changed: 1 addition & 44 deletions
```diff
@@ -2,18 +2,10 @@
 # SPDX-License-Identifier: Apache-2.0
 #
 
-import tempfile
-
 import pytest
 from mmcv.utils import Config
 
-from otx.algorithms.classification.adapters.mmcls.utils import (
-    patch_config,
-    patch_evaluation,
-)
-from otx.algorithms.common.adapters.mmcv.utils import get_dataset_configs
-from otx.api.entities.id import ID
-from otx.api.entities.label import Domain, LabelEntity
+from otx.algorithms.classification.adapters.mmcls.utils import patch_evaluation
 from tests.test_suite.e2e_test_system import e2e_pytest_unit
 
 
@@ -27,41 +19,6 @@ def otx_default_cls_config():
     return conf
 
 
-@pytest.fixture
-def otx_default_labels():
-    return [
-        LabelEntity(name=name, domain=Domain.CLASSIFICATION, is_empty=False, id=ID(i))
-        for i, name in enumerate(["a", "b"])
-    ]
-
-
-@e2e_pytest_unit
-def test_patch_config(otx_default_cls_config, otx_default_labels) -> None:
-    """Test patch_config function.
-
-    <Steps>
-        1. Check work_dir
-        2. Check removed high level pipelines
-        3. Check checkpoint config update
-        4. Check dataset labels
-    """
-
-    with tempfile.TemporaryDirectory() as work_dir:
-        patch_config(otx_default_cls_config, work_dir, otx_default_labels)
-        assert otx_default_cls_config.work_dir == work_dir
-
-        assert otx_default_cls_config.get("train_pipeline", None) is None
-        assert otx_default_cls_config.get("test_pipeline", None) is None
-        assert otx_default_cls_config.get("train_pipeline_strong", None) is None
-
-        assert otx_default_cls_config.checkpoint_config.max_keep_ckpts > 0
-        assert otx_default_cls_config.checkpoint_config.interval > 0
-
-        for subset in ("train", "val", "test"):
-            for cfg in get_dataset_configs(otx_default_cls_config, subset):
-                assert len(cfg.labels) == len(otx_default_labels)
-
-
 @e2e_pytest_unit
 def test_patch_evaluation(otx_default_cls_config) -> None:
     """Test patch_evaluation function.
```
