Skip to content

Commit 17a2be8

Browse files
[AMD] Fix test_split_subview on gfx11/gfx12 (#7457)
Use THREADS_PER_WARP instead of computing it manually, to correctly detect the warp size of 32 on RDNA.
Co-authored-by: Paul Trojahn <[email protected]>
1 parent 967e498 commit 17a2be8

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

python/test/unit/language/test_core.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6319,8 +6319,7 @@ def test_split_subview(M, N, M_tile_size, N_tile_size, device='cuda'):
63196319
if not is_hip():
63206320
pytest.skip("the test is temporary disabled for the Nvidia backend.")
63216321

6322-
threads_per_warp = 64 if is_hip() else 32
6323-
num_raws_per_warp = 16 if is_hip() else 8
6322+
num_raws_per_warp = THREADS_PER_WARP // 4
63246323
num_repeats_M = int(M / M_tile_size)
63256324
num_repeats_N = int(N / N_tile_size)
63266325

@@ -6329,7 +6328,7 @@ def test_split_subview(M, N, M_tile_size, N_tile_size, device='cuda'):
63296328
#shared = #ttg.swizzled_shared<{{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}}>
63306329
#smem = #ttg.shared_memory
63316330
6332-
module attributes {{"ttg.num-ctas" = 1, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = {threads_per_warp} : i32}} {{
6331+
module attributes {{"ttg.num-ctas" = 1, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
63336332
tt.func public @kernel(%arg0: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}) {{
63346333
%cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked>
63356334
%cst_n = arith.constant dense<{N_tile_size}> : tensor<{M_tile_size}x1xi32, #blocked>

0 commit comments

Comments (0)