@@ -24,7 +24,7 @@
 
 Inputs = namedtuple("case", ["x", "y"])
 
-_cpu_cases = [
+_cases = [
     Inputs(x=torch.randn(100), y=torch.randn(100)),
     Inputs(x=torch.randn(100, requires_grad=True), y=torch.randn(100, requires_grad=True)),
     # test that list/numpy arrays still works
@@ -36,41 +36,29 @@
     Inputs(x=torch.randn(5), y=[1, 2, 3, 4, 5]),
 ]
 
-_gpu_cases = [
-    Inputs(x=torch.randn(100, device="cuda"), y=torch.randn(100, device="cuda")),
-    Inputs(
-        x=torch.randn(100, requires_grad=True, device="cuda"), y=torch.randn(100, requires_grad=True, device="cuda")
-    ),
-]
-
 
 _members_to_check = [name for name, member in getmembers(plt) if isfunction(member) and not name.startswith("_")]
 
 
-def string_compare(text1, text2):
-    if text1 is None and text2 is None:
-        return True
-    remove = string.punctuation + string.whitespace
-    return text1.translate(str.maketrans(dict.fromkeys(remove))) == text2.translate(
-        str.maketrans(dict.fromkeys(remove))
-    )
-
-
 @pytest.mark.parametrize("member", _members_to_check)
 def test_members(member):
     """ test that all members have been copied """
     assert member in dir(plt)
     assert member in dir(tp)
 
 
-@pytest.mark.parametrize("test_case", _cpu_cases)
+@pytest.mark.parametrize("test_case", _cases)
 def test_cpu(test_case):
     """ test that it works on cpu """
     assert tp.plot(test_case.x, test_case.y, ".")
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
-@pytest.mark.parametrize("test_case", _gpu_cases)
+@pytest.mark.parametrize("test_case", _cases)
 def test_gpu(test_case):
     """ test that it works on gpu """
-    assert tp.plot(test_case.x, test_case.y, ".")
+    assert tp.plot(
+        test_case.x.cuda() if isinstance(test_case.x, torch.Tensor) else test_case.x,
+        test_case.y.cuda() if isinstance(test_case.y, torch.Tensor) else test_case.y,
+        "."
+    )
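
For context, the test module after this change looks roughly as follows. This is a sketch, not verbatim file contents: the import block and the list/numpy cases on file lines 31-35 are not visible in the diff, so they are assumed here (in particular, `tp` is assumed to be `torchplot` and `plt` to be `matplotlib.pyplot`).

# Sketch of the test module after this change. The imports are assumed,
# since they are not shown in the diff; tp is assumed to be torchplot
# and plt to be matplotlib.pyplot.
from collections import namedtuple
from inspect import getmembers, isfunction

import matplotlib.pyplot as plt
import pytest
import torch

import torchplot as tp

Inputs = namedtuple("case", ["x", "y"])

_cases = [
    Inputs(x=torch.randn(100), y=torch.randn(100)),
    Inputs(x=torch.randn(100, requires_grad=True), y=torch.randn(100, requires_grad=True)),
    # test that list/numpy arrays still work
    # ... additional list/numpy cases (file lines 31-35, not shown in the diff) ...
    Inputs(x=torch.randn(5), y=[1, 2, 3, 4, 5]),
]

# every public function of matplotlib.pyplot should also be exposed by tp
_members_to_check = [name for name, member in getmembers(plt) if isfunction(member) and not name.startswith("_")]


@pytest.mark.parametrize("member", _members_to_check)
def test_members(member):
    """ test that all members have been copied """
    assert member in dir(plt)
    assert member in dir(tp)


@pytest.mark.parametrize("test_case", _cases)
def test_cpu(test_case):
    """ test that it works on cpu """
    assert tp.plot(test_case.x, test_case.y, ".")


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
@pytest.mark.parametrize("test_case", _cases)
def test_gpu(test_case):
    """ test that it works on gpu """
    # tensors are moved to the GPU at call time; non-tensor inputs (plain lists)
    # are passed through unchanged
    assert tp.plot(
        test_case.x.cuda() if isinstance(test_case.x, torch.Tensor) else test_case.x,
        test_case.y.cuda() if isinstance(test_case.y, torch.Tensor) else test_case.y,
        "."
    )

Reusing a single `_cases` list keeps the CPU and GPU tests in sync: `test_gpu` moves tensor inputs to CUDA when calling `tp.plot`, so the separate `_gpu_cases` list (and the unused `string_compare` helper) could be dropped.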