// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=2 | FileCheck %s
// CHECK-LABEL: @indirect_load_two_stages
// CHECK: scf.for
// CHECK: tt.dot
// CHECK: tt.load
// CHECK: async_copy_global_to_local
// CHECK: async_copy_global_to_local

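// Pipelining test for an indirect load: the i64 indices loaded inside the
// loop feed the address computation of one of the dot operands, so with
// num-stages=2 the pipeliner has to stage the index load ahead of the load
// that consumes it. The CHECK lines above assert that, after pipelining, a
// tt.dot is scheduled before the remaining in-loop tt.load and that two
// async_copy_global_to_local copies are emitted.
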
#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [2, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 2], order = [0, 1]}>
#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [2, 1], order = [1, 0]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @indirect_load_two_stages(%arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32, %arg19: i32) attributes {noinline = false} {
    %c32_i32 = arith.constant 32 : i32
    %c16_i32 = arith.constant 16 : i32
    %cst = arith.constant dense<0.000000e+00> : tensor<16x128xf32, #blocked>

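    // Prologue: a scalar row index is loaded once, outside the loop, and
    // folded into the base pointer of the second dot operand below.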
    %0 = tt.get_program_id y : i32
    %1 = tt.addptr %arg3, %0 : !tt.ptr<i64>, i32
    %2 = tt.load %1 : !tt.ptr<i64>

    %7 = tt.get_program_id x : i32
    %8 = arith.muli %7, %c16_i32 : i32
    %10 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>>
    %15 = tt.splat %8 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>>
    %18 = arith.addi %15, %10 : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>>

    %20 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %22 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>>
    %34 = arith.extsi %arg12 : i32 to i64
    %35 = arith.muli %2, %34 : i64
    %36 = tt.addptr %arg2, %35 : !tt.ptr<f32>, i64

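    // %48 holds per-lane pointers into the index array %arg4; each loop
    // iteration loads 32 i64 indices through these pointers.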
    %47 = tt.splat %arg4 : !tt.ptr<i64> -> tensor<32x!tt.ptr<i64>, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %48 = tt.addptr %47, %20 : tensor<32x!tt.ptr<i64>, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>

    %59 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>>
    %61 = arith.extsi %59 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> to tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked3}>>
    %63 = tt.expand_dims %61 {axis = 0 : i32} : tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x128xi64, #blocked3>

    %85 = arith.extsi %22 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> to tensor<32xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>>
    %107 = tt.splat %36 : !tt.ptr<f32> -> tensor<32x128x!tt.ptr<f32>, #blocked3>
    %108 = tt.splat %34 : i64 -> tensor<32x1xi64, #blocked3>
    %109 = tt.broadcast %63 : tensor<1x128xi64, #blocked3> -> tensor<32x128xi64, #blocked3>

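    // Main loop over K in steps of 32. The load of %161 inside the loop is
    // the indirect part: its result is needed to compute the addresses of the
    // 16x32 operand, which is what requires the two-stage schedule.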
    %101 = tt.splat %arg5 : !tt.ptr<f32> -> tensor<16x32x!tt.ptr<f32>, #blocked1>
    %111:1 = scf.for %arg28 = %arg18 to %arg19 step %c32_i32 iter_args(%arg29 = %cst) -> (tensor<16x128xf32, #blocked>) : i32 {
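      // Load 32 i64 column indices for this iteration, then gather the 16x32
      // first dot operand through them.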
      %129 = tt.splat %arg28 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
      %160 = tt.addptr %48, %129 : tensor<32x!tt.ptr<i64>, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
      %161 = tt.load %160 : tensor<32x!tt.ptr<i64>, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
      %162 = tt.expand_dims %161 {axis = 0 : i32} : tensor<32xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi64, #blocked1>
      %163 = tt.broadcast %162 : tensor<1x32xi64, #blocked1> -> tensor<16x32xi64, #blocked1>
      %182 = tt.addptr %101, %163 : tensor<16x32x!tt.ptr<f32>, #blocked1>, tensor<16x32xi64, #blocked1>
      %183 = tt.load %182 : tensor<16x32x!tt.ptr<f32>, #blocked1>

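      // The 32x128 second operand uses a plain strided address that depends
      // only on the induction variable, not on the loaded indices.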
      %197 = arith.extsi %arg28 : i32 to i64
      %198 = tt.splat %197 : i64 -> tensor<32xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>>
      %199 = arith.addi %198, %85 : tensor<32xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>>
      %200 = tt.expand_dims %199 {axis = 1 : i32} : tensor<32xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1xi64, #blocked3>
      %201 = arith.muli %200, %108 : tensor<32x1xi64, #blocked3>
      %202 = tt.broadcast %201 : tensor<32x1xi64, #blocked3> -> tensor<32x128xi64, #blocked3>
      %203 = arith.addi %202, %109 : tensor<32x128xi64, #blocked3>
      %204 = tt.addptr %107, %203 : tensor<32x128x!tt.ptr<f32>, #blocked3>, tensor<32x128xi64, #blocked3>
      %209 = tt.load %204 : tensor<32x128x!tt.ptr<f32>, #blocked3>

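      // Accumulate: (16x32) * (32x128) -> 16x128.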
      %210 = triton_gpu.convert_layout %183 : tensor<16x32xf32, #blocked1> -> tensor<16x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
      %211 = triton_gpu.convert_layout %209 : tensor<32x128xf32, #blocked3> -> tensor<32x128xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
      %212 = tt.dot %210, %211, %arg29 : tensor<16x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<32x128xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x128xf32, #blocked>
      scf.yield %212 : tensor<16x128xf32, #blocked>
    }
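    // Epilogue: store the accumulated 16x128 tile to %arg7.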
    %112 = tt.expand_dims %18 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<16x1xi32, #blocked3>
    %113 = tt.splat %2 : i64 -> tensor<16x1xi64, #blocked3>
    %114 = arith.extsi %112 : tensor<16x1xi32, #blocked3> to tensor<16x1xi64, #blocked3>
    %115 = arith.addi %113, %114 : tensor<16x1xi64, #blocked3>
    %116 = arith.extsi %arg17 : i32 to i64
    %117 = tt.splat %116 : i64 -> tensor<16x1xi64, #blocked3>
    %118 = arith.muli %115, %117 : tensor<16x1xi64, #blocked3>
    %119 = tt.expand_dims %59 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x128xi32, #blocked3>
    %120 = tt.broadcast %118 : tensor<16x1xi64, #blocked3> -> tensor<16x128xi64, #blocked3>
    %121 = arith.extsi %119 : tensor<1x128xi32, #blocked3> to tensor<1x128xi64, #blocked3>
    %122 = tt.broadcast %121 : tensor<1x128xi64, #blocked3> -> tensor<16x128xi64, #blocked3>
    %123 = arith.addi %120, %122 : tensor<16x128xi64, #blocked3>
    %124 = tt.splat %arg7 : !tt.ptr<f32> -> tensor<16x128x!tt.ptr<f32>, #blocked3>
    %125 = tt.addptr %124, %123 : tensor<16x128x!tt.ptr<f32>, #blocked3>, tensor<16x128xi64, #blocked3>
    %128 = triton_gpu.convert_layout %111#0 : tensor<16x128xf32, #blocked> -> tensor<16x128xf32, #blocked3>
    tt.store %125, %128 : tensor<16x128x!tt.ptr<f32>, #blocked3>
    tt.return
  }
}