66
77
88
9+
10+
11+
12+
13+
14+
915class LabelChangedDataset (Dataset ):
1016 def __init__ (self , original_dataset , new_label ):
1117 self .original_dataset = original_dataset
@@ -42,7 +48,7 @@ def get_subsampled_subset(dataset, subset_ratio=0.1):
4248 root = './data' , train = False , download = True , transform = transform_tensor )
4349
4450
45- trainloader_CIFAR10 = DataLoader (trainset_CIFAR10 , batch_size = 64 , shuffle = True , num_workers = 2 )
51+ trainloader_CIFAR10 = DataLoader (trainset_CIFAR10 , batch_size = 65 , shuffle = True , num_workers = 2 )
4652
4753testloader_CIFAR10 = DataLoader (testset_CIFAR10 , batch_size = 16 , shuffle = False , num_workers = 2 )
4854
@@ -55,20 +61,34 @@ def get_subsampled_subset(dataset, subset_ratio=0.1):
5561testset_CIFAR100 = torchvision .datasets .CIFAR100 (
5662 root = './data' , train = False , download = True , transform = transform_tensor )
5763
64+
65+
5866trainloader_CIFAR100 = DataLoader (trainset_CIFAR100 , batch_size = 64 , shuffle = True , num_workers = 2 )
5967
6068testloader_CIFAR100 = DataLoader (testset_CIFAR100 , batch_size = 16 , shuffle = False , num_workers = 2 )
6169
6270
71+
72+
73+
74+
75+
6376testset_CIFAR10_relabled = LabelChangedDataset (testset_CIFAR10 , new_label = 100 )
6477testset_CIFAR100_relabled = LabelChangedDataset (testset_CIFAR100 , new_label = 10 )
6578
6679
67- testloader_CIFAR10_vs_CIFAR100 = DataLoader (ConcatDataset ([testset_CIFAR10 , testset_CIFAR100_relabled ]), shuffle = False , batch_size = 16 )
68- testloader_CIFAR100_vs_CIFAR10 = DataLoader (ConcatDataset ([testset_CIFAR100 , testset_CIFAR10_relabled ]), shuffle = False , batch_size = 16 )
80+ testloader_CIFAR10_vs_CIFAR100 = DataLoader (ConcatDataset ([testset_CIFAR10 , testset_CIFAR100_relabled ]), shuffle = False , batch_size = 8 )
81+ testloader_CIFAR100_vs_CIFAR10 = DataLoader (ConcatDataset ([testset_CIFAR100 , testset_CIFAR10_relabled ]), shuffle = False , batch_size = 8 )
6982
7083def get_loaders (in_dataset = 'CIFAR10' ):
7184 if in_dataset == 'cifar10' :
72- return trainloader_CIFAR10 , testloader_CIFAR10 , testloader_CIFAR10_vs_CIFAR100
85+ return trainloader_CIFAR10 , testloader_CIFAR10 ,testset_CIFAR10 , testloader_CIFAR10_vs_CIFAR100
86+ if in_dataset == 'cifar100' :
87+ return trainloader_CIFAR100 , testloader_CIFAR100 ,testset_CIFAR100 , testloader_CIFAR100_vs_CIFAR10
7388 else :
7489 raise ValueError (f"Dataset '{ in_dataset } ' is not supported." )
90+
91+
92+
93+
94+
0 commit comments