
Commit d0e6080

Update config
1 parent fe609e5 commit d0e6080

3 files changed (+75, -24 lines)


dd_ranking/metrics/dd_ranking_obj.py

Lines changed: 14 additions & 5 deletions
@@ -23,7 +23,7 @@ def __init__(self, config: Config=None, dataset: str='CIFAR10', real_data_path:
                  soft_label_criterion: str='kl', data_aug_func: str='cutmix', aug_params: dict={'cutmix_p': 1.0}, soft_label_mode: str='S',
                  optimizer: str='sgd', lr_scheduler: str='step', temperature: float=1.0, weight_decay: float=0.0005,
                  momentum: float=0.9, num_eval: int=5, im_size: tuple=(32, 32), num_epochs: int=300, use_zca: bool=False,
-                 batch_size: int=256, default_lr: float=0.01, save_path: str=None, device: str="cuda"):
+                 batch_size: int=256, default_lr: float=0.01, save_path: str=None, use_torchvision: bool=False, device: str="cuda"):
 
         if config is not None:
             self.config = config
@@ -46,6 +46,8 @@ def __init__(self, config: Config=None, dataset: str='CIFAR10', real_data_path:
             batch_size = self.config.get('batch_size')
             default_lr = self.config.get('default_lr')
             save_path = self.config.get('save_path')
+            num_workers = self.config.get('num_workers')
+            use_torchvision = self.config.get('use_torchvision')
             device = self.config.get('device')
 
         channel, im_size, num_classes, dst_train, dst_test, class_map, class_map_inv = get_dataset(dataset,
@@ -82,8 +84,6 @@ def __init__(self, config: Config=None, dataset: str='CIFAR10', real_data_path:
         if data_aug_func == 'dsa':
             self.aug_func = DSA_Augmentation(aug_params)
             self.num_epochs = 1000
-        elif data_aug_func == 'zca':
-            self.aug_func = ZCA_Whitening_Augmentation(aug_params)
         elif data_aug_func == 'mixup':
             self.aug_func = Mixup_Augmentation(aug_params)
         elif data_aug_func == 'cutmix':
@@ -98,8 +98,17 @@ def __init__(self, config: Config=None, dataset: str='CIFAR10', real_data_path:
         self.save_path = save_path
 
         # teacher model
-        pretrained_model_path = get_pretrained_model_path(model_name, dataset, ipc)
-        self.teacher_model = build_model(model_name, num_classes=self.num_classes, im_size=self.im_size, pretrained=True, device=self.device, model_path=pretrained_model_path)
+        if not use_torchvision:
+            pretrained_model_path = get_pretrained_model_path(model_name, dataset, ipc)
+        else:
+            pretrained_model_path = None
+        self.teacher_model = build_model(model_name,
+                                         num_classes=self.num_classes,
+                                         im_size=self.im_size,
+                                         pretrained=True,
+                                         device=self.device,
+                                         model_path=pretrained_model_path,
+                                         use_torchvision=use_torchvision)
         self.teacher_model.eval()
 
     def load_real_data(self, dataset, class_map, num_classes):
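The new `use_torchvision` flag changes where the teacher's weights come from: with the flag off (the default), a locally trained checkpoint is resolved via `get_pretrained_model_path`; with it on, no checkpoint path is passed and `build_model` is expected to take pretrained weights from torchvision. The internals of `build_model` are not part of this diff, so the following is only a minimal sketch of that branching, assuming torchvision >= 0.14 (`torchvision.models.get_model`) and using a torchvision architecture as a stand-in for the repo's own networks:

```python
from typing import Optional

import torch
import torchvision
from torch import nn


def load_teacher(model_name: str, num_classes: int, model_path: Optional[str],
                 use_torchvision: bool = False, device: str = "cuda") -> nn.Module:
    # Hypothetical sketch of the branching this commit introduces; not the repo's build_model.
    # Stand-in architecture from torchvision's registry (get_model exists in torchvision >= 0.14).
    weights = "DEFAULT" if use_torchvision else None
    model = torchvision.models.get_model(model_name.lower(), weights=weights)
    if hasattr(model, "fc"):
        # Torchvision heads target ImageNet's 1000 classes; resize for the evaluation dataset.
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    if not use_torchvision and model_path is not None:
        # use_torchvision=False keeps the old behaviour: load a locally trained checkpoint.
        model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()
    return model.to(device)


# Mirrors the two paths in the diff above (names and paths are illustrative):
# teacher = load_teacher("ResNet18", 10, model_path="./checkpoints/resnet18_cifar10.pt")
# teacher = load_teacher("ResNet18", 10, model_path=None, use_torchvision=True)
```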

dd_ranking/metrics/dd_ranking_unified.py

Lines changed: 60 additions & 18 deletions
@@ -18,22 +18,24 @@
 from dd_ranking.utils import train_one_epoch, validate
 from dd_ranking.loss import SoftCrossEntropyLoss, KLDivergenceLoss
 from dd_ranking.aug import DSA_Augmentation, Mixup_Augmentation, Cutmix_Augmentation, ZCA_Whitening_Augmentation
+from dd_ranking.config import Config
 
 
 class Unified_Evaluator:
 
-    def __init__(self,
-                 dataset: str,
-                 real_data_path: str,
-                 ipc: int,
-                 model_name: str,
-                 use_soft_label: bool,
+    def __init__(self,
+                 config: Config=None,
+                 dataset: str='CIFAR10',
+                 real_data_path: str='./dataset',
+                 ipc: int=10,
+                 model_name: str='ConvNet-3',
+                 use_soft_label: bool=False,
                  optimizer: str='sgd',
                  lr_scheduler: str='step',
-                 data_aug_func: str=None,
+                 data_aug_func: str='dsa',
                  aug_params: dict=None,
-                 soft_label_mode: str=None,
-                 soft_label_criterion: str=None,
+                 soft_label_mode: str='M',
+                 soft_label_criterion: str='kl',
                  num_eval: int=5,
                  im_size: tuple=(32, 32),
                  num_epochs: int=300,
@@ -42,14 +44,51 @@ def __init__(self,
                  momentum: float=0.9,
                  use_zca: bool=False,
                  temperature: float=1.0,
+                 use_torchvision: bool=False,
+                 num_workers: int=4,
                  save_path: str=None,
                  device: str="cuda"
                  ):
 
+        if config is not None:
+            self.config = config
+            dataset = self.config.get('dataset', 'CIFAR10')
+            real_data_path = self.config.get('real_data_path', './dataset')
+            ipc = self.config.get('ipc', 10)
+            model_name = self.config.get('model_name', 'ConvNet-3')
+            use_soft_label = self.config.get('use_soft_label', False)
+            soft_label_criterion = self.config.get('soft_label_criterion', 'sce')
+            data_aug_func = self.config.get('data_aug_func', 'dsa')
+            aug_params = self.config.get('aug_params', {
+                "prob_flip": 0.5,
+                "ratio_rotate": 15.0,
+                "saturation": 2.0,
+                "brightness": 1.0,
+                "contrast": 0.5,
+                "ratio_scale": 1.2,
+                "ratio_crop_pad": 0.125,
+                "ratio_cutout": 0.5
+            })
+            soft_label_mode = self.config.get('soft_label_mode', 'S')
+            optimizer = self.config.get('optimizer', 'sgd')
+            lr_scheduler = self.config.get('lr_scheduler', 'step')
+            temperature = self.config.get('temperature', 1.0)
+            weight_decay = self.config.get('weight_decay', 0.0005)
+            momentum = self.config.get('momentum', 0.9)
+            num_eval = self.config.get('num_eval', 5)
+            im_size = self.config.get('im_size', (32, 32))
+            num_epochs = self.config.get('num_epochs', 300)
+            batch_size = self.config.get('batch_size', 256)
+            default_lr = self.config.get('default_lr', 0.01)
+            save_path = self.config.get('save_path', None)
+            num_workers = self.config.get('num_workers', 4)
+            use_torchvision = self.config.get('use_torchvision', False)
+            device = self.config.get('device', 'cuda')
+
         channel, im_size, num_classes, dst_train, dst_test, class_map, class_map_inv = get_dataset(dataset, real_data_path, im_size, use_zca)
         self.num_classes = num_classes
         self.im_size = im_size
-        self.test_loader = DataLoader(dst_test, batch_size=batch_size, num_workers=4, shuffle=False)
+        self.test_loader = DataLoader(dst_test, batch_size=batch_size, num_workers=num_workers, shuffle=False)
 
         self.ipc = ipc
         self.model_name = model_name
@@ -77,26 +116,30 @@ def __init__(self,
             os.makedirs(os.path.dirname(save_path))
         self.save_path = save_path
 
-        pretrained_model_path = get_pretrained_model_path(model_name, dataset, ipc)
+        if not use_torchvision:
+            pretrained_model_path = get_pretrained_model_path(model_name, dataset, ipc)
+        else:
+            pretrained_model_path = None
+
         self.teacher_model = build_model(
             model_name=model_name,
             num_classes=num_classes,
             im_size=self.im_size,
             pretrained=True,
            device=self.device,
-            model_path=pretrained_model_path
+            model_path=pretrained_model_path,
+            use_torchvision=use_torchvision
         )
         self.teacher_model.eval()
 
         if data_aug_func is None:
             self.aug_func = None
-        elif data_aug_func == 'DSA':
+        elif data_aug_func == 'dsa':
             self.aug_func = DSA_Augmentation(aug_params)
-        elif data_aug_func == 'ZCA':
-            self.aug_func = ZCA_Whitening_Augmentation(aug_params)
-        elif data_aug_func == 'Mixup':
+            self.num_epochs = 1000
+        elif data_aug_func == 'mixup':
            self.aug_func = Mixup_Augmentation(aug_params)
-        elif data_aug_func == 'Cutmix':
+        elif data_aug_func == 'cutmix':
            self.aug_func = Cutmix_Augmentation(aug_params)
         else:
             raise ValueError(f"Invalid data augmentation function: {data_aug_func}")
@@ -168,7 +211,6 @@ def compute_metrics_helper(self, model, loader, lr):
         acc = validate(
             model=model,
             loader=loader,
-            aug_func=self.aug_func,
             logging=True,
             device=self.device
         )
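With the new `config` path, every constructor argument of `Unified_Evaluator` can be supplied through a single `Config` object; the keys read in the diff (`dataset`, `ipc`, `use_torchvision`, `num_workers`, and so on) mirror the keyword arguments one to one. A rough usage sketch follows; it assumes `Config` can be built from a plain dict and that the class is importable from the file's module path, neither of which is shown in this commit:

```python
from dd_ranking.config import Config
from dd_ranking.metrics.dd_ranking_unified import Unified_Evaluator

# Assumption: Config wraps a plain dict of settings; the keys below mirror the
# self.config.get(...) calls added in the diff above.
cfg = Config({
    "dataset": "CIFAR10",
    "real_data_path": "./dataset",
    "ipc": 10,
    "model_name": "ConvNet-3",
    "use_soft_label": True,
    "soft_label_mode": "M",
    "soft_label_criterion": "kl",
    "data_aug_func": "dsa",
    "num_workers": 4,           # new key: forwarded to the test DataLoader
    "use_torchvision": False,   # new key: skip the local checkpoint lookup when True
    "device": "cuda",
})

evaluator = Unified_Evaluator(config=cfg)
```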

demo_soft.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@
     "contrast": 0.5,
     "ratio_scale": 1.2,
     "ratio_crop_pad": 0.125,
-    "ratio_cutout": 0.5,
+    "ratio_cutout": 0.5
 }
 
 syn_images = torch.load(os.path.join(syn_data_dir, f"images.pt"), map_location='cpu')
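The only change in demo_soft.py is dropping a trailing comma from the DSA `aug_params` dict, which is cosmetic in Python. For context, that dict and the loaded synthetic images are what the demo hands to an evaluator; below is a rough continuation sketch, where `syn_data_dir` is a hypothetical path and `compute_metrics` is a hypothetical entry point, neither of which appears in this diff:

```python
import os

import torch

from dd_ranking.metrics.dd_ranking_unified import Unified_Evaluator

syn_data_dir = "./synthetic_data"  # hypothetical; the demo sets its real directory earlier in the script
aug_params = {
    "prob_flip": 0.5, "ratio_rotate": 15.0, "saturation": 2.0, "brightness": 1.0,
    "contrast": 0.5, "ratio_scale": 1.2, "ratio_crop_pad": 0.125, "ratio_cutout": 0.5
}
syn_images = torch.load(os.path.join(syn_data_dir, "images.pt"), map_location="cpu")

evaluator = Unified_Evaluator(
    dataset="CIFAR10",
    ipc=10,
    model_name="ConvNet-3",
    use_soft_label=True,
    soft_label_mode="S",
    data_aug_func="dsa",
    aug_params=aug_params,   # the dict shown above
)
results = evaluator.compute_metrics(syn_images)  # hypothetical method name, not part of this commit
print(results)
```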
