@@ -22,7 +22,7 @@ class SoftLabelEvaluator:
2222 def __init__ (self , config : Config = None , dataset : str = 'CIFAR10' , real_data_path : str = './dataset/' , ipc : int = 10 , model_name : str = 'ConvNet-3' ,
2323 soft_label_criterion : str = 'kl' , data_aug_func : str = 'cutmix' , aug_params : dict = {'beta' : 1.0 }, soft_label_mode : str = 'S' ,
2424 optimizer : str = 'sgd' , lr_scheduler : str = 'step' , temperature : float = 1.0 , weight_decay : float = 0.0005 ,
25- momentum : float = 0.9 , num_eval : int = 5 , im_size : tuple = (32 , 32 ), num_epochs : int = 300 , use_zca : bool = False ,
25+ momentum : float = 0.9 , num_eval : int = 5 , im_size : tuple = (32 , 32 ), num_epochs : int = 300 , use_zca : bool = False , use_aug_for_hard : bool = False ,
2626 real_batch_size : int = 256 , syn_batch_size : int = 256 , default_lr : float = 0.01 , save_path : str = None , stu_use_torchvision : bool = False ,
2727 tea_use_torchvision : bool = False , num_workers : int = 4 , teacher_dir : str = './teacher_models' , custom_train_trans : transforms .Compose = None ,
2828 custom_val_trans : transforms .Compose = None , device : str = "cuda" ):
@@ -46,6 +46,7 @@ def __init__(self, config: Config=None, dataset: str='CIFAR10', real_data_path:
4646 im_size = self .config .get ('im_size' )
4747 num_epochs = self .config .get ('num_epochs' )
4848 use_zca = self .config .get ('use_zca' )
49+ use_aug_for_hard = self .config .get ('use_aug_for_hard' )
4950 real_batch_size = self .config .get ('real_batch_size' )
5051 syn_batch_size = self .config .get ('syn_batch_size' )
5152 default_lr = self .config .get ('default_lr' )
@@ -58,15 +59,18 @@ def __init__(self, config: Config=None, dataset: str='CIFAR10', real_data_path:
5859 teacher_dir = self .config .get ('teacher_dir' )
5960 device = self .config .get ('device' )
6061
61- channel , im_size , num_classes , dst_train , dst_test , class_map , class_map_inv = get_dataset (dataset ,
62- real_data_path ,
63- im_size ,
64- use_zca ,
65- custom_train_trans ,
66- custom_val_trans ,
67- device )
68- self .images_train , self .labels_train , self .class_indices_train = self .load_real_data (dst_train , class_map , num_classes )
69- self .test_loader = DataLoader (dst_test , batch_size = real_batch_size , num_workers = num_workers , shuffle = False )
62+ channel , im_size , num_classes , dst_train , dst_test_real , dst_test_syn , class_map , class_map_inv = get_dataset (dataset ,
63+ real_data_path ,
64+ im_size ,
65+ use_zca ,
66+ custom_val_trans ,
67+ device )
68+ # self.images_train, self.labels_train, self.class_indices_train = self.load_real_data(dst_train, class_map, num_classes)
69+ self .class_indices = self .get_class_indices (dst_train , class_map , num_classes )
70+ self .dst_train = dst_train
71+
72+ self .test_loader_real = DataLoader (dst_test_real , batch_size = real_batch_size , num_workers = num_workers , shuffle = False )
73+ self .test_loader_syn = DataLoader (dst_test_syn , batch_size = syn_batch_size , num_workers = num_workers , shuffle = False )
7074
7175 self .soft_label_mode = soft_label_mode
7276 self .soft_label_criterion = soft_label_criterion
@@ -94,6 +98,7 @@ def __init__(self, config: Config=None, dataset: str='CIFAR10', real_data_path:
9498 self .num_workers = num_workers
9599 self .device = device
96100
101+ # data augmentation
97102 if data_aug_func == 'dsa' :
98103 self .aug_func = DSA (aug_params )
99104 elif data_aug_func == 'mixup' :
@@ -102,7 +107,9 @@ def __init__(self, config: Config=None, dataset: str='CIFAR10', real_data_path:
102107 self .aug_func = Cutmix (aug_params )
103108 else :
104109 self .aug_func = None
110+ self .use_aug_for_hard = use_aug_for_hard
105111
112+ # save path
106113 if not save_path :
107114 save_path = f"./results/{ dataset } /{ model_name } /ipc{ ipc } /obj_scores.csv"
108115 if not os .path .exists (os .path .dirname (save_path )):
@@ -139,6 +146,15 @@ def load_real_data(self, dataset, class_map, num_classes):
139146
140147 return images_all , labels_all , class_indices
141148
def get_class_indices(self, dataset, class_map, num_classes):
    """Group sample positions of ``dataset`` by their mapped class index.

    Args:
        dataset: Object exposing an ``imgs`` attribute that iterates as
            ``(item, label)`` pairs; only the label of each pair is read.
            NOTE(review): this looks like the torchvision ``ImageFolder``
            layout -- confirm that every dataset returned by
            ``get_dataset`` actually provides ``imgs``.
        class_map: Mapping indexed by the raw label that yields the class
            index used to select a bucket in the result.
        num_classes: Number of buckets to create.

    Returns:
        A list of ``num_classes`` lists; bucket ``k`` holds, in ascending
        order, the positions within ``dataset.imgs`` whose mapped label
        equals ``k``.
    """
    class_indices = [[] for _ in range(num_classes)]
    for idx, (_, label) in enumerate(dataset.imgs):
        # Tensor labels are unwrapped to a plain Python scalar before the
        # class_map lookup.
        if torch.is_tensor(label):
            label = label.item()
        class_indices[class_map[label]].append(idx)
    return class_indices
157+
142158 def hyper_param_search_for_hard_label (self , image_tensor , image_path , hard_labels , mode = 'real' ):
143159 lr_list = [0.001 , 0.005 , 0.01 , 0.05 , 0.1 ]
144160 best_acc = 0
@@ -177,7 +193,7 @@ def hyper_param_search_for_soft_label(self, image_tensor, image_path, soft_label
177193 model = build_model (
178194 model_name = self .model_name ,
179195 num_classes = self .num_classes ,
180- im_size = self .im_size ,
196+ im_size = self .im_size ,
181197 pretrained = False ,
182198 use_torchvision = self .stu_use_torchvision ,
183199 device = self .device
@@ -196,11 +212,13 @@ def hyper_param_search_for_soft_label(self, image_tensor, image_path, soft_label
196212 return best_acc , best_lr
197213
198214 def compute_hard_label_metrics (self , model , image_tensor , image_path , lr , hard_labels , mode = 'real' ):
199-
200- if image_tensor is None :
201- hard_label_dataset = datasets .ImageFolder (root = image_path , transform = self .custom_train_trans )
215+ if mode == 'real' :
216+ hard_label_dataset = self .dst_train
202217 else :
203- hard_label_dataset = TensorDataset (image_tensor , hard_labels )
218+ if image_tensor is None :
219+ hard_label_dataset = datasets .ImageFolder (root = image_path , transform = self .custom_train_trans )
220+ else :
221+ hard_label_dataset = TensorDataset (image_tensor , hard_labels )
204222 train_loader = DataLoader (hard_label_dataset , batch_size = self .real_batch_size if mode == 'real' else self .syn_batch_size ,
205223 num_workers = self .num_workers , shuffle = True )
206224
@@ -216,15 +234,15 @@ def compute_hard_label_metrics(self, model, image_tensor, image_path, lr, hard_l
216234 loader = train_loader ,
217235 loss_fn = loss_fn ,
218236 optimizer = optimizer ,
219- aug_func = self .aug_func ,
237+ aug_func = self .aug_func if self . use_aug_for_hard else None ,
220238 lr_scheduler = lr_scheduler ,
221239 tea_model = self .teacher_model ,
222240 device = self .device
223241 )
224242 if epoch > 0.8 * self .num_epochs and (epoch + 1 ) % self .test_interval == 0 :
225243 metric = validate (
226244 model = model ,
227- loader = self .test_loader ,
245+ loader = self .test_loader_real ,
228246 device = self .device
229247 )
230248 if metric ['top1' ] > best_acc1 :
@@ -272,7 +290,7 @@ def compute_soft_label_metrics(self, model, image_tensor, image_path, lr, soft_l
272290 if epoch > 0.8 * self .num_epochs and (epoch + 1 ) % self .test_interval == 0 :
273291 metric = validate (
274292 model = model ,
275- loader = self .test_loader ,
293+ loader = self .test_loader_syn ,
276294 device = self .device
277295 )
278296 if metric ['top1' ] > best_acc1 :
@@ -326,10 +344,10 @@ def compute_metrics(self, image_tensor: Tensor=None, image_path: str=None, soft_
326344 )
327345 full_data_hard_label_acc = self .compute_hard_label_metrics (
328346 model = model ,
329- image_tensor = self . images_train ,
347+ image_tensor = None ,
330348 image_path = None ,
331349 lr = self .default_lr ,
332- hard_labels = self . labels_train ,
350+ hard_labels = None ,
333351 mode = 'real'
334352 )
335353 del model
@@ -362,7 +380,7 @@ def compute_metrics(self, image_tensor: Tensor=None, image_path: str=None, soft_
362380 print (f"Syn data soft label acc: { syn_data_soft_label_acc :.2f} %" )
363381
364382 print ("Caculating random data soft label metrics..." )
365- random_images , _ = get_random_images (self .images_train , self .labels_train , self . class_indices_train , self .ipc )
383+ random_images , _ = get_random_images (self .dst_train , self .class_indices , self .ipc )
366384 if self .soft_label_mode == 'S' :
367385 random_data_soft_labels = self .generate_soft_labels (random_images )
368386 else :
0 commit comments