Skip to content

Commit 536b0b9

Browse files
Merge commit '52cf1aee47f806585fcb1a88f5b24880ab6f6257'
2 parents 9ec46fe + 52cf1ae commit 536b0b9

File tree

26 files changed

+1054
-409
lines changed

26 files changed

+1054
-409
lines changed

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,12 +335,18 @@ SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc,
335335
SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc,
336336
Value linear, ArrayRef<unsigned> shape);
337337

338+
SmallVector<unsigned> delinearize(unsigned linear, ArrayRef<unsigned> shape,
339+
ArrayRef<unsigned> order);
340+
338341
Value linearize(RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDim,
339342
ArrayRef<unsigned> shape, ArrayRef<unsigned> order);
340343

341344
Value linearize(RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDim,
342345
ArrayRef<unsigned> shape);
343346

347+
size_t linearize(ArrayRef<unsigned> multiDim, ArrayRef<unsigned> shape,
348+
ArrayRef<unsigned> order);
349+
344350
Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key,
345351
StringRef content);
346352

@@ -495,6 +501,24 @@ inline Value dot(RewriterBase &rewriter, Location loc, ArrayRef<Value> offsets,
495501
return ret;
496502
}
497503

504+
/// Extend 2d shared object to 3d.
505+
///
506+
/// If tensor has 3 dimensions, returns original shared object.
507+
/// If tensor shape is [M, N], return shared object describing shape [1, M, N]
508+
///
509+
/// This Function is used to simplify processing of 2d and 3d dot operands,
510+
/// particularly in the conversion of local_load operation.
511+
///
512+
/// \param rewriter
513+
/// \param loc
514+
/// \param smemObj
515+
/// \param shape shape of a tensor represented by smemObj
516+
/// \returns shared object describing 3d tensor
517+
SharedMemoryObject
518+
getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc,
519+
SharedMemoryObject smemObj,
520+
ArrayRef<int64_t> shape);
521+
498522
// -----------------------------------------------------------------------
499523
// Blocked layout indices
500524
// -----------------------------------------------------------------------

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -882,9 +882,18 @@ def TT_GatherOp : TT_Op<"gather", [Pure,
882882
dimension, and each dimension of the indices tensor that is not the gather
883883
dimension cannot be greater than the corresponding dimension in the input
884884
tensor.
885+
886+
The `efficient_layout` attribute is set when the compiler has determined an
887+
optimized layout for the operation, indicating that it should not be
888+
changed.
885889
}];
886890

887-
let arguments = (ins TT_Tensor:$src, TT_IntTensor:$indices, I32Attr:$axis);
891+
let arguments = (ins
892+
TT_Tensor:$src,
893+
TT_IntTensor:$indices,
894+
I32Attr:$axis,
895+
UnitAttr:$efficient_layout
896+
);
888897
let results = (outs TT_Tensor:$result);
889898

890899
let assemblyFormat = [{

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,12 @@ void dumpHWLayout(RankedTensorType tensorType);
234234
// Return a string representation of the layout of the tensor.
235235
std::string getLayoutStr(RankedTensorType tensorType, bool useHWPointOfView);
236236

237+
template <typename T>
238+
llvm::SmallVector<T> expandMatrixShapeWithBatch(llvm::ArrayRef<T> s);
239+
240+
llvm::SmallVector<unsigned>
241+
expandMatrixOrderWithBatch(llvm::ArrayRef<unsigned> o);
242+
237243
} // namespace gpu
238244
} // namespace triton
239245
} // namespace mlir

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -158,11 +158,20 @@ def TritonGPUOptimizeThreadLocality : Pass<"tritongpu-optimize-thread-locality",
158158
let summary = "Reduce the cost of synchronization between threads in an SM";
159159

160160
let description = [{
161-
The aim of this pass is to reduce cross-thread communication for reduction
162-
operations, by adjusting the reduction size (or layout) to avoid splitting
163-
the reduction operation between multiple threads. Currently, this pass only
164-
optimizes reduction yielded by loop to be thread-local until
165-
after the loop completes.
161+
The aim of this pass is to reduce cross-thread communication for certain
162+
operations, like reductions, reshapes, and gathers.
163+
164+
For reduction operations, this pass attempts to adjust the reduction size
165+
(or layout) to avoid splitting the reduction operation between multiple
166+
threads. Currently, this pass only optimizes reduction yielded by loop to be
167+
thread-local until after the loop completes.
168+
169+
For gathers, this pass will attempt to pick an optimized layout for gather
170+
operations in the module. This is determined based on the shapes of the
171+
gather operands as well as their existing layouts. The pass applies
172+
heuristics to determine when it is appropriate to assign specific layouts
173+
and trigger their respective codegen paths. For now, the pass only attempts
174+
to apply layouts that result in warp-synchronous gathers.
166175
}];
167176

168177
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,7 @@ Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
163163
LogicalResult getConvertBackwardSlice(
164164
Value root, SetVector<Value> &slice, Attribute rootEncoding,
165165
DenseMap<Value, Attribute> &layout,
166-
std::function<bool(Operation *)> stopPropagation = nullptr,
167-
std::function<Value(Value, Attribute)> getExistingConversion = nullptr);
166+
std::function<bool(Operation *)> stopPropagation = nullptr);
168167

169168
// Populate pattern to remove dead cycles in ForOp.
170169
void populateForOpDeadArgumentElimination(RewritePatternSet &patterns);

0 commit comments

Comments
 (0)