Skip to content

Commit 01b2ee7

Browse files
Update demo and config
1 parent ee072fa commit 01b2ee7

File tree

5 files changed

+47
-24
lines changed

5 files changed

+47
-24
lines changed

configs/Demo_Hard_Label.yaml

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,26 @@ real_data_path: "./dataset/"
33
ipc: 10
44
model_name: "ConvNet-3"
55
data_aug_func: "dsa"
6-
aug_params: {"prob_flip": 0.5, "ratio_rotate": 15.0, "saturation": 2.0, "brightness": 1.0, "contrast": 0.5, "ratio_scale": 1.2, "ratio_crop_pad": 0.125, "ratio_cutout": 0.5}
6+
aug_params: {
7+
"prob_flip": 0.5,
8+
"ratio_rotate": 15.0,
9+
"saturation": 2.0,
10+
"brightness": 1.0,
11+
"contrast": 0.5,
12+
"ratio_scale": 1.2,
13+
"ratio_crop_pad": 0.125,
14+
"ratio_cutout": 0.5
15+
}
716
optimizer: "sgd"
8-
lr_scheduler: "sgd"
9-
weight_decay: 0.01
17+
lr_scheduler: "step"
18+
weight_decay: 0.0005
1019
momentum: 0.9
1120
num_eval: 5
1221
im_size: (32, 32)
13-
num_epochs: 300
22+
num_epochs: 1000
1423
use_zca: False
1524
batch_size: 256
1625
default_lr: 0.01
26+
num_workers: 4
1727
save_path: "./results/my_method_hard_label_scores.csv"
1828
device: "cuda"

configs/Demo_Soft_Label.yaml

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,30 @@ dataset: "CIFAR10"
22
real_data_path: "./dataset/"
33
ipc: 10
44
model_name: "ConvNet-3"
5-
data_aug_func: "cutmix"
6-
aug_params: {"cutmix_p": 1.0}
5+
data_aug_func: "dsa"
6+
aug_params: {
7+
"prob_flip": 0.5,
8+
"ratio_rotate": 15.0,
9+
"saturation": 2.0,
10+
"brightness": 1.0,
11+
"contrast": 0.5,
12+
"ratio_scale": 1.2,
13+
"ratio_crop_pad": 0.125,
14+
"ratio_cutout": 0.5
15+
}
716
soft_label_mode: "S"
8-
soft_label_criterion: "kl"
17+
soft_label_criterion: "sce"
918
optimizer: "sgd"
10-
lr_scheduler: "cosine"
19+
lr_scheduler: "step"
1120
temperature: 20.0
12-
weight_decay: 0.01
21+
weight_decay: 0.0005
1322
momentum: 0.9
1423
num_eval: 5
1524
im_size: (32, 32)
16-
num_epochs: 300
17-
use_zca: False
25+
num_epochs: 1000
26+
use_zca: True
1827
batch_size: 256
1928
default_lr: 0.01
29+
num_workers: 4
2030
save_path: "./results/my_method_soft_label_scores.csv"
21-
device: "cuda"
31+
device: "cuda"

dd_ranking/utils/utils.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -534,9 +534,6 @@ def validate(
534534

535535
model.eval()
536536

537-
if aug_func is None:
538-
aug_func = default_augmentation
539-
540537
end = time.time()
541538
last_idx = len(loader) - 1
542539
with torch.no_grad():
@@ -615,9 +612,6 @@ def validate_dc(
615612
model = model.to(device)
616613
model.eval()
617614

618-
if aug_func is None:
619-
aug_func = default_augmentation
620-
621615
for i_batch, datum in enumerate(loader):
622616
img = datum[0].to(device)
623617
lab = datum[1].to(device)

demo_hard.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
config = Config.from_file("./configs/Demo_Hard_Label.yaml")
99
convd3_hard_obj = Hard_Label_Objective_Metrics(config)
1010
syn_images = torch.load(os.path.join("./DC/CIFAR10/IPC10/", f"images.pt"), map_location='cpu')
11-
print(convd3_hard_obj.compute_metrics(syn_images))
11+
print(convd3_hard_obj.compute_metrics(syn_images, syn_lr=0.01))
1212

1313

1414
"""Use hardcoded parameters"""
@@ -17,6 +17,7 @@
1717
ipc = 10 # Specify your IPC
1818
dataset = "CIFAR10" # Specify your dataset name
1919
data_dir = "./datasets" # Specify your dataset path
20+
syn_data_dir = "./DM/CIFAR10/IPC10/" # Specify your synthetic data path
2021
model_name = "ConvNet-3" # Specify your model name
2122
im_size = (32, 32) # Specify your image size
2223

@@ -31,13 +32,14 @@
3132
"ratio_cutout": 0.5,
3233
}
3334

34-
syn_images = torch.load(os.path.join("./DC/CIFAR10/IPC10/", f"images.pt"), map_location='cpu')
35+
syn_images = torch.load(os.path.join(syn_data_dir, f"images.pt"), map_location='cpu')
3536
save_path = f"./results/{dataset}/{model_name}/IPC{ipc}/dm_hard_scores.csv"
3637
convd3_hard_obj = Hard_Label_Objective_Metrics(
3738
dataset=dataset,
3839
real_data_path=data_dir,
3940
ipc=ipc,
4041
model_name=model_name,
42+
default_lr=0.01,
4143
optimizer='sgd', # Use SGD optimizer
4244
lr_scheduler='step', # Use StepLR learning rate scheduler
4345
weight_decay=0.0005,
@@ -47,7 +49,9 @@
4749
data_aug_func='dsa', # Use DSA data augmentation
4850
aug_params=dsa_params, # Specify DSA parameters
4951
im_size=im_size,
52+
num_epochs=1000,
53+
num_workers=4,
5054
device=device,
5155
save_path=save_path
5256
)
53-
print(convd3_hard_obj.compute_metrics(syn_images))
57+
print(convd3_hard_obj.compute_metrics(syn_images, syn_lr=0.01))

demo_soft.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
config = Config.from_file("./configs/Demo_Soft_Label.yaml")
99
convd3_soft_obj = Soft_Label_Objective_Metrics(config)
1010
syn_images = torch.load(os.path.join(syn_data_dir, f"images.pt"), map_location='cpu')
11-
print(convd3_soft_obj.compute_metrics(syn_images))
11+
soft_labels = torch.load(os.path.join(syn_data_dir, f"labels.pt"), map_location='cpu')
12+
print(convd3_soft_obj.compute_metrics(syn_images, soft_labels, syn_lr=0.01))
1213

1314

1415
"""Use hardcoded parameters"""
@@ -32,6 +33,7 @@
3233
}
3334

3435
syn_images = torch.load(os.path.join(syn_data_dir, f"images.pt"), map_location='cpu')
36+
soft_labels = torch.load(os.path.join(syn_data_dir, f"labels.pt"), map_location='cpu')
3537
save_path = f"./results/{dataset}/{model_name}/IPC{ipc}/dm_hard_scores.csv"
3638
convd3_hard_obj = Soft_Label_Objective_Metrics(
3739
dataset=dataset,
@@ -40,16 +42,19 @@
4042
model_name=model_name,
4143
soft_label_criterion='sce', # Use Soft Cross Entropy Loss
4244
soft_label_mode='S', # Use one-to-one image to soft label mapping
45+
default_lr=0.01,
4346
optimizer='sgd', # Use SGD optimizer
4447
lr_scheduler='step', # Use StepLR learning rate scheduler
4548
weight_decay=0.0005,
4649
momentum=0.9,
47-
use_zca=True, # Use ZCA whitening
50+
use_zca=True, # Use ZCA whitening (please disable it if you didn't use it to distill synthetic data)
4851
num_eval=5,
4952
data_aug_func='dsa', # Use DSA data augmentation
5053
aug_params=dsa_params, # Specify dsa parameters
5154
im_size=im_size,
55+
num_epochs=1000,
56+
num_workers=4,
5257
device=device,
5358
save_path=save_path
5459
)
55-
print(convd3_hard_obj.compute_metrics(syn_images))
60+
print(convd3_hard_obj.compute_metrics(syn_images, soft_labels, syn_lr=0.01))

0 commit comments

Comments
 (0)