Commit 3cb3e69
[Backend] Support broadcasted layouts in warp shuffle gather codegen (#5395)
The original implementation sidesteps the issue of dealing with broadcasted layouts. While implementing layout selection for gather in the middle end, I found that it's pretty common for tensors to be too small to avoid broadcasting: e.g., with 4 warps of 32 threads each, a tensor needs at least 128 elements. It turns out that enabling broadcasting support is pretty trivial: in a broadcasted layout, a broadcasted thread can index into any other "group" of threads mapped to the same gather column, and the codegen does not vary across broadcasted warps, so we can use the pseudo-inverse of the source layout regardless of which preimages it picks.
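For intuition, here is a rough sketch of such a too-small tensor (modeled on the constructor style of Triton's LinearLayout unit tests; the helper lambda, function name, and exact signatures are assumptions for illustration, not code from this commit):

    #include <cassert>
    #include "triton/Tools/LinearLayout.h"

    using namespace mlir;
    using namespace mlir::triton;

    // 16 elements on 4 warps of 32 threads: 128 hardware slots, so the
    // layout must broadcast.
    void pseudoinvertExample(MLIRContext &ctx) {
      auto S = [&](StringRef name) { return StringAttr::get(&ctx, name); };

      // Lane bits 0-3 address the 16 elements; lane bit 4 and both warp bits
      // have zero bases, i.e. each half-warp and each warp holds a duplicate
      // copy of the data.
      LinearLayout src({{S("lane"), {{1}, {2}, {4}, {8}, {0}}},
                        {S("warp"), {{0}, {0}}}},
                       {S("dim0")});

      // Zero bases make the layout non-injective, so invert() would assert...
      assert(!src.isInvertible());
      // ...but pseudoinvert() still yields a right inverse B with
      // src(B(x)) == x, mapping each element back to one representative
      // (lane, warp) among its duplicates.
      LinearLayout inv = src.pseudoinvert();
      (void)inv;
    }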
1 parent 6f5baf6

6 files changed: +86 -14
include/triton/Tools/LinearLayout.h

Lines changed: 4 additions & 0 deletions
@@ -679,6 +679,10 @@ class LinearLayout {
 
   // Get the layout that is the inverse of this layout.
   [[nodiscard]] LinearLayout invert() const;
+  // Compute and return a pseudoinverse of this layout. This is a layout such
+  // that `B = A.pseudoinvert()` implies `A(B(x)) = x`. If `A` is invertible,
+  // then this returns `A^-1`.
+  [[nodiscard]] LinearLayout pseudoinvert() const;
 
   // For each in-dim, returns a bitmask of the "free variables" in the layout
   // function.
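As a concrete gloss of the new contract (ours, not from the header): a linear layout is a bit-matrix over GF(2), and broadcasting shows up as zero columns. With one broadcast input bit, A corresponds to M_A = [ I_4 | 0 ], a 4x5 matrix whose fifth column is zero; one pseudoinverse is M_B = [ I_4 ; 0 ] (I_4 stacked on a zero row), since M_A * M_B = I_4 and hence A(B(x)) = x for every x. In the other direction, B(A(y)) recovers y only with the broadcast bit cleared, which is why this is a pseudoinverse rather than an inverse.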

lib/Analysis/Utility.cpp

Lines changed: 2 additions & 7 deletions
@@ -491,13 +491,8 @@ bool GatherLoweringHelper::isWarpLocal() {
   // in the index and source tensors are the same. This means we don't need to
   // xor shuffle across threads before emitting index shuffles; we push warp
   // shuffling to layout conversions.
-  if (srcLayout->sublayout(kLane, otherDims) !=
-      idxLayout->sublayout(kLane, otherDims))
-    return false;
-
-  // Otherwise, the source layout has to be invertible. This primarily means
-  // the codegen path doesn't support broadcasted source layouts.
-  return srcLayout->isInvertible();
+  return srcLayout->sublayout(kLane, otherDims) ==
+         idxLayout->sublayout(kLane, otherDims);
 }
 
 unsigned getNumScratchElements(ArrayRef<unsigned> shape) {

lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp

Lines changed: 4 additions & 3 deletions
@@ -240,9 +240,10 @@ void GatherOpConversion::emitWarpLocalGather(
   // `llvm.select` using `src_reg` to get the right one. `K` is the number of
   // elements per column owned by a thread.
 
-  // Fully invert the source layout. We know it is invertible because
-  // `isWarpLocal` checked this.
-  LinearLayout invSrcLayout = srcLayout.invert();
+  // Invert the source layout. It doesn't matter whether it is fully invertible
+  // with respect to anything except the register input dimension, since we
+  // know the others don't vary in ways that matter for codegen.
+  LinearLayout invSrcLayout = srcLayout.pseudoinvert();
 
   // Sanity check: the warp must be invariant to the index because otherwise the
   // gather would need to read across warps!
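One way to read the relaxed requirement (our gloss): if the source layout broadcasts over warps, then src(reg, lane, w) == src(reg, lane, w') for all warps w and w', so every warp holds the same gather column and emits identical shuffles; the pseudoinverse may pick any fixed representative warp without changing the generated code. Likewise, a broadcast lane bit means two lanes hold the same element, so an index shuffle can read from either. Only the register dimension must be honestly inverted, because the recovered register id feeds the `llvm.select` chain that picks among a thread's `K` shuffle results.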

lib/Tools/LinearLayout.cpp

Lines changed: 5 additions & 1 deletion
@@ -957,9 +957,13 @@ LinearLayout LinearLayout::invertAndCompose(const LinearLayout &outer) const {
 }
 
 LinearLayout LinearLayout::invert() const {
-  // A^-1(x) = A^-1(I(x)), thus A.invert() = I.invertAndCompose(A)
   assert(isInvertible() &&
          "A linear layout must be surjective and square to be invertible");
+  return pseudoinvert();
+}
+
+LinearLayout LinearLayout::pseudoinvert() const {
+  // A^-1(x) = A^-1(I(x)), thus A.invert() = I.invertAndCompose(A)
   LinearLayout identity = LinearLayout::empty();
   for (auto outDim : getOutDimNames()) {
     identity *= LinearLayout::identity1D(getOutDimSize(outDim), outDim, outDim);
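Why the identity-compose trick yields a pseudoinverse (a sketch under our reading of `invertAndCompose`'s contract): C = I.invertAndCompose(A) satisfies A(C(x)) = I(x) = x for all x, so C is a right inverse of A. When A is square and injective, C is exactly A^-1, which is why `invert()` can now simply assert invertibility and delegate. When A has zero (broadcast) basis vectors, C picks one preimage per output, which is all `pseudoinvert()` promises.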

test/Conversion/allocate_shared_memory.mlir

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 // RUN: triton-opt %s --allocate-shared-memory | FileCheck %s
 
-#blocked = #ttg.blocked<{sizePerThread = [32, 32], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [1, 0]}>
 
 // CHECK-LABEL: module
 // CHECK-SAME: ttg.shared = 131072 : i32

test/Conversion/gather_to_llvm.mlir

Lines changed: 70 additions & 2 deletions
@@ -16,7 +16,10 @@
 #crazy_2d_src = #ttg.linear<{register = [[0, 2], [2, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>
 #crazy_2d_idx = #ttg.linear<{register = [[2, 0], [0, 2]], lane = [[0, 8], [16, 0], [1, 0], [8, 0], [4, 0]], warp = [[0, 1], [0, 4]], block = []}>
 
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
+#broadcasted_lane_1d = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#broadcasted_warp_2d = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
 
 // Each source element is mapped to a single thread, so we expect one index shuffle.
 // CHECK-LABEL: @gather_warp_local_trivial
@@ -222,6 +225,59 @@ tt.func private @gather_2d_crazy(%arg0: tensor<32x16xi32, #crazy_2d_idx>, %arg1:
   tt.return %0 : tensor<32x16xf32, #crazy_2d_idx>
 }
 
+// There are 16 elements in the tensor. For each warp, each half-warp is mapped
+// to the 16 elements, so it doesn't matter if the second half [16, 32) indexes
+// into [0, 16), since they contain the same data.
+// CHECK-LABEL: @gather_broadcasted_lane_1d
+tt.func private @gather_broadcasted_lane_1d(%arg0: tensor<16xi32, #broadcasted_lane_1d>, %arg1: tensor<16xf32, #broadcasted_lane_1d>) -> tensor<16xf32, #broadcasted_lane_1d> {
+  // CHECK-NEXT: [[SRC:%.*]] = extractvalue { float } %1, 0
+  // CHECK-NEXT: [[IDX:%.*]] = extractvalue { i32 } %0, 0
+
+  // CHECK-NEXT: [[LANEID:%.*]] = and i32 [[IDX]], 15
+  // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC]] to i32
+  // CHECK-NEXT: [[RES_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31)
+  %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<16xf32, #broadcasted_lane_1d>, tensor<16xi32, #broadcasted_lane_1d>) -> tensor<16xf32, #broadcasted_lane_1d>
+
+  // CHECK-NEXT: [[RES:%.*]] = bitcast i32 [[RES_i32]] to float
+  // CHECK-NEXT: ret float [[RES]]
+  tt.return %0 : tensor<16xf32, #broadcasted_lane_1d>
+}
+
+// Single gather column with 64 elements, all of which have to fit into a single
+// warp, so the whole column is broadcasted across the 4 warps. Each warp
+// processes the same data, so the warp doesn't matter.
+// CHECK-LABEL: @gather_broadcasted_warp_2d
+tt.func private @gather_broadcasted_warp_2d(%arg0: tensor<64x1xi32, #broadcasted_warp_2d>, %arg1: tensor<64x1xf32, #broadcasted_warp_2d>) -> tensor<64x1xf32, #broadcasted_warp_2d> {
+  // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float } %1, 0
+  // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float } %1, 1
+  // CHECK-NEXT: [[IDX0:%.*]] = extractvalue { i32, i32 } %0, 0
+  // CHECK-NEXT: [[IDX1:%.*]] = extractvalue { i32, i32 } %0, 1
+
+  // CHECK-NEXT: [[REGID0:%.*]] = and i32 [[IDX0]], 1
+  // CHECK-NEXT: [[TMP:%.*]] = lshr i32 [[IDX0]], 1
+  // CHECK-NEXT: [[LANEID0:%.*]] = and i32 [[TMP]], 31
+
+  // CHECK-NEXT: [[VALUE0:%.*]] = bitcast float [[SRC0]] to i32
+  // CHECK-NEXT: [[RES0_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], i32 [[LANEID0]], i32 31)
+  // CHECK-NEXT: [[VALUE1:%.*]] = bitcast float [[SRC1]] to i32
+  // CHECK-NEXT: [[RES1_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], i32 [[LANEID0]], i32 31)
+
+  // CHECK-NEXT: [[PICK0:%.*]] = icmp eq i32 [[REGID0]], 0
+  // CHECK-NEXT: select i1 [[PICK0]], i32 [[RES0_i32]], i32 [[RES1_i32]]
+
+  // CHECK: [[REGID1:%.*]] = and i32 [[IDX1]], 1
+  // CHECK-NEXT: [[TMP:%.*]] = lshr i32 [[IDX1]], 1
+  // CHECK-NEXT: [[LANEID1:%.*]] = and i32 [[TMP]], 31
+
+  // CHECK-NEXT: [[RES0_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], i32 [[LANEID1]], i32 31)
+  // CHECK-NEXT: [[RES1_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], i32 [[LANEID1]], i32 31)
+
+  // CHECK-NEXT: [[PICK1:%.*]] = icmp eq i32 [[REGID1]], 0
+  // CHECK-NEXT: select i1 [[PICK1]], i32 [[RES0_i32]], i32 [[RES1_i32]]
+  %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<64x1xf32, #broadcasted_warp_2d>, tensor<64x1xi32, #broadcasted_warp_2d>) -> tensor<64x1xf32, #broadcasted_warp_2d>
+  tt.return %0 : tensor<64x1xf32, #broadcasted_warp_2d>
+}
+
 // Keep LLVM from DCE'ing the above functions. Use volatile stores to stop LLVM
 // from removing unused function results.
 tt.func @anchor(%ptr: !llvm.ptr,
@@ -235,7 +291,11 @@
                 %arg7: tensor<32x2xi32, #span_2d_cols>,
                 %arg8: tensor<32x2xf32, #span_2d_cols>,
                 %arg9: tensor<32x16xi32, #crazy_2d_idx>,
-                %arg10: tensor<32x16xf32, #crazy_2d_src>) {
+                %arg10: tensor<32x16xf32, #crazy_2d_src>,
+                %arg11: tensor<16xi32, #broadcasted_lane_1d>,
+                %arg12: tensor<16xf32, #broadcasted_lane_1d>,
+                %arg13: tensor<64x1xi32, #broadcasted_warp_2d>,
+                %arg14: tensor<64x1xf32, #broadcasted_warp_2d>) {
 
   %0 = tt.call @gather_warp_local_trivial(%arg0, %arg1) : (tensor<32xi32, #trivial_layout>, tensor<32xf32, #trivial_layout>) -> tensor<32xf32, #trivial_layout>
   %1 = builtin.unrealized_conversion_cast %0 : tensor<32xf32, #trivial_layout> to !llvm.struct<(f32)>
@@ -265,6 +325,14 @@
   %13 = builtin.unrealized_conversion_cast %12 : tensor<32x16xf32, #crazy_2d_idx> to !llvm.struct<(f32, f32, f32, f32)>
   llvm.store volatile %13, %ptr : !llvm.struct<(f32, f32, f32, f32)>, !llvm.ptr
 
+  %14 = tt.call @gather_broadcasted_lane_1d(%arg11, %arg12) : (tensor<16xi32, #broadcasted_lane_1d>, tensor<16xf32, #broadcasted_lane_1d>) -> tensor<16xf32, #broadcasted_lane_1d>
+  %15 = builtin.unrealized_conversion_cast %14 : tensor<16xf32, #broadcasted_lane_1d> to !llvm.struct<(f32)>
+  llvm.store volatile %15, %ptr : !llvm.struct<(f32)>, !llvm.ptr
+
+  %16 = tt.call @gather_broadcasted_warp_2d(%arg13, %arg14) : (tensor<64x1xi32, #broadcasted_warp_2d>, tensor<64x1xf32, #broadcasted_warp_2d>) -> tensor<64x1xf32, #broadcasted_warp_2d>
+  %17 = builtin.unrealized_conversion_cast %16 : tensor<64x1xf32, #broadcasted_warp_2d> to !llvm.struct<(f32, f32)>
+  llvm.store volatile %17, %ptr : !llvm.struct<(f32, f32)>, !llvm.ptr
+
   tt.return
 }
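The index arithmetic in the new CHECK lines is the pseudoinverse at work (our gloss): in the 1-D test the 16-element column occupies lane bits 0-3, so the shuffle source is simply `idx & 15`, and lane bit 4 (the broadcast half-warp bit) never appears. In the 2-D test each thread owns two of the 64 column elements, so the inverse splits the index into `reg = idx & 1` and `lane = (idx >> 1) & 31`; the two shuffle results are chosen between by `icmp`/`select` on the register bit, and the warp bits are absent entirely.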
