diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp
index 068965468e..be4b7c4873 100644
--- a/lib/Dialect/TritonGPU/IR/Ops.cpp
+++ b/lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -1,3 +1,4 @@
+#include "intel/include/Dialect/TritonIntelGPU/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
@@ -114,17 +115,36 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
     retTy = RankedTensorType::get(xShape, FloatType::getBF16(ctx));
   } else {
     auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
-    auto newVEncoding = DotOperandEncodingAttr::get(
-        ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
-        oldEncoding.getKWidth() * 2);
-    // Figure out the K dimension for the input A/B, given that the return
-    // type is upcasted A/B type so we need to update the proper dim size.
+
     const int opIdx = oldEncoding.getOpIdx();
     const bool hasBatch = xShape.size() == 3;
     const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
     newShape[kIdx] *= 2;
-    retTy = RankedTensorType::get(newShape, FloatType::getBF16(ctx),
-                                  newVEncoding);
+    Type elemType = FloatType::getBF16(ctx);
+
+    // Note: for Intel, the dot operand layout's kWidth parameter must match
+    // the parent DPAS layout's opsPerChannel, so we need to materialize a
+    // new DPAS layout.
+    Attribute newVEncoding;
+    if (auto dpasEncoding =
+            dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent())) {
+      auto newDpasEncoding = intel::DpasEncodingAttr::get(
+          ctx, dpasEncoding.getRepeatCount(), dpasEncoding.getSystolicDepth(),
+          dpasEncoding.getExecutionSize(),
+          intel::DpasEncodingAttr::getOpsPerChannel(elemType),
+          dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
+          dpasEncoding.getSubGroupSize());
+      newVEncoding = DotOperandEncodingAttr::get(
+          ctx, oldEncoding.getOpIdx(), newDpasEncoding,
+          newDpasEncoding.getOpsPerChannel());
+    } else {
+      // Figure out the K dimension for the input A/B, given that the return
+      // type is the upcasted A/B type, so we need to update the proper dim size.
+      newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(),
+                                                 oldEncoding.getParent(),
+                                                 oldEncoding.getKWidth() * 2);
+    }
+    retTy = RankedTensorType::get(newShape, elemType, newVEncoding);
   }
   inferredReturnTypes.push_back(retTy);
 } else {
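
For reference, the constraint the new Intel branch maintains can be phrased as a small predicate. This is an illustrative sketch, not part of the patch; isValidIntelDotOperandLayout is a hypothetical name, and the types are the ones used in the hunk above:

    // Hypothetical helper (sketch): an Intel dot operand layout is
    // consistent when its kWidth equals the parent DPAS opsPerChannel.
    static bool isValidIntelDotOperandLayout(DotOperandEncodingAttr enc) {
      if (auto dpas = dyn_cast<intel::DpasEncodingAttr>(enc.getParent()))
        return enc.getKWidth() == dpas.getOpsPerChannel();
      return true; // Non-DPAS parents keep the generic kWidth * 2 rule.
    }

This is why the branch rebuilds the DPAS attribute for the bf16 result type instead of merely doubling kWidth.
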
diff --git a/test/TritonIntelGPU/accelerate-matmul-pvc.mlir b/test/TritonIntelGPU/accelerate-matmul-pvc.mlir
index 4161bf50df..ecfecf09be 100644
--- a/test/TritonIntelGPU/accelerate-matmul-pvc.mlir
+++ b/test/TritonIntelGPU/accelerate-matmul-pvc.mlir
@@ -201,3 +201,32 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
     tt.return
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
+
+module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "triton_intel_gpu.min_sg_size" = 16 : i32, "triton_intel_gpu.support_dpas"} {
+  // CHECK: [[BLOCKED:#.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
+  // CHECK: [[BLOCKED1:#.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+  // CHECK: [[BLOCKED2:#.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
+  // CHECK: [[DPAS:#.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
+  // CHECK: [[DPAS1:#.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 4, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [1, 1], A = [8, 32], B = [32, 16], C = [8, 16]}>
+  // CHECK: dot_scaled
+  tt.func @dot_scaled(%a: tensor<128x32xi8, #blocked2>, %scale: tensor<128x2xi8, #blocked1>, %b: tensor<64x128xbf16, #blocked>) -> tensor<128x128xf32, #blocked> {
+    // CHECK: [[CST:%.*]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32, [[BLOCKED2]]>
+    // CHECK: [[C:%.*]] = ttg.convert_layout [[CST]] : tensor<128x128xf32, [[BLOCKED2]]> -> tensor<128x128xf32, [[DPAS]]>
+    // CHECK: [[CVT_ARG0:%.*]] = ttg.convert_layout %arg0 : tensor<128x32xi8, [[BLOCKED]]> -> tensor<128x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>>
+    // CHECK: [[CVT_ARG1:%.*]] = ttg.convert_layout %arg1 : tensor<128x2xi8, [[BLOCKED1]]> -> tensor<128x2xi8, [[BLOCKED1]]>
+    // CHECK: [[A:%.*]] = ttg.upcast_mxfp [[CVT_ARG0]], [[CVT_ARG1]] fp_type = e2m1 : tensor<128x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>>, tensor<128x2xi8, [[BLOCKED1]]> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>
+    // CHECK: [[B:%.*]] = ttg.convert_layout %arg2 : tensor<64x128xbf16, [[BLOCKED2]]> -> tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>>
+    // CHECK: [[D:%.*]] = tt.dot [[A]], [[B]], [[C]] : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>> * tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> -> tensor<128x128xf32, [[DPAS]]>
+    // CHECK: [[RES:%.*]] = ttg.convert_layout {{.*}} : tensor<128x128xf32, [[DPAS]]> -> tensor<128x128xf32, [[BLOCKED2]]>
+
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
+    %result = tt.dot_scaled %a scale %scale, %b, %cst lhs = e2m1 rhs = bf16 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked>
+    tt.return %result : tensor<128x128xf32, #blocked>
+  }
+}
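
The operand shapes in this test follow from the MX packing rules: e2m1 packs two fp4 values per i8, and MX formats carry one 8-bit scale per 32-element block along K. A standalone sketch of the bookkeeping, with constants chosen to match the CHECK lines above:

    #include <cassert>

    int main() {
      const int kPacked = 32;        // tensor<128x32xi8> LHS operand
      const int k = kPacked * 2;     // tensor<128x64xbf16> after upcast
      const int numScales = k / 32;  // tensor<128x2xi8> scale operand
      // opsPerChan = 32-bit channel / element bit width.
      assert(32 / 8 == 4);   // i8 operand: opsPerChan = 4 (DPAS1 above)
      assert(32 / 16 == 2);  // bf16 operand: opsPerChan = 2 (DPAS above)
      assert(k == 64 && numScales == 2);
      return 0;
    }
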
diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/Attributes.h b/third_party/intel/include/Dialect/TritonIntelGPU/IR/Attributes.h
index b159017e34..0b5ab80873 100644
--- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/Attributes.h
+++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/Attributes.h
@@ -3,6 +3,10 @@
 
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 
+namespace mlir {
+class ModuleOp;
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.h.inc"
diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
index 6036c54e84..8cb10fcfe4 100644
--- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
+++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
@@ -74,7 +74,6 @@ along the row (resp. col) dimension.
   );
 
   let extraClassDeclaration = extraDistributedDeclaration # [{
-
     SmallVector<unsigned> getDPASInstShapeA() const;
     SmallVector<unsigned> getDPASInstShapeB() const;
     SmallVector<unsigned> getDPASInstShapeC() const;
@@ -91,7 +90,30 @@ along the row (resp. col) dimension.
       return true;
     }
 
-    SmallVector<unsigned> getContigPerThread();
+    SmallVector<unsigned> getContigPerThread() const;
+
+    struct DPASCapability {
+      explicit DPASCapability(unsigned minSGSize) : executionSize(minSGSize) {}
+      DPASCapability() = default;
+
+      bool isPVC() const {
+        return executionSize == 16;
+      }
+      bool isFalconShore() const {
+        return executionSize == 16;
+      }
+      bool isATSM() const {
+        return executionSize == 8;
+      }
+
+      static constexpr unsigned systolicDepth = 8u;
+      static constexpr unsigned repeatCount = 8u;
+      static constexpr unsigned opsChanBitWidths = 32u;
+      unsigned executionSize = 0u;
+    };
+
+    static DPASCapability getDPASCapability(mlir::ModuleOp mod);
+    static unsigned getOpsPerChannel(Type elemType);
  }];
 
  let hasCustomAssemblyFormat = 1;
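
The new static helpers give call sites a single place to query hardware DPAS parameters. A possible usage sketch (illustrative only; mod is any ModuleOp carrying the triton_intel_gpu.min_sg_size attribute, as in the test module above):

    // Query the DPAS capability from module attributes (names from the .td).
    auto cap = DpasEncodingAttr::getDPASCapability(mod);
    if (cap.isPVC()) {
      // executionSize == 16; with opsPerChan = 2 this yields the tile
      // shapes printed in the test: A = [8, 16], B = [16, 16], C = [8, 16].
    }

Note that isPVC() and isFalconShore() are currently indistinguishable: both report executionSize == 16.
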
diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
index 0546d0c8de..661d2bb9cb 100644
--- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
+++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
@@ -357,7 +357,7 @@ SmallVector<unsigned> DpasEncodingAttr::getElemsPerThreadForOperands(
   return elemsPerThread;
 };
 
-SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() {
+SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() const {
   size_t rank = getWarpsPerCTA().size();
   assert(rank == 2 || rank == 3);
   SmallVector<unsigned> contigPerThread(rank, 1);
@@ -381,6 +381,30 @@ SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() {
          "be smaller than the threads required per row.");
 }
 
+DpasEncodingAttr::DPASCapability
+DpasEncodingAttr::getDPASCapability(ModuleOp mod) {
+  assert(mod && "expected a valid module");
+
+  if (auto minSGSizeAttr = mod->getAttrOfType<IntegerAttr>(
+          triton::gpu::intel::TritonIntelGPUDialect::getMinSGSizeAttrName())) {
+    unsigned minSGSize = minSGSizeAttr.getInt();
+    assert((minSGSize == 8 || minSGSize == 16) && "unsupported minSGSize");
+    return DPASCapability(minSGSize);
+  }
+
+  return DPASCapability();
+}
+
+unsigned DpasEncodingAttr::getOpsPerChannel(Type elemType) {
+  assert(elemType.isIntOrFloat() && "unsupported type for DpasEncodingAttr");
+
+  unsigned dpasElemBitWidths = elemType.getIntOrFloatBitWidth();
+  if (elemType.isFloat8E5M2() || elemType.isFloat8E4M3FN())
+    dpasElemBitWidths *= 2; // We are upcasting FP8 to FP16.
+
+  return DPASCapability::opsChanBitWidths / dpasElemBitWidths;
+}
+
 LogicalResult DpasEncodingAttr::verify(
     ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
     unsigned repeatCount, unsigned systolicDepth, unsigned executionSize,
@@ -469,18 +493,14 @@ void DpasEncodingAttr::print(AsmPrinter &printer) const {
   llvm::ArrayRef rC = shapeC;
   auto warpsPerCTA = getWarpsPerCTA();
   auto repCluster = getRepCluster();
-  printer << "<{"
-          << "repeatCount = " << getRepeatCount() << ", "
+  printer << "<{" << "repeatCount = " << getRepeatCount() << ", "
           << "systolicDepth = " << getSystolicDepth() << ", "
           << "executionSize = " << getExecutionSize() << ", "
           << "opsPerChan = " << getOpsPerChannel() << ", "
           << "threadsPerWarp = " << getSubGroupSize() << ", "
           << "warpsPerCTA = [" << llvm::ArrayRef(warpsPerCTA) << "], "
-          << "repCluster = [" << repCluster << "], "
-          << "A = [" << rA << "], "
-          << "B = [" << rB << "], "
-          << "C = [" << rC << "]"
-          << "}>";
+          << "repCluster = [" << repCluster << "], " << "A = [" << rA << "], "
+          << "B = [" << rB << "], " << "C = [" << rC << "]" << "}>";
 }
 
 std::optional
@@ -553,13 +573,10 @@ Attribute WarpEncodingAttr::parse(AsmParser &parser, Type type) {
 void WarpEncodingAttr::print(mlir::AsmPrinter &printer) const {
   auto threadsPerWarp = getThreadsPerWarp();
   auto sizePerThread = getSizePerThread();
-  printer << "<{"
-          << "sizePerThread = [" << llvm::ArrayRef(sizePerThread)
-          << "]"
+  printer << "<{" << "sizePerThread = ["
+          << llvm::ArrayRef(sizePerThread) << "]"
           << ", threadsPerWarp = [" << llvm::ArrayRef(threadsPerWarp)
-          << "]"
-          << ", order = [" << getOrder() << "]"
-          << "}>";
+          << "]" << ", order = [" << getOrder() << "]" << "}>";
 }
 
 //===----------------------------------------------------------------------===//
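
A sanity-check sketch (not a unit test in the patch) of what the new helper returns for common element types; it uses only standard MLIR Builder type getters:

    #include "intel/include/Dialect/TritonIntelGPU/IR/Attributes.h"
    #include "mlir/IR/Builders.h"
    #include "mlir/IR/MLIRContext.h"
    #include <cassert>

    void checkOpsPerChannel(mlir::MLIRContext &ctx) {
      mlir::Builder b(&ctx);
      using mlir::triton::gpu::intel::DpasEncodingAttr;
      assert(DpasEncodingAttr::getOpsPerChannel(b.getF32Type()) == 1);  // 32/32
      assert(DpasEncodingAttr::getOpsPerChannel(b.getF16Type()) == 2);  // 32/16
      assert(DpasEncodingAttr::getOpsPerChannel(b.getBF16Type()) == 2); // 32/16
      assert(DpasEncodingAttr::getOpsPerChannel(b.getI8Type()) == 4);   // 32/8
      // fp8 (E5M2 / E4M3FN) counts as 16 bits because it is upcast to
      // fp16 first, so it also yields 2.
    }
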
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp
index ae385ed960..ff88d287dc 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp
@@ -1,10 +1,15 @@
+#include "Dialect/TritonIntelGPU/IR/Attributes.h"
+#include "Dialect/TritonIntelGPU/Transforms/Utility.h"
 #include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #include "intel/include/Analysis/DPAS.h"
 #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
 #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -23,40 +28,10 @@ namespace mlir::triton::gpu::intel {
 
 namespace {
 
-struct IntelDPASCapability {
-  uint32_t systolicDepth;
-  uint32_t repeatCount;
-  uint32_t executionSize;
-  uint32_t opsChanBitWidths;
-};
-
-IntelDPASCapability getDPASCapability(unsigned minSGSize) {
-  switch (minSGSize) {
-  case 8: {
-    IntelDPASCapability cap;
-    cap.systolicDepth = 8;
-    cap.repeatCount = 8;
-    cap.executionSize = 8;
-    cap.opsChanBitWidths = 32;
-    return cap;
-  }
-  case 16: {
-    IntelDPASCapability cap;
-    cap.systolicDepth = 8;
-    cap.repeatCount = 8;
-    cap.executionSize = 16;
-    cap.opsChanBitWidths = 32;
-    return cap;
-  }
-  default:
-    return IntelDPASCapability();
-  }
-}
-
-SmallVector<unsigned> getWarpsPerTile(tt::DotOp dotOp,
-                                      struct IntelDPASCapability dpasCap,
-                                      const ArrayRef<int64_t> shape,
-                                      unsigned numWarps) {
+SmallVector<unsigned>
+getWarpsPerTile(tt::DotOp dotOp,
+                ttg::intel::DpasEncodingAttr::DPASCapability dpasCap,
+                const ArrayRef<int64_t> shape, unsigned numWarps) {
   auto filter = [&dotOp](Operation *op) {
     return op->getParentRegion() == dotOp->getParentRegion();
   };
@@ -108,6 +83,7 @@ SmallVector<unsigned> getWarpsPerTile(tt::DotOp dotOp,
 
 class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
   const ttg::intel::DPASAnalysis &dpasAnalysis;
+  using TensorValue = TypedValue<RankedTensorType>;
 
 public:
   BlockedToDPAS(MLIRContext *context,
@@ -116,8 +92,6 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
 
   LogicalResult matchAndRewrite(tt::DotOp dotOp,
                                 PatternRewriter &rewriter) const override {
-    using TensorValue = TypedValue<RankedTensorType>;
-
     RankedTensorType oldRetType = dotOp.getType();
     if (!oldRetType.getEncoding() ||
         isa<ttg::intel::DpasEncodingAttr>(oldRetType.getEncoding()))
@@ -138,20 +112,10 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
     auto oldAType = cast<RankedTensorType>(a.getType());
     auto oldBType = cast<RankedTensorType>(b.getType());
 
-    unsigned minSGSize =
-        mod->getAttrOfType<IntegerAttr>(
-               ttg::intel::TritonIntelGPUDialect::getMinSGSizeAttrName())
-            .getInt();
-    IntelDPASCapability dpasCap = getDPASCapability(minSGSize);
-    unsigned dpasElemBitWidths =
-        oldAType.getElementType().getIntOrFloatBitWidth();
-
-    // We are upcasting FP8 to FP16
-    if (oldAType.getElementType().isFloat8E5M2() ||
-        oldAType.getElementType().isFloat8E4M3FN())
-      dpasElemBitWidths = 2 * dpasElemBitWidths;
-
-    unsigned opsPerChan = dpasCap.opsChanBitWidths / dpasElemBitWidths;
+    auto dpasCap = ttg::intel::DpasEncodingAttr::getDPASCapability(mod);
+    Type elemType = oldAType.getElementType();
+    unsigned opsPerChan =
+        ttg::intel::DpasEncodingAttr::getOpsPerChannel(elemType);
     SmallVector<unsigned> warpsPerTile =
         getWarpsPerTile(dotOp, dpasCap, retShape, numWarps);
     size_t rank = retShape.size();
@@ -163,7 +127,15 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
         dpasCap.executionSize, opsPerChan, warpsPerTile, repCluster,
         threadsPerWarp);
 
-    if (dpasCap.executionSize == 16 /* PVC */) {
+    if (dpasCap.isPVC() || dpasCap.isFalconShore()) {
+      unsigned dpasElemBitWidths =
+          oldAType.getElementType().getIntOrFloatBitWidth();
+
+      // We are upcasting FP8 to FP16.
+      if (oldAType.getElementType().isFloat8E5M2() ||
+          oldAType.getElementType().isFloat8E4M3FN())
+        dpasElemBitWidths = 2 * dpasElemBitWidths;
+
       // Enlarge the repCluster size to use the large 2D load for A and B
       // operands.
+      unsigned maxRepClusterM =
@@ -219,6 +191,181 @@
   }
 };
 
+class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
+  const ttg::intel::DPASAnalysis &dpasAnalysis;
+  using TensorValue = TypedValue<RankedTensorType>;
+
+public:
+  DecomposeScaledBlocked(MLIRContext *context,
+                         const ttg::intel::DPASAnalysis &dpasAnalysis)
+      : OpRewritePattern<tt::DotScaledOp>(context),
+        dpasAnalysis(dpasAnalysis) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(tt::DotScaledOp scaledDotOp,
+                  PatternRewriter &rewriter) const override {
+    RankedTensorType oldRetType = scaledDotOp.getType();
+    if (!isa_and_nonnull<ttg::BlockedEncodingAttr>(oldRetType.getEncoding()))
+      return rewriter.notifyMatchFailure(
+          scaledDotOp, "expected blocked encoding result tensor");
+
+    unsigned rank = oldRetType.getRank();
+    if (rank == 3)
+      return rewriter.notifyMatchFailure(scaledDotOp, "NYI: 3d case");
+
+    TensorValue a = scaledDotOp.getLhs();
+    TensorValue b = scaledDotOp.getRhs();
+    TensorValue aScale = scaledDotOp.getLhsScale();
+    TensorValue bScale = scaledDotOp.getRhsScale();
+    if (aScale && bScale)
+      return rewriter.notifyMatchFailure(scaledDotOp,
+                                         "NYI: both LHS and RHS scale");
+    if (bScale)
+      return rewriter.notifyMatchFailure(scaledDotOp, "NYI: RHS scale");
+    assert(aScale && "LHS scale missing");
+
+    tt::ScaleDotElemType aElemType = scaledDotOp.getLhsType();
+    tt::ScaleDotElemType bElemType = scaledDotOp.getRhsType();
+
+    auto supportsTypes = [](tt::ScaleDotElemType elemType) {
+      return elemType == tt::ScaleDotElemType::E2M1 ||
+             elemType == tt::ScaleDotElemType::E4M3 ||
+             elemType == tt::ScaleDotElemType::E5M2 ||
+             elemType == tt::ScaleDotElemType::BF16;
+    };
+    if (!supportsTypes(aElemType) || !supportsTypes(bElemType))
+      return rewriter.notifyMatchFailure(scaledDotOp, "NYI: mxfp6 operand");
+
+    MLIRContext *ctx = scaledDotOp.getContext();
+    auto mod = scaledDotOp->getParentOfType<ModuleOp>();
+
+    // Convert accumulator.
+    ttg::intel::DpasEncodingAttr dpasEnc =
+        getDPASEncoding(rewriter, scaledDotOp);
+
+    auto newRetType = RankedTensorType::get(
+        oldRetType.getShape(), oldRetType.getElementType(), dpasEnc);
+    TensorValue oldAcc = scaledDotOp.getC();
+    TensorValue newAcc = rewriter.create<ttg::ConvertLayoutOp>(
+        oldAcc.getLoc(), newRetType, oldAcc);
+
+    unsigned opsPerChannel = dpasEnc.getOpsPerChannel();
+    if (aElemType == tt::ScaleDotElemType::E2M1)
+      opsPerChannel *= 2;
+
+    // Upcast A operand.
+    auto dpasEncForA = ttg::intel::DpasEncodingAttr::get(
+        ctx, dpasEnc.getRepeatCount(), dpasEnc.getSystolicDepth(),
+        dpasEnc.getExecutionSize(), opsPerChannel, dpasEnc.getWarpsPerCTA(),
+        dpasEnc.getRepCluster(), dpasEnc.getSubGroupSize());
+    auto newAEncoding = ttg::DotOperandEncodingAttr::get(
+        ctx, 0, dpasEncForA, dpasEncForA.getOpsPerChannel());
+    a = createArg(rewriter, a, aElemType, newAEncoding);
+
+    unsigned warpSize = ttg::TritonGPUDialect::getThreadsPerWarp(mod);
+    unsigned instrShapeM = dpasEnc.getDPASInstShapeA()[1];
+    SmallVector<unsigned> threadsPerWarp{instrShapeM, warpSize / instrShapeM};
+    auto CTALayout = ttg::getCTALayout(oldRetType.getEncoding());
+    auto newScaleEncoding = ttg::BlockedEncodingAttr::get(
+        ctx, {1, 1}, threadsPerWarp, newAEncoding.getWarpsPerCTA(),
+        newAEncoding.getCTAOrder(), CTALayout);
+    aScale = createScale(rewriter, aScale, newScaleEncoding);
+
+    auto retTypeEncoding = ttg::DotOperandEncodingAttr::get(
+        ctx, 0, dpasEnc, dpasEnc.getOpsPerChannel());
+    Value scaledA =
+        createUpcastMxfpOp(rewriter, a, aScale, aElemType, retTypeEncoding);
+
+    // Create B operand.
+    assert(bElemType != tt::ScaleDotElemType::E2M1 &&
+           "NYI: rhs scale for fp4");
+    auto newBEncoding = ttg::DotOperandEncodingAttr::get(
+        ctx, 1, dpasEnc, dpasEnc.getOpsPerChannel());
+    b = createArg(rewriter, b, bElemType, newBEncoding);
+
+    auto newDot = rewriter.create<tt::DotOp>(scaledDotOp.getLoc(), newRetType,
+                                             scaledA, b, newAcc);
+    rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(scaledDotOp, oldRetType,
+                                                      newDot);
+    return success();
+  }
+
+private:
+  // TODO: this works only when the operand to scale is on the LHS; extend it
+  // to support scaling the RHS operand.
+  ttg::intel::DpasEncodingAttr
+  getDPASEncoding(PatternRewriter &rewriter,
+                  tt::DotScaledOp scaledDotOp) const {
+    auto mod = scaledDotOp->getParentOfType<ModuleOp>();
+    auto dpasCap = ttg::intel::DpasEncodingAttr::getDPASCapability(mod);
+    Type elemType = scaledDotOp.getRhs().getType().getElementType();
+    unsigned opsPerChan =
+        ttg::intel::DpasEncodingAttr::getOpsPerChannel(elemType);
+    unsigned numWarps = ttg::TritonGPUDialect::getNumWarps(mod);
+    SmallVector<unsigned> warpsPerTile = {numWarps, 1};
+
+    ArrayRef<int64_t> retShape = scaledDotOp.getType().getShape();
+    size_t rank = retShape.size();
+    SmallVector<unsigned> repCluster(rank, 1);
+
+    unsigned threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod);
+
+    return ttg::intel::DpasEncodingAttr::get(
+        rewriter.getContext(), dpasCap.repeatCount, dpasCap.systolicDepth,
+        dpasCap.executionSize, opsPerChan, warpsPerTile, repCluster,
+        threadsPerWarp);
+  }
+
+  TensorValue createArg(PatternRewriter &rewriter, TensorValue v,
+                        tt::ScaleDotElemType type, Attribute vEncoding) const {
+    RankedTensorType vType = v.getType();
+    auto newVType = RankedTensorType::get(vType.getShape(),
+                                          vType.getElementType(), vEncoding);
+    TensorValue ret =
+        rewriter.create<ttg::ConvertLayoutOp>(v.getLoc(), newVType, v);
+
+    // Convert to bf16.
+    if (type != tt::ScaleDotElemType::E2M1 &&
+        type != tt::ScaleDotElemType::BF16) {
+      assert(type == tt::ScaleDotElemType::E5M2 ||
+             type == tt::ScaleDotElemType::E4M3);
+      auto vTypeBf16 = RankedTensorType::get(
+          newVType.getShape(), rewriter.getBF16Type(), newVType.getEncoding());
+      ret = cast<TypedValue<RankedTensorType>>(
+          rewriter.create<tt::FpToFpOp>(v.getLoc(), vTypeBf16, ret)
+              .getResult());
+    }
+    return ret;
+  }
+
+  TensorValue createScale(PatternRewriter &rewriter, TensorValue scale,
+                          Attribute scaleEncoding) const {
+    RankedTensorType scaleType = scale.getType();
+    auto newScaleDotElemType = RankedTensorType::get(
+        scaleType.getShape(), scaleType.getElementType(), scaleEncoding);
+    return rewriter.create<ttg::ConvertLayoutOp>(scale.getLoc(),
+                                                 newScaleDotElemType, scale);
+  }
+
+  TensorValue createUpcastMxfpOp(PatternRewriter &rewriter, TensorValue a,
+                                 TensorValue scale, tt::ScaleDotElemType type,
+                                 Attribute retTypeEncoding) const {
+    auto aType = cast<RankedTensorType>(a.getType());
+    auto retType = RankedTensorType::get(
+        aType.getShape(), aType.getElementType(), retTypeEncoding);
+    if (type == tt::ScaleDotElemType::E2M1) {
+      SmallVector<int64_t> newShape(aType.getShape());
+      newShape.back() *= 2;
+      retType = RankedTensorType::get(
+          newShape, FloatType::getBF16(rewriter.getContext()),
+          retTypeEncoding);
+    }
+    // TODO: Check whether constructing without an explicit retType works.
+    return rewriter.create<ttg::UpcastMXFPOp>(a.getLoc(), retType, a, scale,
+                                              type);
+  }
+};
+
 } // namespace
 
 static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
@@ -291,7 +438,7 @@ class TritonIntelGPUAccelerateMatmulPass
     auto &dpasAnalysis = getAnalysis<ttg::intel::DPASAnalysis>();
 
     RewritePatternSet patterns(context);
-    patterns.add<BlockedToDPAS>(context, dpasAnalysis);
+    patterns.add<BlockedToDPAS, DecomposeScaledBlocked>(context, dpasAnalysis);
     if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
       signalPassFailure();
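
Registration is the only change needed to activate the new pattern. A minimal sketch of driving both patterns (illustrative; the two pattern classes live in an anonymous namespace, so real callers go through the pass, and runAccelerateMatmul is a hypothetical name):

    void runAccelerateMatmul(mlir::ModuleOp m,
                             const ttg::intel::DPASAnalysis &dpasAnalysis) {
      mlir::MLIRContext *context = m.getContext();
      mlir::RewritePatternSet patterns(context);
      // BlockedToDPAS handles tt.dot; DecomposeScaledBlocked handles
      // tt.dot_scaled (LHS scale only, 2D, e2m1/e4m3/e5m2/bf16 operands).
      patterns.add<BlockedToDPAS, DecomposeScaledBlocked>(context,
                                                          dpasAnalysis);
      if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns))))
        m.emitError("failed to apply accelerate-matmul patterns");
    }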