Skip to content

Commit 5dbfce1

Browse files
v0.1
1 parent 5b3fc13 commit 5dbfce1

File tree

7 files changed

+94
-63
lines changed

7 files changed

+94
-63
lines changed

configs/Demo_Hard_Label.yaml

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,36 @@
11

22
# real data
3-
dataset: "CIFAR10"
4-
real_data_path: "./dataset/"
5-
custom_val_trans: None
3+
dataset: CIFAR10
4+
real_data_path: ./dataset/
5+
custom_val_trans: null
66

77
# synthetic data
88
ipc: 10
9-
im_size: (32, 32)
9+
im_size: [32, 32]
1010

1111
# agent model
12-
model_name: "ConvNet-3"
13-
use_torchvision: False
12+
model_name: ConvNet-3
13+
use_torchvision: false
1414

1515
# data augmentation
1616
data_aug_func: "dsa"
17-
aug_params: {
18-
"prob_flip": 0.5,
19-
"ratio_rotate": 15.0,
20-
"saturation": 2.0,
21-
"brightness": 1.0,
22-
"contrast": 0.5,
23-
"ratio_scale": 1.2,
24-
"ratio_crop_pad": 0.125,
25-
"ratio_cutout": 0.5
26-
}
27-
use_zca: False
17+
aug_params:
18+
prob_flip: 0.5
19+
ratio_rotate: 15.0
20+
saturation: 2.0
21+
brightness: 1.0
22+
contrast: 0.5
23+
ratio_scale: 1.2
24+
ratio_crop_pad: 0.125
25+
ratio_cutout: 0.5
26+
use_zca: false
27+
28+
custom_train_trans: null
29+
custom_val_trans: null
2830

2931
# training specifics
30-
optimizer: "sgd"
31-
lr_scheduler: "step"
32+
optimizer: sgd
33+
lr_scheduler: step
3234
weight_decay: 0.0005
3335
momentum: 0.9
3436
num_eval: 5
@@ -37,7 +39,7 @@ syn_batch_size: 128
3739
real_batch_size: 256
3840
default_lr: 0.01
3941
num_workers: 4
40-
device: "cuda"
42+
device: cuda
4143

4244
# save path
43-
save_path: "./results/my_method_hard_label_scores.csv"
45+
save_path: ./results/my_method_hard_label_scores.csv

configs/Demo_Soft_Label.yaml

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,49 @@
11
# real data
2-
dataset: "CIFAR10"
3-
real_data_path: "./dataset/"
4-
custom_val_trans: None
2+
dataset: CIFAR10
3+
real_data_path: ./dataset/
4+
custom_val_trans: null
55

66
# synthetic data
77
ipc: 10
8-
im_size: (32, 32)
8+
im_size: [32, 32]
99

1010
# agent model
11-
model_name: "ConvNet-3"
12-
stu_use_torchvision: False
13-
tea_use_torchvision: False
14-
teacher_dir: "./teacher_models"
11+
model_name: ConvNet-3
12+
stu_use_torchvision: false
13+
tea_use_torchvision: false
14+
teacher_dir: ./teacher_models
1515

16-
# data augmentation
17-
data_aug_func: "dsa"
18-
aug_params: {
19-
"prob_flip": 0.5,
20-
"ratio_rotate": 15.0,
21-
"saturation": 2.0,
22-
"brightness": 1.0,
23-
"contrast": 0.5,
24-
"ratio_scale": 1.2,
25-
"ratio_crop_pad": 0.125,
26-
"ratio_cutout": 0.5
27-
}
28-
use_zca: True
16+
# syntheticdata augmentation
17+
data_aug_func: dsa
18+
aug_params:
19+
prob_flip: 0.5
20+
ratio_rotate: 15.0
21+
saturation: 2.0
22+
brightness: 1.0
23+
contrast: 0.5
24+
ratio_scale: 1.2
25+
ratio_crop_pad: 0.125
26+
ratio_cutout: 0.5
27+
use_zca: true
28+
custom_train_trans: null
2929

3030
# soft label settings
31-
soft_label_mode: "S"
32-
soft_label_criterion: "sce"
31+
soft_label_mode: S
32+
soft_label_criterion: sce
3333
temperature: 1.0
3434

3535
# training specifics
36-
optimizer: "sgd"
37-
lr_scheduler: "step"
36+
optimizer: sgd
37+
lr_scheduler: step
3838
weight_decay: 0.0005
3939
momentum: 0.9
4040
num_eval: 5
4141
num_epochs: 1000
4242
default_lr: 0.01
4343
num_workers: 4
44-
device: "cuda"
44+
device: cuda
4545
syn_batch_size: 128
4646
real_batch_size: 256
4747

4848
# save path
49-
save_path: "./results/my_method_soft_label_scores.csv"
49+
save_path: ./my_method_soft_label_scores.csv

dd_ranking/config/user_config.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import yaml
22
import json
33
from typing import Dict, Any
4-
4+
from torchvision import transforms
55

66
class Config:
77
"""Configuration object to manage individual configurations."""
@@ -21,9 +21,30 @@ def from_file(cls, filepath: str):
2121
else:
2222
raise ValueError("Unsupported file format. Use YAML or JSON.")
2323
return cls(config)
24+
25+
def load_transforms_from_yaml(self, values):
26+
if values is None:
27+
return None
28+
transform_list = []
29+
for transform in values:
30+
name = transform["name"]
31+
args = transform.get("args", [])
32+
if isinstance(args, dict):
33+
transform_list.append(getattr(transforms, name)(**args))
34+
else:
35+
transform_list.append(getattr(transforms, name)(*args))
36+
37+
return transforms.Compose(transform_list)
2438

2539
def get(self, key: str, default: Any = None):
2640
"""Get a value from the config."""
41+
if key == "custom_train_trans":
42+
return self.load_transforms_from_yaml(self.config["custom_train_trans"])
43+
elif key == "custom_val_trans":
44+
return self.load_transforms_from_yaml(self.config["custom_val_trans"])
45+
elif key == "im_size":
46+
return tuple(self.config.get("im_size", default))
47+
2748
return self.config.get(key, default)
2849

2950
def update(self, overrides: Dict[str, Any]):

dd_ranking/metrics/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .dd_ranking_unified import Unified_Evaluator
2-
from .dd_ranking_obj import Soft_Label_Evaluator, Hard_Label_Evaluator
3-
# from .dd_ranking_aug import Augmentation_Evaluator, DSA_Augmentation_Evaluator, ZCA_Whitening_Augmentation_Evaluator, Mixup_Augmentation_Evaluator, Cutmix_Augmentation_Evaluator
1+
from .general import Unified_Evaluator
2+
from .soft_label import Soft_Label_Evaluator
3+
from .hard_label import Hard_Label_Evaluator

dd_ranking/metrics/general.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def compute_metrics_helper(self, model, loader, lr):
207207
scheduler = get_lr_scheduler(optimizer, self.lr_scheduler, self.num_epochs)
208208

209209
best_acc = 0
210-
for epoch in range(self.num_epochs):
210+
for epoch in tqdm(range(self.num_epochs), total=self.num_epochs, desc="Training"):
211211
train_one_epoch(
212212
model=model,
213213
loader=loader,

demo_hard.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,12 @@
66

77
""" Use config file to specify the arguments (Recommended) """
88
config = Config.from_file("./configs/Demo_Hard_Label.yaml")
9-
convd3_hard_obj = Hard_Label_Objective_Metrics(config)
10-
syn_images = torch.load(os.path.join("./DC/CIFAR10/IPC10/", f"images.pt"), map_location='cpu')
11-
print(convd3_hard_obj.compute_metrics(syn_images, syn_lr=0.01))
9+
hard_label_evaluator = Hard_Label_Objective_Metrics(config)
10+
11+
syn_data_dir = "./baselines/DM/CIFAR10/IPC10/"
12+
syn_images = torch.load(os.path.join(syn_data_dir, f"images.pt"), map_location='cpu')
13+
syn_lr = 0.01
14+
print(hard_label_evaluator.compute_metrics(syn_images, syn_lr=syn_lr))
1215

1316

1417
""" Use keyword arguments """
@@ -34,7 +37,7 @@
3437

3538
syn_images = torch.load(os.path.join(syn_data_dir, f"images.pt"), map_location='cpu')
3639
save_path = f"./results/{dataset}/{model_name}/IPC{ipc}/dm_hard_scores.csv"
37-
convd3_hard_obj = Hard_Label_Objective_Metrics(
40+
hard_label_evaluator = Hard_Label_Objective_Metrics(
3841
dataset=dataset,
3942
real_data_path=data_dir,
4043
ipc=ipc,
@@ -58,4 +61,4 @@
5861
device=device,
5962
save_path=save_path
6063
)
61-
print(convd3_hard_obj.compute_metrics(syn_images, syn_lr=0.01))
64+
print(hard_label_evaluator.compute_metrics(syn_images, syn_lr=0.01))

demo_soft.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
import os
22
import torch
3-
from dd_ranking.metrics import Soft_Label_Objective_Metrics
3+
import warnings
4+
from dd_ranking.metrics import Soft_Label_Evaluator
45
from dd_ranking.config import Config
6+
warnings.filterwarnings("ignore", category=FutureWarning)
57

68

79
""" Use config file to specify the arguments (Recommended) """
810
config = Config.from_file("./configs/Demo_Soft_Label.yaml")
9-
convd3_soft_obj = Soft_Label_Objective_Metrics(config)
11+
soft_label_evaluator = Soft_Label_Evaluator(config)
12+
13+
syn_data_dir = "./baselines/DATM/CIFAR10/IPC10/"
1014
syn_images = torch.load(os.path.join(syn_data_dir, f"images.pt"), map_location='cpu')
1115
soft_labels = torch.load(os.path.join(syn_data_dir, f"labels.pt"), map_location='cpu')
1216
syn_lr = torch.load(os.path.join(syn_data_dir, f"lr.pt"), map_location='cpu')
13-
print(convd3_soft_obj.compute_metrics(syn_images, soft_labels, syn_lr=syn_lr))
17+
print(soft_label_evaluator.compute_metrics(image_tensor=syn_images, soft_labels=soft_labels, syn_lr=syn_lr))
1418

1519

1620
""" Use keyword arguments """
@@ -35,8 +39,9 @@
3539

3640
syn_images = torch.load(os.path.join(syn_data_dir, f"images.pt"), map_location='cpu')
3741
soft_labels = torch.load(os.path.join(syn_data_dir, f"labels.pt"), map_location='cpu')
42+
syn_lr = torch.load(os.path.join(syn_data_dir, f"lr.pt"), map_location='cpu')
3843
save_path = f"./results/{dataset}/{model_name}/IPC{ipc}/dm_hard_scores.csv"
39-
convd3_hard_obj = Soft_Label_Objective_Metrics(
44+
soft_label_evaluator = Soft_Label_Evaluator(
4045
dataset=dataset,
4146
real_data_path=data_dir,
4247
ipc=ipc,
@@ -64,4 +69,4 @@
6469
device=device,
6570
save_path=save_path
6671
)
67-
print(convd3_hard_obj.compute_metrics(syn_images, soft_labels, syn_lr=0.01))
72+
print(soft_label_evaluator.compute_metrics(syn_images, soft_labels, syn_lr=syn_lr))

0 commit comments

Comments
 (0)