
Commit 386840d

Merge OpenAI Triton commit 26b45d8 (#4280)

This PR changes the Triton base from 4b9efc5 to 26b45d8 (May 15). Pass rate: 96.85% -> 95.34% (#4281)
2 parents b929bc9 + 2cc9369 commit 386840d

File tree

33 files changed: +1005 -487 lines


README.md

Lines changed: 1 addition & 0 deletions
@@ -212,6 +212,7 @@ See [`python/triton/knobs.py`](python/triton/knobs.py) for the full list of conf
 - `TRITON_OVERRIDE_DIR` specifies the directory from which to load the IR/ptx/amdgcn files when `TRITON_KERNEL_OVERRIDE` is set to 1.
 - `TRITON_F32_DEFAULT` sets the default input precision of `tl.dot` when using 32-bit floats, which can be either `ieee`, `tf32`, or `tf32x3`.
 - `TRITON_FRONT_END_DEBUGGING=1` disables exception wrapping when an error occurs in the compiler frontend, allowing the full stack trace to be seen.
+- `TRITON_STRIP_DEBUG_INFO` removes all debug information from the module, including location information
 
 N.B. Some of these environment variables don't have a knob in `knobs.py`-- those are only relevant to the C++ layer(s), hence they don't exist in the python layer.
 

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 24 additions & 248 deletions
@@ -338,46 +338,21 @@ using namespace mlir::triton;
 
 class SharedMemoryObject {
 public:
-  SharedMemoryObject(Value base, Type baseElemType, ArrayRef<Value> offsets)
-      : base(base), baseElemType(baseElemType),
-        offsets(offsets.begin(), offsets.end()) {}
+  SharedMemoryObject(Value base, Type baseElemType, ArrayRef<Value> offsets);
 
   SharedMemoryObject(Value base, Type baseElemType, int64_t rank, Location loc,
-                     RewriterBase &rewriter)
-      : base(base), baseElemType(baseElemType) {
-    auto b = TritonLLVMOpBuilder(loc, rewriter);
-    offsets.append(rank, b.i32_val(0));
-  }
+                     RewriterBase &rewriter);
 
   SmallVector<Value> getOffsets() const { return offsets; }
   Value getBase() const { return base; }
   Type getBaseElemType() const { return baseElemType; }
 
-  SmallVector<Value> getElems() const {
-    SmallVector<Value> elems;
-    elems.push_back(base);
-    elems.append(offsets.begin(), offsets.end());
-    return elems;
-  }
+  SmallVector<Value> getElems() const;
 
-  SmallVector<Type> getTypes() const {
-    SmallVector<Type> types;
-    types.push_back(base.getType());
-    types.append(offsets.size(), IntegerType::get(base.getContext(), 32));
-    return types;
-  }
+  SmallVector<Type> getTypes() const;
 
   SmallVector<Value> getStrides(triton::gpu::MemDescType memDesc, Location loc,
-                                RewriterBase &rewriter) const {
-    auto allocShape = memDesc.getAllocShape();
-    auto allocShapePerCTA = triton::gpu::getAllocationShapePerCTA(
-        memDesc.getEncoding(), allocShape);
-    auto layoutOrder = triton::gpu::getOrder(memDesc);
-    auto allocStrides = SharedMemoryObject::getStridesForShape(
-        allocShapePerCTA, layoutOrder, loc, rewriter);
-    return SmallVector<Value>(allocStrides.end() - offsets.size(),
-                              allocStrides.end());
-  }
+                                RewriterBase &rewriter) const;
 
   // TODO(Keren): deprecate the method once AMD backend has cleaned up
   Value getCSwizzleOffset(int dim) const {
@@ -386,50 +361,16 @@ class SharedMemoryObject {
   }
 
   // TODO(Keren): deprecate the method once AMD backend has cleaned up
-  Value getBaseBeforeSlice(int dim, Location loc,
-                           RewriterBase &rewriter) const {
-    auto b = TritonLLVMOpBuilder(loc, rewriter);
-    Value cSwizzleOffset = getCSwizzleOffset(dim);
-    Value offset = b.sub(b.i32_val(0), cSwizzleOffset);
-    Type type = base.getType();
-    return b.gep(type, baseElemType, base, offset);
-  }
+  Value getBaseBeforeSlice(int dim, Location loc, RewriterBase &rewriter) const;
 
 private:
-  static SmallVector<unsigned>
-  getOrderForShape(ArrayRef<int64_t> shape, ArrayRef<unsigned> layoutOrder) {
-    SmallVector<unsigned> order(shape.size());
-    // Default minor-to-major order
-    std::iota(order.rbegin(), order.rend(), 0);
-    if (layoutOrder.size() > 0) {
-      // If a layout order is provided, we assume it specifies the order in
-      // which the dimensions are first accessed, and unspecified dimensions
-      // retain the minor-to-major order. For example, if order = [2, 1, 0] and
-      // layoutOrder = [0, 1], we need to shift `layoutOrder`
-      // by -1 (move them right). The resulting order will then be [1, 2, 0].
-      int rankDiff = layoutOrder.size() - shape.size();
-      auto minRank = std::min<size_t>(shape.size(), layoutOrder.size());
-      for (size_t i = 0; i < minRank; ++i)
-        order[i] = layoutOrder[i] - rankDiff;
-    }
-    assert(isPermutationOfIota(order) && "Invalid order");
-    return order;
-  }
+  static SmallVector<unsigned> getOrderForShape(ArrayRef<int64_t> shape,
+                                                ArrayRef<unsigned> layoutOrder);
 
   static SmallVector<Value> getStridesForShape(ArrayRef<int64_t> shape,
                                                ArrayRef<unsigned> layoutOrder,
                                                Location loc,
-                                               RewriterBase &rewriter) {
-    SmallVector<Value> strides(shape.size());
-    auto order = SharedMemoryObject::getOrderForShape(shape, layoutOrder);
-    int64_t stride = 1;
-    auto b = TritonLLVMOpBuilder(loc, rewriter);
-    for (auto idx : order) {
-      strides[idx] = b.i32_val(stride);
-      stride *= shape[idx];
-    }
-    return strides;
-  }
+                                               RewriterBase &rewriter);
 
   Value base; // i32 ptr. The start address of the shared memory object.
   Type baseElemType;
@@ -486,97 +427,14 @@ inline bool isKernel(FunctionOpInterface funcOp) {
   return funcOp.getVisibility() == SymbolTable::Visibility::Public;
 }
 
-inline Value getStackPointer(RewriterBase &rewriter,
-                             FunctionOpInterface funcOp) {
-  // See NOTE: [Additional Function Arguments]
-  if (!isKernel(funcOp)) {
-    return funcOp.getArgument(funcOp.getNumArguments() - 2);
-  }
-
-  auto mod = funcOp->getParentOfType<ModuleOp>();
-  auto globalBase = dyn_cast<LLVM::GlobalOp>(mod.lookupSymbol("global_smem"));
-  assert(globalBase);
-  return rewriter.create<LLVM::AddressOfOp>(funcOp.getLoc(), globalBase);
-}
-
-inline Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
-                                 const TargetInfoBase &targetInfo,
-                                 FunctionOpInterface funcOp,
-                                 Value allocOffset = {}) {
-  // See NOTE: [Additional Function Arguments]
-  if (!isKernel(funcOp)) {
-    // Base for this function
-    auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1);
-    if (!allocOffset) {
-      return gmemBase;
-    }
-
-    auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
-    auto b = TritonLLVMOpBuilder(loc, rewriter);
-    return b.gep(ptrTy, i8_ty, gmemBase, allocOffset);
-  }
-
-  // Base for entire kernel
-  auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1);
-
-  ModuleOp mod = funcOp.getOperation()->getParentOfType<ModuleOp>();
-  auto allocSizeAttr = mod.getOperation()->getAttrOfType<mlir::IntegerAttr>(
-      "ttg.global_scratch_memory_size");
-  if (!allocSizeAttr) {
-    return gmemBase;
-  }
+Value getStackPointer(RewriterBase &rewriter, FunctionOpInterface funcOp);
 
-  Value gridIdx[3];
-  Value gridDim[2];
-  for (int k = 0; k < 3; ++k) {
-    gridIdx[k] = rewriter.create<GetProgramIdOp>(loc, k);
-  }
-  for (int k = 0; k < 2; ++k) {
-    gridDim[k] = rewriter.create<GetNumProgramsOp>(loc, k);
-  }
+Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
+                          const TargetInfoBase &targetInfo,
+                          FunctionOpInterface funcOp, Value allocOffset);
 
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-  Value linearId = gridIdx[2];
-  for (int k = 0; k < 2; ++k) {
-    linearId = b.add(gridIdx[1 - k], b.mul(linearId, gridDim[1 - k]));
-  }
-  auto numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
-  if (numCTAs > 1) {
-    linearId = b.mul(linearId, b.i32_val(numCTAs));
-    linearId = b.add(linearId, targetInfo.getClusterCTAId(rewriter, loc));
-  }
-
-  auto allocSize = allocSizeAttr.getValue().getZExtValue();
-
-  Value offset = b.mul(linearId, b.i32_val(allocSize));
-  if (allocOffset) {
-    offset = b.add(offset, allocOffset);
-  }
-
-  auto *ctx = rewriter.getContext();
-  auto res =
-      b.gep(mlir::LLVM::LLVMPointerType::get(ctx, 1), i8_ty, gmemBase, offset);
-  return res;
-}
-
-inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
-                                 const TargetInfoBase &target, Operation *op) {
-  auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(),
-                                          target.getSharedAddressSpace());
-  auto func = op->template getParentOfType<FunctionOpInterface>();
-  if (!func)
-    func = cast<FunctionOpInterface>(op);
-
-  assert(op->hasAttr("allocation.offset"));
-  size_t offset = cast<IntegerAttr>(op->getAttr("allocation.offset"))
-                      .getValue()
-                      .getZExtValue();
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-  Value offVal = b.i32_val(offset);
-  Value base =
-      b.gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal);
-  return base;
-}
+Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
                           const TargetInfoBase &target, Operation *op);
 
 // -----------------------------------------------------------------------
 // MXFP utilities
@@ -619,16 +477,8 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr;
 using ::mlir::triton::gpu::NvidiaMmaEncodingAttr;
 using ::mlir::triton::gpu::SliceEncodingAttr;
 
-inline Value dot(RewriterBase &rewriter, Location loc, ArrayRef<Value> offsets,
-                 ArrayRef<Value> strides) {
-  assert(offsets.size() == strides.size());
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-  Value ret = b.i32_val(0);
-  for (auto [offset, stride] : llvm::zip(offsets, strides)) {
-    ret = b.add(ret, b.mul(offset, stride));
-  }
-  return ret;
-}
+Value dot(RewriterBase &rewriter, Location loc, ArrayRef<Value> offsets,
+          ArrayRef<Value> strides);
 
 /// Extend 2d shared object to 3d.
 ///
@@ -720,91 +570,17 @@ SmallVector<Value> unpackLLVector(Location loc, Value llvmVec,
 
 Value packLLVector(Location loc, ValueRange vals, RewriterBase &rewriter);
 
-inline std::optional<LLVM::AtomicBinOp> matchAtomicOp(RMWOp atomicOp) {
-  switch (atomicOp) {
-  case RMWOp::AND:
-    return LLVM::AtomicBinOp::_and;
-  case RMWOp::OR:
-    return LLVM::AtomicBinOp::_or;
-  case RMWOp::XOR:
-    return LLVM::AtomicBinOp::_xor;
-  case RMWOp::ADD:
-    return LLVM::AtomicBinOp::add;
-  case RMWOp::FADD:
-    return LLVM::AtomicBinOp::fadd;
-  case RMWOp::MAX:
-    return LLVM::AtomicBinOp::max;
-  case RMWOp::MIN:
-    return LLVM::AtomicBinOp::min;
-  case RMWOp::UMAX:
-    return LLVM::AtomicBinOp::umax;
-  case RMWOp::UMIN:
-    return LLVM::AtomicBinOp::umin;
-  case RMWOp::XCHG:
-    return LLVM::AtomicBinOp::xchg;
-  default:
-    return {};
-  }
-}
+std::optional<LLVM::AtomicBinOp> matchAtomicOp(RMWOp atomicOp);
 
-inline std::optional<LLVM::AtomicOrdering>
-getMemoryOrdering(MemSemantic memOrdering) {
-  switch (memOrdering) {
-  case MemSemantic::RELAXED:
-    return LLVM::AtomicOrdering::monotonic;
-  case MemSemantic::ACQUIRE:
-    return LLVM::AtomicOrdering::acquire;
-  case MemSemantic::RELEASE:
-    return LLVM::AtomicOrdering::release;
-  case MemSemantic::ACQUIRE_RELEASE:
-    return LLVM::AtomicOrdering::acq_rel;
-  default:
-    return {};
-  }
-}
+std::optional<LLVM::AtomicOrdering> getMemoryOrdering(MemSemantic memOrdering);
 
-inline bool
-isSimpleSharedMemoryAccess(ArrayRef<int64_t> shape,
-                           ArrayRef<int64_t> allocShape,
-                           triton::gpu::SharedEncodingTrait sharedEnc) {
-  auto rank = shape.size();
-  auto swizzledLayout =
-      dyn_cast<triton::gpu::SwizzledSharedEncodingAttr>(sharedEnc);
-  auto nvmmaLayout = dyn_cast<triton::gpu::NVMMASharedEncodingAttr>(sharedEnc);
-  bool noSwizzling = (swizzledLayout && swizzledLayout.getMaxPhase() == 1) ||
-                     (nvmmaLayout && nvmmaLayout.getSwizzlingByteWidth() == 0);
-  return /*no swizzling*/ noSwizzling ||
-         /*swizzling but same shape*/ shape == allocShape ||
-         /*swizzling and rank-reduced and rank >= 2*/
-         (shape == allocShape.take_back(rank) && rank >= 2);
-}
+bool isSimpleSharedMemoryAccess(ArrayRef<int64_t> shape,
+                                ArrayRef<int64_t> allocShape,
+                                triton::gpu::SharedEncodingTrait sharedEnc);
 
-inline llvm::MapVector<StringAttr, int32_t>
-getAllFreeVarMasks(MLIRContext *ctx) {
-  // Mask where all elements are redundant
-  auto kReg = str_attr("reg");
-  auto kLane = str_attr("lane");
-  auto kWarp = str_attr("warp");
-  auto kBlock = str_attr("block");
+llvm::MapVector<StringAttr, int32_t> getAllFreeVarMasks(MLIRContext *ctx);
 
-  int32_t fullMask = -1;
-  llvm::MapVector<StringAttr, int32_t> ret;
-  for (auto dimName : {kReg, kLane, kWarp, kBlock}) {
-    ret[dimName] = fullMask;
-  }
-  return ret;
-}
-
-inline llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks(Type type) {
-  auto ctx = type.getContext();
-  auto tensorTy = dyn_cast<RankedTensorType>(type);
-  if (!tensorTy) {
-    return getAllFreeVarMasks(ctx);
-  }
-  auto ll =
-      triton::gpu::toLinearLayout(tensorTy.getShape(), tensorTy.getEncoding());
-  return ll.getFreeVariableMasks();
-}
+llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks(Type type);
 
 inline bool isCanonicalIndex(unsigned index, unsigned freeVarMask) {
   return (index & freeVarMask) == 0;
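
The bodies above move out of the header unchanged; the diff keeps only the declarations. For orientation, here is a plain-C++ sketch -- an illustration, not part of the patch -- of the arithmetic that getStridesForShape and dot(...) materialize as IR: strides grow innermost-first along the layout order, and dot(...) linearizes an offset vector against them. The names stridesForShape and dotOffsets are hypothetical; the real functions build MLIR Values via TritonLLVMOpBuilder rather than computing integers.

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical scalar model of getStridesForShape: order[0] is the
// fastest-varying dimension, so its stride is 1, and each later dimension's
// stride is the product of the sizes of the dimensions ordered before it.
std::vector<int64_t> stridesForShape(const std::vector<int64_t> &shape,
                                     const std::vector<unsigned> &order) {
  std::vector<int64_t> strides(shape.size());
  int64_t stride = 1;
  for (unsigned idx : order) {
    strides[idx] = stride;
    stride *= shape[idx];
  }
  return strides;
}

// Hypothetical scalar model of dot(...): linear offset = sum_i(offset_i * stride_i).
int64_t dotOffsets(const std::vector<int64_t> &offsets,
                   const std::vector<int64_t> &strides) {
  assert(offsets.size() == strides.size());
  int64_t ret = 0;
  for (size_t i = 0; i < offsets.size(); ++i)
    ret += offsets[i] * strides[i];
  return ret;
}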

include/triton/Dialect/Triton/Transforms/Passes.td

Lines changed: 11 additions & 0 deletions
@@ -79,4 +79,15 @@ def TritonLoopInvariantCodeMotion : Pass</*cli-arg*/"triton-licm", /*Op*/"mlir::
   let dependentDialects = ["mlir::triton::TritonDialect"];
 }
 
+def TritonLoopAwareCSE : Pass<"triton-loop-aware-cse", "mlir::ModuleOp"> {
+  let summary = "CSE within loop bodies";
+
+  let description = [{
+    The `triton-loop-aware-cse` pass performs recursive common subexpression
+    elimination within loop bodies. Unlike regular CSE, which is a single-pass
+    greedy algorithm, this pass can recursively eliminate loop iteration
+    arguments and subcomputations that always have the same value.
+  }];
+}
+
 #endif
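
To make the pass description concrete, here is a hypothetical scalar analogue (not from the patch) of the rewrite it enables: two loop-carried values that start equal and are updated by the same pure computation remain equal on every iteration, so one iteration argument and its update can be folded into the other. Single-pass greedy CSE misses this because each update reads a distinct, not-yet-merged carried value.

// Before: `a` and `b` are separate loop-carried values with equal initial
// values and identical pure updates, hence provably equal on every iteration.
int before(int x, int n) {
  int a = x, b = x;
  for (int i = 0; i < n; ++i) {
    a = a * 2 + 1;
    b = b * 2 + 1; // redundant loop-carried subcomputation
  }
  return a + b;
}

// After loop-aware CSE: `b` and its update are folded into `a`.
int after(int x, int n) {
  int a = x;
  for (int i = 0; i < n; ++i)
    a = a * 2 + 1;
  return a + a;
}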

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 12 additions & 0 deletions
@@ -313,6 +313,18 @@ def TTG_LocalStoreOp : TTG_Op<"local_store"> {
   }];
 }
 
+def TTG_PredicateStageOp: TTG_Op<"predicate_stage",
+    [Pure, AllTypesMatch<["iv", "ub", "step"]>]> {
+  let summary = "pipeliner stage predicate";
+  let arguments = (ins AnySignlessIntegerOrIndex:$iv,
+                       AnySignlessIntegerOrIndex:$ub,
+                       AnySignlessIntegerOrIndex:$step,
+                       I32Attr:$maxStage,
+                       I32Attr:$stage);
+  let results = (outs I1:$result);
+  let assemblyFormat = "$iv `,` $ub `,` $step `maxStage` $maxStage `stage` $stage attr-dict `:` type($iv) `->` type($result)";
+}
+
 def TTG_Fp4ToFpOp : TTG_Op<"fp4_to_fp", [Pure]> {
   let summary = "Upcast fp4 (e2m1) to fp";
 
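
Reading off the declared assemblyFormat, an instance of the new op would print roughly as follows (illustrative only -- the operand names, the `ttg.` prefix, and the elision of attribute type suffixes are assumptions, not taken from the patch):

%pred = ttg.predicate_stage %iv, %ub, %step maxStage 2 stage 0 : i32 -> i1

Given the loop's induction variable, upper bound, and step, the op yields an i1 saying whether the given pipeline stage should execute; the AllTypesMatch trait forces %iv, %ub, and %step to share one integer or index type.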

include/triton/Dialect/TritonGPU/Transforms/PipelineExpander.h

Lines changed: 10 additions & 0 deletions
@@ -57,6 +57,12 @@ struct PipeliningOption {
   /// pipeliner will have to predicate operations in the prologue/epilogue.
   bool supportDynamicLoops = false;
 
+  /// If set, use this function to emit the predicate stage ops instead of the
+  /// default one.
+  using EmitPredicateStageFnType = std::function<Value(
+      RewriterBase &, Value, Value, Value, uint64_t, uint64_t)>;
+  EmitPredicateStageFnType emitPredicateStageFn = nullptr;
+
   // Callback to predicate operations when the prologue or epilogue are not
   // peeled. This takes the original operation, an i1 predicate value and the
   // pattern rewriter. It is expected to replace the given operation with
@@ -95,6 +101,10 @@ FailureOr<scf::ForOp> pipelineForLoop(RewriterBase &rewriter, scf::ForOp forOp,
                                       const PipeliningOption &options,
                                       bool *modifiedIR = nullptr);
 
+Value emitPredicateForStage(RewriterBase &rewriter, Value inductionVar,
+                            Value upperBound, Value step, uint64_t maxStage,
+                            uint64_t stage);
+
 } // namespace triton
 } // namespace mlir
 
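
emitPredicateForStage is the default emitter that a client can now replace via emitPredicateStageFn. As a rough model of its semantics -- an assumption, since the real function builds the comparison as MLIR values and may handle cases such as negative steps differently -- a stage's predicate checks that the original iteration the stage is processing still lies below the upper bound:

#include <cstdint>

// Hypothetical scalar model of a pipeliner stage predicate, assuming a
// positive step: in the steady-state kernel at induction value iv, stage s
// works on the original iteration iv + (maxStage - s) * step, so it must be
// masked off once that iteration would run past the upper bound.
bool stageActive(int64_t iv, int64_t ub, int64_t step, uint64_t maxStage,
                 uint64_t s) {
  return iv + static_cast<int64_t>(maxStage - s) * step < ub;
}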
