@@ -48,7 +48,7 @@ def __len__(self):
4848 return len (self .images )
4949
5050
51- def get_dataset (dataset , data_path , im_size , use_zca , custom_val_trans , device ):
51+ def get_dataset (dataset , data_path , im_size , use_zca , custom_train_trans , custom_val_trans , device ):
5252 class_map_inv = None
5353
5454 if dataset == 'CIFAR10' :
@@ -68,7 +68,7 @@ def get_dataset(dataset, data_path, im_size, use_zca, custom_val_trans, device):
6868 transforms .ToTensor ()
6969 ])
7070
71- dst_train = datasets .CIFAR10 (data_path , train = True , download = True , transform = transform )
71+ dst_train = datasets .CIFAR10 (data_path , train = True , download = True , transform = transform if custom_train_trans is None else custom_train_trans )
7272 dst_test = datasets .CIFAR10 (data_path , train = False , download = True , transform = transform if custom_val_trans is None else custom_val_trans )
7373 class_map = {x : x for x in range (num_classes )}
7474
@@ -89,7 +89,7 @@ def get_dataset(dataset, data_path, im_size, use_zca, custom_val_trans, device):
8989 transforms .ToTensor ()
9090 ])
9191
92- dst_train = datasets .CIFAR100 (data_path , train = True , download = True , transform = transform )
92+ dst_train = datasets .CIFAR100 (data_path , train = True , download = True , transform = transform if custom_train_trans is None else custom_train_trans )
9393 dst_test = datasets .CIFAR100 (data_path , train = False , download = True , transform = transform if custom_val_trans is None else custom_val_trans )
9494 class_map = {x : x for x in range (num_classes )}
9595
@@ -108,7 +108,7 @@ def get_dataset(dataset, data_path, im_size, use_zca, custom_val_trans, device):
108108 transform = transforms .Compose ([
109109 transforms .ToTensor ()
110110 ])
111- dst_train = datasets .ImageFolder (os .path .join (data_path , "train" ), transform = transform )
111+ dst_train = datasets .ImageFolder (os .path .join (data_path , "train" ), transform = transform if custom_train_trans is None else custom_train_trans )
112112 dst_test = datasets .ImageFolder (os .path .join (data_path , "val" ), transform = transform if custom_val_trans is None else custom_val_trans )
113113 class_map = {x : x for x in range (num_classes )}
114114
@@ -129,7 +129,7 @@ def get_dataset(dataset, data_path, im_size, use_zca, custom_val_trans, device):
129129 transforms .CenterCrop (im_size )
130130 ])
131131
132- dst_train = datasets .ImageFolder (os .path .join (data_path , "train" ), transform = transform )
132+ dst_train = datasets .ImageFolder (os .path .join (data_path , "train" ), transform = transform if custom_train_trans is None else custom_train_trans )
133133 dst_train = torch .utils .data .Subset (dst_train , np .squeeze (np .argwhere (np .isin (dst_train .targets , config .img_net_classes ))))
134134 dst_test = datasets .ImageFolder (os .path .join (data_path , "val" ), transform = transform if custom_val_trans is None else custom_val_trans )
135135 dst_test = torch .utils .data .Subset (dst_test , np .squeeze (np .argwhere (np .isin (dst_test .targets , config .img_net_classes ))))
@@ -153,7 +153,7 @@ def get_dataset(dataset, data_path, im_size, use_zca, custom_val_trans, device):
153153 transforms .CenterCrop (im_size )
154154 ])
155155
156- dst_train = datasets .ImageFolder (os .path .join (data_path , "train" ), transform = transform )
156+ dst_train = datasets .ImageFolder (os .path .join (data_path , "train" ), transform = transform if custom_train_trans is None else custom_train_trans )
157157 dst_test = datasets .ImageFolder (os .path .join (data_path , "val" ), transform = transform if custom_val_trans is None else custom_val_trans )
158158
159159 class_map = {x : i for i , x in enumerate (range (num_classes ))}
0 commit comments