99from tqdm import tqdm
1010from torch .utils .data import DataLoader
1111from torch .nn import CrossEntropyLoss
12- from torch .optim import SGD
13- from torch .optim .lr_scheduler import CosineAnnealingLR , StepLR
1412from torchvision import transforms , datasets
15- from dd_ranking .utils import build_model , get_pretrained_model_path
16- from dd_ranking .utils import TensorDataset , get_random_images , get_dataset
13+ from dd_ranking .utils import build_model , get_pretrained_model_path , get_dataset , TensorDataset
1714from dd_ranking .utils import set_seed , get_optimizer , get_lr_scheduler
1815from dd_ranking .utils import train_one_epoch , validate
1916from dd_ranking .loss import SoftCrossEntropyLoss , KLDivergenceLoss
@@ -39,12 +36,17 @@ def __init__(self,
3936 num_eval : int = 5 ,
4037 im_size : tuple = (32 , 32 ),
4138 num_epochs : int = 300 ,
42- batch_size : int = 256 ,
39+ real_batch_size : int = 256 ,
40+ syn_batch_size : int = 256 ,
4341 weight_decay : float = 0.0005 ,
4442 momentum : float = 0.9 ,
4543 use_zca : bool = False ,
4644 temperature : float = 1.0 ,
47- use_torchvision : bool = False ,
45+ stu_use_torchvision : bool = False ,
46+ tea_use_torchvision : bool = False ,
47+ teacher_dir : str = './teacher_models' ,
48+ custom_train_trans : transforms .Compose = None ,
49+ custom_val_trans : transforms .Compose = None ,
4850 num_workers : int = 4 ,
4951 save_path : str = None ,
5052 device : str = "cuda"
@@ -78,20 +80,30 @@ def __init__(self,
7880 num_eval = self .config .get ('num_eval' , 5 )
7981 im_size = self .config .get ('im_size' , (32 , 32 ))
8082 num_epochs = self .config .get ('num_epochs' , 300 )
81- batch_size = self .config .get ('batch_size' , 256 )
83+ real_batch_size = self .config .get ('real_batch_size' , 256 )
84+ syn_batch_size = self .config .get ('syn_batch_size' , 256 )
8285 default_lr = self .config .get ('default_lr' , 0.01 )
8386 save_path = self .config .get ('save_path' , None )
8487 num_workers = self .config .get ('num_workers' , 4 )
85- use_torchvision = self .config .get ('use_torchvision' , False )
88+ stu_use_torchvision = self .config .get ('stu_use_torchvision' , False )
89+ tea_use_torchvision = self .config .get ('tea_use_torchvision' , False )
90+ custom_train_trans = self .config .get ('custom_train_trans' , None )
91+ custom_val_trans = self .config .get ('custom_val_trans' , None )
8692 device = self .config .get ('device' , 'cuda' )
8793
88- channel , im_size , num_classes , dst_train , dst_test , class_map , class_map_inv = get_dataset (dataset , real_data_path , im_size , use_zca )
94+ channel , im_size , num_classes , dst_train , dst_test , class_map , class_map_inv = get_dataset (dataset ,
95+ real_data_path ,
96+ im_size ,
97+ custom_val_trans ,
98+ use_zca )
8999 self .num_classes = num_classes
90100 self .im_size = im_size
91- self .test_loader = DataLoader (dst_test , batch_size = batch_size , num_workers = num_workers , shuffle = False )
101+ self .real_test_loader = DataLoader (dst_test , batch_size = real_batch_size , num_workers = num_workers , shuffle = False )
92102
93103 self .ipc = ipc
94104 self .model_name = model_name
105+ self .stu_use_torchvision = stu_use_torchvision
106+ self .custom_train_trans = custom_train_trans
95107 self .use_soft_label = use_soft_label
96108 if use_soft_label :
97109 assert soft_label_mode is not None , "soft_label_mode must be provided if use_soft_label is True"
@@ -107,7 +119,7 @@ def __init__(self,
107119
108120 self .num_eval = num_eval
109121 self .num_epochs = num_epochs
110- self .batch_size = batch_size
122+ self .syn_batch_size = syn_batch_size
111123 self .device = device
112124
113125 if not save_path :
@@ -117,7 +129,7 @@ def __init__(self,
117129 self .save_path = save_path
118130
119131 if not use_torchvision :
120- pretrained_model_path = get_pretrained_model_path (model_name , dataset , ipc )
132+ pretrained_model_path = get_pretrained_model_path (teacher_dir , model_name , dataset , ipc )
121133 else :
122134 pretrained_model_path = None
123135
@@ -128,15 +140,14 @@ def __init__(self,
128140 pretrained = True ,
129141 device = self .device ,
130142 model_path = pretrained_model_path ,
131- use_torchvision = use_torchvision
143+ use_torchvision = tea_use_torchvision
132144 )
133145 self .teacher_model .eval ()
134146
135147 if data_aug_func is None :
136148 self .aug_func = None
137149 elif data_aug_func == 'dsa' :
138150 self .aug_func = DSA_Augmentation (aug_params )
139- self .num_epochs = 1000
140151 elif data_aug_func == 'mixup' :
141152 self .aug_func = Mixup_Augmentation (aug_params )
142153 elif data_aug_func == 'cutmix' :
@@ -145,7 +156,7 @@ def __init__(self,
145156 raise ValueError (f"Invalid data augmentation function: { data_aug_func } " )
146157
147158 def generate_soft_labels (self , images ):
148- batches = torch .split (images , self .batch_size )
159+ batches = torch .split (images , self .syn_batch_size )
149160 soft_labels = []
150161 with torch .no_grad ():
151162 for image_batch in batches :
@@ -164,12 +175,13 @@ def hyper_param_search(self, loader):
164175 model_name = self .model_name ,
165176 num_classes = self .num_classes ,
166177 im_size = self .im_size ,
167- pretrained = False ,
178+ pretrained = False ,
179+ use_torchvision = self .stu_use_torchvision ,
168180 device = self .device
169181 )
170182 acc = self .compute_metrics_helper (
171183 model = model ,
172- loader = loader ,
184+ loader = loader ,
173185 lr = lr
174186 )
175187 if acc > best_acc :
@@ -180,13 +192,13 @@ def hyper_param_search(self, loader):
def get_loss_fn(self):
    """Build the training criterion selected by this evaluator's label settings.

    When ``self.use_soft_label`` is false a plain ``CrossEntropyLoss`` is
    used. Otherwise the choice is driven by ``self.soft_label_criterion``:
    ``'kl'`` selects ``KLDivergenceLoss`` and ``'sce'`` selects
    ``SoftCrossEntropyLoss``, both constructed with ``self.temperature``.
    Whichever criterion is chosen is moved to ``self.device`` before being
    returned.

    Returns:
        The loss object, after ``.to(self.device)`` has been applied.

    Raises:
        ValueError: If soft labels are enabled and
            ``self.soft_label_criterion`` is neither ``'kl'`` nor ``'sce'``.
    """
    # Hard-label path first, so the soft-label dispatch below stays flat.
    if not self.use_soft_label:
        return CrossEntropyLoss().to(self.device)

    if self.soft_label_criterion == 'kl':
        return KLDivergenceLoss(temperature=self.temperature).to(self.device)
    if self.soft_label_criterion == 'sce':
        return SoftCrossEntropyLoss(temperature=self.temperature).to(self.device)
    raise ValueError(f"Invalid soft label criterion: {self.soft_label_criterion}")
190202
191203 def compute_metrics_helper (self , model , loader , lr ):
192204 loss_fn = self .get_loss_fn ()
@@ -218,9 +230,28 @@ def compute_metrics_helper(self, model, loader, lr):
218230 best_acc = acc
219231 return best_acc
220232
221- def compute_metrics (self , images , labels , syn_lr = None ):
222- syn_dataset = TensorDataset (images , labels )
223- syn_loader = DataLoader (syn_dataset , batch_size = self .batch_size , shuffle = True , num_workers = 4 )
233+ def compute_metrics (self , image_tensor : Tensor = None , image_path : str = None , labels : Tensor = None , syn_lr = None ):
234+ if image_tensor is None and image_path is None :
235+ raise ValueError ("Either image_tensor or image_path must be provided" )
236+
237+ if self .use_soft_label and self .soft_label_mode == 'S' and labels is None :
238+ raise ValueError ("labels must be provided if soft_label_mode is 'S'" )
239+
240+ if image_tensor is None :
241+ syn_dataset = datasets .ImageFolder (root = image_path , transform = self .custom_train_trans )
242+ if labels is not None :
243+ syn_dataset .samples = [(path , labels [idx ]) for idx , (path , _ ) in enumerate (syn_dataset .samples )]
244+ syn_dataset .targets = labels
245+ else :
246+ if labels is not None :
247+ syn_dataset = TensorDataset (image_tensor , labels , transform = self .custom_train_trans )
248+ else :
249+ # use hard labels if labels are not provided
250+ default_labels = torch .tensor (np .array ([np .ones (self .ipc ) * i for i in range (self .num_classes )]),
251+ dtype = torch .long , requires_grad = False ).view (- 1 )
252+ syn_dataset = TensorDataset (image_tensor , default_labels , transform = self .custom_train_trans )
253+
254+ syn_loader = DataLoader (syn_dataset , batch_size = self .syn_batch_size , shuffle = True , num_workers = 4 )
224255
225256 accs = []
226257 lrs = []
@@ -232,12 +263,13 @@ def compute_metrics(self, images, labels, syn_lr=None):
232263 model_name = self .model_name ,
233264 num_classes = self .num_classes ,
234265 im_size = self .im_size ,
235- pretrained = False ,
266+ pretrained = False ,
267+ use_torchvision = self .stu_use_torchvision ,
236268 device = self .device
237269 )
238270 syn_data_acc = self .compute_metrics_helper (
239- model = model ,
240- loader = syn_loader ,
271+ model = model ,
272+ loader = syn_loader ,
241273 lr = syn_lr
242274 )
243275 del model
0 commit comments