@@ -11,6 +11,7 @@
 import unittest
 
 LLVM_ENABLED = torch._C._llvm_enabled()
+HAS_CUDA = torch.cuda.is_available()
 HAS_SYMPY = False
 try:
     import sympy
@@ -44,8 +45,8 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
         return torch.zeros_like(args[0])
 
 
-class TestOperatorAuthoringCPU(JitTestCase):
-    device = "cpu"
+class TestOperatorAuthoring(JitTestCase):
+    device = None
 
     def rand(self, *args, dtype=torch.float32, **kwargs):
         return torch.randint(0, 100, args, dtype=dtype, device=self.device, **kwargs)
@@ -126,16 +127,15 @@ def example(x):
         torch.testing.assert_allclose(x + 3, graph(x))
 
 
-class TestOperatorAuthoringGPU(TestOperatorAuthoringCPU):
+@unittest.skipIf(not HAS_CUDA, "GPU tests require CUDA")
+class TestOperatorAuthoringGPU(TestOperatorAuthoring):
     device = "cuda"
 
 
-if not LLVM_ENABLED:
-    TestOperatorAuthoringCPU = None  # noqa: F811
+@unittest.skipIf(not LLVM_ENABLED, "CPU tests require LLVM")
+class TestOperatorAuthoringCPU(TestOperatorAuthoring):
+    device = "cpu"
 
-# TODO: TestOperatorAuthoringGPU is disabled because it fails on CUDAs.
-# if not torch.cuda.is_available():
-TestOperatorAuthoringGPU = None  # noqa: F811
 
 if __name__ == "__main__":
     run_tests()
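
Stripped down, the layout this change moves to looks like the sketch below: a device-agnostic base class holds the test bodies, and thin per-device subclasses set `device` and are gated with `unittest.skipIf`, instead of rebinding a disabled class to None. The class and test names in the sketch are placeholders rather than the real module's contents, and in the real file the CPU class is additionally gated on `torch._C._llvm_enabled()`.

# Minimal sketch of the skipIf-gated device-subclass pattern, assuming only
# the standard unittest module and torch; names below are illustrative.
import unittest

import torch

HAS_CUDA = torch.cuda.is_available()


class PointwiseTestBase(unittest.TestCase):
    device = None  # subclasses choose the device

    def rand(self, *shape, dtype=torch.float32):
        return torch.randint(0, 100, shape, dtype=dtype, device=self.device)

    def test_add(self):
        if self.device is None:
            self.skipTest("device-agnostic base class; run a subclass instead")
        x, y = self.rand(4), self.rand(4)
        torch.testing.assert_close(x + y, torch.add(x, y))


@unittest.skipIf(not HAS_CUDA, "GPU tests require CUDA")
class PointwiseTestGPU(PointwiseTestBase):
    device = "cuda"


class PointwiseTestCPU(PointwiseTestBase):  # the real file also skips this unless LLVM is enabled
    device = "cpu"


if __name__ == "__main__":
    unittest.main()

Run with `python -m unittest -v`: on a machine without CUDA the GPU tests are reported as skipped rather than disappearing from the suite entirely.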