Skip to content

Commit 97e4e41

Browse files
Update code
1 parent 9f5891b commit 97e4e41

File tree

4 files changed

+12
-8
lines changed

4 files changed

+12
-8
lines changed

dd_ranking/metrics/general.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def __init__(self,
9494
channel, im_size, num_classes, dst_train, dst_test, class_map, class_map_inv = get_dataset(dataset,
9595
real_data_path,
9696
im_size,
97+
custom_train_trans,
9798
custom_val_trans,
9899
use_zca)
99100
self.num_classes = num_classes

dd_ranking/metrics/hard_label.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__(self, config: Config=None, dataset: str='CIFAR10', real_data_path:
5454
real_data_path,
5555
im_size,
5656
use_zca,
57+
custom_train_trans,
5758
custom_val_trans,
5859
device)
5960
self.images_train, self.labels_train, self.class_indices_train = self.load_real_data(dst_train, class_map, num_classes)

dd_ranking/metrics/soft_label.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def __init__(self, config: Config=None, dataset: str='CIFAR10', real_data_path:
2323
soft_label_criterion: str='kl', data_aug_func: str='cutmix', aug_params: dict={'beta': 1.0}, soft_label_mode: str='S',
2424
optimizer: str='sgd', lr_scheduler: str='step', temperature: float=1.0, weight_decay: float=0.0005,
2525
momentum: float=0.9, num_eval: int=5, im_size: tuple=(32, 32), num_epochs: int=300, use_zca: bool=False,
26-
real_batch_size: int=256, syn_batch_size: int=256, default_lr: float=0.01, save_path: str=None,
26+
real_batch_size: int=256, syn_batch_size: int=256, default_lr: float=0.01, save_path: str=None, use_aug_for_hard: bool=False,
2727
stu_use_torchvision: bool=False, tea_use_torchvision: bool=False, num_workers: int=4, teacher_dir: str='./teacher_models',
2828
custom_train_trans: transforms.Compose=None, custom_val_trans: transforms.Compose=None, device: str="cuda"):
2929

@@ -60,6 +60,7 @@ def __init__(self, config: Config=None, dataset: str='CIFAR10', real_data_path:
6060
real_data_path,
6161
im_size,
6262
use_zca,
63+
custom_train_trans,
6364
custom_val_trans,
6465
device)
6566
self.images_train, self.labels_train, self.class_indices_train = self.load_real_data(dst_train, class_map, num_classes)
@@ -99,6 +100,7 @@ def __init__(self, config: Config=None, dataset: str='CIFAR10', real_data_path:
99100
self.aug_func = Cutmix(aug_params)
100101
else:
101102
self.aug_func = None
103+
self.use_aug_for_hard = use_aug_for_hard
102104

103105
if not save_path:
104106
save_path = f"./results/{dataset}/{model_name}/ipc{ipc}/obj_scores.csv"
@@ -213,7 +215,7 @@ def compute_hard_label_metrics(self, model, image_tensor, image_path, lr, hard_l
213215
loader=train_loader,
214216
loss_fn=loss_fn,
215217
optimizer=optimizer,
216-
aug_func=self.aug_func,
218+
aug_func=self.aug_func if self.use_aug_for_hard else None,
217219
lr_scheduler=lr_scheduler,
218220
tea_model=self.teacher_model,
219221
device=self.device

dd_ranking/utils/data.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __len__(self):
4848
return len(self.images)
4949

5050

51-
def get_dataset(dataset, data_path, im_size, use_zca, custom_val_trans, device):
51+
def get_dataset(dataset, data_path, im_size, use_zca, custom_train_trans, custom_val_trans, device):
5252
class_map_inv = None
5353

5454
if dataset == 'CIFAR10':
@@ -68,7 +68,7 @@ def get_dataset(dataset, data_path, im_size, use_zca, custom_val_trans, device):
6868
transforms.ToTensor()
6969
])
7070

71-
dst_train = datasets.CIFAR10(data_path, train=True, download=True, transform=transform)
71+
dst_train = datasets.CIFAR10(data_path, train=True, download=True, transform=transform if custom_train_trans is None else custom_train_trans)
7272
dst_test = datasets.CIFAR10(data_path, train=False, download=True, transform=transform if custom_val_trans is None else custom_val_trans)
7373
class_map = {x: x for x in range(num_classes)}
7474

@@ -89,7 +89,7 @@ def get_dataset(dataset, data_path, im_size, use_zca, custom_val_trans, device):
8989
transforms.ToTensor()
9090
])
9191

92-
dst_train = datasets.CIFAR100(data_path, train=True, download=True, transform=transform)
92+
dst_train = datasets.CIFAR100(data_path, train=True, download=True, transform=transform if custom_train_trans is None else custom_train_trans)
9393
dst_test = datasets.CIFAR100(data_path, train=False, download=True, transform=transform if custom_val_trans is None else custom_val_trans)
9494
class_map = {x: x for x in range(num_classes)}
9595

@@ -108,7 +108,7 @@ def get_dataset(dataset, data_path, im_size, use_zca, custom_val_trans, device):
108108
transform = transforms.Compose([
109109
transforms.ToTensor()
110110
])
111-
dst_train = datasets.ImageFolder(os.path.join(data_path, "train"), transform=transform)
111+
dst_train = datasets.ImageFolder(os.path.join(data_path, "train"), transform=transform if custom_train_trans is None else custom_train_trans)
112112
dst_test = datasets.ImageFolder(os.path.join(data_path, "val"), transform=transform if custom_val_trans is None else custom_val_trans)
113113
class_map = {x: x for x in range(num_classes)}
114114

@@ -129,7 +129,7 @@ def get_dataset(dataset, data_path, im_size, use_zca, custom_val_trans, device):
129129
transforms.CenterCrop(im_size)
130130
])
131131

132-
dst_train = datasets.ImageFolder(os.path.join(data_path, "train"), transform=transform)
132+
dst_train = datasets.ImageFolder(os.path.join(data_path, "train"), transform=transform if custom_train_trans is None else custom_train_trans)
133133
dst_train = torch.utils.data.Subset(dst_train, np.squeeze(np.argwhere(np.isin(dst_train.targets, config.img_net_classes))))
134134
dst_test = datasets.ImageFolder(os.path.join(data_path, "val"), transform=transform if custom_val_trans is None else custom_val_trans)
135135
dst_test = torch.utils.data.Subset(dst_test, np.squeeze(np.argwhere(np.isin(dst_test.targets, config.img_net_classes))))
@@ -153,7 +153,7 @@ def get_dataset(dataset, data_path, im_size, use_zca, custom_val_trans, device):
153153
transforms.CenterCrop(im_size)
154154
])
155155

156-
dst_train = datasets.ImageFolder(os.path.join(data_path, "train"), transform=transform)
156+
dst_train = datasets.ImageFolder(os.path.join(data_path, "train"), transform=transform if custom_train_trans is None else custom_train_trans)
157157
dst_test = datasets.ImageFolder(os.path.join(data_path, "val"), transform=transform if custom_val_trans is None else custom_val_trans)
158158

159159
class_map = {x: i for i, x in enumerate(range(num_classes))}

0 commit comments

Comments
 (0)