1818from dd_ranking .utils import train_one_epoch , validate
1919from dd_ranking .loss import SoftCrossEntropyLoss , KLDivergenceLoss
2020from dd_ranking .aug import DSA_Augmentation , Mixup_Augmentation , Cutmix_Augmentation , ZCA_Whitening_Augmentation
21+ from dd_ranking .config import Config
2122
2223
2324class Unified_Evaluator :
2425
25- def __init__ (self ,
26- dataset : str ,
27- real_data_path : str ,
28- ipc : int ,
29- model_name : str ,
30- use_soft_label : bool ,
26+ def __init__ (self ,
27+ config : Config = None ,
28+ dataset : str = 'CIFAR10' ,
29+ real_data_path : str = './dataset' ,
30+ ipc : int = 10 ,
31+ model_name : str = 'ConvNet-3' ,
32+ use_soft_label : bool = False ,
3133 optimizer : str = 'sgd' ,
3234 lr_scheduler : str = 'step' ,
33- data_aug_func : str = None ,
35+ data_aug_func : str = 'dsa' ,
3436 aug_params : dict = None ,
35- soft_label_mode : str = None ,
36- soft_label_criterion : str = None ,
37+ soft_label_mode : str = 'M' ,
38+ soft_label_criterion : str = 'kl' ,
3739 num_eval : int = 5 ,
3840 im_size : tuple = (32 , 32 ),
3941 num_epochs : int = 300 ,
@@ -42,14 +44,51 @@ def __init__(self,
4244 momentum : float = 0.9 ,
4345 use_zca : bool = False ,
4446 temperature : float = 1.0 ,
47+ use_torchvision : bool = False ,
48+ num_workers : int = 4 ,
4549 save_path : str = None ,
4650 device : str = "cuda"
4751 ):
4852
53+ if config is not None :
54+ self .config = config
55+ dataset = self .config .get ('dataset' , 'CIFAR10' )
56+ real_data_path = self .config .get ('real_data_path' , './dataset' )
57+ ipc = self .config .get ('ipc' , 10 )
58+ model_name = self .config .get ('model_name' , 'ConvNet-3' )
59+ use_soft_label = self .config .get ('use_soft_label' , False )
60+ soft_label_criterion = self .config .get ('soft_label_criterion' , 'sce' )
61+ data_aug_func = self .config .get ('data_aug_func' , 'dsa' )
62+ aug_params = self .config .get ('aug_params' , {
63+ "prob_flip" : 0.5 ,
64+ "ratio_rotate" : 15.0 ,
65+ "saturation" : 2.0 ,
66+ "brightness" : 1.0 ,
67+ "contrast" : 0.5 ,
68+ "ratio_scale" : 1.2 ,
69+ "ratio_crop_pad" : 0.125 ,
70+ "ratio_cutout" : 0.5
71+ })
72+ soft_label_mode = self .config .get ('soft_label_mode' , 'S' )
73+ optimizer = self .config .get ('optimizer' , 'sgd' )
74+ lr_scheduler = self .config .get ('lr_scheduler' , 'step' )
75+ temperature = self .config .get ('temperature' , 1.0 )
76+ weight_decay = self .config .get ('weight_decay' , 0.0005 )
77+ momentum = self .config .get ('momentum' , 0.9 )
78+ num_eval = self .config .get ('num_eval' , 5 )
79+ im_size = self .config .get ('im_size' , (32 , 32 ))
80+ num_epochs = self .config .get ('num_epochs' , 300 )
81+ batch_size = self .config .get ('batch_size' , 256 )
82+ default_lr = self .config .get ('default_lr' , 0.01 )
83+ save_path = self .config .get ('save_path' , None )
84+ num_workers = self .config .get ('num_workers' , 4 )
85+ use_torchvision = self .config .get ('use_torchvision' , False )
86+ device = self .config .get ('device' , 'cuda' )
87+
4988 channel , im_size , num_classes , dst_train , dst_test , class_map , class_map_inv = get_dataset (dataset , real_data_path , im_size , use_zca )
5089 self .num_classes = num_classes
5190 self .im_size = im_size
52- self .test_loader = DataLoader (dst_test , batch_size = batch_size , num_workers = 4 , shuffle = False )
91+ self .test_loader = DataLoader (dst_test , batch_size = batch_size , num_workers = num_workers , shuffle = False )
5392
5493 self .ipc = ipc
5594 self .model_name = model_name
@@ -77,26 +116,30 @@ def __init__(self,
77116 os .makedirs (os .path .dirname (save_path ))
78117 self .save_path = save_path
79118
80- pretrained_model_path = get_pretrained_model_path (model_name , dataset , ipc )
119+ if not use_torchvision :
120+ pretrained_model_path = get_pretrained_model_path (model_name , dataset , ipc )
121+ else :
122+ pretrained_model_path = None
123+
81124 self .teacher_model = build_model (
82125 model_name = model_name ,
83126 num_classes = num_classes ,
84127 im_size = self .im_size ,
85128 pretrained = True ,
86129 device = self .device ,
87- model_path = pretrained_model_path
130+ model_path = pretrained_model_path ,
131+ use_torchvision = use_torchvision
88132 )
89133 self .teacher_model .eval ()
90134
91135 if data_aug_func is None :
92136 self .aug_func = None
93- elif data_aug_func == 'DSA ' :
137+ elif data_aug_func == 'dsa ' :
94138 self .aug_func = DSA_Augmentation (aug_params )
95- elif data_aug_func == 'ZCA' :
96- self .aug_func = ZCA_Whitening_Augmentation (aug_params )
97- elif data_aug_func == 'Mixup' :
139+ self .num_epochs = 1000
140+ elif data_aug_func == 'mixup' :
98141 self .aug_func = Mixup_Augmentation (aug_params )
99- elif data_aug_func == 'Cutmix ' :
142+ elif data_aug_func == 'cutmix ' :
100143 self .aug_func = Cutmix_Augmentation (aug_params )
101144 else :
102145 raise ValueError (f"Invalid data augmentation function: { data_aug_func } " )
@@ -168,7 +211,6 @@ def compute_metrics_helper(self, model, loader, lr):
168211 acc = validate (
169212 model = model ,
170213 loader = loader ,
171- aug_func = self .aug_func ,
172214 logging = True ,
173215 device = self .device
174216 )
0 commit comments