File tree Expand file tree Collapse file tree 4 files changed +59
-15
lines changed
Expand file tree Collapse file tree 4 files changed +59
-15
lines changed Original file line number Diff line number Diff line change 1+
2+ # real data
13dataset : " CIFAR10"
24real_data_path : " ./dataset/"
5+ custom_val_trans : None
6+
7+ # synthetic data
38ipc : 10
9+ im_size : (32, 32)
10+
11+ # agent model
412model_name : " ConvNet-3"
13+ use_torchvision : False
14+
15+ # data augmentation
516data_aug_func : " dsa"
617aug_params : {
718 " prob_flip " : 0.5,
@@ -13,16 +24,20 @@ aug_params: {
1324 " ratio_crop_pad " : 0.125,
1425 " ratio_cutout " : 0.5
1526}
27+ use_zca : False
28+
29+ # training specifics
1630optimizer : " sgd"
1731lr_scheduler : " step"
1832weight_decay : 0.0005
1933momentum : 0.9
2034num_eval : 5
21- im_size : (32, 32)
2235num_epochs : 1000
23- use_zca : False
24- batch_size : 256
36+ syn_batch_size : 128
37+ real_batch_size : 256
2538default_lr : 0.01
2639num_workers : 4
27- save_path : " ./results/my_method_hard_label_scores.csv"
28- device : " cuda"
40+ device : " cuda"
41+
42+ # save path
43+ save_path : " ./results/my_method_hard_label_scores.csv"
Original file line number Diff line number Diff line change 1+ # real data
12dataset : " CIFAR10"
23real_data_path : " ./dataset/"
4+ custom_val_trans : None
5+
6+ # synthetic data
37ipc : 10
8+ im_size : (32, 32)
9+
10+ # agent model
411model_name : " ConvNet-3"
12+ stu_use_torchvision : False
13+ tea_use_torchvision : False
14+ teacher_dir : " ./teacher_models"
15+
16+ # data augmentation
517data_aug_func : " dsa"
618aug_params : {
719 " prob_flip " : 0.5,
@@ -13,19 +25,25 @@ aug_params: {
1325 " ratio_crop_pad " : 0.125,
1426 " ratio_cutout " : 0.5
1527}
28+ use_zca : True
29+
30+ # soft label settings
1631soft_label_mode : " S"
1732soft_label_criterion : " sce"
33+ temperature : 1.0
34+
35+ # training specifics
1836optimizer : " sgd"
1937lr_scheduler : " step"
20- temperature : 20.0
2138weight_decay : 0.0005
2239momentum : 0.9
2340num_eval : 5
24- im_size : (32, 32)
2541num_epochs : 1000
26- use_zca : True
27- batch_size : 256
2842default_lr : 0.01
2943num_workers : 4
30- save_path : " ./results/my_method_soft_label_scores.csv"
3144device : " cuda"
45+ syn_batch_size : 128
46+ real_batch_size : 256
47+
48+ # save path
49+ save_path : " ./results/my_method_soft_label_scores.csv"
Original file line number Diff line number Diff line change 44from dd_ranking .config import Config
55
66
7- """Use config file to specify the parameters (Recommended)"""
7+ """ Use config file to specify the arguments (Recommended) """
88config = Config .from_file ("./configs/Demo_Hard_Label.yaml" )
99convd3_hard_obj = Hard_Label_Objective_Metrics (config )
1010syn_images = torch .load (os .path .join ("./DC/CIFAR10/IPC10/" , f"images.pt" ), map_location = 'cpu' )
1111print (convd3_hard_obj .compute_metrics (syn_images , syn_lr = 0.01 ))
1212
1313
14- """Use hardcoded parameters """
14+ """ Use keyword arguments """
1515device = "cuda"
1616method_name = "DM" # Specify your method name
1717ipc = 10 # Specify your IPC
5151 im_size = im_size ,
5252 num_epochs = 1000 ,
5353 num_workers = 4 ,
54+ use_torchvision = False ,
55+ syn_batch_size = 128 ,
56+ real_batch_size = 256 ,
57+ custom_val_trans = None ,
5458 device = device ,
5559 save_path = save_path
5660)
Original file line number Diff line number Diff line change 44from dd_ranking .config import Config
55
66
7- """Use config file to specify the parameters (Recommended)"""
7+ """ Use config file to specify the arguments (Recommended) """
88config = Config .from_file ("./configs/Demo_Soft_Label.yaml" )
99convd3_soft_obj = Soft_Label_Objective_Metrics (config )
1010syn_images = torch .load (os .path .join (syn_data_dir , f"images.pt" ), map_location = 'cpu' )
1111soft_labels = torch .load (os .path .join (syn_data_dir , f"labels.pt" ), map_location = 'cpu' )
12- print (convd3_soft_obj .compute_metrics (syn_images , soft_labels , syn_lr = 0.01 ))
12+ syn_lr = torch .load (os .path .join (syn_data_dir , f"lr.pt" ), map_location = 'cpu' )
13+ print (convd3_soft_obj .compute_metrics (syn_images , soft_labels , syn_lr = syn_lr ))
1314
1415
15- """Use hardcoded parameters """
16+ """ Use keyword arguments """
1617device = "cuda"
1718method_name = "DATM" # Specify your method name
1819ipc = 10 # Specify your IPC
5455 im_size = im_size ,
5556 num_epochs = 1000 ,
5657 num_workers = 4 ,
58+ stu_use_torchvision = False ,
59+ tea_use_torchvision = False ,
60+ custom_val_trans = None ,
61+ syn_batch_size = 128 ,
62+ real_batch_size = 256 ,
63+ teacher_dir = './teacher_models' ,
5764 device = device ,
5865 save_path = save_path
5966)
You can’t perform that action at this time.
0 commit comments