We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 0f32064 commit ad15940Copy full SHA for ad15940
tests/test_dataloaders.py
@@ -17,18 +17,25 @@ def test_uspsdataset0_6():
17
18
# Create a h5 file
19
with h5py.File(tf, "w") as f:
20
+ targets = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
21
+ indices = np.arange(len(targets))
22
# Populate the file with data
23
f["train/data"] = np.random.rand(10, 16 * 16)
- f["train/target"] = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
24
+ f["train/target"] = targets
25
26
trans = transforms.Compose(
27
[
- transforms.Resize((16, 16)), # At least for USPS
28
+ transforms.Resize((16, 16)),
29
transforms.ToTensor(),
30
]
31
)
- dataset = USPSDataset0_6(data_path=tempdir, train=True, transform=trans)
32
+ dataset = USPSDataset0_6(
33
+ data_path=tempdir,
34
+ sample_ids=indices,
35
+ train=True,
36
+ transform=trans,
37
+ )
38
assert len(dataset) == 10
39
data, target = dataset[0]
40
assert data.shape == (1, 16, 16)
- assert all(target == np.array([0, 0, 0, 0, 0, 0, 1]))
41
+ assert target == 6
0 commit comments