11import os
22import argparse
33import torch
4- from ddranking .metrics import AugmentationRobustnessEvaluator
4+ from ddranking .metrics import AugmentationRobustScore
55from ddranking .config import Config
66
7- def main (args ):
8- root = args .root
9- method_name = args .method_name
10- dataset = args .dataset
11- ipc = args .ipc
7+ """ Use config file to specify the arguments (Recommended) """
8+ config = Config .from_file ("./configs/Demo_ARS.yaml" )
9+ aug_evaluator = AugmentationRobustScore (config )
1210
13- print (f"Evaluating { method_name } on { dataset } with ipc{ ipc } " )
14- syn_image_dir = os .path .join (root , f"DD-Ranking/baselines/{ method_name } /{ dataset } /IPC{ ipc } /" )
15- config = Config .from_file (f"./{ method_name } /{ dataset } /IPC{ ipc } _Aug.yaml" )
16- aug_obj = AugmentationRobustnessEvaluator (config )
17- aug_obj .compute_metrics (image_path = syn_image_dir , syn_lr = args .syn_lr )
11+ syn_data_dir = "./baselines/SRe2L/ImageNet1K/IPC10/"
12+ print (aug_evaluator .compute_metrics (image_path = syn_data_dir , syn_lr = 0.001 ))
1813
1914
20- if __name__ == "__main__" :
21- parser = argparse .ArgumentParser ()
22- parser .add_argument ("--root" , type = str , default = "/home/wangkai/" )
23- parser .add_argument ("--method_name" , type = str , default = "SRe2L" )
24- parser .add_argument ("--dataset" , type = str , default = "ImageNet1K" )
25- parser .add_argument ("--ipc" , type = int , default = 10 )
26- parser .add_argument ("--syn_lr" , type = float , default = 0.001 )
27- args = parser .parse_args ()
28- main (args )
15+ """ Use keyword arguments """
16+ from torchvision import transforms
17+ device = "cuda"
18+ method_name = "SRe2L" # Specify your method name
19+ ipc = 10 # Specify your IPC
20+ dataset = "ImageNet1K" # Specify your dataset name
21+ syn_data_dir = "./SRe2L/ImageNet1K/IPC10/" # Specify your synthetic data path
22+ data_dir = "./datasets" # Specify your dataset path
23+ model_name = "ResNet-18-BN" # Specify your model name
24+ im_size = (224 , 224 ) # Specify your image size
25+ cutmix_params = { # Specify your data augmentation parameters
26+ "beta" : 1.0
27+ }
28+
29+ syn_images = torch .load (os .path .join (syn_data_dir , f"images.pt" ), map_location = 'cpu' )
30+ soft_labels = torch .load (os .path .join (syn_data_dir , f"labels.pt" ), map_location = 'cpu' )
31+ syn_lr = torch .load (os .path .join (syn_data_dir , f"lr.pt" ), map_location = 'cpu' )
32+ save_path = f"./results/{ dataset } /{ model_name } /IPC{ ipc } /dm_hard_scores.csv"
33+
34+ custom_train_trans = transforms .Compose ([
35+ transforms .RandomResizedCrop (224 , scale = (0.08 , 1.0 )),
36+ transforms .RandomHorizontalFlip (),
37+ transforms .ToTensor (),
38+ transforms .Normalize (mean = [0.485 , 0.456 , 0.406 ], std = [0.229 , 0.224 , 0.225 ]),
39+ ])
40+ custom_val_trans = transforms .Compose ([
41+ transforms .Resize (256 ),
42+ transforms .CenterCrop (224 ),
43+ transforms .ToTensor (),
44+ transforms .Normalize (mean = [0.485 , 0.456 , 0.406 ], std = [0.229 , 0.224 , 0.225 ]),
45+ ])
46+
47+ aug_evaluator = AugmentationRobustScore (
48+ dataset = dataset ,
49+ real_data_path = data_dir ,
50+ ipc = ipc ,
51+ model_name = model_name ,
52+ label_type = 'soft' ,
53+ soft_label_criterion = 'kl' , # Use Soft Cross Entropy Loss
54+ soft_label_mode = 'M' , # Use one-to-one image to soft label mapping
55+ loss_fn_kwargs = {'temperature' : 1.0 , 'scale_loss' : False },
56+ optimizer = 'adamw' , # Use SGD optimizer
57+ lr_scheduler = 'cosine' , # Use StepLR learning rate scheduler
58+ weight_decay = 0.01 ,
59+ momentum = 0.9 ,
60+ num_eval = 5 ,
61+ data_aug_func = 'cutmix' , # Use DSA data augmentation
62+ aug_params = cutmix_params , # Specify dsa parameters
63+ im_size = im_size ,
64+ num_epochs = 300 ,
65+ num_workers = 4 ,
66+ stu_use_torchvision = True ,
67+ tea_use_torchvision = True ,
68+ random_data_format = 'tensor' ,
69+ random_data_path = './random_data' ,
70+ custom_train_trans = custom_train_trans ,
71+ custom_val_trans = custom_val_trans ,
72+ batch_size = 256 ,
73+ teacher_dir = './teacher_models' ,
74+ teacher_model_name = ['ResNet-18-BN' ],
75+ device = device ,
76+ dist = True ,
77+ save_path = save_path
78+ )
79+ print (aug_evaluator .compute_metrics (image_path = syn_data_dir , syn_lr = 0.001 ))
0 commit comments