diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp
index 9bdf6b4743..95f0cf8732 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp
@@ -34,9 +34,10 @@ namespace mlir::triton::gpu::intel {
 namespace {
 
 // FIXME: Remove once IGC can split large 2D block loads.
-static void setAttrOnBOperand(tt::DotOp dotOp, StringRef attrName,
+static void setAttrOnBOperand(Operation *op, StringRef attrName,
                               Attribute attr) {
-  Operation *defOp = dotOp.getB().getDefiningOp();
+  assert(isa<tt::DotOp>(op) && "Unexpected operation type");
+  Operation *defOp = cast<tt::DotOp>(op).getB().getDefiningOp();
   while (auto convOp = dyn_cast_or_null<ttg::ConvertLayoutOp>(defOp))
     defOp = convOp.getSrc().getDefiningOp();
   if (auto transOp = dyn_cast_or_null<tt::TransOp>(defOp))
@@ -46,7 +47,8 @@ static void setAttrOnBOperand(tt::DotOp dotOp, StringRef attrName,
 }
 
 SmallVector<unsigned>
-getWarpsPerTile(tt::DotOp dotOp, ttgi::DpasEncodingAttr::DPASCapability dpasCap,
+getWarpsPerTile(Operation *dotOp,
+                ttgi::DpasEncodingAttr::DPASCapability dpasCap,
                 const ArrayRef<int64_t> shape, unsigned numWarps) {
   auto filter = [&dotOp](Operation *op) {
     return op->getParentRegion() == dotOp->getParentRegion();
@@ -60,8 +62,8 @@ getWarpsPerTile(tt::DotOp dotOp, ttgi::DpasEncodingAttr::DPASCapability dpasCap,
         MLIRContext *ctx = forOp->getContext();
         StringRef attrName =
             ttgi::TritonIntelGPUDialect::getOneMatrixPerLoadAttrName();
-        setAttrOnBOperand(dotOp, attrName, UnitAttr::get(ctx));
-        setAttrOnBOperand(cast<tt::DotOp>(op), attrName, UnitAttr::get(ctx));
+        setAttrOnBOperand(dotOp, attrName, UnitAttr::get(ctx));
+        setAttrOnBOperand(op, attrName, UnitAttr::get(ctx));
       }
       SmallVector<unsigned> ret(shape.size(), 1);
       ret[0] = numWarps;
@@ -108,42 +110,44 @@ getWarpsPerTile(tt::DotOp dotOp, ttgi::DpasEncodingAttr::DPASCapability dpasCap,
   return ret;
 }
 
-class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
-  const ttg::intel::DPASAnalysis &dpasAnalysis;
+template <typename OpTy,
+          typename = std::enable_if_t<llvm::is_one_of<OpTy, tt::DotOp>::value>>
+class BlockedToDPAS : public OpRewritePattern<OpTy> {
+  const ttgi::DPASAnalysis &dpasAnalysis;
   using TensorValue = TypedValue<RankedTensorType>;
 
 public:
-  BlockedToDPAS(MLIRContext *context,
-                const ttg::intel::DPASAnalysis &dpasAnalysis)
-      : OpRewritePattern<tt::DotOp>(context), dpasAnalysis(dpasAnalysis) {}
+  BlockedToDPAS(MLIRContext *context, const ttgi::DPASAnalysis &dpasAnalysis,
+                int benefit)
+      : OpRewritePattern<OpTy>(context, benefit), dpasAnalysis(dpasAnalysis) {}
 
-  LogicalResult matchAndRewrite(tt::DotOp dotOp,
+  LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    RankedTensorType oldRetType = dotOp.getType();
+    RankedTensorType oldRetType = op.getType();
     if (!oldRetType.getEncoding() ||
         isa<ttgi::DpasEncodingAttr>(oldRetType.getEncoding()))
       return failure();
 
-    auto funcOp = dotOp->getParentOfType<FunctionOpInterface>();
-    if (dpasAnalysis.canUseDPAS(funcOp) !=
-        ttg::intel::DPASAnalysis::Result::True)
+    auto funcOp = op->template getParentOfType<FunctionOpInterface>();
+    if (dpasAnalysis.canUseDPAS(funcOp) != ttgi::DPASAnalysis::Result::True)
       return failure();
 
     // Create DPAS encoding for the given number of warps
     ArrayRef<int64_t> retShape = oldRetType.getShape();
     unsigned numWarps = ttg::lookupNumWarps(funcOp);
 
-    TensorValue a = dotOp.getA();
-    TensorValue b = dotOp.getB();
+    TensorValue a = op.getA();
+    TensorValue b = op.getB();
     auto oldAType = cast<RankedTensorType>(a.getType());
     auto oldBType = cast<RankedTensorType>(b.getType());
 
-    ModuleOp mod = funcOp->getParentOfType<ModuleOp>();
-    auto dpasCap = ttgi::DpasEncodingAttr::getDPASCapability(mod);
+    ModuleOp mod = funcOp->template getParentOfType<ModuleOp>();
+    ttgi::DpasEncodingAttr::DPASCapability dpasCap =
+        ttgi::DpasEncodingAttr::getDPASCapability(mod);
     Type elemType = oldAType.getElementType();
     unsigned opsPerChan = ttgi::DpasEncodingAttr::getOpsPerChannel(elemType);
     SmallVector<unsigned> warpsPerTile =
-        getWarpsPerTile(dotOp, dpasCap, retShape, numWarps);
+        getWarpsPerTile(op, dpasCap, retShape, numWarps);
 
     size_t rank = retShape.size();
     SmallVector<unsigned> repCluster(rank, 1);
@@ -156,6 +160,7 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
                          : dpasCap.systolicDepth * 2; // A is packed to i16 or i32.
     unsigned minM = mlir::ceil(threadsPerWarp, numElemsPerRowForA);
     repeatCount = std::max(repeatCount, minM);
+
     auto dpasEnc = ttgi::DpasEncodingAttr::get(
         oldRetType.getContext(), repeatCount, dpasCap.systolicDepth,
         dpasCap.executionSize, opsPerChan, warpsPerTile, repCluster,
@@ -194,11 +199,11 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
                                         threadsPerWarp);
     }
 
-    RankedTensorType newRetType =
+    auto newRetType =
         RankedTensorType::get(retShape, oldRetType.getElementType(), dpasEnc);
 
     // convert accumulator
-    TensorValue oldAcc = dotOp.getC();
+    TensorValue oldAcc = op.getC();
     auto newAcc = ttg::ConvertLayoutOp::create(rewriter, oldAcc.getLoc(),
                                                newRetType, oldAcc);
     // opA are packed to i16 for scalar type < 16 bits. opB are packed to i32.
@@ -215,15 +220,17 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
     a = ttg::ConvertLayoutOp::create(rewriter, a.getLoc(), newAType, a);
     b = ttg::ConvertLayoutOp::create(rewriter, b.getLoc(), newBType, b);
 
-    auto newDot = tt::DotOp::create(rewriter, dotOp.getLoc(), newRetType, a, b,
-                                    newAcc, dotOp.getInputPrecision(),
-                                    dotOp.getMaxNumImpreciseAcc());
-    rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(dotOp, oldRetType,
+    auto newDot =
+        tt::DotOp::create(rewriter, op.getLoc(), newRetType, a, b, newAcc,
+                          op.getInputPrecision(), op.getMaxNumImpreciseAcc());
+
+    rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(op, oldRetType,
                                                       newDot.getResult());
     return success();
   }
 };
+
 } // namespace
 
 static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
@@ -258,13 +265,13 @@ static void decomposeMixedModeDotOp(ModuleOp mod) {
     OpBuilder builder(dotOp);
     Type AElType = dotOp.getA().getType().getElementType();
     auto dpasLayout =
-        dyn_cast<ttg::intel::DpasEncodingAttr>(D.getType().getEncoding());
+        dyn_cast<ttgi::DpasEncodingAttr>(D.getType().getEncoding());
 
     Type promoteType;
     if (dpasLayout) {
       bool isNativeFP8 = isa<Float8E5M2Type, Float8E4M3FNType>(AElType);
-      // fp8 is not natively supported by the the DPAS instruction, promote it
-      // to fp16.
+      // fp8 is not always natively supported by the DPAS instruction,
+      // promote it to fp16 when necessary.
       if (!isNativeFP8)
         return;
       promoteType = builder.getF16Type();
@@ -428,23 +435,24 @@ static void transposeDots(ModuleOp m) {
 }
 
 class TritonIntelGPUAccelerateMatmulPass
-    : public triton::gpu::intel::impl::TritonIntelGPUAccelerateMatmulBase<
+    : public ttgi::impl::TritonIntelGPUAccelerateMatmulBase<
          TritonIntelGPUAccelerateMatmulPass> {
 public:
-  using triton::gpu::intel::impl::TritonIntelGPUAccelerateMatmulBase<
+  using ttgi::impl::TritonIntelGPUAccelerateMatmulBase<
      TritonIntelGPUAccelerateMatmulPass>::TritonIntelGPUAccelerateMatmulBase;
 
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp m = getOperation();
-    auto &dpasAnalysis = getAnalysis<ttg::intel::DPASAnalysis>();
+    auto &dpasAnalysis = getAnalysis<ttgi::DPASAnalysis>();
 
    // Transpose dotOp operations that have a scale on the RHS.
    transposeDots(m);
 
    RewritePatternSet patterns(context);
    constexpr int benefitDefault = 1;
-    patterns.add<BlockedToDPAS>(context, dpasAnalysis);
+    patterns.add<BlockedToDPAS<tt::DotOp>>(context, dpasAnalysis,
+                                           benefitDefault + 1);
    ttgi::populateDecomposeScaledBlockedPatterns(patterns, benefitDefault);
    if (applyPatternsGreedily(m, std::move(patterns)).failed())
      signalPassFailure();
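
Note: the diff turns `BlockedToDPAS` into a class template gated on the op type through a defaulted `std::enable_if_t` template parameter, so only whitelisted dot-like ops can instantiate the pattern. Below is a minimal standalone C++17 sketch of that gating technique, assuming nothing beyond the standard library: `Pattern`, `Widget`, `Gadget`, and the local `is_one_of` are illustrative stand-ins (the real code relies on `llvm::is_one_of` and the MLIR pattern classes), not Triton or LLVM APIs.

```cpp
#include <type_traits>

// Stand-ins for op classes; hypothetical, for illustration only.
struct Widget {};
struct Gadget {};
struct Other {};

// Local equivalent of llvm::is_one_of, kept here so the sketch is
// self-contained: true iff T is one of Ts...
template <typename T, typename... Ts>
struct is_one_of : std::disjunction<std::is_same<T, Ts>...> {};

// The defaulted second parameter only has a valid type when OpTy is in the
// allowed set, mirroring how BlockedToDPAS<OpTy> is constrained in the diff.
template <typename OpTy,
          typename = std::enable_if_t<is_one_of<OpTy, Widget, Gadget>::value>>
class Pattern {
public:
  // matchAndRewrite(...) would live here in the real rewrite pattern.
};

int main() {
  Pattern<Widget> ok; // compiles: Widget is in the allowed set
  // Pattern<Other> bad; // would fail to compile: enable_if_t has no type
  (void)ok;
  return 0;
}
```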