// RUN: triton-opt %s -split-input-file -tritonamdgpu-reorder-instructions | FileCheck %s

-// Check that we order load, local_alloc, local_store (optional) and local_load one after another. This is useful
-// for making sure that Q tensor in FA is hoisted out of the main loop and kept in registers
+// Check that we place local_alloc, local_store (optional) and local_load right after the definition of their operands
+// in cases where local_alloc is in the loop but its operand is not.
+// This is useful for making sure that the Q tensor in FA is hoisted out of the main loop and kept in registers
// throughout the computation.
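+//
+// Schematically (an illustrative sketch only; FileCheck does not match these lines,
+// and the value names are made up for exposition):
+//   before: %q = arith.truncf ... ; scf.for { local_alloc %q ; local_load ... }
+//   after:  %q = arith.truncf ... ; local_alloc %q ; local_load ; scf.for { ... }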
-// CHECK-LABEL: order_load_alloc_local_load
-// CHECK: %[[LOAD:.+]] = tt.load
-// CHECK-NEXT: %[[ALLOC:.+]] = triton_gpu.local_alloc %[[LOAD]]
-// CHECK-NEXT: triton_gpu.local_load %[[ALLOC]]
-#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
-#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
-#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
-module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
-  tt.func public @order_load_alloc_local_load(%arg0: tensor<32x32x!tt.ptr<f32>, #blocked>) attributes {noinline = false} {
-    %9 = tt.load %arg0 : tensor<32x32x!tt.ptr<f32>, #blocked>
-    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
-    %10 = triton_gpu.local_alloc %9 : (tensor<32x32xf32, #blocked>) -> !tt.memdesc<32x32xf32, #shared>
-    %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
-    %11 = triton_gpu.local_load %10 : !tt.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
-    %12 = tt.dot %11, %cst_0, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma>
-    %13 = triton_gpu.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked>
-    tt.store %arg0, %13 : tensor<32x32x!tt.ptr<f32>, #blocked>
+
+// CHECK-LABEL: hoist_q_out_of_the_loop
+// CHECK: %[[TRUNCF:.+]] = arith.truncf
+// CHECK-NEXT: %[[ALLOC:.+]] = triton_gpu.local_alloc %[[TRUNCF]]
+// CHECK-NEXT: triton_gpu.local_load %[[ALLOC]]
+// CHECK: scf.for
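+// The CHECK-NEXT chain above pins local_alloc and local_load directly after the
+// truncf that defines their operand, and the trailing CHECK: scf.for confirms the
+// pair now sits before the loop.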
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
+#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
+#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func public @hoist_q_out_of_the_loop(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %c0_i32 = arith.constant 0 : i32
+    %cst = arith.constant 1.44269502 : f32
+    %c128_i32 = arith.constant 128 : i32
+    %c128_i64 = arith.constant 128 : i64
+    %c0_i64 = arith.constant 0 : i64
+    %cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mfma>
+    %1 = tt.get_program_id y : i32
+    %2 = arith.muli %1, %arg7 : i32
+    %3 = tt.addptr %arg0, %2 : !tt.ptr<f16>, i32
+    %12 = tt.splat %3 : !tt.ptr<f16> -> tensor<256x128x!tt.ptr<f16>, #blocked1>
+    %41 = tt.load %12 : tensor<256x128x!tt.ptr<f16>, #blocked1>
+    %42 = arith.extf %41 : tensor<256x128xf16, #blocked1> to tensor<256x128xf32, #blocked1>
+    %43 = tt.splat %cst : f32 -> tensor<256x128xf32, #blocked1>
+    %44 = arith.mulf %42, %43 : tensor<256x128xf32, #blocked1>
+    %45 = arith.truncf %44 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1>
+    %54:1 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg26 = %c0_i64) -> (i64) : i32 {
+      %73 = tt.splat %3 : !tt.ptr<f16> -> tensor<128x128x!tt.ptr<f16>, #blocked2>
+      %74 = tt.load %73 : tensor<128x128x!tt.ptr<f16>, #blocked2>
+      %75 = triton_gpu.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory>
+      %76 = triton_gpu.local_load %75 : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
+      %77 = triton_gpu.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory>
+      %78 = triton_gpu.local_load %77 : !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
+      %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma>
+      %107 = arith.addi %arg26, %c128_i64 : i64
+      scf.yield %107 : i64
+    } {tt.divisibility_arg1 = dense<128> : tensor<1xi32>}
    tt.return
  }
}
+
+
+// -----
+// Check that the reordering described in hoist_q_out_of_the_loop is not performed when both
+// local_alloc and the op defining its src tensor are inside the loop.
+// CHECK-LABEL: no_hoist_q_type_reordering
+// CHECK: scf.for
+// CHECK: %[[TRUNCF:.+]] = arith.truncf
+// CHECK-NEXT: arith.constant
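+// Here truncf is defined inside the loop, so no hoisting applies; the constant that
+// immediately follows truncf in the input must still follow it after the pass runs.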
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
+#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
+#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func public @no_hoist_q_type_reordering(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %c0_i32 = arith.constant 0 : i32
+    %cst = arith.constant 1.44269502 : f32
+    %c128_i32 = arith.constant 128 : i32
+    %c128_i64 = arith.constant 128 : i64
+    %c0_i64 = arith.constant 0 : i64
+    %1 = tt.get_program_id y : i32
+    %2 = arith.muli %1, %arg7 : i32
+    %3 = tt.addptr %arg0, %2 : !tt.ptr<f16>, i32
+    %12 = tt.splat %3 : !tt.ptr<f16> -> tensor<256x128x!tt.ptr<f16>, #blocked1>
+    %41 = tt.load %12 : tensor<256x128x!tt.ptr<f16>, #blocked1>
+    %42 = arith.extf %41 : tensor<256x128xf16, #blocked1> to tensor<256x128xf32, #blocked1>
+    %43 = tt.splat %cst : f32 -> tensor<256x128xf32, #blocked1>
+    %44 = arith.mulf %42, %43 : tensor<256x128xf32, #blocked1>
+    %54:1 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg26 = %c0_i64) -> (i64) : i32 {
+      %45 = arith.truncf %44 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1>
+      %cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mfma>
+      %73 = tt.splat %3 : !tt.ptr<f16> -> tensor<128x128x!tt.ptr<f16>, #blocked2>
+      %74 = tt.load %73 : tensor<128x128x!tt.ptr<f16>, #blocked2>
+      %75 = triton_gpu.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory>
+      %76 = triton_gpu.local_load %75 : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
+      %77 = triton_gpu.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory>
+      %78 = triton_gpu.local_load %77 : !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
+      %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma>
+      %107 = arith.addi %arg26, %c128_i64 : i64
+      scf.yield %107 : i64
+    } {tt.divisibility_arg1 = dense<128> : tensor<1xi32>}
+    tt.return
+  }
+}
+
+// -----
+#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
+#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
+#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
+
// CHECK-LABEL: order_load_alloc_local_load_local_store
// CHECK: %[[LOAD:.+]] = tt.load
// CHECK: %[[ALLOC:.+]] = triton_gpu.local_alloc