Skip to content

Commit 3584a9a

Browse files
njriasanmeta-codesync[bot]
authored andcommitted
[TLX] [FA] Enable H-DIM=64 for Backwards (#711)
Summary: Enables H-DIM=64 for Backwards which already works. Pull Request resolved: #711 Reviewed By: htyu Differential Revision: D88212441 Pulled By: njriasan fbshipit-source-id: f90d0b4fc652002a6ab95f39abd423e5eb9668c6
1 parent 89823c2 commit 3584a9a

File tree

1 file changed

+1
-5
lines changed

1 file changed

+1
-5
lines changed

third_party/tlx/tutorials/blackwell-fa-ws-pipelined-persistent_test.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1750,11 +1750,7 @@ def test_op(Z, H, N_CTX, HEAD_DIM, mode, provider, causal, dtype=torch.float16):
17501750
sm_scale = 0.5
17511751
# reference implementation
17521752
ref_dtype = dtype
1753-
if mode == "bwd" and HEAD_DIM == 64:
1754-
pytest.skip("Only test bwd with 128")
1755-
elif mode == "fwd" and not causal and HEAD_DIM == 128:
1756-
pytest.skip("Only test fwd with causal")
1757-
elif mode == "bwd" and causal:
1753+
if mode == "bwd" and causal:
17581754
pytest.skip("Causal not supported for bwd yet")
17591755
if mode == "fwd" and "fp8" in provider:
17601756
ref_dtype = torch.float32

0 commit comments

Comments
 (0)