diff --git a/src/datasets/cifar10.py b/src/datasets/cifar10.py index b31c681..312c151 100644 --- a/src/datasets/cifar10.py +++ b/src/datasets/cifar10.py @@ -40,7 +40,7 @@ def __init__(self, root: str, normal_class=5): train_set = MyCIFAR10(root=self.root, train=True, download=True, transform=transform, target_transform=target_transform) # Subset train set to normal class - train_idx_normal = get_target_label_idx(train_set.train_labels, self.normal_classes) + train_idx_normal = get_target_label_idx(train_set.targets, self.normal_classes) self.train_set = Subset(train_set, train_idx_normal) self.test_set = MyCIFAR10(root=self.root, train=False, download=True, @@ -61,9 +61,9 @@ def __getitem__(self, index): triple: (image, target, index) where target is index of the target class. """ if self.train: - img, target = self.train_data[index], self.train_labels[index] + img, target = self.data[index], self.targets[index] else: - img, target = self.test_data[index], self.test_labels[index] + img, target = self.data[index], self.targets[index] # doing this so that it is consistent with all other datasets # to return a PIL Image diff --git a/src/main.py b/src/main.py index a558cae..2497861 100644 --- a/src/main.py +++ b/src/main.py @@ -176,8 +176,8 @@ def main(dataset_name, net_name, xp_path, data_path, load_config, load_model, ob X_outliers = dataset.test_set.test_data[idx_sorted[-32:], ...].unsqueeze(1) if dataset_name == 'cifar10': - X_normals = torch.tensor(np.transpose(dataset.test_set.test_data[idx_sorted[:32], ...], (0, 3, 1, 2))) - X_outliers = torch.tensor(np.transpose(dataset.test_set.test_data[idx_sorted[-32:], ...], (0, 3, 1, 2))) + X_normals = torch.tensor(np.transpose(dataset.test_set.data[idx_sorted[:32], ...], (0, 3, 1, 2))) + X_outliers = torch.tensor(np.transpose(dataset.test_set.data[idx_sorted[-32:], ...], (0, 3, 1, 2))) plot_images_grid(X_normals, export_img=xp_path + '/normals', title='Most normal examples', padding=2) plot_images_grid(X_outliers, export_img=xp_path + '/outliers', title='Most anomalous examples', padding=2)