Skip to content

Commit 98a284c

Browse files
Fix a bug
1 parent a025322 commit 98a284c

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

ddranking/metrics/soft_label.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def load_real_data(self, dataset, class_map, num_classes):
148148

149149
def get_class_indices(self, dataset, class_map, num_classes):
150150
class_indices = [[] for c in range(num_classes)]
151-
for idx, (_, label) in enumerate(dataset.imgs):
151+
for idx, label in enumerate(dataset.targets):
152152
if torch.is_tensor(label):
153153
label = label.item()
154154
true_label = class_map[label]

ddranking/utils/data.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,14 @@ class Config:
3838
class TensorDataset(torch.utils.data.Dataset):
3939

4040
def __init__(self, images: Tensor, labels: Tensor):
41+
super(TensorDataset, self).__init__()
4142
self.images = images
42-
self.labels = labels
43+
self.targets = labels
44+
self.imgs = [(image, label) for image, label in zip(images, labels)]
4345

4446
def __getitem__(self, index: int):
4547
image = self.images[index]
46-
label = self.labels[index]
48+
label = self.targets[index]
4749
return image, label
4850

4951
def __len__(self):

0 commit comments

Comments
 (0)