Commit 8d9a552

Update demo
1 parent: ae4ff45

4 files changed (+97, -37 lines)

demo_aug.py

Lines changed: 71 additions & 20 deletions
@@ -1,28 +1,79 @@
 import os
 import argparse
 import torch
-from ddranking.metrics import AugmentationRobustnessEvaluator
+from ddranking.metrics import AugmentationRobustScore
 from ddranking.config import Config
 
-def main(args):
-    root = args.root
-    method_name = args.method_name
-    dataset = args.dataset
-    ipc = args.ipc
+""" Use config file to specify the arguments (Recommended) """
+config = Config.from_file("./configs/Demo_ARS.yaml")
+aug_evaluator = AugmentationRobustScore(config)
 
-    print(f"Evaluating {method_name} on {dataset} with ipc{ipc}")
-    syn_image_dir = os.path.join(root, f"DD-Ranking/baselines/{method_name}/{dataset}/IPC{ipc}/")
-    config = Config.from_file(f"./{method_name}/{dataset}/IPC{ipc}_Aug.yaml")
-    aug_obj = AugmentationRobustnessEvaluator(config)
-    aug_obj.compute_metrics(image_path=syn_image_dir, syn_lr=args.syn_lr)
+syn_data_dir = "./baselines/SRe2L/ImageNet1K/IPC10/"
+print(aug_evaluator.compute_metrics(image_path=syn_data_dir, syn_lr=0.001))
 
 
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--root", type=str, default="/home/wangkai/")
-    parser.add_argument("--method_name", type=str, default="SRe2L")
-    parser.add_argument("--dataset", type=str, default="ImageNet1K")
-    parser.add_argument("--ipc", type=int, default=10)
-    parser.add_argument("--syn_lr", type=float, default=0.001)
-    args = parser.parse_args()
-    main(args)
+""" Use keyword arguments """
+from torchvision import transforms
+device = "cuda"
+method_name = "SRe2L"  # Specify your method name
+ipc = 10  # Specify your IPC
+dataset = "ImageNet1K"  # Specify your dataset name
+syn_data_dir = "./SRe2L/ImageNet1K/IPC10/"  # Specify your synthetic data path
+data_dir = "./datasets"  # Specify your dataset path
+model_name = "ResNet-18-BN"  # Specify your model name
+im_size = (224, 224)  # Specify your image size
+cutmix_params = {  # Specify your data augmentation parameters
+    "beta": 1.0
+}
+
+syn_images = torch.load(os.path.join(syn_data_dir, f"images.pt"), map_location='cpu')
+soft_labels = torch.load(os.path.join(syn_data_dir, f"labels.pt"), map_location='cpu')
+syn_lr = torch.load(os.path.join(syn_data_dir, f"lr.pt"), map_location='cpu')
+save_path = f"./results/{dataset}/{model_name}/IPC{ipc}/dm_hard_scores.csv"
+
+custom_train_trans = transforms.Compose([
+    transforms.RandomResizedCrop(224, scale=(0.08, 1.0)),
+    transforms.RandomHorizontalFlip(),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+])
+custom_val_trans = transforms.Compose([
+    transforms.Resize(256),
+    transforms.CenterCrop(224),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+])
+
+aug_evaluator = AugmentationRobustScore(
+    dataset=dataset,
+    real_data_path=data_dir,
+    ipc=ipc,
+    model_name=model_name,
+    label_type='soft',
+    soft_label_criterion='kl',  # Use KL divergence loss
+    soft_label_mode='M',  # Use one-to-many image to soft label mapping
+    loss_fn_kwargs={'temperature': 1.0, 'scale_loss': False},
+    optimizer='adamw',  # Use AdamW optimizer
+    lr_scheduler='cosine',  # Use cosine learning rate scheduler
+    weight_decay=0.01,
+    momentum=0.9,
+    num_eval=5,
+    data_aug_func='cutmix',  # Use CutMix data augmentation
+    aug_params=cutmix_params,  # Specify CutMix parameters
+    im_size=im_size,
+    num_epochs=300,
+    num_workers=4,
+    stu_use_torchvision=True,
+    tea_use_torchvision=True,
+    random_data_format='tensor',
+    random_data_path='./random_data',
+    custom_train_trans=custom_train_trans,
+    custom_val_trans=custom_val_trans,
+    batch_size=256,
+    teacher_dir='./teacher_models',
+    teacher_model_name=['ResNet-18-BN'],
+    device=device,
+    dist=True,
+    save_path=save_path
+)
+print(aug_evaluator.compute_metrics(image_path=syn_data_dir, syn_lr=0.001))

demo_hard.py

Lines changed: 11 additions & 7 deletions
@@ -1,14 +1,14 @@
 import os
 import torch
 import warnings
-from ddranking.metrics import HardLabelEvaluator
+from ddranking.metrics import LabelRobustScoreHard
 from ddranking.config import Config
 warnings.filterwarnings("ignore")
 
 
 """ Use config file to specify the arguments (Recommended) """
-config = Config.from_file("./configs/Demo_Hard_Label.yaml")
-hard_label_evaluator = HardLabelEvaluator(config)
+config = Config.from_file("./configs/Demo_LRS_Hard_Label.yaml")
+hard_label_evaluator = LabelRobustScoreHard(config)
 
 syn_data_dir = "./baselines/DM/CIFAR10/IPC10/"
 syn_images = torch.load(os.path.join(syn_data_dir, f"images.pt"), map_location='cpu')
@@ -39,12 +39,11 @@
 
 syn_images = torch.load(os.path.join(syn_data_dir, f"images.pt"), map_location='cpu')
 save_path = f"./results/{dataset}/{model_name}/IPC{ipc}/dm_hard_scores.csv"
-hard_label_evaluator = HardLabelEvaluator(
+hard_label_evaluator = LabelRobustScoreHard(
     dataset=dataset,
     real_data_path=data_dir,
     ipc=ipc,
     model_name=model_name,
-    default_lr=0.01,
     optimizer='sgd',  # Use SGD optimizer
     lr_scheduler='step',  # Use StepLR learning rate scheduler
     weight_decay=0.0005,
@@ -57,10 +56,15 @@
     num_epochs=1000,
     num_workers=4,
     use_torchvision=False,
-    syn_batch_size=128,
+    syn_batch_size=256,
     real_batch_size=256,
+    custom_train_trans=None,
     custom_val_trans=None,
     device=device,
-    save_path=save_path
+    dist=True,
+    save_path=save_path,
+    random_data_format='tensor',
+    random_data_path='./random_data',
+    eval_full_data=True,
 )
 print(hard_label_evaluator.compute_metrics(image_tensor=syn_images, syn_lr=0.01))

demo_soft.py

Lines changed: 13 additions & 6 deletions
@@ -1,14 +1,14 @@
 import os
 import torch
 import warnings
-from ddranking.metrics import SoftLabelEvaluator
+from ddranking.metrics import LabelRobustScoreSoft
 from ddranking.config import Config
 warnings.filterwarnings("ignore")
 
 
 """ Use config file to specify the arguments (Recommended) """
-config = Config.from_file("./configs/Demo_Soft_Label.yaml")
-soft_label_evaluator = SoftLabelEvaluator(config)
+config = Config.from_file("./configs/Demo_LRS_Soft_Label.yaml")
+soft_label_evaluator = LabelRobustScoreSoft(config)
 
 syn_data_dir = "./baselines/DATM/CIFAR10/IPC10/"
 syn_images = torch.load(os.path.join(syn_data_dir, f"images.pt"), map_location='cpu')
@@ -41,16 +41,17 @@
 soft_labels = torch.load(os.path.join(syn_data_dir, f"labels.pt"), map_location='cpu')
 syn_lr = torch.load(os.path.join(syn_data_dir, f"lr.pt"), map_location='cpu')
 save_path = f"./results/{dataset}/{model_name}/IPC{ipc}/dm_hard_scores.csv"
-soft_label_evaluator = SoftLabelEvaluator(
+soft_label_evaluator = LabelRobustScoreSoft(
     dataset=dataset,
     real_data_path=data_dir,
     ipc=ipc,
     model_name=model_name,
     soft_label_criterion='sce',  # Use Soft Cross Entropy Loss
     soft_label_mode='S',  # Use one-to-one image to soft label mapping
-    default_lr=0.01,
+    loss_fn_kwargs={'temperature': 1.0, 'scale_loss': False},
     optimizer='sgd',  # Use SGD optimizer
     lr_scheduler='step',  # Use StepLR learning rate scheduler
+    step_size=500,
     weight_decay=0.0005,
     momentum=0.9,
     use_zca=True,  # Use ZCA whitening (please disable it if you didn't use it to distill synthetic data)
@@ -60,13 +61,19 @@
     im_size=im_size,
     num_epochs=1000,
     num_workers=4,
+    eval_full_data=True,
     stu_use_torchvision=False,
     tea_use_torchvision=False,
+    random_data_format='tensor',
+    random_data_path='./random_data',
+    custom_train_trans=None,
     custom_val_trans=None,
-    syn_batch_size=128,
+    syn_batch_size=256,
     real_batch_size=256,
     teacher_dir='./teacher_models',
+    teacher_model_name=['ConvNet-3'],
     device=device,
+    dist=True,
     save_path=save_path
 )
 print(soft_label_evaluator.compute_metrics(image_tensor=syn_images, soft_labels=soft_labels, syn_lr=syn_lr))

doc/metrics/ars.md

Lines changed: 2 additions & 4 deletions
@@ -25,8 +25,7 @@ dd_ranking.metrics.AugmentationRobustScore(config: Optional[Config] = None,
     use_zca: bool=False,
     random_data_format: str='image',
     random_data_path: str=None,
-    real_batch_size: int=256,
-    syn_batch_size: int=256,
+    batch_size: int=256,
     save_path: str=None,
     stu_use_torchvision: bool=False,
     tea_use_torchvision: bool=False,
@@ -65,8 +64,7 @@ A class for evaluating the performance of a dataset distillation method with sof
 - **num_eval**(<span style="color:#FF6B00;">int</span>): Number of evaluations to perform.
 - **im_size**(<span style="color:#FF6B00;">tuple</span>): Size of the images.
 - **num_epochs**(<span style="color:#FF6B00;">int</span>): Number of epochs to train.
-- **real_batch_size**(<span style="color:#FF6B00;">int</span>): Batch size for the real dataset.
-- **syn_batch_size**(<span style="color:#FF6B00;">int</span>): Batch size for the synthetic dataset.
+- **batch_size**(<span style="color:#FF6B00;">int</span>): Batch size for model training.
 - **stu_use_torchvision**(<span style="color:#FF6B00;">bool</span>): Whether to use torchvision to initialize the student model.
 - **tea_use_torchvision**(<span style="color:#FF6B00;">bool</span>): Whether to use torchvision to initialize the teacher model.
 - **teacher_dir**(<span style="color:#FF6B00;">str</span>): Path to the teacher model.
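
For reference, here is a minimal keyword-argument sketch of how the renamed `batch_size` parameter fits into the constructor documented above. The argument names and the `compute_metrics` call are taken from this commit's diffs; the dataset, model, paths, and result file are illustrative placeholders borrowed from demo_aug.py, not required values, and the argument subset shown is an assumption rather than the full signature.

```python
# Minimal sketch: AugmentationRobustScore with the single batch_size argument
# that replaces real_batch_size / syn_batch_size in this commit.
# Dataset, model, and path values below are placeholders, not required values.
from ddranking.metrics import AugmentationRobustScore

aug_evaluator = AugmentationRobustScore(
    dataset="ImageNet1K",
    real_data_path="./datasets",
    ipc=10,
    model_name="ResNet-18-BN",
    label_type='soft',
    soft_label_criterion='kl',
    soft_label_mode='M',
    optimizer='adamw',
    lr_scheduler='cosine',
    data_aug_func='cutmix',
    aug_params={"beta": 1.0},
    im_size=(224, 224),
    num_epochs=300,
    batch_size=256,  # single training batch size (renamed argument)
    stu_use_torchvision=True,
    tea_use_torchvision=True,
    teacher_dir='./teacher_models',
    teacher_model_name=['ResNet-18-BN'],
    device="cuda",
    save_path="./results/ars_scores.csv",  # placeholder output path
)
print(aug_evaluator.compute_metrics(image_path="./baselines/SRe2L/ImageNet1K/IPC10/", syn_lr=0.001))
```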
