Skip to content

Commit 73941d7

Browse files
Update demo
1 parent 4763c02 commit 73941d7

File tree

4 files changed

+59
-15
lines changed

4 files changed

+59
-15
lines changed

configs/Demo_Hard_Label.yaml

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,18 @@
1+
2+
# real data
13
dataset: "CIFAR10"
24
real_data_path: "./dataset/"
5+
custom_val_trans: None
6+
7+
# synthetic data
38
ipc: 10
9+
im_size: (32, 32)
10+
11+
# agent model
412
model_name: "ConvNet-3"
13+
use_torchvision: False
14+
15+
# data augmentation
516
data_aug_func: "dsa"
617
aug_params: {
718
"prob_flip": 0.5,
@@ -13,16 +24,20 @@ aug_params: {
1324
"ratio_crop_pad": 0.125,
1425
"ratio_cutout": 0.5
1526
}
27+
use_zca: False
28+
29+
# training specifics
1630
optimizer: "sgd"
1731
lr_scheduler: "step"
1832
weight_decay: 0.0005
1933
momentum: 0.9
2034
num_eval: 5
21-
im_size: (32, 32)
2235
num_epochs: 1000
23-
use_zca: False
24-
batch_size: 256
36+
syn_batch_size: 128
37+
real_batch_size: 256
2538
default_lr: 0.01
2639
num_workers: 4
27-
save_path: "./results/my_method_hard_label_scores.csv"
28-
device: "cuda"
40+
device: "cuda"
41+
42+
# save path
43+
save_path: "./results/my_method_hard_label_scores.csv"

configs/Demo_Soft_Label.yaml

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,19 @@
1+
# real data
12
dataset: "CIFAR10"
23
real_data_path: "./dataset/"
4+
custom_val_trans: None
5+
6+
# synthetic data
37
ipc: 10
8+
im_size: (32, 32)
9+
10+
# agent model
411
model_name: "ConvNet-3"
12+
stu_use_torchvision: False
13+
tea_use_torchvision: False
14+
teacher_dir: "./teacher_models"
15+
16+
# data augmentation
517
data_aug_func: "dsa"
618
aug_params: {
719
"prob_flip": 0.5,
@@ -13,19 +25,25 @@ aug_params: {
1325
"ratio_crop_pad": 0.125,
1426
"ratio_cutout": 0.5
1527
}
28+
use_zca: True
29+
30+
# soft label settings
1631
soft_label_mode: "S"
1732
soft_label_criterion: "sce"
33+
temperature: 1.0
34+
35+
# training specifics
1836
optimizer: "sgd"
1937
lr_scheduler: "step"
20-
temperature: 20.0
2138
weight_decay: 0.0005
2239
momentum: 0.9
2340
num_eval: 5
24-
im_size: (32, 32)
2541
num_epochs: 1000
26-
use_zca: True
27-
batch_size: 256
2842
default_lr: 0.01
2943
num_workers: 4
30-
save_path: "./results/my_method_soft_label_scores.csv"
3144
device: "cuda"
45+
syn_batch_size: 128
46+
real_batch_size: 256
47+
48+
# save path
49+
save_path: "./results/my_method_soft_label_scores.csv"

demo_hard.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44
from dd_ranking.config import Config
55

66

7-
"""Use config file to specify the parameters (Recommended)"""
7+
""" Use config file to specify the arguments (Recommended) """
88
config = Config.from_file("./configs/Demo_Hard_Label.yaml")
99
convd3_hard_obj = Hard_Label_Objective_Metrics(config)
1010
syn_images = torch.load(os.path.join("./DC/CIFAR10/IPC10/", f"images.pt"), map_location='cpu')
1111
print(convd3_hard_obj.compute_metrics(syn_images, syn_lr=0.01))
1212

1313

14-
"""Use hardcoded parameters"""
14+
""" Use keyword arguments """
1515
device = "cuda"
1616
method_name = "DM" # Specify your method name
1717
ipc = 10 # Specify your IPC
@@ -51,6 +51,10 @@
5151
im_size=im_size,
5252
num_epochs=1000,
5353
num_workers=4,
54+
use_torchvision=False,
55+
syn_batch_size=128,
56+
real_batch_size=256,
57+
custom_val_trans=None,
5458
device=device,
5559
save_path=save_path
5660
)

demo_soft.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,16 @@
44
from dd_ranking.config import Config
55

66

7-
"""Use config file to specify the parameters (Recommended)"""
7+
""" Use config file to specify the arguments (Recommended) """
88
config = Config.from_file("./configs/Demo_Soft_Label.yaml")
99
convd3_soft_obj = Soft_Label_Objective_Metrics(config)
1010
syn_images = torch.load(os.path.join(syn_data_dir, f"images.pt"), map_location='cpu')
1111
soft_labels = torch.load(os.path.join(syn_data_dir, f"labels.pt"), map_location='cpu')
12-
print(convd3_soft_obj.compute_metrics(syn_images, soft_labels, syn_lr=0.01))
12+
syn_lr = torch.load(os.path.join(syn_data_dir, f"lr.pt"), map_location='cpu')
13+
print(convd3_soft_obj.compute_metrics(syn_images, soft_labels, syn_lr=syn_lr))
1314

1415

15-
"""Use hardcoded parameters"""
16+
""" Use keyword arguments """
1617
device = "cuda"
1718
method_name = "DATM" # Specify your method name
1819
ipc = 10 # Specify your IPC
@@ -54,6 +55,12 @@
5455
im_size=im_size,
5556
num_epochs=1000,
5657
num_workers=4,
58+
stu_use_torchvision=False,
59+
tea_use_torchvision=False,
60+
custom_val_trans=None,
61+
syn_batch_size=128,
62+
real_batch_size=256,
63+
teacher_dir='./teacher_models',
5764
device=device,
5865
save_path=save_path
5966
)

0 commit comments

Comments
 (0)