Skip to content

Commit eab0b08

Browse files
committed
Update to handle nr_channels arg
1 parent 717eb01 commit eab0b08

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

CollaborativeCoding/dataloaders/usps_0_6.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import h5py as h5
1010
import numpy as np
11+
import torch
1112
from PIL import Image
1213
from torch.utils.data import Dataset
1314

@@ -83,6 +84,7 @@ def __init__(
8384
sample_ids: list,
8485
train: bool = False,
8586
transform=None,
87+
nr_channels=1,
8688
):
8789
super().__init__()
8890

@@ -91,6 +93,7 @@ def __init__(
9193
self.transform = transform
9294
self.mode = "train" if train else "test"
9395
self.sample_ids = sample_ids
96+
self.nr_channels = nr_channels
9497

9598
def __len__(self):
9699
return len(self.sample_ids)
@@ -100,11 +103,18 @@ def __getitem__(self, id):
100103

101104
with h5.File(self.filepath, "r") as f:
102105
data = f[self.mode]["data"][index].astype(np.uint8)
103-
label = f[self.mode]["target"][index]
106+
label = int(f[self.mode]["target"][index])
104107

105-
data = Image.fromarray(data, mode="L")
108+
if self.nr_channels == 1:
109+
data = Image.fromarray(data, mode="L")
110+
elif self.nr_channels == 3:
111+
data = Image.fromarray(data, mode="RGB")
112+
else:
113+
raise ValueError("Invalid number of channels")
106114

107115
if self.transform:
108116
data = self.transform(data)
109117

118+
# label = torch.tensor(label).long()
119+
110120
return data, label

CollaborativeCoding/load_data.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,23 +89,23 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
8989
sample_ids=train_samples,
9090
train=True,
9191
transform=transform,
92-
nr_channels=kwargs.get("nr_channels"),
92+
nr_channels=kwargs.get("nr_channels", 1),
9393
)
9494

9595
val = dataset(
9696
data_path=data_dir,
9797
sample_ids=val_samples,
9898
train=True,
9999
transform=transform,
100-
nr_channels=kwargs.get("nr_channels"),
100+
nr_channels=kwargs.get("nr_channels", 1),
101101
)
102102

103103
test = dataset(
104104
data_path=data_dir,
105105
sample_ids=test_samples,
106106
train=False,
107107
transform=transform,
108-
nr_channels=kwargs.get("nr_channels"),
108+
nr_channels=kwargs.get("nr_channels", 1),
109109
)
110110

111111
return train, val, test

0 commit comments

Comments
 (0)