Skip to content

Commit 98e531c

Browse files
Add tests to check how transformations are applied.
1 parent 7da96fd commit 98e531c

File tree

1 file changed

+91
-0
lines changed

1 file changed

+91
-0
lines changed

tests/test_transforms.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import numpy as np
2+
import pytest
3+
from torchvision.transforms import ToTensor
4+
from torch.utils.data import DataLoader
5+
import torch
6+
7+
from continuum.datasets import InMemoryDataset
8+
from continuum.scenarios import ClassIncremental, InstanceIncremental, ContinualScenario
9+
10+
11+
@pytest.fixture
12+
def dataset():
13+
x = np.random.randint(0, 255, (100, 4, 4, 3), dtype=np.uint8)
14+
y = np.random.randint(0, 3, (100,), dtype=np.int16)
15+
t = np.ones_like(y)
16+
17+
t[:30] = 0
18+
t[30:60] = 1
19+
t[60:] = 2
20+
21+
return InMemoryDataset(x, y, t)
22+
23+
24+
25+
@pytest.mark.parametrize("scenario,opt", [
26+
(ClassIncremental, {'increment': 1}),
27+
(InstanceIncremental, {}),
28+
(ContinualScenario, {})
29+
])
30+
def test_same_transforms(dataset, scenario, opt):
31+
trsfs = [
32+
ToTensor(),
33+
lambda tensor: tensor.fill_(0)
34+
]
35+
s = scenario(dataset, transformations=trsfs, **opt)
36+
37+
for taskset in s:
38+
loader = DataLoader(taskset)
39+
for x, _, _ in loader:
40+
assert torch.unique(x).numpy().tolist() == [0]
41+
42+
43+
@pytest.mark.parametrize("scenario,opt", [
44+
(ClassIncremental, {'increment': 1}),
45+
(InstanceIncremental, {}),
46+
(ContinualScenario, {})
47+
])
48+
def test_diff_transforms(dataset, scenario, opt):
49+
trsfs = [
50+
[ToTensor(), lambda tensor1: tensor1.fill_(0)],
51+
[ToTensor(), lambda tensor2: tensor2.fill_(1)],
52+
[ToTensor(), lambda tensor3: tensor3.fill_(2)],
53+
]
54+
s = scenario(dataset, transformations=trsfs, **opt)
55+
56+
for taskid, taskset in enumerate(s):
57+
loader = DataLoader(taskset)
58+
for x, _, _ in loader:
59+
assert torch.unique(x).numpy().tolist() == [taskid]
60+
61+
62+
@pytest.mark.parametrize("scenario,opt,error", [
63+
(ClassIncremental, {'increment': 1}, True),
64+
(InstanceIncremental, {}, True),
65+
(ContinualScenario, {}, True),
66+
(ClassIncremental, {'increment': 1}, False),
67+
(InstanceIncremental, {}, False),
68+
(ContinualScenario, {}, False)
69+
])
70+
def test_diff_transforms_slice(dataset, scenario, opt, error):
71+
trsfs = [
72+
[ToTensor(), lambda tensor1: tensor1.fill_(0)],
73+
[ToTensor(), lambda tensor2: tensor2.fill_(1)],
74+
[ToTensor(), lambda tensor3: tensor3.fill_(2)],
75+
]
76+
s = scenario(dataset, transformations=trsfs, **opt)
77+
78+
for taskid in range(len(s)):
79+
if not error:
80+
taskset = s[taskid]
81+
loader = DataLoader(taskset)
82+
for x, _, _ in loader:
83+
assert torch.unique(x).numpy().tolist() == [taskid]
84+
else:
85+
with pytest.raises(ValueError):
86+
s[:taskid]
87+
88+
89+
90+
91+

0 commit comments

Comments
 (0)