|
| 1 | +// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm --convert-nv-gpu-to-llvm | mlir-translate -mlir-to-llvmir | opt -S -O1 | FileCheck %s |
| 2 | + |
| 3 | +#linear = #ttg.linear<{register = [[0, 2], [2, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}> |
| 4 | + |
| 5 | +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { |
| 6 | + |
| 7 | +// CHECK-LABEL: @reduce_linear_layout |
| 8 | +tt.func private @reduce_linear_layout(%arg0: tensor<32x16xi32, #linear>) -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> { |
| 9 | + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue {{.*}} %0, 0 |
| 10 | + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue {{.*}} %0, 1 |
| 11 | + // CHECK-NEXT: [[SRC2:%.*]] = extractvalue {{.*}} %0, 2 |
| 12 | + // CHECK-NEXT: [[SRC3:%.*]] = extractvalue {{.*}} %0, 3 |
| 13 | + |
| 14 | + // The layout looks lke |
| 15 | + // [[ T0:0, T32:0, T0:1, T32:1, ... |
| 16 | + // [ T4:0, T36:0, T4:1, T36:1, ... |
| 17 | + // [ T0:2, T32:2, T0:3, T32:3, ... |
| 18 | + // [ T4:2, T36:2, T4:3, T36:3, |
| 19 | + // ... |
| 20 | + // |
| 21 | + // A reduction along axis=0 consists of adding registers (0, 2) and (1, 3) |
| 22 | + // before shuffling. |
| 23 | + // |
| 24 | + // Columns along axis=0 are contained within a warp, so reduction arcoss warps |
| 25 | + // is not needed. |
| 26 | + |
| 27 | + // Reduce within threads |
| 28 | + // CHECK-NEXT: [[SUM0:%.*]] = add i32 [[SRC0]], [[SRC2]] |
| 29 | + // CHECK-NEXT: [[SUM1:%.*]] = add i32 [[SRC1]], [[SRC3]] |
| 30 | + |
| 31 | + // Reduce within warp. |
| 32 | + // CHECK-NEXT: [[W0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[SUM0]], i32 16, i32 31) |
| 33 | + // CHECK-NEXT: [[WSUM0:%.*]] = add i32 [[W0]], [[SUM0]] |
| 34 | + // CHECK-NEXT: [[W1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM0]], i32 8, i32 31) |
| 35 | + // CHECK-NEXT: [[WSUM1:%.*]] = add i32 [[WSUM0]], [[W1]] |
| 36 | + // CHECK-NEXT: [[W2:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM1]], i32 4, i32 31) |
| 37 | + // CHECK-NEXT: [[WSUM2:%.*]] = add i32 [[WSUM1]], [[W2]] |
| 38 | + // CHECK-NEXT: [[W3:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM2]], i32 2, i32 31) |
| 39 | + // CHECK-NEXT: [[WSUM3:%.*]] = add i32 [[WSUM2]], [[W3]] |
| 40 | + |
| 41 | + // CHECK-NEXT: [[W4:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[SUM1]], i32 16, i32 31) |
| 42 | + // CHECK-NEXT: [[WSUM4:%.*]] = add i32 [[W4]], [[SUM1]] |
| 43 | + // CHECK-NEXT: [[W5:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM4]], i32 8, i32 31) |
| 44 | + // CHECK-NEXT: [[WSUM5:%.*]] = add i32 [[WSUM4]], [[W5]] |
| 45 | + // CHECK-NEXT: [[W6:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM5]], i32 4, i32 31) |
| 46 | + // CHECK-NEXT: [[WSUM6:%.*]] = add i32 [[WSUM5]], [[W6]] |
| 47 | + // CHECK-NEXT: [[W7:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM6]], i32 2, i32 31) |
| 48 | + // CHECK-NEXT: [[WSUM7:%.*]] = add i32 [[WSUM6]], [[W7]] |
| 49 | + |
| 50 | + // CHECK-NEXT: [[DST0:%.*]] = insertvalue { i32, i32 } undef, i32 [[WSUM3]], 0 |
| 51 | + // CHECK-NEXT: [[DST1:%.*]] = insertvalue { i32, i32 } [[DST0]], i32 [[WSUM7]], 1 |
| 52 | + |
| 53 | + %0 = "tt.reduce"(%arg0) ({ |
| 54 | + ^bb0(%arg1: i32, %arg2: i32): |
| 55 | + %1 = arith.addi %arg1, %arg2 : i32 |
| 56 | + tt.reduce.return %1 : i32 |
| 57 | + }) {axis = 0 : i32} : (tensor<32x16xi32, #linear>) -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> |
| 58 | + |
| 59 | + // CHECK-NEXT: ret { i32, i32 } [[DST1]] |
| 60 | + tt.return %0 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> |
| 61 | +} |
| 62 | + |
| 63 | +tt.func @anchor(%ptr: !llvm.ptr, %arg0: tensor<32x16xi32, #linear>) { |
| 64 | + %0 = tt.call @reduce_linear_layout(%arg0) : (tensor<32x16xi32, #linear>) -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> |
| 65 | + %1 = builtin.unrealized_conversion_cast %0 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> to !llvm.struct<(i32, i32)> |
| 66 | + llvm.store volatile %1, %ptr : !llvm.struct<(i32, i32)>, !llvm.ptr |
| 67 | + tt.return |
| 68 | +} |
| 69 | + |
| 70 | +} |
0 commit comments