Commit 2c0b791

[Triton] Add tl.gather with a naive codegen implementation (#5262)
This PR adds a `tl.gather` builtin that implements a local gather along a single axis, with semantics matching `torch.gather`. `tl.gather` generates a `tt.gather` op, which is piped through the compiler mostly untouched for now, since the codegen is very naive. `tt.gather` is lowered by writing the source tensor into shared memory and then gathering out of shared memory, so it requires scratch space to be allocated. In a follow-up, I will implement an optimized layout rule for the op that ensures the gather axis fits into a single warp, allowing the gather to be implemented using warp shuffles. There are other avenues for optimization as well: a `tt.gather(tt.load)` where the load has only one use can be lowered into a DMA from global memory to shared memory, followed by a gather directly out of shared memory.
1 parent b8a4b87 commit 2c0b791
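For context, Python-side usage should look roughly like the sketch below; the exact `tl.gather` call signature (argument order, keyword names) is an assumption inferred from the `torch.gather` semantics described above, not something this diff pins down.

    import triton
    import triton.language as tl

    @triton.jit
    def gather_rows_kernel(src_ptr, idx_ptr, out_ptr,
                           M: tl.constexpr, K: tl.constexpr, N: tl.constexpr):
        # An M x N source block and a K x N index block: same rank, and only
        # the gather axis (axis 0 here) may differ in size.
        rm = tl.arange(0, M)[:, None]
        rk = tl.arange(0, K)[:, None]
        rn = tl.arange(0, N)[None, :]
        src = tl.load(src_ptr + rm * N + rn)
        idx = tl.load(idx_ptr + rk * N + rn)  # idx_ptr assumed to hold int32
        # out[k, n] = src[idx[k, n], n], as in torch.gather(src, 0, idx)
        out = tl.gather(src, idx, 0)  # assumed signature: (src, indices, axis)
        tl.store(out_ptr + rk * N + rn, out)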

File tree: 26 files changed, +518 -20 lines

include/triton/Analysis/Utility.h

Lines changed: 13 additions & 0 deletions
@@ -153,6 +153,19 @@ class ScanLoweringHelper {
   SmallVector<Type> srcElementTypes;
 };
 
+// Helper class for lowering `tt.gather` operations. This class shares lowering
+// logic between shared memory allocation and LLVM codegen.
+class GatherLoweringHelper {
+public:
+  GatherLoweringHelper(triton::GatherOp gatherOp);
+
+  // Get the shared memory scratch size required by this op.
+  unsigned getScratchSizeInBytes();
+
+private:
+  triton::GatherOp gatherOp;
+};
+
 // Decomposes a reshape into simpler pieces.
 //
 // As an example, suppose we have a reshape from [4,4,4] to [2,2,8,2].

include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 4 additions & 0 deletions
@@ -92,6 +92,10 @@ void populateScanOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
                                   RewritePatternSet &patterns,
                                   const TargetInfoBase &targetInfo,
                                   PatternBenefit benefit);
+void populateGatherOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
+                                    RewritePatternSet &patterns,
+                                    const TargetInfoBase &targetInfo,
+                                    PatternBenefit benefit);
 
 void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
                                            const TargetInfoBase &targetInfo,

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 5 additions & 0 deletions
@@ -1125,6 +1125,11 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter,
 
 // Emit indices calculation within each ConversionPattern, and returns a
 // [elemsPerThread X rank] index matrix.
+//
+// For example, if a thread owns `elemsPerThread` elements of a tensor with
+// type `type` and layout `layout`, the result will contain `elemsPerThread`
+// vectors. Each vector contains the SSA values of the indices required to
+// access the corresponding element, starting from the inner dimension.
 SmallVector<SmallVector<Value>>
 emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
             Attribute layout, RankedTensorType type, bool withCTAOffset);

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 26 additions & 0 deletions
@@ -869,6 +869,32 @@ def TT_HistogramOp : TT_Op<"histogram", [Pure]> {
   }];
 }
 
+//
+// Gather Op
+//
+def TT_GatherOp : TT_Op<"gather", [Pure,
+                                   DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "local gather operation";
+  let description = [{
+    Gather elements from the input tensor using the indices tensor along a
+    single specified axis. The output tensor has the same shape as the indices
+    tensor. The input and indices tensors must have the same number of
+    dimensions, and each dimension of the indices tensor that is not the gather
+    dimension cannot be greater than the corresponding dimension in the input
+    tensor.
+  }];
+
+  let arguments = (ins TT_Tensor:$src, TT_IntTensor:$indices, I32Attr:$axis);
+  let results = (outs TT_Tensor:$result);
+
+  let assemblyFormat = [{
+    $src `[` $indices `]` attr-dict `:`
+    functional-type(operands, results)
+  }];
+
+  let hasVerifier = 1;
+}
+
 //
 // Print Op
 //
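Given this assemblyFormat, the op's textual form should look along the lines of `%y = tt.gather %x[%i] {axis = 0 : i32} : (tensor<4x3xf32>, tensor<2x3xi32>) -> tensor<2x3xf32>`; this is a hand-written illustration inferred from the declaration, not a snippet taken from the commit.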

lib/Analysis/Allocation.cpp

Lines changed: 4 additions & 0 deletions
@@ -125,6 +125,10 @@ unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
     ScanLoweringHelper helper(scanOp);
     return helper.getScratchSizeInBytes();
   }
+  if (auto gatherOp = dyn_cast<GatherOp>(op)) {
+    GatherLoweringHelper helper(gatherOp);
+    return helper.getScratchSizeInBytes();
+  }
   if (auto histogram = dyn_cast<HistogramOp>(op)) {
     auto dstTy = histogram.getType();
     int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp(

lib/Analysis/Utility.cpp

Lines changed: 11 additions & 0 deletions
@@ -408,6 +408,17 @@ unsigned ScanLoweringHelper::getAxisBlockStride() {
   llvm_unreachable("Axis not found in order");
 }
 
+GatherLoweringHelper::GatherLoweringHelper(triton::GatherOp gatherOp)
+    : gatherOp(gatherOp) {}
+
+unsigned GatherLoweringHelper::getScratchSizeInBytes() {
+  // For now, lower the gather op by writing the source tensor to shared memory.
+  // TODO(jeff): Leverage locality to avoid using scratch space when possible.
+  RankedTensorType srcType = gatherOp.getSrc().getType();
+  return product(srcType.getShape()) *
+         ceil<unsigned>(srcType.getElementTypeBitWidth(), 8);
+}
+
 unsigned getNumScratchElements(ArrayRef<unsigned> shape) {
   if (shape.empty())
     return 0;
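The scratch requirement is simply the byte footprint of the full source tensor, since the naive lowering stages every element in shared memory. A small sketch of the same arithmetic (hypothetical helper, for illustration only):

    from math import ceil, prod

    def gather_scratch_bytes(src_shape, elem_bitwidth):
        # Mirrors GatherLoweringHelper::getScratchSizeInBytes: every source
        # element is staged in shared memory exactly once.
        return prod(src_shape) * ceil(elem_bitwidth / 8)

    assert gather_scratch_bytes((128, 16), 16) == 4096  # 128*16 fp16 elements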

lib/Conversion/TritonGPUToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ add_triton_library(TritonGPUToLLVM
   AllocateSharedMemory.cpp
   ReduceOpToLLVM.cpp
   ScanOpToLLVM.cpp
+  GatherOpToLLVM.cpp
   ConvertLayoutOpToLLVM.cpp
   ControlFlowOpToLLVM.cpp
   FuncOpToLLVM.cpp
lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
+#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
+#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
+
+using namespace mlir;
+using namespace mlir::triton;
+
+namespace {
+class GatherOpConversion : public ConvertOpToLLVMPattern<GatherOp> {
+public:
+  GatherOpConversion(LLVMTypeConverter &typeConverter,
+                     const TargetInfoBase &targetInfo, PatternBenefit benefit)
+      : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
+  }
+
+  LogicalResult
+  matchAndRewrite(GatherOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+
+private:
+  const TargetInfoBase &targetInfo;
+};
+
+LogicalResult
+GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  RankedTensorType srcType = op.getSrc().getType();
+
+  // Compute the src subtensor shape owned by this CTA.
+  SmallVector<unsigned> srcShapePerCTA =
+      convertType<unsigned>(triton::gpu::getShapePerCTA(srcType));
+
+  // Grab the src values in this thread.
+  SmallVector<Value> srcValues =
+      unpackLLElements(loc, adaptor.getSrc(), rewriter);
+
+  // Emit the indices of the src values owned by this thread.
+  SmallVector<SmallVector<Value>> srcIndices =
+      emitIndices(loc, rewriter, targetInfo, srcType.getEncoding(),
+                  op.getSrc().getType(), /*withCTAOffset=*/true);
+
+  // Store the src values owned by the thread into their respective location in
+  // the scratch memory.
+  assert(srcValues.size() == srcIndices.size());
+
+  // Get the base pointer to the scratch memory.
+  Value smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op);
+
+  // For each src element owned by the thread, index into the scratch memory and
+  // then store it.
+  Type elemType = getTypeConverter()->convertType(srcType.getElementType());
+  for (auto [value, indices] : llvm::zip(srcValues, srcIndices)) {
+    // Convert the index at each dim into a single offset given the shape of the
+    // tensor.
+    Value offset = LLVM::linearize(rewriter, loc, indices, srcShapePerCTA);
+    // Emit the offset into the shared memory and then store the value.
+    Value ptr = gep(smemBase.getType(), elemType, smemBase, offset);
+    store(value, ptr);
+  }
+
+  // Synchronize the whole CTA.
+  barrier();
+
+  // Grab the index values owned by this thread.
+  SmallVector<Value> idxValues =
+      unpackLLElements(loc, adaptor.getIndices(), rewriter);
+
+  // Apply the layout of the destination tensor to obtain the indices of the
+  // column to gather along, then for each column, replace the index along the
+  // gather axis with the appropriate index value.
+  //
+  //   I = LL(pid)
+  //   idx = indices[I]
+  //   I_gather = [I[d] if d != axis else idx for d in range(len(I))]
+  //   out[I] = src[I_gather]
+  RankedTensorType dstType = op.getType();
+  SmallVector<SmallVector<Value>> dstIndices =
+      emitIndices(loc, rewriter, targetInfo, dstType.getEncoding(), dstType,
+                  /*withCTAOffset=*/true);
+
+  unsigned idxWidth = op.getIndices().getType().getElementTypeBitWidth();
+  unsigned axis = op.getAxis();
+  SmallVector<Value> results(dstIndices.size());
+  for (auto [i, idx, indices] : llvm::enumerate(idxValues, dstIndices)) {
+    // The LL index computations are performed with 32 bit integers. If the
+    // indices are something else, cast them to i32.
+    if (idxWidth > 32) {
+      idx = trunc(i32_ty, idx);
+    } else if (idxWidth < 32) {
+      // Negative indices don't make sense, so zero-extend.
+      idx = zext(i32_ty, idx);
+    }
+    indices[axis] = idx;
+    Value offset = LLVM::linearize(rewriter, loc, indices, srcShapePerCTA);
+    Value ptr = gep(smemBase.getType(), elemType, smemBase, offset);
+    results[i] = load(elemType, ptr);
+  }
+
+  Value packed =
+      packLLElements(loc, getTypeConverter(), results, rewriter, dstType);
+  rewriter.replaceOp(op, packed);
+  return success();
+}
+
+} // namespace
+
+void triton::populateGatherOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
+                                            RewritePatternSet &patterns,
+                                            const TargetInfoBase &targetInfo,
+                                            PatternBenefit benefit) {
+  patterns.insert<GatherOpConversion>(typeConverter, targetInfo, benefit);
+}
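The pseudocode in the comment above (`out[I] = src[I_gather]`) can be stated as an executable reference model. This NumPy sketch illustrates the intended semantics and is not code from the commit:

    import numpy as np

    def gather_reference(src, indices, axis):
        # For every output coordinate I, replace the coordinate along `axis`
        # with indices[I] and read the source element at that position.
        out = np.empty(indices.shape, dtype=src.dtype)
        for I in np.ndindex(*indices.shape):
            I_gather = list(I)
            I_gather[axis] = indices[I]
            out[I] = src[tuple(I_gather)]
        return out

    src = np.arange(12).reshape(4, 3)
    idx = np.array([[3, 0, 1], [2, 2, 0]])
    # Matches torch.gather(src, 0, idx) on the same data.
    print(gather_reference(src, idx, axis=0))  # [[9 1 5] [6 7 2]]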

lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp

Lines changed: 1 addition & 0 deletions
@@ -537,6 +537,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
       GenericOpPattern<triton::MakeRangeOp>, TritonExpandDimsPattern,
       TritonTransPattern, TritonDotPattern, GenericOpPattern<triton::LoadOp>,
       GenericOpPattern<triton::StoreOp>, GenericOpPattern<triton::HistogramOp>,
+      GenericOpPattern<triton::GatherOp>,
       GenericOpPattern<triton::ExternElementwiseOp>,
       GenericOpPattern<triton::PrintOp>, GenericOpPattern<triton::AssertOp>,
       GenericOpPattern<triton::AtomicCASOp>,

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 48 additions & 0 deletions
@@ -1073,6 +1073,54 @@ Speculation::Speculatability ExternElementwiseOp::getSpeculatability() {
   return Speculation::NotSpeculatable;
 }
 
+// -- GatherOp --
+LogicalResult GatherOp::verify() {
+  RankedTensorType indicesTy = getIndices().getType();
+  RankedTensorType srcTy = getSrc().getType();
+  RankedTensorType resTy = getResult().getType();
+
+  if (indicesTy.getShape() != resTy.getShape()) {
+    return emitOpError("indices and output shapes must match");
+  }
+  if (indicesTy.getEncoding() != resTy.getEncoding()) {
+    return emitOpError("indices and output encodings must match");
+  }
+  if (srcTy.getElementType() != resTy.getElementType()) {
+    return emitOpError("input and output element types must match");
+  }
+  if (srcTy.getRank() != indicesTy.getRank()) {
+    return emitOpError("input and indices ranks must match");
+  }
+  if (getAxis() >= srcTy.getRank()) {
+    return emitOpError("gather dimension must be less than the input rank");
+  }
+  for (int dim = 0; dim < indicesTy.getRank(); ++dim) {
+    if (dim == getAxis())
+      continue;
+    if (indicesTy.getShape()[dim] != srcTy.getShape()[dim]) {
+      return emitOpError("indices dimension ")
+             << dim << " must match the corresponding input dimension";
+    }
+  }
+
+  return success();
+}
+
+LogicalResult GatherOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  GatherOpAdaptor adaptor(operands, attributes, properties, regions);
+  auto indicesType = cast<RankedTensorType>(adaptor.getIndices().getType());
+  auto srcType = cast<RankedTensorType>(adaptor.getSrc().getType());
+
+  // Shape and encoding of the indices with the element type of the src.
+  inferredReturnTypes.push_back(
+      RankedTensorType::get(indicesType.getShape(), srcType.getElementType(),
+                            indicesType.getEncoding()));
+  return success();
+}
+
 // -- ExperimentalTensormapCreateOp --
 LogicalResult ExperimentalTensormapCreateOp::verify() {
   auto rank = getBoxDim().size();
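The verifier's shape rules condense to a handful of checks; note that it requires exact equality on non-gather dimensions, which is slightly stricter than the "cannot be greater than" wording in the op's description. A sketch of the same rules (hypothetical helper, for illustration):

    def verify_gather(src_shape, idx_shape, axis):
        # Mirrors GatherOp::verify: equal ranks, axis in range, and every
        # non-gather dimension of the indices equal to the source's.
        if len(src_shape) != len(idx_shape):
            raise ValueError("input and indices ranks must match")
        if axis >= len(src_shape):
            raise ValueError("gather dimension must be less than the input rank")
        for d, (s, i) in enumerate(zip(src_shape, idx_shape)):
            if d != axis and i != s:
                raise ValueError(f"indices dimension {d} must match the "
                                 "corresponding input dimension")

    verify_gather((4, 3), (2, 3), axis=0)  # ok: only the gather axis differs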
