-
Notifications
You must be signed in to change notification settings - Fork 68
Open
Copy link
Description
Given the following test case:
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
#dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}>
#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg5: i32) {
%c1_i64 = arith.constant 1 : i64
%c0_i32 = arith.constant 0 : i32
%c0_i64 = arith.constant 0 : i64
%c32_i32 = arith.constant 32 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<64x256xf32, #dpas>
%18 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #blocked>>
%22 = tt.make_tensor_ptr %arg1, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked1>>
%23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>) : i32 {
%28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major" } : !tt.ptr<tensor<64x32xf16, #blocked>>
%29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #blocked1>>
%30 = ttg.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0>
%31 = ttg.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1>
%32 = tt.dot %30, %31, %arg10, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas>
%33 = tt.advance %arg11, [%c0_i32, %c32_i32] : <tensor<64x32xf16, #blocked>>
%34 = tt.advance %arg12, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #blocked1>>
scf.yield %32, %33, %34 : tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>
}
%24 = arith.truncf %23#0 : tensor<64x256xf32, #dpas> to tensor<64x256xf16, #dpas>
%27 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #dpas>>
tt.store %27, %24 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #dpas>>
// Issue: the load below causes an additional block pointer to be allocated and passed through the loop, increasing register pressure in the loop.
// To avoid this problem we can reuse %23#1 (which will have a dot layout because of the tt.dot operation in the loop) by injecting a convert layout op to
// convert that value to a blocked layout.
%28 = tt.load %23#1 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major" } : !tt.ptr<tensor<64x32xf16, #blocked>>
%29 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #blocked>>
tt.store %29, %28 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #blocked>>
tt.return
}
}
The code generated by the RemoveLayoutConversions pass is:
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} {
tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32) {
%c1_i64 = arith.constant 1 : i64
%c0_i32 = arith.constant 0 : i32
%c0_i64 = arith.constant 0 : i64
%c32_i32 = arith.constant 32 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<64x256xf32, #mma>
%0 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
%1 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #blocked>>
%2 = tt.make_tensor_ptr %arg1, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
%3:4 = scf.for %arg4 = %c0_i32 to %arg3 step %c32_i32 iter_args(%arg5 = %cst, %arg6 = %1, %arg7 = %0, %arg8 = %2) -> (tensor<64x256xf32, #mma>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>, !tt.ptr<tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>) : i32 {
%8 = tt.load %arg7 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
%9 = tt.load %arg8 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
%10 = tt.dot %8, %9, %arg5, inputPrecision = tf32 : tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x256xf32, #mma>
%11 = tt.advance %arg7, [%c0_i32, %c32_i32] : <tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
%12 = tt.advance %arg6, [%c0_i32, %c32_i32] : <tensor<64x32xf16, #blocked>>
%13 = tt.advance %arg8, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
scf.yield %10, %12, %11, %13 : tensor<64x256xf32, #mma>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>, !tt.ptr<tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
}
%4 = arith.truncf %3#0 : tensor<64x256xf32, #mma> to tensor<64x256xf16, #mma>
%5 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #mma>>
tt.store %5, %4 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #mma>>
%6 = tt.load %3#1 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #blocked>>
%7 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #blocked>>
tt.store %7, %6 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #blocked>>
tt.return
}
}
We can observe that one additional loop-carried value (`%12`) has been added to the loop, potentially increasing register pressure in the loop. The additional value is added because a load op (`%6`) uses a return value from the loop which has a dot layout, while the load expects a blocked layout.
To reproduce:
/triton-opt ~/tmp/test1.mlir --tritonintelgpu-remove-layout-conversions --debug-only=tritonintelgpu-remove-layout-conversions