Skip to content

Commit 38eb9f4

Browse files
authored
Fix PointwiseCompiler on CUDA (#203)
1 parent 13892fa commit 38eb9f4

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

functorch/_src/operator_authoring.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -316,7 +316,7 @@ def compute_code(self):
316316
loopnest = _te.LoopNest(_te.Block([out]), output_bufs)
317317

318318
if self.device == "cuda" and loops:
319-
flattened = _te.LoopNest.flatten(loops)
319+
flattened = loopnest.flatten(loops)
320320
assert flattened
321321
inner = _te.LoopNest.split_with_mask(flattened, 512)
322322
assert inner

test/test_operator_authoring.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
import unittest
1212

1313
LLVM_ENABLED = torch._C._llvm_enabled()
14+
HAS_CUDA = torch.cuda.is_available()
1415
HAS_SYMPY = False
1516
try:
1617
import sympy
@@ -44,8 +45,8 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
4445
return torch.zeros_like(args[0])
4546

4647

47-
class TestOperatorAuthoringCPU(JitTestCase):
48-
device = "cpu"
48+
class TestOperatorAuthoring(JitTestCase):
49+
device = None
4950

5051
def rand(self, *args, dtype=torch.float32, **kwargs):
5152
return torch.randint(0, 100, args, dtype=dtype, device=self.device, **kwargs)
@@ -126,16 +127,15 @@ def example(x):
126127
torch.testing.assert_allclose(x + 3, graph(x))
127128

128129

129-
class TestOperatorAuthoringGPU(TestOperatorAuthoringCPU):
130+
@unittest.skipIf(not HAS_CUDA, "GPU tests require CUDA")
131+
class TestOperatorAuthoringGPU(TestOperatorAuthoring):
130132
device = "cuda"
131133

132134

133-
if not LLVM_ENABLED:
134-
TestOperatorAuthoringCPU = None # noqa: F811
135+
@unittest.skipIf(not LLVM_ENABLED, "CPU tests require LLVM")
136+
class TestOperatorAuthoringCPU(TestOperatorAuthoring):
137+
device = "cpu"
135138

136-
# TODO: TestOperatorAuthoringGPU is disabled because it fails on CUDAs.
137-
# if not torch.cuda.is_available():
138-
TestOperatorAuthoringGPU = None # noqa: F811
139139

140140
if __name__ == "__main__":
141141
run_tests()

0 commit comments

Comments
 (0)