11import os
22import torch
3- from dd_ranking .metrics import Soft_Label_Objective_Metrics
3+ import warnings
4+ from dd_ranking .metrics import Soft_Label_Evaluator
45from dd_ranking .config import Config
6+ warnings .filterwarnings ("ignore" , category = FutureWarning )
57
68
79""" Use config file to specify the arguments (Recommended) """
810config = Config .from_file ("./configs/Demo_Soft_Label.yaml" )
9- convd3_soft_obj = Soft_Label_Objective_Metrics (config )
11+ soft_label_evaluator = Soft_Label_Evaluator (config )
12+
13+ syn_data_dir = "./baselines/DATM/CIFAR10/IPC10/"
1014syn_images = torch .load (os .path .join (syn_data_dir , f"images.pt" ), map_location = 'cpu' )
1115soft_labels = torch .load (os .path .join (syn_data_dir , f"labels.pt" ), map_location = 'cpu' )
1216syn_lr = torch .load (os .path .join (syn_data_dir , f"lr.pt" ), map_location = 'cpu' )
13- print (convd3_soft_obj .compute_metrics (syn_images , soft_labels , syn_lr = syn_lr ))
17+ print (soft_label_evaluator .compute_metrics (image_tensor = syn_images , soft_labels = soft_labels , syn_lr = syn_lr ))
1418
1519
1620""" Use keyword arguments """
3539
3640syn_images = torch .load (os .path .join (syn_data_dir , f"images.pt" ), map_location = 'cpu' )
3741soft_labels = torch .load (os .path .join (syn_data_dir , f"labels.pt" ), map_location = 'cpu' )
42+ syn_lr = torch .load (os .path .join (syn_data_dir , f"lr.pt" ), map_location = 'cpu' )
3843save_path = f"./results/{ dataset } /{ model_name } /IPC{ ipc } /dm_hard_scores.csv"
39- convd3_hard_obj = Soft_Label_Objective_Metrics (
44+ soft_label_evaluator = Soft_Label_Evaluator (
4045 dataset = dataset ,
4146 real_data_path = data_dir ,
4247 ipc = ipc ,
6469 device = device ,
6570 save_path = save_path
6671)
67- print (convd3_hard_obj .compute_metrics (syn_images , soft_labels , syn_lr = 0.01 ))
72+ print (soft_label_evaluator .compute_metrics (syn_images , soft_labels , syn_lr = syn_lr ))
0 commit comments