Skip to content

Commit 90fce43

Browse files
authored
Update data_loader.py
1 parent ae54006 commit 90fce43

File tree

1 file changed

+24
-4
lines changed

1 file changed

+24
-4
lines changed

AROS/data_loader.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@
66

77

88

9+
10+
11+
12+
13+
14+
915
class LabelChangedDataset(Dataset):
1016
def __init__(self, original_dataset, new_label):
1117
self.original_dataset = original_dataset
@@ -42,7 +48,7 @@ def get_subsampled_subset(dataset, subset_ratio=0.1):
4248
root='./data', train=False, download=True, transform=transform_tensor)
4349

4450

45-
trainloader_CIFAR10 = DataLoader(trainset_CIFAR10, batch_size=64, shuffle=True, num_workers=2)
51+
trainloader_CIFAR10 = DataLoader(trainset_CIFAR10, batch_size=65, shuffle=True, num_workers=2)
4652

4753
testloader_CIFAR10 = DataLoader(testset_CIFAR10, batch_size=16, shuffle=False, num_workers=2)
4854

@@ -55,20 +61,34 @@ def get_subsampled_subset(dataset, subset_ratio=0.1):
5561
testset_CIFAR100 = torchvision.datasets.CIFAR100(
5662
root='./data', train=False, download=True, transform=transform_tensor)
5763

64+
65+
5866
trainloader_CIFAR100 = DataLoader(trainset_CIFAR100, batch_size=64, shuffle=True, num_workers=2)
5967

6068
testloader_CIFAR100 = DataLoader(testset_CIFAR100, batch_size=16, shuffle=False, num_workers=2)
6169

6270

71+
72+
73+
74+
75+
6376
testset_CIFAR10_relabled = LabelChangedDataset(testset_CIFAR10, new_label=100)
6477
testset_CIFAR100_relabled = LabelChangedDataset(testset_CIFAR100, new_label=10)
6578

6679

67-
testloader_CIFAR10_vs_CIFAR100 = DataLoader(ConcatDataset([testset_CIFAR10, testset_CIFAR100_relabled]), shuffle=False, batch_size=16)
68-
testloader_CIFAR100_vs_CIFAR10 = DataLoader(ConcatDataset([testset_CIFAR100, testset_CIFAR10_relabled]), shuffle=False, batch_size=16)
80+
testloader_CIFAR10_vs_CIFAR100 = DataLoader(ConcatDataset([testset_CIFAR10, testset_CIFAR100_relabled]), shuffle=False, batch_size=8)
81+
testloader_CIFAR100_vs_CIFAR10 = DataLoader(ConcatDataset([testset_CIFAR100, testset_CIFAR10_relabled]), shuffle=False, batch_size=8)
6982

7083
def get_loaders(in_dataset='CIFAR10'):
7184
if in_dataset == 'cifar10':
72-
return trainloader_CIFAR10, testloader_CIFAR10, testloader_CIFAR10_vs_CIFAR100
85+
return trainloader_CIFAR10, testloader_CIFAR10,testset_CIFAR10, testloader_CIFAR10_vs_CIFAR100
86+
if in_dataset == 'cifar100':
87+
return trainloader_CIFAR100, testloader_CIFAR100,testset_CIFAR100, testloader_CIFAR100_vs_CIFAR10
7388
else:
7489
raise ValueError(f"Dataset '{in_dataset}' is not supported.")
90+
91+
92+
93+
94+

0 commit comments

Comments
 (0)