Skip to content

Commit dc3a1bb

Browse files
2 parents e1c6107 + 5dbfce1 commit dc3a1bb

File tree

16 files changed

+1085
-958
lines changed

16 files changed

+1085
-958
lines changed

configs/Demo_Hard_Label.yaml

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,36 @@
11

22
# real data
3-
dataset: "CIFAR10"
4-
real_data_path: "./dataset/"
5-
custom_val_trans: None
3+
dataset: CIFAR10
4+
real_data_path: ./dataset/
5+
custom_val_trans: null
66

77
# synthetic data
88
ipc: 10
9-
im_size: (32, 32)
9+
im_size: [32, 32]
1010

1111
# agent model
12-
model_name: "ConvNet-3"
13-
use_torchvision: False
12+
model_name: ConvNet-3
13+
use_torchvision: false
1414

1515
# data augmentation
1616
data_aug_func: "dsa"
17-
aug_params: {
18-
"prob_flip": 0.5,
19-
"ratio_rotate": 15.0,
20-
"saturation": 2.0,
21-
"brightness": 1.0,
22-
"contrast": 0.5,
23-
"ratio_scale": 1.2,
24-
"ratio_crop_pad": 0.125,
25-
"ratio_cutout": 0.5
26-
}
27-
use_zca: False
17+
aug_params:
18+
prob_flip: 0.5
19+
ratio_rotate: 15.0
20+
saturation: 2.0
21+
brightness: 1.0
22+
contrast: 0.5
23+
ratio_scale: 1.2
24+
ratio_crop_pad: 0.125
25+
ratio_cutout: 0.5
26+
use_zca: false
27+
28+
custom_train_trans: null
29+
custom_val_trans: null
2830

2931
# training specifics
30-
optimizer: "sgd"
31-
lr_scheduler: "step"
32+
optimizer: sgd
33+
lr_scheduler: step
3234
weight_decay: 0.0005
3335
momentum: 0.9
3436
num_eval: 5
@@ -37,7 +39,7 @@ syn_batch_size: 128
3739
real_batch_size: 256
3840
default_lr: 0.01
3941
num_workers: 4
40-
device: "cuda"
42+
device: cuda
4143

4244
# save path
43-
save_path: "./results/my_method_hard_label_scores.csv"
45+
save_path: ./results/my_method_hard_label_scores.csv

configs/Demo_Soft_Label.yaml

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,49 @@
11
# real data
2-
dataset: "CIFAR10"
3-
real_data_path: "./dataset/"
4-
custom_val_trans: None
2+
dataset: CIFAR10
3+
real_data_path: ./dataset/
4+
custom_val_trans: null
55

66
# synthetic data
77
ipc: 10
8-
im_size: (32, 32)
8+
im_size: [32, 32]
99

1010
# agent model
11-
model_name: "ConvNet-3"
12-
stu_use_torchvision: False
13-
tea_use_torchvision: False
14-
teacher_dir: "./teacher_models"
11+
model_name: ConvNet-3
12+
stu_use_torchvision: false
13+
tea_use_torchvision: false
14+
teacher_dir: ./teacher_models
1515

16-
# data augmentation
17-
data_aug_func: "dsa"
18-
aug_params: {
19-
"prob_flip": 0.5,
20-
"ratio_rotate": 15.0,
21-
"saturation": 2.0,
22-
"brightness": 1.0,
23-
"contrast": 0.5,
24-
"ratio_scale": 1.2,
25-
"ratio_crop_pad": 0.125,
26-
"ratio_cutout": 0.5
27-
}
28-
use_zca: True
16+
# syntheticdata augmentation
17+
data_aug_func: dsa
18+
aug_params:
19+
prob_flip: 0.5
20+
ratio_rotate: 15.0
21+
saturation: 2.0
22+
brightness: 1.0
23+
contrast: 0.5
24+
ratio_scale: 1.2
25+
ratio_crop_pad: 0.125
26+
ratio_cutout: 0.5
27+
use_zca: true
28+
custom_train_trans: null
2929

3030
# soft label settings
31-
soft_label_mode: "S"
32-
soft_label_criterion: "sce"
31+
soft_label_mode: S
32+
soft_label_criterion: sce
3333
temperature: 1.0
3434

3535
# training specifics
36-
optimizer: "sgd"
37-
lr_scheduler: "step"
36+
optimizer: sgd
37+
lr_scheduler: step
3838
weight_decay: 0.0005
3939
momentum: 0.9
4040
num_eval: 5
4141
num_epochs: 1000
4242
default_lr: 0.01
4343
num_workers: 4
44-
device: "cuda"
44+
device: cuda
4545
syn_batch_size: 128
4646
real_batch_size: 256
4747

4848
# save path
49-
save_path: "./results/my_method_soft_label_scores.csv"
49+
save_path: ./my_method_soft_label_scores.csv

dd_ranking/config/user_config.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import yaml
22
import json
33
from typing import Dict, Any
4-
4+
from torchvision import transforms
55

66
class Config:
77
"""Configuration object to manage individual configurations."""
@@ -21,9 +21,30 @@ def from_file(cls, filepath: str):
2121
else:
2222
raise ValueError("Unsupported file format. Use YAML or JSON.")
2323
return cls(config)
24+
25+
def load_transforms_from_yaml(self, values):
26+
if values is None:
27+
return None
28+
transform_list = []
29+
for transform in values:
30+
name = transform["name"]
31+
args = transform.get("args", [])
32+
if isinstance(args, dict):
33+
transform_list.append(getattr(transforms, name)(**args))
34+
else:
35+
transform_list.append(getattr(transforms, name)(*args))
36+
37+
return transforms.Compose(transform_list)
2438

2539
def get(self, key: str, default: Any = None):
2640
"""Get a value from the config."""
41+
if key == "custom_train_trans":
42+
return self.load_transforms_from_yaml(self.config["custom_train_trans"])
43+
elif key == "custom_val_trans":
44+
return self.load_transforms_from_yaml(self.config["custom_val_trans"])
45+
elif key == "im_size":
46+
return tuple(self.config.get("im_size", default))
47+
2748
return self.config.get(key, default)
2849

2950
def update(self, overrides: Dict[str, Any]):

dd_ranking/metrics/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .dd_ranking_unified import Unified_Evaluator
2-
from .dd_ranking_obj import Soft_Label_Objective_Metrics, Hard_Label_Objective_Metrics
3-
from .dd_ranking_aug import Augmentation_Metrics, DSA_Augmentation_Metrics, ZCA_Whitening_Augmentation_Metrics, Mixup_Augmentation_Metrics, Cutmix_Augmentation_Metrics
1+
from .general import Unified_Evaluator
2+
from .soft_label import Soft_Label_Evaluator
3+
from .hard_label import Hard_Label_Evaluator
Lines changed: 60 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,8 @@
99
from tqdm import tqdm
1010
from torch.utils.data import DataLoader
1111
from torch.nn import CrossEntropyLoss
12-
from torch.optim import SGD
13-
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
1412
from torchvision import transforms, datasets
15-
from dd_ranking.utils import build_model, get_pretrained_model_path
16-
from dd_ranking.utils import TensorDataset, get_random_images, get_dataset
13+
from dd_ranking.utils import build_model, get_pretrained_model_path, get_dataset, TensorDataset
1714
from dd_ranking.utils import set_seed, get_optimizer, get_lr_scheduler
1815
from dd_ranking.utils import train_one_epoch, validate
1916
from dd_ranking.loss import SoftCrossEntropyLoss, KLDivergenceLoss
@@ -39,12 +36,17 @@ def __init__(self,
3936
num_eval: int=5,
4037
im_size: tuple=(32, 32),
4138
num_epochs: int=300,
42-
batch_size: int=256,
39+
real_batch_size: int=256,
40+
syn_batch_size: int=256,
4341
weight_decay: float=0.0005,
4442
momentum: float=0.9,
4543
use_zca: bool=False,
4644
temperature: float=1.0,
47-
use_torchvision: bool=False,
45+
stu_use_torchvision: bool=False,
46+
tea_use_torchvision: bool=False,
47+
teacher_dir: str='./teacher_models',
48+
custom_train_trans: transforms.Compose=None,
49+
custom_val_trans: transforms.Compose=None,
4850
num_workers: int=4,
4951
save_path: str=None,
5052
device: str="cuda"
@@ -78,20 +80,30 @@ def __init__(self,
7880
num_eval = self.config.get('num_eval', 5)
7981
im_size = self.config.get('im_size', (32, 32))
8082
num_epochs = self.config.get('num_epochs', 300)
81-
batch_size = self.config.get('batch_size', 256)
83+
real_batch_size = self.config.get('real_batch_size', 256)
84+
syn_batch_size = self.config.get('syn_batch_size', 256)
8285
default_lr = self.config.get('default_lr', 0.01)
8386
save_path = self.config.get('save_path', None)
8487
num_workers = self.config.get('num_workers', 4)
85-
use_torchvision = self.config.get('use_torchvision', False)
88+
stu_use_torchvision = self.config.get('stu_use_torchvision', False)
89+
tea_use_torchvision = self.config.get('tea_use_torchvision', False)
90+
custom_train_trans = self.config.get('custom_train_trans', None)
91+
custom_val_trans = self.config.get('custom_val_trans', None)
8692
device = self.config.get('device', 'cuda')
8793

88-
channel, im_size, num_classes, dst_train, dst_test, class_map, class_map_inv = get_dataset(dataset, real_data_path, im_size, use_zca)
94+
channel, im_size, num_classes, dst_train, dst_test, class_map, class_map_inv = get_dataset(dataset,
95+
real_data_path,
96+
im_size,
97+
custom_val_trans,
98+
use_zca)
8999
self.num_classes = num_classes
90100
self.im_size = im_size
91-
self.test_loader = DataLoader(dst_test, batch_size=batch_size, num_workers=num_workers, shuffle=False)
101+
self.real_test_loader = DataLoader(dst_test, batch_size=real_batch_size, num_workers=num_workers, shuffle=False)
92102

93103
self.ipc = ipc
94104
self.model_name = model_name
105+
self.stu_use_torchvision = stu_use_torchvision
106+
self.custom_train_trans = custom_train_trans
95107
self.use_soft_label = use_soft_label
96108
if use_soft_label:
97109
assert soft_label_mode is not None, "soft_label_mode must be provided if use_soft_label is True"
@@ -107,7 +119,7 @@ def __init__(self,
107119

108120
self.num_eval = num_eval
109121
self.num_epochs = num_epochs
110-
self.batch_size = batch_size
122+
self.syn_batch_size = syn_batch_size
111123
self.device = device
112124

113125
if not save_path:
@@ -117,7 +129,7 @@ def __init__(self,
117129
self.save_path = save_path
118130

119131
if not use_torchvision:
120-
pretrained_model_path = get_pretrained_model_path(model_name, dataset, ipc)
132+
pretrained_model_path = get_pretrained_model_path(teacher_dir, model_name, dataset, ipc)
121133
else:
122134
pretrained_model_path = None
123135

@@ -128,15 +140,14 @@ def __init__(self,
128140
pretrained=True,
129141
device=self.device,
130142
model_path=pretrained_model_path,
131-
use_torchvision=use_torchvision
143+
use_torchvision=tea_use_torchvision
132144
)
133145
self.teacher_model.eval()
134146

135147
if data_aug_func is None:
136148
self.aug_func = None
137149
elif data_aug_func == 'dsa':
138150
self.aug_func = DSA_Augmentation(aug_params)
139-
self.num_epochs = 1000
140151
elif data_aug_func == 'mixup':
141152
self.aug_func = Mixup_Augmentation(aug_params)
142153
elif data_aug_func == 'cutmix':
@@ -145,7 +156,7 @@ def __init__(self,
145156
raise ValueError(f"Invalid data augmentation function: {data_aug_func}")
146157

147158
def generate_soft_labels(self, images):
148-
batches = torch.split(images, self.batch_size)
159+
batches = torch.split(images, self.syn_batch_size)
149160
soft_labels = []
150161
with torch.no_grad():
151162
for image_batch in batches:
@@ -164,12 +175,13 @@ def hyper_param_search(self, loader):
164175
model_name=self.model_name,
165176
num_classes=self.num_classes,
166177
im_size=self.im_size,
167-
pretrained=False,
178+
pretrained=False,
179+
use_torchvision=self.stu_use_torchvision,
168180
device=self.device
169181
)
170182
acc = self.compute_metrics_helper(
171183
model=model,
172-
loader=loader,
184+
loader=loader,
173185
lr=lr
174186
)
175187
if acc > best_acc:
@@ -180,13 +192,13 @@ def hyper_param_search(self, loader):
180192
def get_loss_fn(self):
181193
if self.use_soft_label:
182194
if self.soft_label_criterion == 'kl':
183-
return KLDivergenceLoss(temperature=self.temperature)
195+
return KLDivergenceLoss(temperature=self.temperature).to(self.device)
184196
elif self.soft_label_criterion == 'sce':
185-
return SoftCrossEntropyLoss()
186-
else:
197+
return SoftCrossEntropyLoss(temperature=self.temperature).to(self.device)
198+
else:
187199
raise ValueError(f"Invalid soft label criterion: {self.soft_label_criterion}")
188200
else:
189-
return nn.CrossEntropyLoss()
201+
return CrossEntropyLoss().to(self.device)
190202

191203
def compute_metrics_helper(self, model, loader, lr):
192204
loss_fn = self.get_loss_fn()
@@ -195,7 +207,7 @@ def compute_metrics_helper(self, model, loader, lr):
195207
scheduler = get_lr_scheduler(optimizer, self.lr_scheduler, self.num_epochs)
196208

197209
best_acc = 0
198-
for epoch in range(self.num_epochs):
210+
for epoch in tqdm(range(self.num_epochs), total=self.num_epochs, desc="Training"):
199211
train_one_epoch(
200212
model=model,
201213
loader=loader,
@@ -218,9 +230,28 @@ def compute_metrics_helper(self, model, loader, lr):
218230
best_acc = acc
219231
return best_acc
220232

221-
def compute_metrics(self, images, labels, syn_lr=None):
222-
syn_dataset = TensorDataset(images, labels)
223-
syn_loader = DataLoader(syn_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)
233+
def compute_metrics(self, image_tensor: Tensor=None, image_path: str=None, labels: Tensor=None, syn_lr=None):
234+
if image_tensor is None and image_path is None:
235+
raise ValueError("Either image_tensor or image_path must be provided")
236+
237+
if self.use_soft_label and self.soft_label_mode == 'S' and labels is None:
238+
raise ValueError("labels must be provided if soft_label_mode is 'S'")
239+
240+
if image_tensor is None:
241+
syn_dataset = datasets.ImageFolder(root=image_path, transform=self.custom_train_trans)
242+
if labels is not None:
243+
syn_dataset.samples = [(path, labels[idx]) for idx, (path, _) in enumerate(syn_dataset.samples)]
244+
syn_dataset.targets = labels
245+
else:
246+
if labels is not None:
247+
syn_dataset = TensorDataset(image_tensor, labels, transform=self.custom_train_trans)
248+
else:
249+
# use hard labels if labels are not provided
250+
default_labels = torch.tensor(np.array([np.ones(self.ipc) * i for i in range(self.num_classes)]),
251+
dtype=torch.long, requires_grad=False).view(-1)
252+
syn_dataset = TensorDataset(image_tensor, default_labels, transform=self.custom_train_trans)
253+
254+
syn_loader = DataLoader(syn_dataset, batch_size=self.syn_batch_size, shuffle=True, num_workers=4)
224255

225256
accs = []
226257
lrs = []
@@ -232,12 +263,13 @@ def compute_metrics(self, images, labels, syn_lr=None):
232263
model_name=self.model_name,
233264
num_classes=self.num_classes,
234265
im_size=self.im_size,
235-
pretrained=False,
266+
pretrained=False,
267+
use_torchvision=self.stu_use_torchvision,
236268
device=self.device
237269
)
238270
syn_data_acc = self.compute_metrics_helper(
239-
model=model,
240-
loader=syn_loader,
271+
model=model,
272+
loader=syn_loader,
241273
lr=syn_lr
242274
)
243275
del model

0 commit comments

Comments
 (0)