Skip to content

Commit dba750e

Browse files
[AMD] Support async load in ping-pong pass (#7458)
This PR introduces two additional pingpong transforms for GEMM kernels that use async_copy: - When num_stages=3 is used with a narrow tile_K, it schedules most of the memory ops and the dot op in separate clusters. num_stages=3 naturally utilizes prefetching, and a narrow tile_K is required to accommodate the increased shared memory requirement from the prefetch. - For mxfp types, this introduces initial support that has one small memory cluster with async_copy and another cluster with local_load and dot_scaled interleaved. The performance of this variant currently relies heavily on the backend compiler. - This also introduces a small change to the stream pipeliner. When num_stages=3 and async_copy is enabled, async_wait is scheduled at the end of the loop instead of the top, and it carries its token over. No major performance difference is expected, but this is required by pingpong scheduling to handle dependencies between warps. It can still be explicitly disabled with `TRITON_HIP_USE_BLOCK_PINGPONG=0`.
1 parent ea82c9f commit dba750e

File tree

10 files changed

+494
-103
lines changed

10 files changed

+494
-103
lines changed

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1996,6 +1996,13 @@ SwizzledSharedEncodingAttr AMDMfmaEncodingAttr::composeSharedLayoutForOperand(
19961996
ArrayRef<unsigned> sharedOrder, unsigned vectorSize, unsigned elemBitWidth,
19971997
bool needTrans) const {
19981998
int kDimIndex = operandIdx == 0 ? 1 : 0;
1999+
2000+
// Disable swizzling for scales
2001+
if (operandIdx >= 2) {
2002+
return SwizzledSharedEncodingAttr::get(getContext(), 1, 1, 1, sharedOrder,
2003+
ctaLayout);
2004+
}
2005+
19992006
if (needTrans)
20002007
kDimIndex = 1 - kDimIndex;
20012008

python/src/passes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,8 @@
3636
#define ADD_PASS_OPTION_WRAPPER_4(name, builder, ty0, ty1, ty2, ty3) \
3737
m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, \
3838
ty3 val3) { pm.addPass(builder({val0, val1, val2, val3})); })
39+
40+
#define ADD_PASS_OPTION_WRAPPER_5(name, builder, ty0, ty1, ty2, ty3, ty4) \
41+
m.def(name, \
42+
[](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, ty3 val3, \
43+
ty4 val4) { pm.addPass(builder({val0, val1, val2, val3, val4})); })

test/TritonGPU/amd/amd-block-pingpong.mlir

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: triton-opt %s -split-input-file --tritonamdgpu-block-pingpong="num-stages=2" | FileCheck %s
2+
// RUN: triton-opt %s -split-input-file --tritonamdgpu-block-pingpong="num-stages=3" | FileCheck %s --check-prefixes CHECK-NS3
23

34
//CHECK-LABEL: pingpong_small
45
//CHECK: ttg.local_load
@@ -1835,3 +1836,140 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
18351836
tt.return
18361837
}
18371838
}
1839+
1840+
// -----
1841+
// CHECK-LABEL: async_ns3_gemm
1842+
// CHECK-NOT: rocdl
1843+
// CHECK-NS3-LABEL: async_ns3_gemm
1844+
// CHECK-NS3: amdgpu.cond_barrier
1845+
// CHECK-NS3: %[[LL0:.+]] = ttg.local_load
1846+
// CHECK-NS3: %[[LL1:.+]] = ttg.local_load
1847+
// CHECK-NS3: ttg.async_wait
1848+
// CHECK-NS3: tt.dot %[[LL0]], %[[LL1]]
1849+
// CHECK-NS3: amdgpu.cond_barrier
1850+
1851+
#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [32, 2], warpsPerCTA = [1, 8], order = [0, 1]}>
1852+
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
1853+
#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>
1854+
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [0, 1]}>
1855+
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
1856+
#smem = #ttg.shared_memory
1857+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
1858+
tt.func public @async_ns3_gemm(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: tensor<256x32x!tt.ptr<bf16>, #blocked>, %arg11: tensor<32x256x!tt.ptr<bf16>, #blocked1>, %arg12: !ttg.memdesc<256x32xbf16, #shared, #smem, mutable>, %arg13: !ttg.memdesc<256x32xbf16, #shared, #smem, mutable>, %arg14: !ttg.async.token, %arg15: !ttg.memdesc<32x256xbf16, #shared1, #smem, mutable>, %arg16: !ttg.memdesc<32x256xbf16, #shared1, #smem, mutable>, %arg17: !ttg.async.token, %arg18: !ttg.async.token, %arg19: !ttg.async.token, %arg20: tensor<256x32xi32, #blocked>, %arg21: tensor<32x256xi32, #blocked1>, %arg22: !ttg.memdesc<3x256x32xbf16, #shared, #smem, mutable>, %arg23: !ttg.memdesc<3x32x256xbf16, #shared1, #smem, mutable>, %arg24: tensor<256x256x!tt.ptr<bf16>, #mma>, %arg25: tensor<256x256xi1, #mma>) attributes {noinline = false} {
1859+
%c3_i32 = arith.constant 3 : i32
1860+
%c0_i32 = arith.constant 0 : i32
1861+
%c1_i32 = arith.constant 1 : i32
1862+
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
1863+
%0:12 = scf.for %arg26 = %c0_i32 to %arg9 step %c1_i32 iter_args(%arg27 = %cst, %arg28 = %arg10, %arg29 = %arg11, %arg30 = %c1_i32, %arg31 = %arg12, %arg32 = %arg13, %arg33 = %arg14, %arg34 = %arg15, %arg35 = %arg16, %arg36 = %arg17, %arg37 = %arg18, %arg38 = %arg19) -> (tensor<256x256xf32, #mma>, tensor<256x32x!tt.ptr<bf16>, #blocked>, tensor<32x256x!tt.ptr<bf16>, #blocked1>, i32, !ttg.memdesc<256x32xbf16, #shared, #smem, mutable>, !ttg.memdesc<256x32xbf16, #shared, #smem, mutable>, !ttg.async.token, !ttg.memdesc<32x256xbf16, #shared1, #smem, mutable>, !ttg.memdesc<32x256xbf16, #shared1, #smem, mutable>, !ttg.async.token, !ttg.async.token, !ttg.async.token) : i32 {
1864+
%4 = tt.addptr %arg28, %arg20 : tensor<256x32x!tt.ptr<bf16>, #blocked>, tensor<256x32xi32, #blocked>
1865+
%5 = tt.addptr %arg29, %arg21 : tensor<32x256x!tt.ptr<bf16>, #blocked1>, tensor<32x256xi32, #blocked1>
1866+
%6 = arith.addi %arg30, %c1_i32 : i32
1867+
%7 = arith.cmpi slt, %6, %c3_i32 : i32
1868+
%8 = arith.select %7, %6, %c0_i32 : i32
1869+
%9 = ttg.memdesc_subview %arg22[%8, %c0_i32, %c0_i32] : !ttg.memdesc<3x256x32xbf16, #shared, #smem, mutable> -> !ttg.memdesc<256x32xbf16, #shared, #smem, mutable>
1870+
%10 = ttg.async_copy_global_to_local %4, %9 : tensor<256x32x!tt.ptr<bf16>, #blocked> -> <256x32xbf16, #shared, #smem, mutable>
1871+
%11 = ttg.async_commit_group %10
1872+
%12 = ttg.local_load %arg31 token %arg33 : !ttg.memdesc<256x32xbf16, #shared, #smem, mutable> -> tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
1873+
%13 = ttg.memdesc_subview %arg23[%8, %c0_i32, %c0_i32] : !ttg.memdesc<3x32x256xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x256xbf16, #shared1, #smem, mutable>
1874+
%14 = ttg.async_copy_global_to_local %5, %13 : tensor<32x256x!tt.ptr<bf16>, #blocked1> -> <32x256xbf16, #shared1, #smem, mutable>
1875+
%15 = ttg.async_commit_group %14
1876+
%16 = ttg.local_load %arg34 token %arg36 : !ttg.memdesc<32x256xbf16, #shared1, #smem, mutable> -> tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
1877+
%17 = tt.dot %12, %16, %arg27 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<256x256xf32, #mma>
1878+
%18 = ttg.async_wait %arg37 {num = 0 : i32}
1879+
%19 = ttg.async_wait %arg38 {num = 0 : i32}
1880+
scf.yield %17, %4, %5, %8, %arg32, %9, %18, %arg35, %13, %19, %11, %15 : tensor<256x256xf32, #mma>, tensor<256x32x!tt.ptr<bf16>, #blocked>, tensor<32x256x!tt.ptr<bf16>, #blocked1>, i32, !ttg.memdesc<256x32xbf16, #shared, #smem, mutable>, !ttg.memdesc<256x32xbf16, #shared, #smem, mutable>, !ttg.async.token, !ttg.memdesc<32x256xbf16, #shared1, #smem, mutable>, !ttg.memdesc<32x256xbf16, #shared1, #smem, mutable>, !ttg.async.token, !ttg.async.token, !ttg.async.token
1881+
}
1882+
%1 = ttg.async_wait %0#10 {num = 0 : i32}
1883+
%2 = ttg.async_wait %0#11 {num = 0 : i32}
1884+
ttg.local_dealloc %arg22 : !ttg.memdesc<3x256x32xbf16, #shared, #smem, mutable>
1885+
ttg.local_dealloc %arg23 : !ttg.memdesc<3x32x256xbf16, #shared1, #smem, mutable>
1886+
%3 = arith.truncf %0#0 : tensor<256x256xf32, #mma> to tensor<256x256xbf16, #mma>
1887+
tt.store %arg24, %3, %arg25 : tensor<256x256x!tt.ptr<bf16>, #mma>
1888+
tt.return
1889+
}
1890+
}
1891+
1892+
1893+
// -----
1894+
// CHECK-LABEL: gemm_mxfp4
1895+
// CHECK: amdgpu.cond_barrier
1896+
// CHECK: %[[WAIT:.+]] = ttg.async_wait
1897+
// CHECK: ttg.async_copy_global_to_local
1898+
// CHECK: ttg.async_copy_global_to_local
1899+
// CHECK: ttg.async_copy_global_to_local
1900+
// CHECK: ttg.async_copy_global_to_local
1901+
// CHECK: rocdl.sched.barrier 0
1902+
// CHECK: rocdl.s.barrier
1903+
// CHECK: rocdl.sched.barrier 0
1904+
// CHECK: %[[LL0:.+]] = ttg.local_load
1905+
// CHECK-SAME: %[[WAIT]]
1906+
// CHECK: %[[LL1:.+]] = ttg.local_load
1907+
// CHECK-SAME: %[[WAIT]]
1908+
// CHECK: %[[LL2:.+]] = ttg.local_load
1909+
// CHECK-SAME: %[[WAIT]]
1910+
// CHECK: %[[LL3:.+]] = ttg.local_load
1911+
// CHECK-SAME: %[[WAIT]]
1912+
// CHECK: tt.dot_scaled %[[LL2]] scale %[[LL0]], %[[LL3]] scale %[[LL1]]
1913+
// CHECK: amdgpu.cond_barrier
1914+
1915+
#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 8], order = [0, 1]}>
1916+
#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
1917+
#blocked2 = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
1918+
#linear = #ttg.linear<{register = [[0, 4], [32, 0], [64, 0], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], warp = [[0, 0], [0, 0], [16, 0]], block = []}>
1919+
#linear1 = #ttg.linear<{register = [[0, 4], [64, 0], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], warp = [[16, 0], [32, 0], [0, 0]], block = []}>
1920+
#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>
1921+
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
1922+
#shared1 = #ttg.swizzled_shared<{vec = 16, perPhase = 2, maxPhase = 8, order = [1, 0]}>
1923+
#shared2 = #ttg.swizzled_shared<{vec = 16, perPhase = 2, maxPhase = 8, order = [0, 1]}>
1924+
#smem = #ttg.shared_memory
1925+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
1926+
tt.func public @gemm_mxfp4(%arg0: !tt.ptr<i8> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: !tt.ptr<i8> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg4: !tt.ptr<i8> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: tensor<256x8x!tt.ptr<i8>, #blocked>, %arg15: tensor<256x8x!tt.ptr<i8>, #blocked>, %arg16: tensor<256x128x!tt.ptr<i8>, #blocked1>, %arg17: tensor<128x256x!tt.ptr<i8>, #blocked2>, %arg18: !ttg.async.token, %arg19: !ttg.async.token, %arg20: !ttg.async.token, %arg21: !ttg.async.token, %arg22: !ttg.memdesc<256x8xi8, #shared, #smem, mutable>, %arg23: !ttg.memdesc<256x8xi8, #shared, #smem, mutable>, %arg24: !ttg.memdesc<256x128xi8, #shared1, #smem, mutable>, %arg25: !ttg.memdesc<128x256xi8, #shared2, #smem, mutable>, %arg26: tensor<256x8xi32, #blocked>, %arg27: tensor<256x8xi32, #blocked>, %arg28: tensor<256x256x!tt.ptr<bf16>, #mma>, %arg29: tensor<256x256xi1, #mma>) attributes {noinline = false} {
1927+
%c63_i32 = arith.constant 63 : i32
1928+
%c2_i32 = arith.constant 2 : i32
1929+
%cst = arith.constant dense<128> : tensor<256x128xi32, #blocked1>
1930+
%cst_0 = arith.constant dense<128> : tensor<128x256xi32, #blocked2>
1931+
%c1_i32 = arith.constant 1 : i32
1932+
%c0_i32 = arith.constant 0 : i32
1933+
%cst_1 = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
1934+
%0 = ttg.local_alloc : () -> !ttg.memdesc<2x256x128xi8, #shared1, #smem, mutable>
1935+
%1 = ttg.local_alloc : () -> !ttg.memdesc<2x128x256xi8, #shared2, #smem, mutable>
1936+
%2 = ttg.local_alloc : () -> !ttg.memdesc<2x256x8xi8, #shared, #smem, mutable>
1937+
%3 = ttg.local_alloc : () -> !ttg.memdesc<2x256x8xi8, #shared, #smem, mutable>
1938+
%4:14 = scf.for %arg30 = %c0_i32 to %c63_i32 step %c1_i32 iter_args(%arg31 = %cst_1, %arg32 = %arg14, %arg33 = %arg15, %arg34 = %arg16, %arg35 = %arg17, %arg36 = %c0_i32, %arg37 = %arg18, %arg38 = %arg19, %arg39 = %arg20, %arg40 = %arg21, %arg41 = %arg22, %arg42 = %arg23, %arg43 = %arg24, %arg44 = %arg25) -> (tensor<256x256xf32, #mma>, tensor<256x8x!tt.ptr<i8>, #blocked>, tensor<256x8x!tt.ptr<i8>, #blocked>, tensor<256x128x!tt.ptr<i8>, #blocked1>, tensor<128x256x!tt.ptr<i8>, #blocked2>, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.memdesc<256x8xi8, #shared, #smem, mutable>, !ttg.memdesc<256x8xi8, #shared, #smem, mutable>, !ttg.memdesc<256x128xi8, #shared1, #smem, mutable>, !ttg.memdesc<128x256xi8, #shared2, #smem, mutable>) : i32 {
1939+
%7 = ttg.async_wait %arg37, %arg38, %arg39, %arg40 {num = 0 : i32}
1940+
%8 = tt.addptr %arg34, %cst : tensor<256x128x!tt.ptr<i8>, #blocked1>, tensor<256x128xi32, #blocked1>
1941+
%9 = tt.addptr %arg35, %cst_0 : tensor<128x256x!tt.ptr<i8>, #blocked2>, tensor<128x256xi32, #blocked2>
1942+
%10 = tt.addptr %arg32, %arg26 : tensor<256x8x!tt.ptr<i8>, #blocked>, tensor<256x8xi32, #blocked>
1943+
%11 = tt.addptr %arg33, %arg27 : tensor<256x8x!tt.ptr<i8>, #blocked>, tensor<256x8xi32, #blocked>
1944+
%12 = arith.addi %arg36, %c1_i32 : i32
1945+
%13 = arith.cmpi slt, %12, %c2_i32 : i32
1946+
%14 = arith.select %13, %12, %c0_i32 : i32
1947+
%15 = ttg.memdesc_subview %2[%14, %c0_i32, %c0_i32] : !ttg.memdesc<2x256x8xi8, #shared, #smem, mutable> -> !ttg.memdesc<256x8xi8, #shared, #smem, mutable>
1948+
%16 = ttg.async_copy_global_to_local %10, %15 : tensor<256x8x!tt.ptr<i8>, #blocked> -> <256x8xi8, #shared, #smem, mutable>
1949+
%17 = ttg.async_commit_group %16
1950+
%18 = ttg.local_load %arg41 token %7 : !ttg.memdesc<256x8xi8, #shared, #smem, mutable> -> tensor<256x8xi8, #linear>
1951+
%19 = ttg.memdesc_subview %3[%14, %c0_i32, %c0_i32] : !ttg.memdesc<2x256x8xi8, #shared, #smem, mutable> -> !ttg.memdesc<256x8xi8, #shared, #smem, mutable>
1952+
%20 = ttg.async_copy_global_to_local %11, %19 : tensor<256x8x!tt.ptr<i8>, #blocked> -> <256x8xi8, #shared, #smem, mutable>
1953+
%21 = ttg.async_commit_group %20
1954+
%22 = ttg.local_load %arg42 token %7 : !ttg.memdesc<256x8xi8, #shared, #smem, mutable> -> tensor<256x8xi8, #linear1>
1955+
%23 = ttg.memdesc_subview %0[%14, %c0_i32, %c0_i32] : !ttg.memdesc<2x256x128xi8, #shared1, #smem, mutable> -> !ttg.memdesc<256x128xi8, #shared1, #smem, mutable>
1956+
%24 = ttg.async_copy_global_to_local %8, %23 : tensor<256x128x!tt.ptr<i8>, #blocked1> -> <256x128xi8, #shared1, #smem, mutable>
1957+
%25 = ttg.async_commit_group %24
1958+
%26 = ttg.local_load %arg43 token %7 : !ttg.memdesc<256x128xi8, #shared1, #smem, mutable> -> tensor<256x128xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
1959+
%27 = ttg.memdesc_subview %1[%14, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x256xi8, #shared2, #smem, mutable> -> !ttg.memdesc<128x256xi8, #shared2, #smem, mutable>
1960+
%28 = ttg.async_copy_global_to_local %9, %27 : tensor<128x256x!tt.ptr<i8>, #blocked2> -> <128x256xi8, #shared2, #smem, mutable>
1961+
%29 = ttg.async_commit_group %28
1962+
%30 = ttg.local_load %arg44 token %7 : !ttg.memdesc<128x256xi8, #shared2, #smem, mutable> -> tensor<128x256xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>
1963+
%31 = tt.dot_scaled %26 scale %18, %30 scale %22, %arg31 lhs = e2m1 rhs = e2m1 {fastMath = false} : tensor<256x128xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, tensor<256x8xi8, #linear> * tensor<128x256xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, tensor<256x8xi8, #linear1> -> tensor<256x256xf32, #mma>
1964+
scf.yield %31, %10, %11, %8, %9, %14, %17, %21, %25, %29, %15, %19, %23, %27 : tensor<256x256xf32, #mma>, tensor<256x8x!tt.ptr<i8>, #blocked>, tensor<256x8x!tt.ptr<i8>, #blocked>, tensor<256x128x!tt.ptr<i8>, #blocked1>, tensor<128x256x!tt.ptr<i8>, #blocked2>, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.memdesc<256x8xi8, #shared, #smem, mutable>, !ttg.memdesc<256x8xi8, #shared, #smem, mutable>, !ttg.memdesc<256x128xi8, #shared1, #smem, mutable>, !ttg.memdesc<128x256xi8, #shared2, #smem, mutable>
1965+
}
1966+
%5 = ttg.async_wait %4#6, %4#7, %4#8, %4#9 {num = 0 : i32}
1967+
ttg.local_dealloc %0 : !ttg.memdesc<2x256x128xi8, #shared1, #smem, mutable>
1968+
ttg.local_dealloc %1 : !ttg.memdesc<2x128x256xi8, #shared2, #smem, mutable>
1969+
ttg.local_dealloc %2 : !ttg.memdesc<2x256x8xi8, #shared, #smem, mutable>
1970+
ttg.local_dealloc %3 : !ttg.memdesc<2x256x8xi8, #shared, #smem, mutable>
1971+
%6 = arith.truncf %4#0 : tensor<256x256xf32, #mma> to tensor<256x256xbf16, #mma>
1972+
tt.store %arg28, %6, %arg29 : tensor<256x256x!tt.ptr<bf16>, #mma>
1973+
tt.return
1974+
}
1975+
}

test/TritonGPU/amd/mfma-double-rate.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,22 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
132132
tt.return
133133
}
134134
}
135+
136+
// -----
137+
138+
// CHECK-LABEL:mxfp4_2step
139+
#linear = #ttg.linear<{register = [[0, 4], [32, 0], [64, 0], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], warp = [[0, 0], [0, 0], [16, 0]], block = []}>
140+
#linear1 = #ttg.linear<{register = [[0, 4], [64, 0], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], warp = [[16, 0], [32, 0], [0, 0]], block = []}>
141+
#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>
142+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
143+
tt.func public @mxfp4_2step(%arg0: tensor<256x128xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<256x8xi8, #linear>, %arg2: tensor<128x256xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg3: tensor<256x8xi8, #linear1>) {
144+
// CHECK-COUNT-32: rocdl.mfma.scale.f32.16x16x128.f8f6f4
145+
// CHECK: rocdl.sched.barrier 0
146+
// CHECK: rocdl.s.barrier
147+
// CHECK: rocdl.sched.barrier 0
148+
// CHECK-COUNT-32: rocdl.mfma.scale.f32.16x16x128.f8f6f4
149+
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
150+
%dots = tt.dot_scaled %arg0 scale %arg1, %arg2 scale %arg3, %cst lhs = e2m1 rhs = e2m1 {fastMath = false, pingpong_2step} : tensor<256x128xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, tensor<256x8xi8, #linear> * tensor<128x256xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, tensor<256x8xi8, #linear1> -> tensor<256x256xf32, #mma>
151+
tt.return
152+
}
153+
}

third_party/amd/backend/compiler.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,10 @@ def make_ttgir(mod, metadata, options):
218218
global_prefetch = knobs.amd.global_prefetch
219219
local_prefetch = knobs.amd.local_prefetch
220220
use_async_copy = knobs.amd.use_async_copy
221+
use_block_pingpong = is_pingpong_schedule_enabled(options.arch)
221222

222-
amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch, use_async_copy)
223+
amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch, use_async_copy,
224+
use_block_pingpong)
223225
if use_async_copy:
224226
amd.passes.ttgpuir.add_coalesce_async_copy(pm, options.arch)
225227
passes.common.add_canonicalizer(pm)
@@ -232,8 +234,7 @@ def make_ttgir(mod, metadata, options):
232234
amd.passes.ttgpuir.add_in_thread_transpose(pm)
233235
passes.ttgpuir.add_remove_layout_conversions(pm)
234236
amd.passes.ttgpuir.add_reorder_instructions(pm)
235-
use_block_pingpong = is_pingpong_schedule_enabled(options.arch)
236-
if use_block_pingpong and options.num_stages == 2:
237+
if use_block_pingpong and options.num_stages > 1:
237238
amd.passes.ttgpuir.add_block_pingpong(pm, options.num_stages)
238239

239240
if knobs.amd.use_buffer_ops:

third_party/amd/include/TritonAMDGPUTransforms/Passes.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ def TritonAMDGPUStreamPipeline : Pass<"tritonamdgpu-stream-pipeline", "mlir::Mod
2626
Option<"useAsyncCopy", "use_async_copy",
2727
"bool", /*default*/"false",
2828
"Use AsyncCopyGlobalToLocal to directly load to shared memory">,
29+
Option<"usePingpong", "use_pingpong",
30+
"bool", /*default*/"false",
31+
"Use schedules to enable block ping-pong">,
2932
];
3033
}
3134

0 commit comments

Comments
 (0)