
Commit 5b3fc13

Restructure the repo
1 parent: 8896385

10 files changed: +992 additions, −897 deletions

dd_ranking/metrics/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1,3 +1,3 @@
 from .dd_ranking_unified import Unified_Evaluator
-from .dd_ranking_obj import Soft_Label_Objective_Metrics, Hard_Label_Objective_Metrics
-from .dd_ranking_aug import Augmentation_Metrics, DSA_Augmentation_Metrics, ZCA_Whitening_Augmentation_Metrics, Mixup_Augmentation_Metrics, Cutmix_Augmentation_Metrics
+from .dd_ranking_obj import Soft_Label_Evaluator, Hard_Label_Evaluator
+# from .dd_ranking_aug import Augmentation_Evaluator, DSA_Augmentation_Evaluator, ZCA_Whitening_Augmentation_Evaluator, Mixup_Augmentation_Evaluator, Cutmix_Augmentation_Evaluator
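
For orientation, a minimal import sketch using the renamed exports; whether downstream code imports them from dd_ranking.metrics exactly like this is an assumption, not something this commit shows.

# Hypothetical downstream usage after the rename (sketch only)
from dd_ranking.metrics import Unified_Evaluator, Soft_Label_Evaluator, Hard_Label_Evaluator
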
Lines changed: 59 additions & 27 deletions
@@ -9,11 +9,8 @@
 from tqdm import tqdm
 from torch.utils.data import DataLoader
 from torch.nn import CrossEntropyLoss
-from torch.optim import SGD
-from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
 from torchvision import transforms, datasets
-from dd_ranking.utils import build_model, get_pretrained_model_path
-from dd_ranking.utils import TensorDataset, get_random_images, get_dataset
+from dd_ranking.utils import build_model, get_pretrained_model_path, get_dataset, TensorDataset
 from dd_ranking.utils import set_seed, get_optimizer, get_lr_scheduler
 from dd_ranking.utils import train_one_epoch, validate
 from dd_ranking.loss import SoftCrossEntropyLoss, KLDivergenceLoss
@@ -39,12 +36,17 @@ def __init__(self,
                  num_eval: int=5,
                  im_size: tuple=(32, 32),
                  num_epochs: int=300,
-                 batch_size: int=256,
+                 real_batch_size: int=256,
+                 syn_batch_size: int=256,
                  weight_decay: float=0.0005,
                  momentum: float=0.9,
                  use_zca: bool=False,
                  temperature: float=1.0,
-                 use_torchvision: bool=False,
+                 stu_use_torchvision: bool=False,
+                 tea_use_torchvision: bool=False,
+                 teacher_dir: str='./teacher_models',
+                 custom_train_trans: transforms.Compose=None,
+                 custom_val_trans: transforms.Compose=None,
                  num_workers: int=4,
                  save_path: str=None,
                  device: str="cuda"
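
For readers of the diff, a hedged construction sketch using the new keyword arguments. The class name Soft_Label_Evaluator is inferred from the __init__.py exports above (the file name for this hunk is not shown), and the dataset, model, and path values are placeholders rather than project defaults.

from dd_ranking.metrics import Soft_Label_Evaluator  # class name assumed from the exports above

evaluator = Soft_Label_Evaluator(
    dataset='CIFAR10',                  # placeholder dataset name
    real_data_path='./data',            # placeholder path to the real dataset
    model_name='ConvNet-3',             # placeholder architecture name
    ipc=10,                             # images per class of the synthetic set
    real_batch_size=256,                # batch size for the real test loader
    syn_batch_size=256,                 # batch size for the synthetic train loader
    stu_use_torchvision=False,          # build the student from the repo's own model zoo
    tea_use_torchvision=False,          # likewise for the teacher
    teacher_dir='./teacher_models',     # where pretrained teacher checkpoints live
    custom_train_trans=None,            # optional transforms.Compose for synthetic images
    custom_val_trans=None,              # optional transforms.Compose for the real test set
    device='cuda',
)
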
@@ -78,20 +80,30 @@ def __init__(self,
         num_eval = self.config.get('num_eval', 5)
         im_size = self.config.get('im_size', (32, 32))
         num_epochs = self.config.get('num_epochs', 300)
-        batch_size = self.config.get('batch_size', 256)
+        real_batch_size = self.config.get('real_batch_size', 256)
+        syn_batch_size = self.config.get('syn_batch_size', 256)
         default_lr = self.config.get('default_lr', 0.01)
         save_path = self.config.get('save_path', None)
         num_workers = self.config.get('num_workers', 4)
-        use_torchvision = self.config.get('use_torchvision', False)
+        stu_use_torchvision = self.config.get('stu_use_torchvision', False)
+        tea_use_torchvision = self.config.get('tea_use_torchvision', False)
+        custom_train_trans = self.config.get('custom_train_trans', None)
+        custom_val_trans = self.config.get('custom_val_trans', None)
         device = self.config.get('device', 'cuda')

-        channel, im_size, num_classes, dst_train, dst_test, class_map, class_map_inv = get_dataset(dataset, real_data_path, im_size, use_zca)
+        channel, im_size, num_classes, dst_train, dst_test, class_map, class_map_inv = get_dataset(dataset,
+                                                                                                   real_data_path,
+                                                                                                   im_size,
+                                                                                                   custom_val_trans,
+                                                                                                   use_zca)
         self.num_classes = num_classes
         self.im_size = im_size
-        self.test_loader = DataLoader(dst_test, batch_size=batch_size, num_workers=num_workers, shuffle=False)
+        self.real_test_loader = DataLoader(dst_test, batch_size=real_batch_size, num_workers=num_workers, shuffle=False)

         self.ipc = ipc
         self.model_name = model_name
+        self.stu_use_torchvision = stu_use_torchvision
+        self.custom_train_trans = custom_train_trans
         self.use_soft_label = use_soft_label
         if use_soft_label:
             assert soft_label_mode is not None, "soft_label_mode must be provided if use_soft_label is True"
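
Since custom_train_trans and custom_val_trans are ordinary torchvision transforms.Compose objects, a caller could assemble them as below; the concrete transform choices and normalization statistics are illustrative, not values defined by this commit.

from torchvision import transforms

# Illustrative transforms a caller might pass in; when the synthetic data is supplied
# as an image tensor rather than an image folder, ToTensor() should be omitted.
custom_train_trans = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),  # standard CIFAR-10 stats
])
custom_val_trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])
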
@@ -107,7 +119,7 @@ def __init__(self,

         self.num_eval = num_eval
         self.num_epochs = num_epochs
-        self.batch_size = batch_size
+        self.syn_batch_size = syn_batch_size
         self.device = device

         if not save_path:
@@ -117,7 +129,7 @@ def __init__(self,
         self.save_path = save_path

         if not use_torchvision:
-            pretrained_model_path = get_pretrained_model_path(model_name, dataset, ipc)
+            pretrained_model_path = get_pretrained_model_path(teacher_dir, model_name, dataset, ipc)
         else:
             pretrained_model_path = None

@@ -128,15 +140,14 @@ def __init__(self,
             pretrained=True,
             device=self.device,
             model_path=pretrained_model_path,
-            use_torchvision=use_torchvision
+            use_torchvision=tea_use_torchvision
         )
         self.teacher_model.eval()

         if data_aug_func is None:
             self.aug_func = None
         elif data_aug_func == 'dsa':
             self.aug_func = DSA_Augmentation(aug_params)
-            self.num_epochs = 1000
         elif data_aug_func == 'mixup':
             self.aug_func = Mixup_Augmentation(aug_params)
         elif data_aug_func == 'cutmix':
@@ -145,7 +156,7 @@ def __init__(self,
             raise ValueError(f"Invalid data augmentation function: {data_aug_func}")

     def generate_soft_labels(self, images):
-        batches = torch.split(images, self.batch_size)
+        batches = torch.split(images, self.syn_batch_size)
         soft_labels = []
         with torch.no_grad():
             for image_batch in batches:
@@ -164,12 +175,13 @@ def hyper_param_search(self, loader):
                 model_name=self.model_name,
                 num_classes=self.num_classes,
                 im_size=self.im_size,
-                pretrained=False,
+                pretrained=False,
+                use_torchvision=self.stu_use_torchvision,
                 device=self.device
             )
             acc = self.compute_metrics_helper(
                 model=model,
-                loader=loader,
+                loader=loader,
                 lr=lr
             )
             if acc > best_acc:
@@ -180,13 +192,13 @@ def hyper_param_search(self, loader):
     def get_loss_fn(self):
         if self.use_soft_label:
             if self.soft_label_criterion == 'kl':
-                return KLDivergenceLoss(temperature=self.temperature)
+                return KLDivergenceLoss(temperature=self.temperature).to(self.device)
             elif self.soft_label_criterion == 'sce':
-                return SoftCrossEntropyLoss()
-            else:
+                return SoftCrossEntropyLoss(temperature=self.temperature).to(self.device)
+            else:
                 raise ValueError(f"Invalid soft label criterion: {self.soft_label_criterion}")
         else:
-            return nn.CrossEntropyLoss()
+            return CrossEntropyLoss().to(self.device)

     def compute_metrics_helper(self, model, loader, lr):
         loss_fn = self.get_loss_fn()
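
For intuition about the temperature argument now passed to both soft-label losses, here is a generic temperature-scaled soft cross-entropy; it is a sketch of the usual formulation, not necessarily how dd_ranking.loss.SoftCrossEntropyLoss is implemented.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SoftCrossEntropySketch(nn.Module):
    """Generic temperature-scaled soft cross-entropy (illustrative only)."""
    def __init__(self, temperature: float = 1.0):
        super().__init__()
        self.temperature = temperature

    def forward(self, student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
        # Soften both distributions with the same temperature, then take the
        # expected negative log-likelihood of the student under the teacher targets.
        teacher_probs = F.softmax(teacher_logits / self.temperature, dim=1)
        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=1)
        return -(teacher_probs * student_log_probs).sum(dim=1).mean()
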
@@ -218,9 +230,28 @@ def compute_metrics_helper(self, model, loader, lr):
                 best_acc = acc
         return best_acc

-    def compute_metrics(self, images, labels, syn_lr=None):
-        syn_dataset = TensorDataset(images, labels)
-        syn_loader = DataLoader(syn_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)
+    def compute_metrics(self, image_tensor: Tensor=None, image_path: str=None, labels: Tensor=None, syn_lr=None):
+        if image_tensor is None and image_path is None:
+            raise ValueError("Either image_tensor or image_path must be provided")
+
+        if self.use_soft_label and self.soft_label_mode == 'S' and labels is None:
+            raise ValueError("labels must be provided if soft_label_mode is 'S'")
+
+        if image_tensor is None:
+            syn_dataset = datasets.ImageFolder(root=image_path, transform=self.custom_train_trans)
+            if labels is not None:
+                syn_dataset.samples = [(path, labels[idx]) for idx, (path, _) in enumerate(syn_dataset.samples)]
+                syn_dataset.targets = labels
+        else:
+            if labels is not None:
+                syn_dataset = TensorDataset(image_tensor, labels, transform=self.custom_train_trans)
+            else:
+                # use hard labels if labels are not provided
+                default_labels = torch.tensor(np.array([np.ones(self.ipc) * i for i in range(self.num_classes)]),
+                                              dtype=torch.long, requires_grad=False).view(-1)
+                syn_dataset = TensorDataset(image_tensor, default_labels, transform=self.custom_train_trans)
+
+        syn_loader = DataLoader(syn_dataset, batch_size=self.syn_batch_size, shuffle=True, num_workers=4)

         accs = []
         lrs = []
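
A hedged call sketch for the new compute_metrics signature, continuing the (assumed) evaluator object from the constructor sketch above; tensor shapes, label contents, and the handling of the return value are placeholders.

import torch

syn_images = torch.randn(100, 3, 32, 32)   # e.g. 10 classes x ipc=10 synthetic images (placeholder)
soft_labels = torch.randn(100, 10)          # only required when soft_label_mode == 'S'

# Option 1: pass the synthetic images as a tensor; if labels are omitted,
# hard labels are generated from ipc and num_classes.
result_from_tensor = evaluator.compute_metrics(image_tensor=syn_images, labels=soft_labels, syn_lr=0.01)

# Option 2: point at an ImageFolder-style directory of saved synthetic images.
result_from_folder = evaluator.compute_metrics(image_path='./synthetic_images', syn_lr=0.01)
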
@@ -232,12 +263,13 @@ def compute_metrics(self, images, labels, syn_lr=None):
                 model_name=self.model_name,
                 num_classes=self.num_classes,
                 im_size=self.im_size,
-                pretrained=False,
+                pretrained=False,
+                use_torchvision=self.stu_use_torchvision,
                 device=self.device
             )
             syn_data_acc = self.compute_metrics_helper(
-                model=model,
-                loader=syn_loader,
+                model=model,
+                loader=syn_loader,
                 lr=syn_lr
             )
             del model
