Skip to content

Commit 8099136

Browse files
Fix a bug
1 parent 23ae0a1 commit 8099136

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.
258258
- [Dai Liu](https://scholar.google.com/citations?user=3aWKpkQAAAAJ&hl=en)
259259
- [Ziheng Qin](https://henryqin1997.github.io/ziheng_qin/)
260260
- [Kaipeng Zhang](https://kpzhang93.github.io/)
261+
- [Zheng Zhu](http://www.zhengzhu.net/)
261262
- [Zhangyang Wang](https://vita-group.github.io/)
262263
- [Bo Zhao](https://www.bozhao.me/)
263264
- [Yang You](https://www.comp.nus.edu.sg/~youy/)

dd_ranking/metrics/soft_label.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ def __init__(self, config: Config=None, dataset: str='CIFAR10', real_data_path:
2323
soft_label_criterion: str='kl', data_aug_func: str='cutmix', aug_params: dict={'beta': 1.0}, soft_label_mode: str='S',
2424
optimizer: str='sgd', lr_scheduler: str='step', temperature: float=1.0, weight_decay: float=0.0005,
2525
momentum: float=0.9, num_eval: int=5, im_size: tuple=(32, 32), num_epochs: int=300, use_zca: bool=False,
26-
real_batch_size: int=256, syn_batch_size: int=256, default_lr: float=0.01, save_path: str=None, use_aug_for_hard: bool=False,
27-
stu_use_torchvision: bool=False, tea_use_torchvision: bool=False, num_workers: int=4, teacher_dir: str='./teacher_models',
28-
custom_train_trans: transforms.Compose=None, custom_val_trans: transforms.Compose=None, device: str="cuda"):
26+
real_batch_size: int=256, syn_batch_size: int=256, default_lr: float=0.01, save_path: str=None, stu_use_torchvision: bool=False,
27+
tea_use_torchvision: bool=False, num_workers: int=4, teacher_dir: str='./teacher_models', custom_train_trans: transforms.Compose=None,
28+
custom_val_trans: transforms.Compose=None, device: str="cuda"):
2929

3030
if config is not None:
3131
self.config = config
@@ -45,12 +45,14 @@ def __init__(self, config: Config=None, dataset: str='CIFAR10', real_data_path:
4545
num_eval = self.config.get('num_eval')
4646
im_size = self.config.get('im_size')
4747
num_epochs = self.config.get('num_epochs')
48+
use_zca = self.config.get('use_zca')
4849
real_batch_size = self.config.get('real_batch_size')
4950
syn_batch_size = self.config.get('syn_batch_size')
5051
default_lr = self.config.get('default_lr')
5152
save_path = self.config.get('save_path')
5253
num_workers = self.config.get('num_workers')
53-
use_torchvision = self.config.get('use_torchvision')
54+
stu_use_torchvision = self.config.get('stu_use_torchvision')
55+
tea_use_torchvision = self.config.get('tea_use_torchvision')
5456
custom_train_trans = self.config.get('custom_train_trans')
5557
custom_val_trans = self.config.get('custom_val_trans')
5658
teacher_dir = self.config.get('teacher_dir')
@@ -100,7 +102,6 @@ def __init__(self, config: Config=None, dataset: str='CIFAR10', real_data_path:
100102
self.aug_func = Cutmix(aug_params)
101103
else:
102104
self.aug_func = None
103-
self.use_aug_for_hard = use_aug_for_hard
104105

105106
if not save_path:
106107
save_path = f"./results/{dataset}/{model_name}/ipc{ipc}/obj_scores.csv"

0 commit comments

Comments
 (0)