// RUN: triton-opt %s -split-input-file --allow-unregistered-dialect --nvws-assign-stage-phase -canonicalize | FileCheck %s
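// These tests check that stage/phase indices are materialized for nvws.aref put/get
// enter ops: the placeholder [%c0_i32, %c0_i32] operands are expected to be replaced
// by per-partition counters carried as scf.for iter_args. For the 3-deep arefs below,
// each stage counter starts at 2, is incremented every iteration and wrapped back to 0
// when it reaches 3, and the matching phase bit (0 on the put side, 1 on the get side)
// is flipped whenever the stage wraps.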

#shared0 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 16}>
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-warps" = 4 : i32} {

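  // A single producer (partition 0) stores into a depth-3 aref that is read by two
  // consumers (partitions 1 and 2), so the loop is expected to carry three stage/phase
  // pairs (six extra iter_args), one per partition.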
  // CHECK-LABEL: @two_consumers
  tt.func @two_consumers(%arg0: i32, %arg1: i32, %arg2: i32) {
    // CHECK: [[C3:%.*]] = arith.constant 3 : i32
    // CHECK: [[C1:%.*]] = arith.constant 1 : i32
    // CHECK: [[C2:%.*]] = arith.constant 2 : i32
    // CHECK: [[C0:%.*]] = arith.constant 0 : i32
    %ub = arith.constant 4 : i32
    %c0_i32 = arith.constant 0 : i32
    %0 = ttg.local_alloc : () -> !ttg.memdesc<3x1xi32, #shared, #smem, mutable>
    // CHECK: [[AREF:%.*]] = nvws.aref.create
    %1 = nvws.aref.create %0 : <[!ttg.memdesc<3x1xi32, #shared, #smem, mutable>]>
    // CHECK: [[IDX:%.*]]:6 = scf.for [[I:%.*]] = [[LB:%.*]] to [[UB:%.*]] step [[STEP:%.*]] iter_args([[S0:%.*]] = [[C2]], [[P0:%.*]] = [[C0]], [[S1:%.*]] = [[C2]], [[P1:%.*]] = [[C1]], [[S2:%.*]] = [[C2]], [[P2:%.*]] = [[C1]])
    scf.for %arg3 = %arg0 to %arg1 step %arg2 : i32 {
      %2 = "op_a"() {ttg.partition = 0 : i32} : () -> tensor<1xi32, #blocked>
      // CHECK: op_a
      // CHECK-NEXT: [[S0a:%.*]] = arith.addi [[S0]], [[C1]]
      // CHECK-NEXT: [[CMP:%.*]] = arith.cmpi eq, [[S0a]], [[C3]]
      // CHECK-NEXT: [[S0b:%.*]] = arith.select [[CMP]], [[C0]], [[S0a]]
      // CHECK-NEXT: [[P0a:%.*]] = arith.xori [[P0]], [[C1]]
      // CHECK-NEXT: [[P0b:%.*]] = arith.select [[CMP]], [[P0a]], [[P0]]
      // CHECK-NEXT: put.enter [[AREF]][[[S0b]], [[P0b]]]
      %buffers, %token = nvws.aref.put.enter %1[%c0_i32, %c0_i32] {ttg.partition = 0 : i32} : <[!ttg.memdesc<3x1xi32, #shared, #smem, mutable>]> -> !ttg.memdesc<1xi32, #shared, #smem, mutable, 1x1>, !ttg.async.token
      ttg.local_store %2, %buffers {ttg.partition = 0 : i32} : tensor<1xi32, #blocked> -> !ttg.memdesc<1xi32, #shared, #smem, mutable, 1x1>
      // CHECK: put.exit [[AREF]][[[S0b]]]
      nvws.aref.put.exit %1[%c0_i32], %token [#nvws.async_op<none>] {ttg.partition = 0 : i32} : <[!ttg.memdesc<3x1xi32, #shared, #smem, mutable>]>, !ttg.async.token

      // CHECK-NEXT: [[S1a:%.*]] = arith.addi [[S1]], [[C1]]
      // CHECK-NEXT: [[CMP:%.*]] = arith.cmpi eq, [[S1a]], [[C3]]
      // CHECK-NEXT: [[S1b:%.*]] = arith.select [[CMP]], [[C0]], [[S1a]]
      // CHECK-NEXT: [[P1a:%.*]] = arith.xori [[P1]], [[C1]]
      // CHECK-NEXT: [[P1b:%.*]] = arith.select [[CMP]], [[P1a]], [[P1]]
      // CHECK-NEXT: {{.*}}, [[TOK1:%.*]] = nvws.aref.get.enter [[AREF]][[[S1b]], [[P1b]]] {ttg.partition = 1 : i32}
      %buffers_0, %token_1 = nvws.aref.get.enter %1[%c0_i32, %c0_i32] {ttg.partition = 1 : i32} : <[!ttg.memdesc<3x1xi32, #shared, #smem, mutable>]> -> !ttg.memdesc<1xi32, #shared, #smem, mutable, 1x1>, !ttg.async.token
      %3 = ttg.local_load %buffers_0 {ttg.partition = 1 : i32} : !ttg.memdesc<1xi32, #shared, #smem, mutable, 1x1> -> tensor<1xi32, #blocked>
      // CHECK: get.exit [[AREF]][[[S1b]]], [[TOK1]] [#nvws.async_op<none>] {ttg.partition = 1 : i32}
      nvws.aref.get.exit %1[%c0_i32], %token_1 [#nvws.async_op<none>] {ttg.partition = 1 : i32} : <[!ttg.memdesc<3x1xi32, #shared, #smem, mutable>]>, !ttg.async.token
      "op_b"(%3) {ttg.partition = 1 : i32} : (tensor<1xi32, #blocked>) -> ()

      // CHECK: op_b
      // CHECK-NEXT: [[S2a:%.*]] = arith.addi [[S2]], [[C1]]
      // CHECK-NEXT: [[CMP:%.*]] = arith.cmpi eq, [[S2a]], [[C3]]
      // CHECK-NEXT: [[S2b:%.*]] = arith.select [[CMP]], [[C0]], [[S2a]]
      // CHECK-NEXT: [[P2a:%.*]] = arith.xori [[P2]], [[C1]]
      // CHECK-NEXT: [[P2b:%.*]] = arith.select [[CMP]], [[P2a]], [[P2]]
      // CHECK-NEXT: {{.*}}, [[TOK2:%.*]] = nvws.aref.get.enter [[AREF]][[[S2b]], [[P2b]]] {ttg.partition = 2 : i32}
      %buffers_2, %token_3 = nvws.aref.get.enter %1[%c0_i32, %c0_i32] {ttg.partition = 2 : i32} : <[!ttg.memdesc<3x1xi32, #shared, #smem, mutable>]> -> !ttg.memdesc<1xi32, #shared, #smem, mutable, 1x1>, !ttg.async.token
      %4 = ttg.local_load %buffers_2 {ttg.partition = 2 : i32} : !ttg.memdesc<1xi32, #shared, #smem, mutable, 1x1> -> tensor<1xi32, #blocked>
      // CHECK: get.exit [[AREF]][[[S2b]]], [[TOK2]] [#nvws.async_op<none>] {ttg.partition = 2 : i32}
      nvws.aref.get.exit %1[%c0_i32], %token_3 [#nvws.async_op<none>] {ttg.partition = 2 : i32} : <[!ttg.memdesc<3x1xi32, #shared, #smem, mutable>]>, !ttg.async.token
      "op_c"(%4) {ttg.partition = 2 : i32} : (tensor<1xi32, #blocked>) -> ()
      "op_d"(%4) {ttg.partition = 2 : i32} : (tensor<1xi32, #blocked>) -> ()
      // CHECK: op_c
      // CHECK-NEXT: op_d
      // CHECK-NEXT: yield [[S0b]], [[P0b]], [[S1b]], [[P1b]], [[S2b]], [[P2b]]

    } {ttg.partition.stages = [0 : i32, 2 : i32, 2 : i32], ttg.warp_specialize.tag = 0 : i32}

    ttg.local_dealloc %0 : !ttg.memdesc<3x1xi32, #shared, #smem, mutable>
    tt.return
  }

}

// -----

#shared0 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 16}>
#smem = #ttg.shared_memory
module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
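  // Two arefs, each backed by a pair of 3-deep buffers, are accessed from the loop;
  // the second one is only touched inside an scf.if. The stage/phase counters for that
  // aref are expected to be yielded out of the scf.if (with the incoming values passed
  // through on the else branch) and then fed into the loop's scf.yield together with
  // the counters updated in the loop body.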
  // CHECK-LABEL: @aref_lowering
  tt.func @aref_lowering(%d : !ttg.memdesc<3x64x16xf16, #shared0, #smem>,
                         %e : !ttg.memdesc<3x16x32xf16, #shared0, #smem>,
                         %cond : i1) {
    // CHECK: [[C3:%.*]] = arith.constant 3 : i32
    // CHECK: [[C2:%.*]] = arith.constant 2 : i32
    // CHECK: [[C0:%.*]] = arith.constant 0 : i32
    // CHECK: [[C1:%.*]] = arith.constant 1 : i32
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %lb = arith.constant 0 : i32
    %ub = arith.constant 4 : i32

    // CHECK: [[AREF0:%.*]] = nvws.aref.create
    // CHECK-NEXT: [[AREF1:%.*]] = nvws.aref.create
    %aref0 = nvws.aref.create %d, %e : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>
    %aref1 = nvws.aref.create %d, %e : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>
    // CHECK: [[IDX:%.*]]:8 = scf.for [[I:%.*]] = [[LB:%.*]] to [[UB:%.*]] step [[STEP:%.*]] iter_args([[S0:%.*]] = [[C2]], [[P0:%.*]] = [[C0]], [[S1:%.*]] = [[C2]], [[P1:%.*]] = [[C1]], [[S2:%.*]] = [[C2]], [[P2:%.*]] = [[C0]], [[S3:%.*]] = [[C2]], [[P3:%.*]] = [[C1]])
    scf.for %i = %lb to %ub step %c1_i32 : i32 {
      // CHECK: [[S0a:%.*]] = arith.addi [[S0]], [[C1]]
      // CHECK-NEXT: [[CMP:%.*]] = arith.cmpi eq, [[S0a]], [[C3]]
      // CHECK-NEXT: [[S0b:%.*]] = arith.select [[CMP]], [[C0]], [[S0a]]
      // CHECK-NEXT: [[P0a:%.*]] = arith.xori [[P0]], [[C1]]
      // CHECK-NEXT: [[P0b:%.*]] = arith.select [[CMP]], [[P0a]], [[P0]]
      // CHECK-NEXT: put.enter [[AREF0]][[[S0b]], [[P0b]]]
      %1:3 = nvws.aref.put.enter %aref0[%c0_i32, %c0_i32] {ttg.partition = 0 : i32} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>, !ttg.async.token
      "op1"(%1#0) {ttg.partition = 0 : i32} : (!ttg.memdesc<64x16xf16, #shared0, #smem>) -> ()
      "op2"(%1#1) {ttg.partition = 0 : i32} : (!ttg.memdesc<16x32xf16, #shared0, #smem>) -> ()
      // CHECK: op2
      // CHECK-NEXT: put.exit [[AREF0]][[[S0b]]]
      nvws.aref.put.exit %aref0[%c0_i32], %1#2 [#nvws.async_op<tma_load>, #nvws.async_op<none>] {ttg.partition = 0 : i32} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>, !ttg.async.token

      // CHECK-NEXT: [[S1a:%.*]] = arith.addi [[S1]], [[C1]]
      // CHECK-NEXT: [[CMP:%.*]] = arith.cmpi eq, [[S1a]], [[C3]]
      // CHECK-NEXT: [[S1b:%.*]] = arith.select [[CMP]], [[C0]], [[S1a]]
      // CHECK-NEXT: [[P1a:%.*]] = arith.xori [[P1]], [[C1]]
      // CHECK-NEXT: [[P1b:%.*]] = arith.select [[CMP]], [[P1a]], [[P1]]
      // CHECK-NEXT: {{.*}}, [[TOK1:%.*]] = nvws.aref.get.enter [[AREF0]][[[S1b]], [[P1b]]] {ttg.partition = 1 : i32}
      %2:3 = nvws.aref.get.enter %aref0[%c0_i32, %c0_i32] {ttg.partition = 1 : i32} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>, !ttg.async.token
      "op3"(%2#0, %2#1) {ttg.partition = 1 : i32} : (!ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> ()
      // CHECK: op3
      // CHECK-NEXT: get.exit [[AREF0]][[[S1b]]], [[TOK1]] [#nvws.async_op<tc5mma>] {ttg.partition = 1 : i32}
      nvws.aref.get.exit %aref0[%c0_i32], %2#2 [#nvws.async_op<tc5mma>] {ttg.partition = 1 : i32} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>, !ttg.async.token
      // CHECK: [[IDX1:%.*]]:4 = scf.if
      scf.if %cond {
        // CHECK-NEXT: [[S2a:%.*]] = arith.addi [[S2]], [[C1]]
        // CHECK-NEXT: [[CMP:%.*]] = arith.cmpi eq, [[S2a]], [[C3]]
        // CHECK-NEXT: [[S2b:%.*]] = arith.select [[CMP]], [[C0]], [[S2a]]
        // CHECK-NEXT: [[P2a:%.*]] = arith.xori [[P2]], [[C1]]
        // CHECK-NEXT: [[P2b:%.*]] = arith.select [[CMP]], [[P2a]], [[P2]]
        // CHECK-NEXT: {{.*}}, [[TOK2:%.*]] = nvws.aref.put.enter [[AREF1]][[[S2b]], [[P2b]]] {ttg.partition = 0 : i32}
        // CHECK-NEXT: op4
        // CHECK-NEXT: put.exit [[AREF1]][[[S2b]]]
        %4:3 = nvws.aref.put.enter %aref1[%c0_i32, %c0_i32] {ttg.partition = 0 : i32} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>, !ttg.async.token
        "op4"(%4#0, %4#1) {ttg.partition = 0 : i32} : (!ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> ()
        nvws.aref.put.exit %aref1[%c0_i32], %4#2 [#nvws.async_op<none>] {ttg.partition = 0 : i32} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>, !ttg.async.token
        // CHECK-NEXT: [[S3a:%.*]] = arith.addi [[S3]], [[C1]]
        // CHECK-NEXT: [[CMP:%.*]] = arith.cmpi eq, [[S3a]], [[C3]]
        // CHECK-NEXT: [[S3b:%.*]] = arith.select [[CMP]], [[C0]], [[S3a]]
        // CHECK-NEXT: [[P3a:%.*]] = arith.xori [[P3]], [[C1]]
        // CHECK-NEXT: [[P3b:%.*]] = arith.select [[CMP]], [[P3a]], [[P3]]
        // CHECK-NEXT: {{.*}}, [[TOK3:%.*]] = nvws.aref.get.enter [[AREF1]][[[S3b]], [[P3b]]] {ttg.partition = 1 : i32}
        // CHECK-NEXT: op5
        // CHECK-NEXT: get.exit [[AREF1]][[[S3b]]], [[TOK3]] [#nvws.async_op<tc5mma>] {ttg.partition = 1 : i32}
        %5:3 = nvws.aref.get.enter %aref1[%c0_i32, %c0_i32] {ttg.partition = 1 : i32} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>, !ttg.async.token
        "op5"(%5#0, %5#1) {ttg.partition = 1 : i32} : (!ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> ()
        nvws.aref.get.exit %aref1[%c0_i32], %5#2 [#nvws.async_op<tc5mma>] {ttg.partition = 1 : i32} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>, !ttg.async.token
        // CHECK-NEXT: yield [[S2b]], [[P2b]], [[S3b]], [[P3b]]
      }
      // CHECK-NEXT: } else {
      // CHECK-NEXT: yield [[S2]], [[P2]], [[S3]], [[P3]]
      // CHECK-NEXT: }
      // CHECK: scf.yield [[S0b]], [[P0b]], [[S1b]], [[P1b]], [[IDX1]]#0, [[IDX1]]#1, [[IDX1]]#2, [[IDX1]]#3

    } {ttg.warp_specialize.tag = 0 : i32}
    tt.return
  }
}