Skip to content

Commit 3b62a9b

Browse files
Update demo config
1 parent 8d9a552 commit 3b62a9b

File tree

3 files changed

+75
-1
lines changed

3 files changed

+75
-1
lines changed

configs/Demo_ARS.yaml

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# real data
2+
dataset: ImageNet1K
3+
real_data_path: ./dataset/ImageNet1K/
4+
5+
# synthetic data
6+
ipc: 10
7+
im_size: [224, 224]
8+
9+
# agent model
10+
model_name: ResNet-18-BN
11+
stu_use_torchvision: true
12+
tea_use_torchvision: true
13+
teacher_dir: ./teacher_models
14+
teacher_model_names: [ResNet-18-BN]
15+
16+
# syntheticdata augmentation
17+
data_aug_func: cutmix
18+
aug_params:
19+
beta: 1.0
20+
use_zca: false
21+
22+
custom_train_trans:
23+
- name: RandomResizedCrop
24+
args:
25+
size: 224
26+
scale: [0.08, 1.0]
27+
- name: RandomHorizontalFlip
28+
args:
29+
p: 0.5
30+
- name: ToTensor
31+
- name: Normalize
32+
args:
33+
mean: [0.485, 0.456, 0.406]
34+
std: [0.229, 0.224, 0.225]
35+
36+
custom_val_trans:
37+
- name: Resize
38+
args:
39+
size: 256
40+
- name: CenterCrop
41+
args:
42+
size: 224
43+
- name: ToTensor
44+
- name: Normalize
45+
args:
46+
mean: [0.485, 0.456, 0.406]
47+
std: [0.229, 0.224, 0.225]
48+
49+
# soft label settings
50+
label_type: soft
51+
soft_label_mode: M
52+
soft_label_criterion: kl
53+
loss_fn_kwargs:
54+
temperature: 1.0
55+
scale_loss: false
56+
57+
# training specifics
58+
optimizer: adamw
59+
lr_scheduler: cosine
60+
weight_decay: 0.01
61+
momentum: 0.9
62+
num_eval: 5
63+
num_epochs: 300
64+
num_workers: 4
65+
device: cuda
66+
dist: true
67+
batch_size: 1024
68+
random_data_path: ./random_data/my_method/ImageNet1K/IPC10/
69+
random_data_format: image
70+
71+
# save path
72+
save_path: ./my_method_ars_scores.csv
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ custom_val_trans: null
3131
# training specifics
3232
optimizer: sgd
3333
lr_scheduler: step
34+
step_size: 500
3435
weight_decay: 0.0005
3536
momentum: 0.9
3637
num_eval: 5
@@ -43,7 +44,7 @@ device: cuda
4344
dist: true
4445
eval_full_data: false
4546
random_data_path: ./results/my_method_random_data.pt
46-
random_data_format: tensors
47+
random_data_format: tensor
4748

4849
# save path
4950
save_path: ./results/my_method_hard_label_scores.csv
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ loss_fn_kwargs:
5959
optimizer: adamw
6060
lr_scheduler: cosine
6161
weight_decay: 0.01
62+
momentum: 0.9
6263
num_eval: 5
6364
num_epochs: 300
6465
num_workers: 4

0 commit comments

Comments
 (0)