44
55from mmcv import ConfigDict
66
7- from otx .mpa .cls .stage import ClsStage , Stage
7+ from otx .mpa .cls .stage import ClsStage
88from otx .mpa .utils .config_utils import update_or_add_custom_hook
99from otx .mpa .utils .logger import get_logger
1010
@@ -35,9 +35,7 @@ def configure_task(self, cfg, training, **kwargs):
3535 # noqa: C901
3636 def configure_task_adapt (self , cfg , training , ** kwargs ):
3737 """Configure for Task Adaptation Task"""
38-
39- self .adapt_type = cfg ["task_adapt" ].get ("op" , "REPLACE" )
40- train_data_cfg = Stage .get_data_cfg (cfg , "train" )
38+ train_data_cfg = self .get_data_cfg (cfg , "train" )
4139 if training :
4240 if train_data_cfg .type not in CLASS_INC_DATASET :
4341 logger .warning (f"Class Incremental Learning for { train_data_cfg .type } is not yet supported!" )
@@ -46,24 +44,17 @@ def configure_task_adapt(self, cfg, training, **kwargs):
4644
4745 if cfg .model .type in WEIGHT_MIX_CLASSIFIER :
4846 cfg .model .task_adapt = ConfigDict (
49- src_classes = self .model_classes ,
50- dst_classes = self .data_classes ,
47+ src_classes = self .org_model_classes ,
48+ dst_classes = self .model_classes ,
5149 )
5250 else :
5351 logger .warning (f"Weight mixing for { cfg .model .type } is not yet supported!" )
5452
55- # refine self.dst_class following adapt_type (REPLACE, MERGE)
56- self .refine_classes (train_data_cfg )
57- cfg .model .head .num_classes = len (self .dst_classes )
53+ train_data_cfg .classes = self .model_classes
5854
5955 # configure loss, sampler, task_adapt_hook
6056 self .configure_task_modules (cfg )
6157
62- else : # if eval phase (eval)
63- if train_data_cfg .get ("new_classes" ):
64- self .refine_classes (train_data_cfg )
65- cfg .model .head .num_classes = len (self .dst_classes )
66-
6758 def configure_task_modules (self , cfg ):
6859 if not cfg .model .get ("multilabel" , False ) and not cfg .model .get ("hierarchical" , False ):
6960 efficient_mode = cfg ["task_adapt" ].get ("efficient_mode" , True )
@@ -73,8 +64,8 @@ def configure_task_modules(self, cfg):
7364 efficient_mode = cfg ["task_adapt" ].get ("efficient_mode" , False )
7465 sampler_type = "cls_incr"
7566
76- if len (set (self .model_classes ) & set (self .dst_classes )) == 0 or set (self .model_classes ) == set (
77- self .dst_classes
67+ if len (set (self .org_model_classes ) & set (self .model_classes )) == 0 or set (self .org_model_classes ) == set (
68+ self .model_classes
7869 ):
7970 sampler_flag = False
8071 else :
@@ -83,8 +74,8 @@ def configure_task_modules(self, cfg):
8374 # Update Task Adapt Hook
8475 task_adapt_hook = ConfigDict (
8576 type = "TaskAdaptHook" ,
86- src_classes = self .old_classes ,
87- dst_classes = self .dst_classes ,
77+ src_classes = self .org_model_classes ,
78+ dst_classes = self .model_classes ,
8879 model_type = cfg .model .type ,
8980 sampler_flag = sampler_flag ,
9081 sampler_type = sampler_type ,
@@ -93,8 +84,8 @@ def configure_task_modules(self, cfg):
9384 update_or_add_custom_hook (cfg , task_adapt_hook )
9485
9586 def configure_loss (self , cfg ):
96- if len (set (self .model_classes ) & set (self .dst_classes )) == 0 or set (self .model_classes ) == set (
97- self .dst_classes
87+ if len (set (self .org_model_classes ) & set (self .model_classes )) == 0 or set (self .org_model_classes ) == set (
88+ self .model_classes
9889 ):
9990 cfg .model .head .loss = dict (type = "CrossEntropyLoss" , loss_weight = 1.0 )
10091 else :
@@ -104,20 +95,6 @@ def configure_loss(self, cfg):
10495 )
10596 ib_loss_hook = ConfigDict (
10697 type = "IBLossHook" ,
107- dst_classes = self .dst_classes ,
98+ dst_classes = self .model_classes ,
10899 )
109100 update_or_add_custom_hook (cfg , ib_loss_hook )
110-
111- def refine_classes (self , train_cfg ):
112- # Get 'new_classes' in data.train_cfg & get 'old_classes' pretreained model meta data CLASSES
113- new_classes = train_cfg ["new_classes" ]
114- self .old_classes = self .model_meta ["CLASSES" ]
115- if self .adapt_type == "REPLACE" :
116- # if 'REPLACE' operation, then self.dst_classes -> data_classes
117- self .dst_classes = self .data_classes .copy ()
118- elif self .adapt_type == "MERGE" :
119- # if 'MERGE' operation, then self.dst_classes -> old_classes + new_classes (merge)
120- self .dst_classes = self .old_classes + [cls for cls in new_classes if cls not in self .old_classes ]
121- else :
122- raise KeyError (f"{ self .adapt_type } is not supported for task_adapt options!" )
123- train_cfg .classes = self .dst_classes
0 commit comments