Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 124 additions & 0 deletions lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,114 @@ namespace triton {

namespace {

/// FoldCmpSelectToMinMaxPattern is an optimization pattern that matches
/// 'triton::ReduceOp' operations whose reduction body consists of a
/// single 'arith.select' operation based on a floating-point comparsion,
/// and rewrites them into equivalent 'arith.minf', 'arith.maxf',
/// 'arith.MinimumFOp' or 'arith.MaximumFOp' operations.
///
/// This pattern handles the following cases:
///
/// 1. ** Simple Min/Max Reduction **
/// - select (cmpf ogt a, b), a, b --> maxf(a, b)
/// - select (cmpf olt a, b), a, b --> minf(a, b)
///
/// 2. ** NaN-Aware Min/Max Reduction **
/// - select (cmpf ogt a, b) || cmpf une a, a), a, b --> arith.maximumf(a, b)
/// - select (cmpf olt a, b) || cmpf une a, a), a, b --> arith.minimumf(a, b)
///
/// These transformations not only improve IR canonicalization but also
/// allow the successful lowering of tt.reduce operations to linalg operations,
/// which is already supported in the triton-shared dialect conversion pipeline.

struct FoldCmpSelectToMinMaxPattern
: public OpRewritePattern<triton::ReduceOp> {
using OpRewritePattern<triton::ReduceOp>::OpRewritePattern;

LogicalResult matchAndRewrite(triton::ReduceOp op,
PatternRewriter &rewriter) const override {
Block &body = *op.getBody();
auto *term = body.getTerminator();

// Get the value being yielded from the reduction.
Value ret = term->getOperand(0);

// Check if the yielded value is produced by an arith.select operation.
auto sel = ret.getDefiningOp<arith::SelectOp>();
if (!sel || !isa<FloatType>(sel.getType()))
return failure(); // Only handle floating-point types.

// Extract the condition and operands of the select operation.
auto cond = sel.getCondition().getDefiningOp();
Value trueVal = sel.getTrueValue();
Value falseVal = sel.getFalseValue();

// Case 1: Simple Min/Max Reduction.
if (auto cmp = dyn_cast<arith::CmpFOp>(cond)) {
// Match: select (cmpf ogt a, b), a, b → arith.maxf(a, b).
if (cmp.getPredicate() == arith::CmpFPredicate::OGT &&
trueVal == cmp.getLhs() && falseVal == cmp.getRhs()) {
rewriter.setInsertionPoint(sel);
auto maxOp =
rewriter.create<arith::MaxNumFOp>(sel.getLoc(), trueVal, falseVal);
rewriter.replaceOp(sel, maxOp.getResult());
return success();
}
// Match: select (cmpf olt a, b), a, b → arith.minf(a, b).
if (cmp.getPredicate() == arith::CmpFPredicate::OLT &&
trueVal == cmp.getLhs() && falseVal == cmp.getRhs()) {
rewriter.setInsertionPoint(sel);
auto minOp =
rewriter.create<arith::MinNumFOp>(sel.getLoc(), trueVal, falseVal);
rewriter.replaceOp(sel, minOp.getResult());
return success();
}
}

// Case 2: NaN-Aware Min/Max Reduction.
if (auto ori = dyn_cast<arith::OrIOp>(cond)) {
// Extract both sides of the OR condition.
auto cmp1 = ori.getLhs().getDefiningOp<arith::CmpFOp>();
auto cmp2 = ori.getRhs().getDefiningOp<arith::CmpFOp>();
if (!cmp1 || !cmp2)
return failure();

// Helper lambdas to identify comparison patterns.
auto isOGT = [&](arith::CmpFOp cmp) {
return cmp.getPredicate() == arith::CmpFPredicate::OGT &&
trueVal == cmp.getLhs() && falseVal == cmp.getRhs();
};
auto isOLT = [&](arith::CmpFOp cmp) {
return cmp.getPredicate() == arith::CmpFPredicate::OLT &&
trueVal == cmp.getLhs() && falseVal == cmp.getRhs();
};
auto isNaN = [&](arith::CmpFOp cmp) {
return cmp.getPredicate() == arith::CmpFPredicate::UNE &&
trueVal == cmp.getLhs() && trueVal == cmp.getRhs();
};

// Match: select ((ogt(a, b) || une(a, a)), a, b) -> arith.maximumf(a, b).
if ((isOGT(cmp1) && isNaN(cmp2)) || (isOGT(cmp2) && isNaN(cmp1))) {
rewriter.setInsertionPoint(sel);
auto maxOp =
rewriter.create<arith::MaximumFOp>(sel.getLoc(), trueVal, falseVal);
rewriter.replaceOp(sel, maxOp.getResult());
return success();
}

// Match: select ((olt(a, b) || une(a, a)), a, b) -> arith.minimumf(a, b).
if ((isOLT(cmp1) && isNaN(cmp2)) || (isOLT(cmp2) && isNaN(cmp1))) {
rewriter.setInsertionPoint(sel);
auto minOp =
rewriter.create<arith::MinimumFOp>(sel.getLoc(), trueVal, falseVal);
rewriter.replaceOp(sel, minOp.getResult());
return success();
}
}
return failure();
}
};

class TritonArithToLinalgPass
: public triton::impl::TritonArithToLinalgBase<TritonArithToLinalgPass> {
using TritonArithToLinalgBase<
Expand Down Expand Up @@ -75,6 +183,17 @@ class TritonArithToLinalgPass
}
}

LogicalResult foldCmpSelectToMinMax() {
auto moduleOp = getOperation();
RewritePatternSet patterns(&getContext());
patterns.add<FoldCmpSelectToMinMaxPattern>(&getContext());

if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) {
return failure();
}
return success();
}

LogicalResult applyTensorConcatDecomposition() {
auto moduleOp = getOperation();
MLIRContext *context = &getContext();
Expand Down Expand Up @@ -185,6 +304,11 @@ class TritonArithToLinalgPass
target.addLegalOp<triton::AssertOp>();
}

// Fold cmp/select patterns before applying the main conversion patterns.
if (failed(foldCmpSelectToMinMax())) {
signalPassFailure();
}

triton::populateTritonArithToLinalgConversionPatterns(
pidsToFuncArgs, addptrToLinalg, assertToCf, transposeReduceToRank0,
patterns);
Expand Down