Skip to content

Commit ad15940

Browse files
committed
Adjust test to comply with new functionality
1 parent 0f32064 commit ad15940

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

tests/test_dataloaders.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,25 @@ def test_uspsdataset0_6():
1717

1818
# Create a h5 file
1919
with h5py.File(tf, "w") as f:
20+
targets = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
21+
indices = np.arange(len(targets))
2022
# Populate the file with data
2123
f["train/data"] = np.random.rand(10, 16 * 16)
22-
f["train/target"] = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
24+
f["train/target"] = targets
2325

2426
trans = transforms.Compose(
2527
[
26-
transforms.Resize((16, 16)), # At least for USPS
28+
transforms.Resize((16, 16)),
2729
transforms.ToTensor(),
2830
]
2931
)
30-
dataset = USPSDataset0_6(data_path=tempdir, train=True, transform=trans)
32+
dataset = USPSDataset0_6(
33+
data_path=tempdir,
34+
sample_ids=indices,
35+
train=True,
36+
transform=trans,
37+
)
3138
assert len(dataset) == 10
3239
data, target = dataset[0]
3340
assert data.shape == (1, 16, 16)
34-
assert all(target == np.array([0, 0, 0, 0, 0, 0, 1]))
41+
assert target == 6

0 commit comments

Comments
 (0)