
Commit 2350d5a

Merge commit '61daa335a8b21102fae790ed286c7ddf71383af5'
2 parents: 0c70ca3 + 61daa33


37 files changed: +790, -46 lines changed

include/triton/Analysis/Utility.h

Lines changed: 13 additions & 0 deletions
@@ -153,6 +153,19 @@ class ScanLoweringHelper {
   SmallVector<Type> srcElementTypes;
 };
 
+// Helper class for lowering `tt.gather` operations. This class shares lowering
+// logic between shared memory allocation and LLVM codegen.
+class GatherLoweringHelper {
+public:
+  GatherLoweringHelper(triton::GatherOp gatherOp);
+
+  // Get the shared memory scratch size required by this op.
+  unsigned getScratchSizeInBytes();
+
+private:
+  triton::GatherOp gatherOp;
+};
+
 // Decomposes a reshape into simpler pieces.
 //
 // As an example, suppose we have a reshape from [4,4,4] to [2,2,8,2].

include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 4 additions & 0 deletions
@@ -92,6 +92,10 @@ void populateScanOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
                                   RewritePatternSet &patterns,
                                   const TargetInfoBase &targetInfo,
                                   PatternBenefit benefit);
+void populateGatherOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
+                                    RewritePatternSet &patterns,
+                                    const TargetInfoBase &targetInfo,
+                                    PatternBenefit benefit);
 
 void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
                                            const TargetInfoBase &targetInfo,

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 5 additions & 0 deletions
@@ -1125,6 +1125,11 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter,
 
 // Emit indices calculation within each ConversionPattern, and returns a
 // [elemsPerThread X rank] index matrix.
+//
+// For example, if a thread owns `elemsPerThread` elements of a tensor with
+// type `type` and layout `layout`, the result will contain `elemsPerThread`
+// vectors. Each vector contains the SSA values of the indices required to
+// access the corresponding element, starting from the inner dimension.
 SmallVector<SmallVector<Value>>
 emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
             Attribute layout, RankedTensorType type, bool withCTAOffset);
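
To make the shape of the returned index matrix concrete, here is a hypothetical call-site fragment. It relies only on the signature declared above; the names `loc`, `rewriter`, `targetInfo`, and `tensorTy` are assumptions standing in for values a ConversionPattern would already have in scope.

// Hypothetical call-site fragment (not standalone): `loc`, `rewriter`,
// `targetInfo`, and `tensorTy` are assumed to be available in the pattern.
SmallVector<SmallVector<Value>> indices =
    emitIndices(loc, rewriter, targetInfo, tensorTy.getEncoding(), tensorTy,
                /*withCTAOffset=*/false);
// indices.size() == elemsPerThread and indices[i].size() == rank: each inner
// vector holds one SSA index value per tensor dimension for the i-th element
// owned by the current thread.
for (const SmallVector<Value> &multiDimIdx : indices) {
  // Use multiDimIdx to address the corresponding element, e.g. when
  // computing a shared-memory offset.
  (void)multiDimIdx;
}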

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 26 additions & 0 deletions
@@ -869,6 +869,32 @@ def TT_HistogramOp : TT_Op<"histogram", [Pure]> {
   }];
 }
 
+//
+// Gather Op
+//
+def TT_GatherOp : TT_Op<"gather", [Pure,
+                                   DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "local gather operation";
+  let description = [{
+    Gather elements from the input tensor using the indices tensor along a
+    single specified axis. The output tensor has the same shape as the indices
+    tensor. The input and indices tensors must have the same number of
+    dimensions, and each dimension of the indices tensor that is not the gather
+    dimension cannot be greater than the corresponding dimension in the input
+    tensor.
+  }];
+
+  let arguments = (ins TT_Tensor:$src, TT_IntTensor:$indices, I32Attr:$axis);
+  let results = (outs TT_Tensor:$result);
+
+  let assemblyFormat = [{
+    $src `[` $indices `]` attr-dict `:`
+    functional-type(operands, results)
+  }];
+
+  let hasVerifier = 1;
+}
+
 //
 // Print Op
 //
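
As a rough illustration of the semantics described in the op definition above, the following standalone sketch models a 2-D gather. The helper name, element types, and shapes are illustrative assumptions, not part of the op.

#include <cassert>
#include <cstdint>
#include <vector>

using Matrix = std::vector<std::vector<float>>;
using IndexMatrix = std::vector<std::vector<int32_t>>;

// Reference-semantics sketch: the output has the shape of `indices`, and each
// element is read from `src` with the gather dimension replaced by the value
// stored in `indices` at that position.
Matrix referenceGather(const Matrix &src, const IndexMatrix &indices, int axis) {
  assert(axis == 0 || axis == 1);
  Matrix out(indices.size());
  for (size_t i = 0; i < indices.size(); ++i) {
    out[i].resize(indices[i].size());
    for (size_t j = 0; j < indices[i].size(); ++j) {
      int32_t idx = indices[i][j];
      out[i][j] = (axis == 0) ? src[idx][j] : src[i][idx];
    }
  }
  return out;
}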

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 12 additions & 0 deletions
@@ -192,4 +192,16 @@ def TritonGPULoopScheduling: Pass<"tritongpu-loop-scheduling", "mlir::ModuleOp">
                  "number of pipeline stages">
   ];
 }
+
+def TritonGPUCoalesceAsyncCopy: Pass<"tritongpu-coalesce-async-copy", "mlir::ModuleOp"> {
+  let summary = "Improve coalescing for async global to local copies";
+
+  let description = "For AsyncCopyGlobalToLocal ops where the shared encoding's vec is less than "
+                    "the blocked encoding's sizePerThread, this pass improves coalescing by clipping the "
+                    "sizePerThread value";
+
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
+                           "mlir::triton::TritonDialect"];
+}
+
 #endif
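
The clipping rule the pass description mentions can be pictured with a small standalone sketch. The function name, the plain-vector representation, and the choice of clipping only the contiguous dimension are assumptions of this sketch, not details taken from the pass implementation.

#include <algorithm>
#include <vector>

// Hedged sketch: if the shared encoding only accepts `sharedVec` contiguous
// elements per access, writing more than that per thread does not help, so
// the blocked encoding's sizePerThread is reduced along its contiguous
// dimension (`contigDim` is an assumption here).
std::vector<unsigned> clipSizePerThread(std::vector<unsigned> sizePerThread,
                                        unsigned contigDim, unsigned sharedVec) {
  sizePerThread[contigDim] = std::min(sizePerThread[contigDim], sharedVec);
  return sizePerThread;
}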

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 5 additions & 0 deletions
@@ -202,6 +202,11 @@ enum class MMALoadType {
   // pipelining
 };
 MMALoadType getMMALoadType(Operation *loadOp);
+
+// Returns composed LinearLayout for register to shared copy
+std::optional<triton::LinearLayout>
+getRegToSharedLayout(MLIRContext *ctx, ArrayRef<int64_t> shape,
+                     Attribute srcEnc, Attribute dstEnc, int elemBitWidth);
 } // namespace mlir
 
 #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
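
A hypothetical call-site fragment for the declaration above, using only the declared signature; the surrounding names (`ctx`, `shape`, `srcEnc`, `dstEnc`, `elemBitWidth`) are assumptions standing in for values taken from the copy's source and destination descriptors.

// Hypothetical call-site fragment (not standalone).
std::optional<triton::LinearLayout> regToShared =
    getRegToSharedLayout(ctx, shape, srcEnc, dstEnc, elemBitWidth);
if (!regToShared) {
  // The register and shared encodings could not be composed into a single
  // linear layout; a caller would fall back to a more general path.
}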

lib/Analysis/Allocation.cpp

Lines changed: 4 additions & 0 deletions
@@ -125,6 +125,10 @@ unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
     ScanLoweringHelper helper(scanOp);
     return helper.getScratchSizeInBytes();
   }
+  if (auto gatherOp = dyn_cast<GatherOp>(op)) {
+    GatherLoweringHelper helper(gatherOp);
+    return helper.getScratchSizeInBytes();
+  }
   if (auto histogram = dyn_cast<HistogramOp>(op)) {
     auto dstTy = histogram.getType();
     int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp(

lib/Analysis/Utility.cpp

Lines changed: 11 additions & 0 deletions
@@ -415,6 +415,17 @@ unsigned ScanLoweringHelper::getAxisBlockStride() {
   llvm_unreachable("Axis not found in order");
 }
 
+GatherLoweringHelper::GatherLoweringHelper(triton::GatherOp gatherOp)
+    : gatherOp(gatherOp) {}
+
+unsigned GatherLoweringHelper::getScratchSizeInBytes() {
+  // For now, lower the gather op by writing the source tensor to shared memory.
+  // TODO(jeff): Leverage locality to avoid using scratch space when possible.
+  RankedTensorType srcType = gatherOp.getSrc().getType();
+  return product(srcType.getShape()) *
+         ceil<unsigned>(srcType.getElementTypeBitWidth(), 8);
+}
+
 unsigned getNumScratchElements(ArrayRef<unsigned> shape) {
   if (shape.empty())
     return 0;
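
To make the scratch-size formula above concrete, here is a standalone sketch that reproduces the same arithmetic for a hypothetical 128x64 f16 source tensor; the shape and element type are illustrative only.

#include <cstdint>
#include <iostream>

int main() {
  const int64_t shape[] = {128, 64}; // hypothetical source tensor shape
  const unsigned elemBitWidth = 16;  // e.g. f16
  int64_t numElements = 1;
  for (int64_t dim : shape)
    numElements *= dim;
  const unsigned bytesPerElement = (elemBitWidth + 7) / 8; // ceil(bits / 8)
  // 128 * 64 * 2 = 16384 bytes of shared-memory scratch space.
  std::cout << numElements * bytesPerElement << " bytes\n";
  return 0;
}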

lib/Conversion/TritonGPUToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ add_triton_library(TritonGPUToLLVM
   AllocateSharedMemory.cpp
   ReduceOpToLLVM.cpp
   ScanOpToLLVM.cpp
+  GatherOpToLLVM.cpp
   ConvertLayoutOpToLLVM.cpp
   ControlFlowOpToLLVM.cpp
   FuncOpToLLVM.cpp

lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 2 additions & 3 deletions
@@ -325,13 +325,12 @@ struct ElementwiseInlineAsmOpConversion
     // asmResults is a flat struct; pack its values into
     // [return_value][op.getPackedElement()].
     SmallVector<SmallVector<Value>> ret(op->getNumResults());
+    int structIdx = 0;
     for (int i = 0; i < op->getNumResults(); i++) {
-      int structIdx = 0;
       for (int j = 0; j < op.getPackedElement(); j++) {
         Value val;
         if (asmRetTypes.size() > 1) {
-          val =
-              extract_val(asmResults, i * op.getPackedElement() + structIdx++);
+          val = extract_val(asmResults, structIdx++);
         } else {
           val = asmResults;
         }
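
The hoisted counter visits the same flat-struct positions as the previous `i * op.getPackedElement() + structIdx++` expression, since that inner counter was reset to zero for each result. The following standalone sketch checks the equivalence for hypothetical result and packing counts.

#include <cassert>

int main() {
  const int numResults = 2, packedElement = 4; // hypothetical counts
  int structIdx = 0;
  for (int i = 0; i < numResults; ++i)
    for (int j = 0; j < packedElement; ++j)
      // The running counter equals the index the old expression computed.
      assert(structIdx++ == i * packedElement + j);
  return 0;
}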
