Commit c418898

fix tests
1 parent 3555227 commit c418898

1 file changed: +8, -20 lines changed

tests/test_torchplot.py

Lines changed: 8 additions & 20 deletions
@@ -24,7 +24,7 @@
 
 Inputs = namedtuple("case", ["x", "y"])
 
-_cpu_cases = [
+_cases = [
     Inputs(x=torch.randn(100), y=torch.randn(100)),
     Inputs(x=torch.randn(100, requires_grad=True), y=torch.randn(100, requires_grad=True)),
     # test that list/numpy arrays still works
@@ -36,41 +36,29 @@
     Inputs(x=torch.randn(5), y=[1, 2, 3, 4, 5]),
 ]
 
-_gpu_cases = [
-    Inputs(x=torch.randn(100, device="cuda"), y=torch.randn(100, device="cuda")),
-    Inputs(
-        x=torch.randn(100, requires_grad=True, device="cuda"), y=torch.randn(100, requires_grad=True, device="cuda")
-    ),
-]
-
 
 _members_to_check = [name for name, member in getmembers(plt) if isfunction(member) and not name.startswith("_")]
 
 
-def string_compare(text1, text2):
-    if text1 is None and text2 is None:
-        return True
-    remove = string.punctuation + string.whitespace
-    return text1.translate(str.maketrans(dict.fromkeys(remove))) == text2.translate(
-        str.maketrans(dict.fromkeys(remove))
-    )
-
-
 @pytest.mark.parametrize("member", _members_to_check)
 def test_members(member):
     """ test that all members have been copied """
     assert member in dir(plt)
     assert member in dir(tp)
 
 
-@pytest.mark.parametrize("test_case", _cpu_cases)
+@pytest.mark.parametrize("test_case", _cases)
 def test_cpu(test_case):
     """ test that it works on cpu """
     assert tp.plot(test_case.x, test_case.y, ".")
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
-@pytest.mark.parametrize("test_case", _gpu_cases)
+@pytest.mark.parametrize("test_case", _cases)
 def test_gpu(test_case):
     """ test that it works on gpu """
-    assert tp.plot(test_case.x, test_case.y, ".")
+    assert tp.plot(
+        test_case.x.cuda() if isinstance(test_case.x, torch.Tensor) else test_case.x,
+        test_case.y.cuda() if isinstance(test_case.y, torch.Tensor) else test_case.y,
+        "."
+    )
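With the separate _gpu_cases list removed, the GPU test now reuses the shared _cases list and moves tensor inputs to CUDA at call time, so list and numpy inputs continue to pass through unchanged. Below is a minimal sketch of the pattern the rewritten test_gpu exercises, assuming torchplot is installed and imported as tp (as in the test module) and that a CUDA device is available; it is an illustration, not part of the commit.

    import torch
    import torchplot as tp  # assumed importable, matching the test module's `tp`

    # CPU-defined inputs, mirroring one entry of _cases
    x = torch.randn(100, requires_grad=True)
    y = torch.randn(100, requires_grad=True)

    # Tensors are moved to the GPU only at plot time; non-tensor inputs
    # (lists, numpy arrays) are passed through as-is.
    if torch.cuda.is_available():
        tp.plot(x.cuda(), y.cuda(), ".")
    else:
        tp.plot(x, y, ".")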
