
Commit 03cb38c

Merge commit '3cb3e693aefd8ca6f1021f3ddec098e07e3ab4ed'
2 parents: b70c7f7 + 3cb3e69

File tree

6 files changed (+86, -14)

include/triton/Tools/LinearLayout.h

Lines changed: 4 additions & 0 deletions

@@ -679,6 +679,10 @@ class LinearLayout {
 
   // Get the layout that is the inverse of this layout.
   [[nodiscard]] LinearLayout invert() const;
+  // Compute and return a pseudoinverse of this layout. This is a layout such
+  // that `B = A.pseudoinvert()` implies that `A(B(x)) = x`. If `A` is
+  // invertible, then this returns `A^-1`.
+  [[nodiscard]] LinearLayout pseudoinvert() const;
 
   // For each in-dim, returns a bitmask of the "free variables" in the layout
   // function.
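
For intuition: a LinearLayout is a linear map over GF(2), and the new method promises a right inverse. A hedged restatement of the comment in math form (not part of the commit):

\[
B = A.\texttt{pseudoinvert()} \quad\Longrightarrow\quad A(B(x)) = x \quad \text{for all } x,
\]

so \(A \circ B = \mathrm{id}\) on A's output space. When \(A\) is square and surjective, the right inverse is unique and equals \(A^{-1}\).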

lib/Analysis/Utility.cpp

Lines changed: 2 additions & 7 deletions

@@ -497,13 +497,8 @@ bool GatherLoweringHelper::isWarpLocal() {
   // in the index and source tensors are the same. This means we don't need to
   // xor shuffle across threads before emitting index shuffles; we push warp
   // shuffling to layout conversions.
-  if (srcLayout->sublayout(kLane, otherDims) !=
-      idxLayout->sublayout(kLane, otherDims))
-    return false;
-
-  // Otherwise, the source layout has to be invertible. This primarily means
-  // the codegen path doesn't support broadcasted source layouts.
-  return srcLayout->isInvertible();
+  return srcLayout->sublayout(kLane, otherDims) ==
+         idxLayout->sublayout(kLane, otherDims);
 }
 
 unsigned getNumScratchElements(ArrayRef<unsigned> shape) {
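
Restated in math form (a hedged reading, using the function's own identifiers): after this hunk,

\[
\texttt{isWarpLocal()} \iff \texttt{srcLayout.sublayout(kLane, otherDims)} = \texttt{idxLayout.sublayout(kLane, otherDims)},
\]

and the former invertibility requirement on the source layout disappears: broadcasted source layouts are instead handled during codegen via pseudoinvert(), as the next file shows.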

lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp

Lines changed: 4 additions & 3 deletions

@@ -240,9 +240,10 @@ void GatherOpConversion::emitWarpLocalGather(
   // `llvm.select` using `src_reg` to get the right one. `K` is the number of
   // elements per column owned by a thread.
 
-  // Fully invert the source layout. We know it is invertible because
-  // `isWarpLocal` checked this.
-  LinearLayout invSrcLayout = srcLayout.invert();
+  // Invert the source layout. It only matters that the layout is invertible
+  // with respect to the register input dimension; the other inputs don't vary
+  // in ways that matter for codegen.
+  LinearLayout invSrcLayout = srcLayout.pseudoinvert();
 
   // Sanity check: the warp must be invariant to the index because otherwise the
   // gather would need to read across warps!

lib/Tools/LinearLayout.cpp

Lines changed: 5 additions & 1 deletion

@@ -920,9 +920,13 @@ LinearLayout LinearLayout::invertAndCompose(const LinearLayout &outer) const {
 }
 
 LinearLayout LinearLayout::invert() const {
-  // A^-1(x) = A^-1(I(x)), thus A.invert() = I.invertAndCompose(A)
   assert(isInvertible() &&
          "A linear layout must be surjective and square to be invertible");
+  return pseudoinvert();
+}
+
+LinearLayout LinearLayout::pseudoinvert() const {
+  // A^-1(x) = A^-1(I(x)), thus A.pseudoinvert() = I.invertAndCompose(A)
   LinearLayout identity = LinearLayout::empty();
   for (auto outDim : getOutDimNames()) {
     identity *= LinearLayout::identity1D(getOutDimSize(outDim), outDim, outDim);
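
The new entry point reuses the identity-composition trick, so the hunk itself is small. For readers who want the underlying linear algebra spelled out, here is a hedged, self-contained sketch (plain C++, deliberately independent of Triton's LinearLayout API; every name is illustrative) of computing a right inverse over GF(2), the property pseudoinvert() provides:

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// A matrix over GF(2): column k is the bitmask image of input basis vector
// e_k, mirroring how LinearLayout stores bases. Illustrative only.
using Mat = std::vector<uint32_t>;

// y = A * x over GF(2): bit k of x selects column k.
uint32_t apply(const Mat &A, uint32_t x) {
  uint32_t y = 0;
  for (size_t k = 0; k < A.size(); ++k)
    if (x >> k & 1)
      y ^= A[k];
  return y;
}

// Compute B with A * B = I, assuming A is surjective onto outBits bits.
// Gaussian elimination over columns, remembering which input combination
// produced each reduced column; back-substitution then picks one preimage
// per output basis vector.
Mat rightInverse(const Mat &A, int outBits) {
  // pivot[b] = (reduced column with leading bit b, input combo producing it)
  std::vector<std::pair<uint32_t, uint32_t>> pivot(outBits, {0, 0});
  for (size_t k = 0; k < A.size(); ++k) {
    uint32_t v = A[k], combo = uint32_t(1) << k;
    for (int b = outBits - 1; b >= 0; --b) {
      if (!(v >> b & 1))
        continue;
      if (pivot[b].first) {
        v ^= pivot[b].first; // eliminate bit b and keep reducing
        combo ^= pivot[b].second;
      } else {
        pivot[b] = {v, combo}; // new pivot for bit b
        break;
      }
    }
  }
  Mat B(outBits);
  for (int j = 0; j < outBits; ++j) {
    uint32_t t = uint32_t(1) << j, x = 0;
    for (int b = outBits - 1; b >= 0; --b) {
      if (t >> b & 1) {
        assert(pivot[b].first && "matrix is not surjective");
        t ^= pivot[b].first;
        x ^= pivot[b].second;
      }
    }
    B[j] = x; // one preimage of e_j
  }
  return B;
}

int main() {
  // A broadcasted 3-input -> 2-output map: two inputs hit the same output
  // bit, so A is surjective but not invertible.
  Mat A = {0b01, 0b10, 0b10};
  Mat B = rightInverse(A, /*outBits=*/2);
  for (uint32_t x = 0; x < 4; ++x)
    assert(apply(A, apply(B, x)) == x); // A(B(x)) == x
}

When A is square and invertible, the back-substitution degenerates to the ordinary inverse, which is why invert() can simply assert and defer to pseudoinvert().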

test/Conversion/allocate_shared_memory.mlir

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 // RUN: triton-opt %s --allocate-shared-memory | FileCheck %s
 
-#blocked = #ttg.blocked<{sizePerThread = [32, 32], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [1, 0]}>
 
 // CHECK-LABEL: module
 // CHECK-SAME: ttg.shared = 131072 : i32

test/Conversion/gather_to_llvm.mlir

Lines changed: 70 additions & 2 deletions

@@ -16,7 +16,10 @@
 #crazy_2d_src = #ttg.linear<{register = [[0, 2], [2, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>
 #crazy_2d_idx = #ttg.linear<{register = [[2, 0], [0, 2]], lane = [[0, 8], [16, 0], [1, 0], [8, 0], [4, 0]], warp = [[0, 1], [0, 4]], block = []}>
 
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
+#broadcasted_lane_1d = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#broadcasted_warp_2d = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
 
 // Each source element is mapped to a single thread, so we expect one index shuffle.
 // CHECK-LABEL: @gather_warp_local_trivial

@@ -222,6 +225,59 @@ tt.func private @gather_2d_crazy(%arg0: tensor<32x16xi32, #crazy_2d_idx>, %arg1:
   tt.return %0 : tensor<32x16xf32, #crazy_2d_idx>
 }
 
+// There are 16 elements in the tensor. Within each warp, each half-warp is
+// mapped to the 16 elements, so it doesn't matter if the second half [16, 32)
+// indexes into [0, 16), since both halves contain the same data.
+// CHECK-LABEL: @gather_broadcasted_lane_1d
+tt.func private @gather_broadcasted_lane_1d(%arg0: tensor<16xi32, #broadcasted_lane_1d>, %arg1: tensor<16xf32, #broadcasted_lane_1d>) -> tensor<16xf32, #broadcasted_lane_1d> {
+  // CHECK-NEXT: [[SRC:%.*]] = extractvalue { float } %1, 0
+  // CHECK-NEXT: [[IDX:%.*]] = extractvalue { i32 } %0, 0
+
+  // CHECK-NEXT: [[LANEID:%.*]] = and i32 [[IDX]], 15
+  // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC]] to i32
+  // CHECK-NEXT: [[RES_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31)
+  %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<16xf32, #broadcasted_lane_1d>, tensor<16xi32, #broadcasted_lane_1d>) -> tensor<16xf32, #broadcasted_lane_1d>
+
+  // CHECK-NEXT: [[RES:%.*]] = bitcast i32 [[RES_i32]] to float
+  // CHECK-NEXT: ret float [[RES]]
+  tt.return %0 : tensor<16xf32, #broadcasted_lane_1d>
+}
+
+// Single gather column with 64 elements, all of which have to fit into a single
+// warp, so the whole column is broadcasted across the 4 warps. Each warp
+// processes the same data, so the warp ID doesn't matter.
+// CHECK-LABEL: @gather_broadcasted_warp_2d
+tt.func private @gather_broadcasted_warp_2d(%arg0: tensor<64x1xi32, #broadcasted_warp_2d>, %arg1: tensor<64x1xf32, #broadcasted_warp_2d>) -> tensor<64x1xf32, #broadcasted_warp_2d> {
+  // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float } %1, 0
+  // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float } %1, 1
+  // CHECK-NEXT: [[IDX0:%.*]] = extractvalue { i32, i32 } %0, 0
+  // CHECK-NEXT: [[IDX1:%.*]] = extractvalue { i32, i32 } %0, 1
+
+  // CHECK-NEXT: [[REGID0:%.*]] = and i32 [[IDX0]], 1
+  // CHECK-NEXT: [[TMP:%.*]] = lshr i32 [[IDX0]], 1
+  // CHECK-NEXT: [[LANEID0:%.*]] = and i32 [[TMP]], 31
+
+  // CHECK-NEXT: [[VALUE0:%.*]] = bitcast float [[SRC0]] to i32
+  // CHECK-NEXT: [[RES0_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], i32 [[LANEID0]], i32 31)
+  // CHECK-NEXT: [[VALUE1:%.*]] = bitcast float [[SRC1]] to i32
+  // CHECK-NEXT: [[RES1_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], i32 [[LANEID0]], i32 31)
+
+  // CHECK-NEXT: [[PICK0:%.*]] = icmp eq i32 [[REGID0]], 0
+  // CHECK-NEXT: select i1 [[PICK0]], i32 [[RES0_i32]], i32 [[RES1_i32]]
+
+  // CHECK: [[REGID1:%.*]] = and i32 [[IDX1]], 1
+  // CHECK-NEXT: [[TMP:%.*]] = lshr i32 [[IDX1]], 1
+  // CHECK-NEXT: [[LANEID1:%.*]] = and i32 [[TMP]], 31
+
+  // CHECK-NEXT: [[RES0_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], i32 [[LANEID1]], i32 31)
+  // CHECK-NEXT: [[RES1_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], i32 [[LANEID1]], i32 31)
+
+  // CHECK-NEXT: [[PICK1:%.*]] = icmp eq i32 [[REGID1]], 0
+  // CHECK-NEXT: select i1 [[PICK1]], i32 [[RES0_i32]], i32 [[RES1_i32]]
+  %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<64x1xf32, #broadcasted_warp_2d>, tensor<64x1xi32, #broadcasted_warp_2d>) -> tensor<64x1xf32, #broadcasted_warp_2d>
+  tt.return %0 : tensor<64x1xf32, #broadcasted_warp_2d>
+}
+
 // Keep LLVM from DCE'ing the above functions. Use volatile stores to stop LLVM
 // from removing unused function results.
 tt.func @anchor(%ptr: !llvm.ptr,

@@ -235,7 +291,11 @@ tt.func @anchor(%ptr: !llvm.ptr,
                 %arg7: tensor<32x2xi32, #span_2d_cols>,
                 %arg8: tensor<32x2xf32, #span_2d_cols>,
                 %arg9: tensor<32x16xi32, #crazy_2d_idx>,
-                %arg10: tensor<32x16xf32, #crazy_2d_src>) {
+                %arg10: tensor<32x16xf32, #crazy_2d_src>,
+                %arg11: tensor<16xi32, #broadcasted_lane_1d>,
+                %arg12: tensor<16xf32, #broadcasted_lane_1d>,
+                %arg13: tensor<64x1xi32, #broadcasted_warp_2d>,
+                %arg14: tensor<64x1xf32, #broadcasted_warp_2d>) {
 
   %0 = tt.call @gather_warp_local_trivial(%arg0, %arg1) : (tensor<32xi32, #trivial_layout>, tensor<32xf32, #trivial_layout>) -> tensor<32xf32, #trivial_layout>
   %1 = builtin.unrealized_conversion_cast %0 : tensor<32xf32, #trivial_layout> to !llvm.struct<(f32)>

@@ -265,6 +325,14 @@ tt.func @anchor(%ptr: !llvm.ptr,
   %13 = builtin.unrealized_conversion_cast %12 : tensor<32x16xf32, #crazy_2d_idx> to !llvm.struct<(f32, f32, f32, f32)>
   llvm.store volatile %13, %ptr : !llvm.struct<(f32, f32, f32, f32)>, !llvm.ptr
 
+  %14 = tt.call @gather_broadcasted_lane_1d(%arg11, %arg12) : (tensor<16xi32, #broadcasted_lane_1d>, tensor<16xf32, #broadcasted_lane_1d>) -> tensor<16xf32, #broadcasted_lane_1d>
+  %15 = builtin.unrealized_conversion_cast %14 : tensor<16xf32, #broadcasted_lane_1d> to !llvm.struct<(f32)>
+  llvm.store volatile %15, %ptr : !llvm.struct<(f32)>, !llvm.ptr
+
+  %16 = tt.call @gather_broadcasted_warp_2d(%arg13, %arg14) : (tensor<64x1xi32, #broadcasted_warp_2d>, tensor<64x1xf32, #broadcasted_warp_2d>) -> tensor<64x1xf32, #broadcasted_warp_2d>
+  %17 = builtin.unrealized_conversion_cast %16 : tensor<64x1xf32, #broadcasted_warp_2d> to !llvm.struct<(f32, f32)>
+  llvm.store volatile %17, %ptr : !llvm.struct<(f32, f32)>, !llvm.ptr
+
   tt.return
 }
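
To connect the CHECK lines to the pseudoinverted layout: for #broadcasted_warp_2d, the inverted map sends bit 0 of a gather index to the register dimension and bits 1-5 to the lane, which is exactly the and/lshr/and sequence the test expects. A hedged sketch of that arithmetic (plain C++; the helper and its bit widths are illustrative, not lifted from the lowering):

#include <cassert>
#include <cstdint>

// Split a gather index into (register id, lane id) for a layout whose
// pseudoinverse assigns 1 register bit and 5 lane bits, as in the 64x1
// test above. Illustrative helper, not Triton code.
struct SrcCoord {
  uint32_t regId;  // which of the thread's K register values to select
  uint32_t laneId; // which lane to shuffle from
};

SrcCoord decomposeGatherIndex(uint32_t idx) {
  SrcCoord c;
  c.regId = idx & 1;          // CHECK: and i32 [[IDX]], 1
  c.laneId = (idx >> 1) & 31; // CHECK: lshr ...; and i32 ..., 31
  return c;
}

int main() {
  // Index 5 lives in register 1 of lane 2; every warp holds the same
  // broadcasted column, so no warp id enters the computation.
  SrcCoord c = decomposeGatherIndex(5);
  assert(c.regId == 1 && c.laneId == 2);
}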
