@@ -23,9 +23,9 @@ def __init__(self, config: Config=None, dataset: str='CIFAR10', real_data_path:
2323 soft_label_criterion : str = 'kl' , data_aug_func : str = 'cutmix' , aug_params : dict = {'beta' : 1.0 }, soft_label_mode : str = 'S' ,
2424 optimizer : str = 'sgd' , lr_scheduler : str = 'step' , temperature : float = 1.0 , weight_decay : float = 0.0005 ,
2525 momentum : float = 0.9 , num_eval : int = 5 , im_size : tuple = (32 , 32 ), num_epochs : int = 300 , use_zca : bool = False ,
26- real_batch_size : int = 256 , syn_batch_size : int = 256 , default_lr : float = 0.01 , save_path : str = None , use_aug_for_hard : bool = False ,
27- stu_use_torchvision : bool = False , tea_use_torchvision : bool = False , num_workers : int = 4 , teacher_dir : str = './teacher_models' ,
28- custom_train_trans : transforms . Compose = None , custom_val_trans : transforms .Compose = None , device : str = "cuda" ):
26+ real_batch_size : int = 256 , syn_batch_size : int = 256 , default_lr : float = 0.01 , save_path : str = None , stu_use_torchvision : bool = False ,
27+ tea_use_torchvision : bool = False , num_workers : int = 4 , teacher_dir : str = './teacher_models' , custom_train_trans : transforms . Compose = None ,
28+ custom_val_trans : transforms .Compose = None , device : str = "cuda" ):
2929
3030 if config is not None :
3131 self .config = config
@@ -45,12 +45,14 @@ def __init__(self, config: Config=None, dataset: str='CIFAR10', real_data_path:
4545 num_eval = self .config .get ('num_eval' )
4646 im_size = self .config .get ('im_size' )
4747 num_epochs = self .config .get ('num_epochs' )
48+ use_zca = self .config .get ('use_zca' )
4849 real_batch_size = self .config .get ('real_batch_size' )
4950 syn_batch_size = self .config .get ('syn_batch_size' )
5051 default_lr = self .config .get ('default_lr' )
5152 save_path = self .config .get ('save_path' )
5253 num_workers = self .config .get ('num_workers' )
53- use_torchvision = self .config .get ('use_torchvision' )
54+ stu_use_torchvision = self .config .get ('stu_use_torchvision' )
55+ tea_use_torchvision = self .config .get ('tea_use_torchvision' )
5456 custom_train_trans = self .config .get ('custom_train_trans' )
5557 custom_val_trans = self .config .get ('custom_val_trans' )
5658 teacher_dir = self .config .get ('teacher_dir' )
@@ -100,7 +102,6 @@ def __init__(self, config: Config=None, dataset: str='CIFAR10', real_data_path:
100102 self .aug_func = Cutmix (aug_params )
101103 else :
102104 self .aug_func = None
103- self .use_aug_for_hard = use_aug_for_hard
104105
105106 if not save_path :
106107 save_path = f"./results/{ dataset } /{ model_name } /ipc{ ipc } /obj_scores.csv"
0 commit comments