// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm --convert-nv-gpu-to-llvm | mlir-translate -mlir-to-llvmir | opt -S -O1 | FileCheck %s

// Check the optimized LLVM IR, since InstCombine makes the linear layout
// logic understandable enough (in simple cases) to check correctness by eye.

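// Reading the layouts below: each basis vector in a #ttg.linear layout is the
// tensor-coordinate offset contributed by one bit of the corresponding
// hardware index (register, lane, warp), with the contributions of all set
// bits XOR-combined (roughly; see Triton's LinearLayout documentation for the
// precise semantics).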
#crazy_2d_src = #ttg.linear<{register = [[0, 2], [2, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>
#crazy_2d_idx = #ttg.linear<{register = [[2, 0], [0, 2]], lane = [[0, 8], [16, 0], [1, 0], [8, 0], [4, 0]], warp = [[0, 1], [0, 4]], block = []}>
#broadcasted_lane_1d = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#broadcasted_warp_2d = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {

// CHECK-LABEL: @gather_2d_crazy
tt.func private @gather_2d_crazy(%arg0: tensor<32x16xi32, #crazy_2d_idx>, %arg1: tensor<32x16xf32, #crazy_2d_src>) -> tensor<32x16xf32, #crazy_2d_idx> {
  // The specific logic becomes hard to grasp here. Just check the shuffles.
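  // Each gathered column spans two source registers here, so for each of the
  // four index registers we expect two index shuffles whose results are
  // merged by an icmp/select pair: 8 shuffles in total.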

  // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float, float, float } %1, 0
  // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float, float, float } %1, 1
  // CHECK-NEXT: [[SRC2:%.*]] = extractvalue { float, float, float, float } %1, 2
  // CHECK-NEXT: [[SRC3:%.*]] = extractvalue { float, float, float, float } %1, 3

  // CHECK: [[VALUE0:%.*]] = bitcast float [[SRC0]] to i32
  // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]],
  // CHECK-NEXT: {{%.*}} = bitcast i32 {{%.*}} to float
  // CHECK-NEXT: {{%.*}} = icmp eq i32
  // CHECK-NEXT: {{%.*}} = select i1
  // CHECK-NEXT: [[VALUE2:%.*]] = bitcast float [[SRC2]] to i32
  // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE2]],

  // CHECK: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]],
  // CHECK-NEXT: {{%.*}} = bitcast i32 {{%.*}} to float
  // CHECK-NEXT: {{%.*}} = icmp eq i32
  // CHECK-NEXT: {{%.*}} = select i1
  // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE2]],

  // CHECK: [[VALUE1:%.*]] = bitcast float [[SRC1]] to i32
  // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]],
  // CHECK-NEXT: [[VALUE3:%.*]] = bitcast float [[SRC3]] to i32
  // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE3]],
  // CHECK-NEXT: {{%.*}} = icmp eq i32
  // CHECK-NEXT: {{%.*}} = select i1
  // CHECK-NEXT: {{%.*}} = bitcast i32 {{%.*}} to float

  // CHECK: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]],
  // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE3]],
  // CHECK-NEXT: {{%.*}} = icmp eq i32
  // CHECK-NEXT: {{%.*}} = select i1
  // CHECK-NEXT: {{%.*}} = bitcast i32 {{%.*}} to float

  %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32x16xf32, #crazy_2d_src>, tensor<32x16xi32, #crazy_2d_idx>) -> tensor<32x16xf32, #crazy_2d_idx>
  tt.return %0 : tensor<32x16xf32, #crazy_2d_idx>
}

// The tensor has 16 elements but each warp has 32 lanes, so within every warp
// both half-warps map to the same 16 elements. It therefore doesn't matter
// whether an index resolves into [0, 16) or [16, 32): both halves hold the
// same data.
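// Hence the index needs no lane decomposition at all: it is masked to 4 bits
// and used directly as the source lane of a single index shuffle.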
// CHECK-LABEL: @gather_broadcasted_lane_1d
tt.func private @gather_broadcasted_lane_1d(%arg0: tensor<16xi32, #broadcasted_lane_1d>, %arg1: tensor<16xf32, #broadcasted_lane_1d>) -> tensor<16xf32, #broadcasted_lane_1d> {
  // CHECK-NEXT: [[SRC:%.*]] = extractvalue { float } %1, 0
  // CHECK-NEXT: [[IDX:%.*]] = extractvalue { i32 } %0, 0

  // CHECK-NEXT: [[LANEID:%.*]] = and i32 [[IDX]], 15
  // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC]] to i32
  // CHECK-NEXT: [[RES_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31)
  %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<16xf32, #broadcasted_lane_1d>, tensor<16xi32, #broadcasted_lane_1d>) -> tensor<16xf32, #broadcasted_lane_1d>

  // CHECK-NEXT: [[RES:%.*]] = bitcast i32 [[RES_i32]] to float
  // CHECK-NEXT: ret float [[RES]]
  tt.return %0 : tensor<16xf32, #broadcasted_lane_1d>
}

// A single gather column with 64 elements, all of which have to fit into one
// warp, so the whole column is broadcast across the 4 warps. Every warp
// processes the same data, so the warp ID doesn't matter.
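// Each thread holds two rows, so bit 0 of the index selects the source
// register and bits [1:5] select the source lane; one shuffle per source
// register plus a select on the register bit produces each result.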
// CHECK-LABEL: @gather_broadcasted_warp_2d
tt.func private @gather_broadcasted_warp_2d(%arg0: tensor<64x1xi32, #broadcasted_warp_2d>, %arg1: tensor<64x1xf32, #broadcasted_warp_2d>) -> tensor<64x1xf32, #broadcasted_warp_2d> {
  // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float } %1, 0
  // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float } %1, 1
  // CHECK-NEXT: [[IDX0:%.*]] = extractvalue { i32, i32 } %0, 0
  // CHECK-NEXT: [[IDX1:%.*]] = extractvalue { i32, i32 } %0, 1

  // CHECK-NEXT: [[REGID0:%.*]] = and i32 [[IDX0]], 1
  // CHECK-NEXT: [[TMP:%.*]] = lshr i32 [[IDX0]], 1
  // CHECK-NEXT: [[LANEID0:%.*]] = and i32 [[TMP]], 31

  // CHECK-NEXT: [[VALUE0:%.*]] = bitcast float [[SRC0]] to i32
  // CHECK-NEXT: [[RES0_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], i32 [[LANEID0]], i32 31)
  // CHECK-NEXT: [[VALUE1:%.*]] = bitcast float [[SRC1]] to i32
  // CHECK-NEXT: [[RES1_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], i32 [[LANEID0]], i32 31)

  // CHECK-NEXT: [[PICK0:%.*]] = icmp eq i32 [[REGID0]], 0
  // CHECK-NEXT: select i1 [[PICK0]], i32 [[RES0_i32]], i32 [[RES1_i32]]

  // CHECK: [[REGID1:%.*]] = and i32 [[IDX1]], 1
  // CHECK-NEXT: [[TMP:%.*]] = lshr i32 [[IDX1]], 1
  // CHECK-NEXT: [[LANEID1:%.*]] = and i32 [[TMP]], 31

  // CHECK-NEXT: [[RES0_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], i32 [[LANEID1]], i32 31)
  // CHECK-NEXT: [[RES1_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], i32 [[LANEID1]], i32 31)

  // CHECK-NEXT: [[PICK1:%.*]] = icmp eq i32 [[REGID1]], 0
  // CHECK-NEXT: select i1 [[PICK1]], i32 [[RES0_i32]], i32 [[RES1_i32]]
  %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<64x1xf32, #broadcasted_warp_2d>, tensor<64x1xi32, #broadcasted_warp_2d>) -> tensor<64x1xf32, #broadcasted_warp_2d>
  tt.return %0 : tensor<64x1xf32, #broadcasted_warp_2d>
}

// Keep LLVM from DCE'ing the above functions. Use volatile stores to stop LLVM
// from removing unused function results.
tt.func @anchor_warp4(%ptr: !llvm.ptr,
                      %arg9: tensor<32x16xi32, #crazy_2d_idx>,
                      %arg10: tensor<32x16xf32, #crazy_2d_src>,
                      %arg11: tensor<16xi32, #broadcasted_lane_1d>,
                      %arg12: tensor<16xf32, #broadcasted_lane_1d>,
                      %arg13: tensor<64x1xi32, #broadcasted_warp_2d>,
                      %arg14: tensor<64x1xf32, #broadcasted_warp_2d>) {

  %12 = tt.call @gather_2d_crazy(%arg9, %arg10) : (tensor<32x16xi32, #crazy_2d_idx>, tensor<32x16xf32, #crazy_2d_src>) -> tensor<32x16xf32, #crazy_2d_idx>
  %13 = builtin.unrealized_conversion_cast %12 : tensor<32x16xf32, #crazy_2d_idx> to !llvm.struct<(f32, f32, f32, f32)>
  llvm.store volatile %13, %ptr : !llvm.struct<(f32, f32, f32, f32)>, !llvm.ptr

  %14 = tt.call @gather_broadcasted_lane_1d(%arg11, %arg12) : (tensor<16xi32, #broadcasted_lane_1d>, tensor<16xf32, #broadcasted_lane_1d>) -> tensor<16xf32, #broadcasted_lane_1d>
  %15 = builtin.unrealized_conversion_cast %14 : tensor<16xf32, #broadcasted_lane_1d> to !llvm.struct<(f32)>
  llvm.store volatile %15, %ptr : !llvm.struct<(f32)>, !llvm.ptr

  %16 = tt.call @gather_broadcasted_warp_2d(%arg13, %arg14) : (tensor<64x1xi32, #broadcasted_warp_2d>, tensor<64x1xf32, #broadcasted_warp_2d>) -> tensor<64x1xf32, #broadcasted_warp_2d>
  %17 = builtin.unrealized_conversion_cast %16 : tensor<64x1xf32, #broadcasted_warp_2d> to !llvm.struct<(f32, f32)>
  llvm.store volatile %17, %ptr : !llvm.struct<(f32, f32)>, !llvm.ptr

  tt.return
}

}