third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp (74 changes: 41 additions & 33 deletions)
@@ -34,9 +34,10 @@ namespace mlir::triton::gpu::intel {
 namespace {
 
 // FIXME: Remove once IGC can split large 2D block loads.
-static void setAttrOnBOperand(tt::DotOp dotOp, StringRef attrName,
+static void setAttrOnBOperand(Operation *op, StringRef attrName,
                               Attribute attr) {
-  Operation *defOp = dotOp.getB().getDefiningOp();
+  assert(isa<tt::DotOp>(op) && "Unexpected operation type");
+  Operation *defOp = cast<tt::DotOp>(op).getB().getDefiningOp();
   while (auto convOp = dyn_cast_or_null<ttg::ConvertLayoutOp>(defOp))
     defOp = convOp.getSrc().getDefiningOp();
   if (auto transOp = dyn_cast_or_null<tt::TransOp>(defOp))
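
The rewritten helper accepts a generic Operation * (still asserting tt::DotOp for now) and walks the B operand's def chain, peeling any number of ConvertLayoutOps and then one optional TransOp before tagging the defining op. A minimal stand-in sketch of that walk, in plain C++ with hypothetical Node types rather than the MLIR API; the final trans-peel-and-tag step is an assumption about the hidden tail of the function:

#include <cassert>
#include <map>
#include <string>

// Stand-ins for MLIR ops: each node optionally wraps a source node.
struct Node {
  std::string kind;    // "convert_layout", "trans", "load", ...
  Node *src = nullptr; // defining op of this node's input
  std::map<std::string, bool> attrs;
};

// Mirrors the walk above: skip any convert_layouts, then one optional
// trans, then tag whatever ultimately defines the operand. The trans
// peel and the final tagging are assumptions about the elided lines.
void setAttrOnDefiningOp(Node *defOp, const std::string &attrName) {
  while (defOp && defOp->kind == "convert_layout")
    defOp = defOp->src;
  if (defOp && defOp->kind == "trans")
    defOp = defOp->src;
  if (defOp)
    defOp->attrs[attrName] = true;
}

int main() {
  Node load{"load"};
  Node trans{"trans", &load};
  Node cvt{"convert_layout", &trans};
  setAttrOnDefiningOp(&cvt, "one_matrix_per_load"); // hypothetical attr key
  assert(load.attrs.count("one_matrix_per_load")); // attr lands on the load
  return 0;
}
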
@@ -46,7 +47,8 @@ static void setAttrOnBOperand(tt::DotOp dotOp, StringRef attrName,
 }
 
 SmallVector<unsigned>
-getWarpsPerTile(tt::DotOp dotOp, ttgi::DpasEncodingAttr::DPASCapability dpasCap,
+getWarpsPerTile(Operation *dotOp,
+                ttgi::DpasEncodingAttr::DPASCapability dpasCap,
                 const ArrayRef<int64_t> shape, unsigned numWarps) {
   auto filter = [&dotOp](Operation *op) {
     return op->getParentRegion() == dotOp->getParentRegion();
@@ -60,8 +62,8 @@ getWarpsPerTile(tt::DotOp dotOp, ttgi::DpasEncodingAttr::DPASCapability dpasCap,
       MLIRContext *ctx = forOp->getContext();
       StringRef attrName =
           ttgi::TritonIntelGPUDialect::getOneMatrixPerLoadAttrName();
       setAttrOnBOperand(dotOp, attrName, UnitAttr::get(ctx));
-      setAttrOnBOperand(cast<tt::DotOp>(op), attrName, UnitAttr::get(ctx));
+      setAttrOnBOperand(op, attrName, UnitAttr::get(ctx));
     }
     SmallVector<unsigned> ret(shape.size(), 1);
     ret[0] = numWarps;
@@ -108,42 +110,44 @@ getWarpsPerTile(tt::DotOp dotOp, ttgi::DpasEncodingAttr::DPASCapability dpasCap,
   return ret;
 }
 
-class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
-  const ttg::intel::DPASAnalysis &dpasAnalysis;
+template <class OpTy,
+          typename = std::enable_if_t<llvm::is_one_of<OpTy, tt::DotOp>::value>>
+class BlockedToDPAS : public OpRewritePattern<OpTy> {
+  const ttgi::DPASAnalysis &dpasAnalysis;
   using TensorValue = TypedValue<RankedTensorType>;
 
 public:
-  BlockedToDPAS(MLIRContext *context,
-                const ttg::intel::DPASAnalysis &dpasAnalysis)
-      : OpRewritePattern<tt::DotOp>(context), dpasAnalysis(dpasAnalysis) {}
+  BlockedToDPAS(MLIRContext *context, const ttgi::DPASAnalysis &dpasAnalysis,
+                int benefit)
+      : OpRewritePattern<OpTy>(context, benefit), dpasAnalysis(dpasAnalysis) {}
 
-  LogicalResult matchAndRewrite(tt::DotOp dotOp,
+  LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    RankedTensorType oldRetType = dotOp.getType();
+    RankedTensorType oldRetType = op.getType();
     if (!oldRetType.getEncoding() ||
         isa<ttgi::DpasEncodingAttr>(oldRetType.getEncoding()))
       return failure();
 
-    auto funcOp = dotOp->getParentOfType<FunctionOpInterface>();
-    if (dpasAnalysis.canUseDPAS(funcOp) !=
-        ttg::intel::DPASAnalysis::Result::True)
+    auto funcOp = op->template getParentOfType<FunctionOpInterface>();
+    if (dpasAnalysis.canUseDPAS(funcOp) != ttgi::DPASAnalysis::Result::True)
       return failure();
 
     // Create DPAS encoding for the given number of warps
     ArrayRef<int64_t> retShape = oldRetType.getShape();
     unsigned numWarps = ttg::lookupNumWarps(funcOp);
 
-    TensorValue a = dotOp.getA();
-    TensorValue b = dotOp.getB();
+    TensorValue a = op.getA();
+    TensorValue b = op.getB();
     auto oldAType = cast<RankedTensorType>(a.getType());
     auto oldBType = cast<RankedTensorType>(b.getType());
 
-    ModuleOp mod = funcOp->getParentOfType<ModuleOp>();
-    auto dpasCap = ttgi::DpasEncodingAttr::getDPASCapability(mod);
+    ModuleOp mod = funcOp->template getParentOfType<ModuleOp>();
+    ttgi::DpasEncodingAttr::DPASCapability dpasCap =
+        ttgi::DpasEncodingAttr::getDPASCapability(mod);
     Type elemType = oldAType.getElementType();
     unsigned opsPerChan = ttgi::DpasEncodingAttr::getOpsPerChannel(elemType);
     SmallVector<unsigned> warpsPerTile =
-        getWarpsPerTile(dotOp, dpasCap, retShape, numWarps);
+        getWarpsPerTile(op, dpasCap, retShape, numWarps);
     size_t rank = retShape.size();
     SmallVector<unsigned> repCluster(rank, 1);
 
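
The pattern is now a class template gated by llvm::is_one_of, so only tt::DotOp can instantiate it today while leaving room to admit further dot-like ops later. A self-contained sketch of that gating mechanism, with a local is_one_of standing in for llvm::is_one_of and a hypothetical DotScaledOp as a possible future entry:

#include <type_traits>

// Local stand-in for llvm::is_one_of: true iff T matches one of Ts.
template <typename T, typename... Ts>
struct is_one_of : std::disjunction<std::is_same<T, Ts>...> {};

struct DotOp {};       // stand-in for tt::DotOp
struct DotScaledOp {}; // hypothetical future op type

template <class OpTy,
          typename = std::enable_if_t<is_one_of<OpTy, DotOp>::value>>
class BlockedToDPAS { /* pattern body elided */ };

int main() {
  BlockedToDPAS<DotOp> ok; // compiles: DotOp is on the allow-list
  (void)ok;
  // BlockedToDPAS<DotScaledOp> would fail to compile today; widening the
  // gate to is_one_of<OpTy, DotOp, DotScaledOp> is all it would take.
  return 0;
}
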
@@ -156,6 +160,7 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
             : dpasCap.systolicDepth * 2; // A is packed to i16 or i32.
     unsigned minM = mlir::ceil<unsigned>(threadsPerWarp, numElemsPerRowForA);
     repeatCount = std::max(repeatCount, minM);
+
     auto dpasEnc = ttgi::DpasEncodingAttr::get(
         oldRetType.getContext(), repeatCount, dpasCap.systolicDepth,
         dpasCap.executionSize, opsPerChan, warpsPerTile, repCluster,
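
The clamp raises the DPAS repeat count so the A tile holds at least one packed element per lane: minM is the ceiling of threadsPerWarp over numElemsPerRowForA. A worked example with illustrative numbers (threadsPerWarp = 16, systolicDepth = 8, an fp16 A operand, initial repeatCount = 8; none of these values come from the diff):

#include <algorithm>
#include <cstdio>

int main() {
  unsigned threadsPerWarp = 16;
  unsigned systolicDepth = 8;
  unsigned numElemsPerRowForA = systolicDepth * 2; // 16: fp16 packs to i16
  unsigned minM =                                  // ceiling division,
      (threadsPerWarp + numElemsPerRowForA - 1) /  // as mlir::ceil does
      numElemsPerRowForA;                          // -> 1
  unsigned repeatCount = std::max(8u, minM);       // clamp leaves it at 8
  std::printf("minM = %u, repeatCount = %u\n", minM, repeatCount);
  return 0;
}
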
@@ -194,11 +199,11 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
                                         threadsPerWarp);
     }
 
-    RankedTensorType newRetType =
+    auto newRetType =
         RankedTensorType::get(retShape, oldRetType.getElementType(), dpasEnc);
 
     // convert accumulator
-    TensorValue oldAcc = dotOp.getC();
+    TensorValue oldAcc = op.getC();
     auto newAcc = ttg::ConvertLayoutOp::create(rewriter, oldAcc.getLoc(),
                                                newRetType, oldAcc);
     // opA are packed to i16 for scalar type < 16 bits. opB are packed to i32.
@@ -215,15 +220,17 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {

     a = ttg::ConvertLayoutOp::create(rewriter, a.getLoc(), newAType, a);
     b = ttg::ConvertLayoutOp::create(rewriter, b.getLoc(), newBType, b);
-    auto newDot = tt::DotOp::create(rewriter, dotOp.getLoc(), newRetType, a, b,
-                                    newAcc, dotOp.getInputPrecision(),
-                                    dotOp.getMaxNumImpreciseAcc());
 
-    rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(dotOp, oldRetType,
+    auto newDot =
+        tt::DotOp::create(rewriter, op.getLoc(), newRetType, a, b, newAcc,
+                          op.getInputPrecision(), op.getMaxNumImpreciseAcc());
+
+    rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(op, oldRetType,
                                                       newDot.getResult());
     return success();
   }
 };
+
 } // namespace
 
 static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
@@ -258,13 +265,13 @@ static void decomposeMixedModeDotOp(ModuleOp mod) {
     OpBuilder builder(dotOp);
     Type AElType = dotOp.getA().getType().getElementType();
     auto dpasLayout =
-        dyn_cast<ttg::intel::DpasEncodingAttr>(D.getType().getEncoding());
+        dyn_cast<ttgi::DpasEncodingAttr>(D.getType().getEncoding());
 
     Type promoteType;
     if (dpasLayout) {
       bool isNativeFP8 = isa<Float8E5M2Type, Float8E4M3FNType>(AElType);
-      // fp8 is not natively supported by the DPAS instruction, promote it
-      // to fp16.
+      // fp8 is not always natively supported by the DPAS instruction,
+      // promote it to fp16 when necessary.
       if (!isNativeFP8)
         return;
       promoteType = builder.getF16Type();
@@ -428,23 +435,24 @@ static void transposeDots(ModuleOp m) {
 }
 
 class TritonIntelGPUAccelerateMatmulPass
-    : public triton::gpu::intel::impl::TritonIntelGPUAccelerateMatmulBase<
+    : public ttgi::impl::TritonIntelGPUAccelerateMatmulBase<
           TritonIntelGPUAccelerateMatmulPass> {
 public:
-  using triton::gpu::intel::impl::TritonIntelGPUAccelerateMatmulBase<
+  using ttgi::impl::TritonIntelGPUAccelerateMatmulBase<
       TritonIntelGPUAccelerateMatmulPass>::TritonIntelGPUAccelerateMatmulBase;
 
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp m = getOperation();
-    auto &dpasAnalysis = getAnalysis<ttg::intel::DPASAnalysis>();
+    auto &dpasAnalysis = getAnalysis<ttgi::DPASAnalysis>();
 
     // Transpose dotOp operations that have a scale on the RHS.
     transposeDots(m);
 
     RewritePatternSet patterns(context);
     constexpr int benefitDefault = 1;
-    patterns.add<BlockedToDPAS>(context, dpasAnalysis);
+    patterns.add<BlockedToDPAS<tt::DotOp>>(context, dpasAnalysis,
+                                           benefitDefault + 1);
     ttgi::populateDecomposeScaledBlockedPatterns(patterns, benefitDefault);
     if (applyPatternsGreedily(m, std::move(patterns)).failed())
       signalPassFailure();
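
The benefit argument added to the constructor feeds OpRewritePattern's pattern benefit: when several patterns match the same op, the greedy rewrite driver tries higher-benefit patterns first, so registering BlockedToDPAS<tt::DotOp> at benefitDefault + 1 gives it first claim on plain dot ops ahead of the decompose-scaled patterns at benefitDefault. A standalone sketch of that ordering rule with stand-in types (this is not the MLIR driver itself):

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Stand-in "patterns": the driver gives the highest benefit the first try.
struct Pattern {
  std::string name;
  int benefit;
};

int main() {
  constexpr int benefitDefault = 1;
  std::vector<Pattern> candidates = {
      {"DecomposeScaledBlocked", benefitDefault},
      {"BlockedToDPAS<tt::DotOp>", benefitDefault + 1},
  };
  std::stable_sort(candidates.begin(), candidates.end(),
                   [](const Pattern &a, const Pattern &b) {
                     return a.benefit > b.benefit; // higher benefit first
                   });
  for (const Pattern &p : candidates)
    std::cout << p.name << " (benefit " << p.benefit << ")\n";
  // Prints BlockedToDPAS<tt::DotOp> first, mirroring the registration above.
  return 0;
}
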