Skip to content

Commit 4f1a47c

Browse files
authored
[HOTFIX][RELEASE1.0] Fixed label_scheme mismatch in classification (#1841)
* Fix some cls-patch * Remove comments * Fix some var
1 parent c226919 commit 4f1a47c

File tree

5 files changed

+45
-78
lines changed

5 files changed

+45
-78
lines changed

otx/mpa/cls/exporter.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ def model_builder_helper(*args, **kwargs):
4141
return model
4242

4343
kwargs["model_builder"] = model_builder_helper
44-
4544
return super().run(model_cfg, model_ckpt, data_cfg, **kwargs)
4645

4746
@staticmethod

otx/mpa/cls/incremental/stage.py

Lines changed: 12 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from mmcv import ConfigDict
66

7-
from otx.mpa.cls.stage import ClsStage, Stage
7+
from otx.mpa.cls.stage import ClsStage
88
from otx.mpa.utils.config_utils import update_or_add_custom_hook
99
from otx.mpa.utils.logger import get_logger
1010

@@ -35,9 +35,7 @@ def configure_task(self, cfg, training, **kwargs):
3535
# noqa: C901
3636
def configure_task_adapt(self, cfg, training, **kwargs):
3737
"""Configure for Task Adaptation Task"""
38-
39-
self.adapt_type = cfg["task_adapt"].get("op", "REPLACE")
40-
train_data_cfg = Stage.get_data_cfg(cfg, "train")
38+
train_data_cfg = self.get_data_cfg(cfg, "train")
4139
if training:
4240
if train_data_cfg.type not in CLASS_INC_DATASET:
4341
logger.warning(f"Class Incremental Learning for {train_data_cfg.type} is not yet supported!")
@@ -46,24 +44,17 @@ def configure_task_adapt(self, cfg, training, **kwargs):
4644

4745
if cfg.model.type in WEIGHT_MIX_CLASSIFIER:
4846
cfg.model.task_adapt = ConfigDict(
49-
src_classes=self.model_classes,
50-
dst_classes=self.data_classes,
47+
src_classes=self.org_model_classes,
48+
dst_classes=self.model_classes,
5149
)
5250
else:
5351
logger.warning(f"Weight mixing for {cfg.model.type} is not yet supported!")
5452

55-
# refine self.dst_class following adapt_type (REPLACE, MERGE)
56-
self.refine_classes(train_data_cfg)
57-
cfg.model.head.num_classes = len(self.dst_classes)
53+
train_data_cfg.classes = self.model_classes
5854

5955
# configure loss, sampler, task_adapt_hook
6056
self.configure_task_modules(cfg)
6157

62-
else: # if eval phase (eval)
63-
if train_data_cfg.get("new_classes"):
64-
self.refine_classes(train_data_cfg)
65-
cfg.model.head.num_classes = len(self.dst_classes)
66-
6758
def configure_task_modules(self, cfg):
6859
if not cfg.model.get("multilabel", False) and not cfg.model.get("hierarchical", False):
6960
efficient_mode = cfg["task_adapt"].get("efficient_mode", True)
@@ -73,8 +64,8 @@ def configure_task_modules(self, cfg):
7364
efficient_mode = cfg["task_adapt"].get("efficient_mode", False)
7465
sampler_type = "cls_incr"
7566

76-
if len(set(self.model_classes) & set(self.dst_classes)) == 0 or set(self.model_classes) == set(
77-
self.dst_classes
67+
if len(set(self.org_model_classes) & set(self.model_classes)) == 0 or set(self.org_model_classes) == set(
68+
self.model_classes
7869
):
7970
sampler_flag = False
8071
else:
@@ -83,8 +74,8 @@ def configure_task_modules(self, cfg):
8374
# Update Task Adapt Hook
8475
task_adapt_hook = ConfigDict(
8576
type="TaskAdaptHook",
86-
src_classes=self.old_classes,
87-
dst_classes=self.dst_classes,
77+
src_classes=self.org_model_classes,
78+
dst_classes=self.model_classes,
8879
model_type=cfg.model.type,
8980
sampler_flag=sampler_flag,
9081
sampler_type=sampler_type,
@@ -93,8 +84,8 @@ def configure_task_modules(self, cfg):
9384
update_or_add_custom_hook(cfg, task_adapt_hook)
9485

9586
def configure_loss(self, cfg):
96-
if len(set(self.model_classes) & set(self.dst_classes)) == 0 or set(self.model_classes) == set(
97-
self.dst_classes
87+
if len(set(self.org_model_classes) & set(self.model_classes)) == 0 or set(self.org_model_classes) == set(
88+
self.model_classes
9889
):
9990
cfg.model.head.loss = dict(type="CrossEntropyLoss", loss_weight=1.0)
10091
else:
@@ -104,20 +95,6 @@ def configure_loss(self, cfg):
10495
)
10596
ib_loss_hook = ConfigDict(
10697
type="IBLossHook",
107-
dst_classes=self.dst_classes,
98+
dst_classes=self.model_classes,
10899
)
109100
update_or_add_custom_hook(cfg, ib_loss_hook)
110-
111-
def refine_classes(self, train_cfg):
112-
# Get 'new_classes' in data.train_cfg & get 'old_classes' pretreained model meta data CLASSES
113-
new_classes = train_cfg["new_classes"]
114-
self.old_classes = self.model_meta["CLASSES"]
115-
if self.adapt_type == "REPLACE":
116-
# if 'REPLACE' operation, then self.dst_classes -> data_classes
117-
self.dst_classes = self.data_classes.copy()
118-
elif self.adapt_type == "MERGE":
119-
# if 'MERGE' operation, then self.dst_classes -> old_classes + new_classes (merge)
120-
self.dst_classes = self.old_classes + [cls for cls in new_classes if cls not in self.old_classes]
121-
else:
122-
raise KeyError(f"{self.adapt_type} is not supported for task_adapt options!")
123-
train_cfg.classes = self.dst_classes

otx/mpa/cls/stage.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -143,25 +143,27 @@ def configure_task(self, cfg, training, **kwargs):
143143
self.configure_classes(cfg)
144144

145145
def configure_classes(self, cfg):
146-
model_classes, data_classes = [], []
147-
self.model_meta = self.get_model_meta(cfg)
148-
train_data_cfg = Stage.get_data_cfg(cfg, "train")
149-
if isinstance(train_data_cfg, list):
150-
train_data_cfg = train_data_cfg[0]
151-
152-
model_classes = Stage.get_model_classes(cfg)
153-
data_classes = Stage.get_data_classes(cfg)
154-
155-
if cfg.get("model_classes", []):
156-
cfg.model.head.num_classes = len(cfg.model_classes) # read from prev model label_schema
157-
elif model_classes:
158-
cfg.model.head.num_classes = len(model_classes) # read from ckpt meta
159-
elif data_classes:
160-
cfg.model.head.num_classes = len(data_classes) # read from env label_schema
161-
self.model_meta["CLASSES"] = model_classes
162-
163-
if not train_data_cfg.get("new_classes", False): # when train_data_cfg doesn't have 'new_classes' key
164-
new_classes = np.setdiff1d(data_classes, model_classes).tolist()
165-
train_data_cfg["new_classes"] = new_classes
146+
"""Patch classes for model and dataset."""
147+
self.task_adapt_op = "REPLACE"
148+
if "task_adapt" in cfg:
149+
self.task_adapt_op = cfg["task_adapt"].get("op", "REPLACE")
150+
151+
org_model_classes = self.get_model_classes(cfg)
152+
data_classes = self.get_data_classes(cfg)
153+
154+
# Model classes
155+
if self.task_adapt_op == "REPLACE":
156+
if len(data_classes) == 0:
157+
model_classes = org_model_classes.copy()
158+
else:
159+
model_classes = data_classes.copy()
160+
elif self.task_adapt_op == "MERGE":
161+
model_classes = org_model_classes + [cls for cls in data_classes if cls not in org_model_classes]
162+
else:
163+
raise KeyError(f"{self.task_adapt_op} is not supported for task_adapt options!")
164+
165+
# Model architecture
166+
cfg.model.head.num_classes = len(model_classes)
167+
168+
self.org_model_classes = org_model_classes
166169
self.model_classes = model_classes
167-
self.data_classes = data_classes

otx/mpa/cls/trainer.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,6 @@ def run(self, model_cfg, model_ckpt, data_cfg, **kwargs): # noqa: C901
6464
cfg.checkpoint_config.meta["tasks"] = repr_ds.tasks
6565
if hasattr(repr_ds, "CLASSES"):
6666
cfg.checkpoint_config.meta["CLASSES"] = repr_ds.CLASSES
67-
if "task_adapt" in cfg:
68-
if hasattr(self, "model_tasks"): # for incremnetal learning
69-
cfg.checkpoint_config.meta.update({"tasks": self.model_tasks})
70-
# instead of update(self.old_tasks), update using "self.model_tasks"
71-
if self.model_classes:
72-
cfg.checkpoint_config.meta.update({"CLASSES": self.model_classes})
73-
# FIXME:self.dst_classes?
7467

7568
self.configure_samples_per_gpu(cfg, "train", self.distributed)
7669
self.configure_fp16_optimizer(cfg, self.distributed)
@@ -85,10 +78,6 @@ def run(self, model_cfg, model_ckpt, data_cfg, **kwargs): # noqa: C901
8578

8679
self.configure_compat_cfg(cfg)
8780

88-
# Save config
89-
# cfg.dump(osp.join(cfg.work_dir, 'config.py'))
90-
# logger.info(f'Config:\n{cfg.pretty_text}')
91-
9281
# register custom eval hooks
9382
validate = True if cfg.data.get("val", None) else False
9483
if validate:

tests/unit/mpa/cls/incremental/test_cls_incremental_stage.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,28 +15,28 @@ def setup(self) -> None:
1515

1616
@pytest.mark.parametrize("mode", ["MERGE", "REPLACE"])
1717
@e2e_pytest_unit
18-
def test_configure_task_adapt(self, mode):
18+
def test_configure_classes(self, mode, mocker):
19+
1920
self.stage.cfg.merge_from_dict(self.data_cfg)
2021
self.stage.cfg.task_adapt.op = mode
21-
model_classes = ["label_0", "label_1", "label_3", "label_n"]
22-
self.stage.model_meta = {"CLASSES": model_classes}
23-
self.stage.model_classes = model_classes
24-
self.stage.data_classes = self.data_cfg.data.train.data_classes
25-
self.stage.new_classes = self.data_cfg.data.train.data_classes
26-
self.stage.configure_task_adapt(self.stage.cfg, True)
22+
origin_model_classes = ["label_0", "label_3", "label_n"]
23+
mocker.patch("otx.mpa.cls.incremental.stage.IncrClsStage.get_model_classes", return_value=origin_model_classes)
24+
self.stage.data_classes = self.data_cfg.data.train.data_classes # ["label_0", "label_1"]
25+
self.stage.configure_classes(self.stage.cfg)
26+
merge_target = ["label_0", "label_1", "label_3", "label_n"]
2727

2828
if mode == "REPLACE":
2929
assert self.stage.cfg.model.head.num_classes == len(self.stage.data_classes)
3030
else:
31-
assert self.stage.cfg.model.head.num_classes == len(self.stage.data_classes) + len(self.stage.new_classes)
31+
assert self.stage.cfg.model.head.num_classes == len(merge_target)
3232

3333
@e2e_pytest_unit
3434
def test_configure_task_modules(self, monkeypatch, mocker):
3535
mock_update_hook = mocker.patch("otx.mpa.cls.incremental.stage.update_or_add_custom_hook")
3636
# some dummy classes
37+
self.stage.data_classes = [0, 1]
3738
self.stage.model_classes = [0, 1]
38-
self.stage.dst_classes = [0, 1]
39-
self.stage.old_classes = [0, 1]
39+
self.stage.org_model_classes = [0, 1]
4040
self.stage.configure_task_modules(self.stage.cfg)
4141

4242
mock_update_hook.assert_called_once()

0 commit comments

Comments
 (0)