From 4728298e77e67db98eb1c00c4b4cb945c7136fa0 Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore"
Date: Mon, 11 Nov 2024 18:21:19 +0000
Subject: [PATCH 1/9] [NFC]: Clean up AccelerateMatmul.cpp

Signed-off-by: Tiotto, Ettore
---
 .../AccelerateMatmul.cpp                      | 93 +++++++++----------
 1 file changed, 45 insertions(+), 48 deletions(-)

diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp
index 3e636d5bae..ae385ed960 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp
@@ -4,7 +4,6 @@
 #include "intel/include/Analysis/DPAS.h"
 #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
 #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
-#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -14,9 +13,8 @@
 #define PVC_2D_LOAD_MAXIMUM_BYTES_OF_COLS 64
 
 using namespace mlir;
-using namespace mlir::triton;
-using namespace mlir::triton::gpu;
-using DPASAnalysis = intel::DPASAnalysis;
+namespace tt = mlir::triton;
+namespace ttg = mlir::triton::gpu;
 
 namespace mlir::triton::gpu::intel {
 #define GEN_PASS_DEF_TRITONINTELGPUACCELERATEMATMUL
@@ -55,7 +53,7 @@ IntelDPASCapability getDPASCapability(unsigned minSGSize) {
   }
 }
 
-SmallVector<unsigned> getWarpsPerTile(DotOp dotOp,
+SmallVector<unsigned> getWarpsPerTile(tt::DotOp dotOp,
                                       struct IntelDPASCapability dpasCap,
                                       const ArrayRef<int64_t> shape,
                                       unsigned numWarps) {
@@ -66,7 +64,7 @@ SmallVector<unsigned> getWarpsPerTile(DotOp dotOp,
   SetVector<Operation *> slices = getSlice(dotOp, {filter});
   // TODO: revisit this in flash attention.
   for (Operation *op : slices)
-    if (isa<DotOp>(op) && (op != dotOp))
+    if (isa<tt::DotOp>(op) && (op != dotOp))
       return {numWarps, 1};
 
   size_t rank = shape.size();
@@ -108,41 +106,41 @@ SmallVector<unsigned> getWarpsPerTile(DotOp dotOp,
   return ret;
 }
 
-class BlockedToDPAS : public RewritePattern {
-  const DPASAnalysis &dpasAnalysis;
+class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
+  const ttg::intel::DPASAnalysis &dpasAnalysis;
 
 public:
-  BlockedToDPAS(MLIRContext *context, const DPASAnalysis &dpasAnalysis)
-      : RewritePattern(DotOp::getOperationName(), 2, context),
-        dpasAnalysis(dpasAnalysis) {}
+  BlockedToDPAS(MLIRContext *context,
+                const ttg::intel::DPASAnalysis &dpasAnalysis)
+      : OpRewritePattern<tt::DotOp>(context), dpasAnalysis(dpasAnalysis) {}
 
-  LogicalResult matchAndRewrite(Operation *op,
+  LogicalResult matchAndRewrite(tt::DotOp dotOp,
                                 PatternRewriter &rewriter) const override {
-    DotOp dotOp = cast<DotOp>(op);
-    RankedTensorType oldRetType =
-        cast<RankedTensorType>(dotOp.getResult().getType());
+    using TensorValue = TypedValue<RankedTensorType>;
+
+    RankedTensorType oldRetType = dotOp.getType();
     if (!oldRetType.getEncoding() ||
-        isa<intel::DpasEncodingAttr>(oldRetType.getEncoding()))
+        isa<ttg::intel::DpasEncodingAttr>(oldRetType.getEncoding()))
       return failure();
 
-    auto funcOp = op->getParentOfType<FunctionOpInterface>();
-    if (dpasAnalysis.canUseDPAS(funcOp) != DPASAnalysis::Result::True)
+    auto funcOp = dotOp->getParentOfType<FunctionOpInterface>();
+    if (dpasAnalysis.canUseDPAS(funcOp) !=
+        ttg::intel::DPASAnalysis::Result::True)
       return failure();
 
     // Create DPAS encoding for the given number of warps
     ArrayRef<int64_t> retShape = oldRetType.getShape();
-    size_t rank = retShape.size();
     ModuleOp mod = funcOp->getParentOfType<ModuleOp>();
-    unsigned numWarps = TritonGPUDialect::getNumWarps(mod);
+    unsigned numWarps = ttg::TritonGPUDialect::getNumWarps(mod);
 
-    Value a = dotOp.getA();
-    Value b = dotOp.getB();
-    RankedTensorType oldAType = cast<RankedTensorType>(a.getType());
-    RankedTensorType oldBType = cast<RankedTensorType>(b.getType());
+    TensorValue a = dotOp.getA();
+    TensorValue b = dotOp.getB();
+    auto oldAType = cast<RankedTensorType>(a.getType());
+    auto oldBType = cast<RankedTensorType>(b.getType());
 
     unsigned minSGSize =
         mod->getAttrOfType<IntegerAttr>(
-               intel::TritonIntelGPUDialect::getMinSGSizeAttrName())
+               ttg::intel::TritonIntelGPUDialect::getMinSGSizeAttrName())
            .getInt();
     IntelDPASCapability dpasCap = getDPASCapability(minSGSize);
     unsigned dpasElemBitWidths =
@@ -156,10 +154,11 @@ class BlockedToDPAS : public RewritePattern {
     unsigned opsPerChan = dpasCap.opsChanBitWidths / dpasElemBitWidths;
     SmallVector<unsigned> warpsPerTile =
        getWarpsPerTile(dotOp, dpasCap, retShape, numWarps);
+    size_t rank = retShape.size();
     SmallVector<unsigned> repCluster(rank, 1);
 
-    unsigned threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod);
-    auto dpasEnc = intel::DpasEncodingAttr::get(
+    unsigned threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod);
+    auto dpasEnc = ttg::intel::DpasEncodingAttr::get(
         oldRetType.getContext(), dpasCap.repeatCount, dpasCap.systolicDepth,
         dpasCap.executionSize, opsPerChan, warpsPerTile, repCluster,
         threadsPerWarp);
@@ -184,7 +183,7 @@ class BlockedToDPAS : public RewritePattern {
       repCluster[rank - 2] = repClusterDimM;
       repCluster[rank - 1] = repClusterDimN;
 
-      dpasEnc = intel::DpasEncodingAttr::get(
+      dpasEnc = ttg::intel::DpasEncodingAttr::get(
          oldRetType.getContext(), dpasCap.repeatCount, dpasCap.systolicDepth,
          dpasCap.executionSize, opsPerChan, warpsPerTile, repCluster,
          threadsPerWarp);
@@ -194,28 +193,28 @@ class BlockedToDPAS : public RewritePattern {
         RankedTensorType::get(retShape, oldRetType.getElementType(), dpasEnc);
 
     // convert accumulator
-    Value oldAcc = dotOp.getC();
-    ConvertLayoutOp newAcc =
-        rewriter.create<ConvertLayoutOp>(oldAcc.getLoc(), newRetType, oldAcc);
+    TensorValue oldAcc = dotOp.getC();
+    auto newAcc = rewriter.create<ttg::ConvertLayoutOp>(oldAcc.getLoc(),
+                                                        newRetType, oldAcc);
 
-    DotOperandEncodingAttr newAEncoding = DotOperandEncodingAttr::get(
+    auto newAEncoding = ttg::DotOperandEncodingAttr::get(
         oldAType.getContext(), 0, newRetType.getEncoding(), opsPerChan);
-    DotOperandEncodingAttr newBEncoding = DotOperandEncodingAttr::get(
+    auto newBEncoding = ttg::DotOperandEncodingAttr::get(
         oldBType.getContext(), 1, newRetType.getEncoding(), opsPerChan);
 
-    RankedTensorType newAType = RankedTensorType::get(
+    auto newAType = RankedTensorType::get(
         oldAType.getShape(), oldAType.getElementType(), newAEncoding);
-    RankedTensorType newBType = RankedTensorType::get(
+    auto newBType = RankedTensorType::get(
         oldBType.getShape(), oldBType.getElementType(), newBEncoding);
 
-    a = rewriter.create<ConvertLayoutOp>(a.getLoc(), newAType, a);
-    b = rewriter.create<ConvertLayoutOp>(b.getLoc(), newBType, b);
-    DotOp newDot = rewriter.create<DotOp>(dotOp.getLoc(), newRetType, a, b,
-                                          newAcc, dotOp.getInputPrecision(),
-                                          dotOp.getMaxNumImpreciseAcc());
+    a = rewriter.create<ttg::ConvertLayoutOp>(a.getLoc(), newAType, a);
+    b = rewriter.create<ttg::ConvertLayoutOp>(b.getLoc(), newBType, b);
+    auto newDot = rewriter.create<tt::DotOp>(dotOp.getLoc(), newRetType, a, b,
+                                             newAcc, dotOp.getInputPrecision(),
+                                             dotOp.getMaxNumImpreciseAcc());
 
-    rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, oldRetType,
-                                                 newDot.getResult());
+    rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(dotOp, oldRetType,
+                                                      newDot.getResult());
     return success();
   }
 };
@@ -230,7 +229,7 @@ static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
 
   return llvm::TypeSwitch<Type, Value>(elemType)
       .Case<FloatType>([&](auto) {
-        return builder.create<FpToFpOp>(loc, tensorPromotedType, operand);
+        return builder.create<tt::FpToFpOp>(loc, tensorPromotedType, operand);
       })
       .Case<IntegerType>([&](auto) {
         unsigned tgtBitWidth =
            elemType.getIntOrFloatBitWidth(),
@@ -248,12 +247,12 @@ static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
 // promote operands of dot op if the existing combination is not natively
 // supported.
 static void decomposeMixedModeDotOp(ModuleOp mod) {
-  mod.walk([](DotOp dotOp) -> void {
+  mod.walk([](tt::DotOp dotOp) -> void {
     auto D = dotOp.getD();
     OpBuilder builder(dotOp);
     Type AElType = dotOp.getA().getType().getElementType();
     auto dpasLayout =
-        dyn_cast<intel::DpasEncodingAttr>(D.getType().getEncoding());
+        dyn_cast<ttg::intel::DpasEncodingAttr>(D.getType().getEncoding());
 
     Type promoteType;
     if (dpasLayout) {
@@ -289,15 +288,13 @@ class TritonIntelGPUAccelerateMatmulPass
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp m = getOperation();
-    DPASAnalysis &dpasAnalysis = getAnalysis<DPASAnalysis>();
+    auto &dpasAnalysis = getAnalysis<ttg::intel::DPASAnalysis>();
 
     RewritePatternSet patterns(context);
-    patterns.add<BlockedToDPAS>(context, dpasAnalysis);
+    patterns.add<BlockedToDPAS>(context, dpasAnalysis);
     if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
       signalPassFailure();
 
-    // now that we pick the scalar type decompose dot that are not natively
-    // supported.
     decomposeMixedModeDotOp(m);
   }
 };

From f02c70ff614caa6047e53185d1b33525cc9916c4 Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore"
Date: Wed, 13 Nov 2024 15:50:46 +0000
Subject: [PATCH 2/9] Codegen for tritongpu.upcast_mxfp

Signed-off-by: Tiotto, Ettore
---
 .../Conversion/TritonGPUToLLVM/Utility.h      | 12 ++-
 lib/Conversion/TritonGPUToLLVM/Utility.cpp    | 69 +++++++++------
 lib/Dialect/TritonGPU/IR/Ops.cpp              | 21 ++++-
 .../TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp   | 33 ++------
 .../lib/TritonIntelGPUToLLVM/CMakeLists.txt   |  1 +
 .../PatternTritonGPUOpToLLVM.h                |  5 ++
 .../TritonIntelGPUToLLVM/PipelineManager.h    |  2 +
 .../TritonIntelGPUToLLVM/UpcastMXFPToLLVM.cpp | 83 +++++++++++++++++++
 .../UpcastMXFPToLLVM.cpp                      | 69 ++-------------
 9 files changed, 170 insertions(+), 125 deletions(-)
 create mode 100644 third_party/intel/lib/TritonIntelGPUToLLVM/UpcastMXFPToLLVM.cpp

diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h
index f0f62b8e83..aa51ff2e59 100644
--- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h
+++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -396,10 +396,14 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
 // MXFP utilities
 // -----------------------------------------------------------------------
 
-// Convert one int8, which contain, 2 packed mxfp4 values, into 2 bf16
-// standalone values and returns them as a pair for (high 4 bits, low 4 bits).
-std::pair<Value, Value> convertMxfp4x2ToBf16x2(RewriterBase &rewriter,
-                                               Location loc, Value v);
+// Convert each value, which is an int8 containing 2 packed mxfp4 values,
+// into 2 standalone bf16 values.
+SmallVector<Value> convertMxfp4x2ToBf16x2(RewriterBase &rewriter, Location loc,
+                                          ArrayRef<Value>);
+
+// Scale an mxfp4 value by a given scale.
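+// The scale is an e8m0 exponent byte; a scale of 0xff yields NaN, as per the
+// mxfp specification.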
+Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale); + } // namespace LLVM /* ------------------------------------ */ diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index fdca492875..b5ab3601ea 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -862,32 +862,49 @@ SmallVector getWrappedMultiDimOffset( return multiDimOffsetWrapped; } -std::pair convertMxfp4x2ToBf16x2(RewriterBase &rewriter, - Location loc, Value v) { - auto em0 = and_(v, i8_val(0x70)); - auto em1 = and_(v, i8_val(0x7)); - Value v0 = or_(shl(zext(i16_ty, em0), i16_val(2)), - shl(zext(i16_ty, and_(v, i8_val(0x80))), i16_val(8))); - Value v1 = or_(shl(zext(i16_ty, em1), i16_val(6)), - shl(zext(i16_ty, and_(v, i8_val(0x8))), i16_val(12))); - - // Three cases: - // 1) x is normal and non-zero: Correct bias - v0 = select(icmp_ne(and_(em0, i8_val(0x60)), i8_val(0)), - add(v0, i16_val((127 - 1) << 7)), v0); - v1 = select(icmp_ne(and_(em1, i8_val(0x6)), i8_val(0)), - add(v1, i16_val((127 - 1) << 7)), v1); - - // 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in - // bf16 - v0 = select(icmp_eq(em0, i8_val(0x10)), - or_(i16_val(16128), and_(v0, i16_val(0x8000))), v0); - v1 = select(icmp_eq(em1, i8_val(0x1)), - or_(i16_val(16128), and_(v1, i16_val(0x8000))), v1); - // 3) x is zero, nothing to do - - return {v0, v1}; -} +SmallVector convertMxfp4x2ToBf16x2(RewriterBase &rewriter, Location loc, + ArrayRef values) { + SmallVector results; + for (auto v : values) { + auto em0 = and_(v, i8_val(0x70)); + auto em1 = and_(v, i8_val(0x7)); + Value v0 = or_(shl(zext(i16_ty, em0), i16_val(2)), + shl(zext(i16_ty, and_(v, i8_val(0x80))), i16_val(8))); + Value v1 = or_(shl(zext(i16_ty, em1), i16_val(6)), + shl(zext(i16_ty, and_(v, i8_val(0x8))), i16_val(12))); + + // Three cases: + // 1) x is normal and non-zero: Correct bias + v0 = select(icmp_ne(and_(em0, i8_val(0x60)), i8_val(0)), + add(v0, i16_val((127 - 1) << 7)), v0); + v1 = select(icmp_ne(and_(em1, i8_val(0x6)), i8_val(0)), + add(v1, i16_val((127 - 1) << 7)), v1); + + // 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in + // bf16 + v0 = bitcast(select(icmp_eq(em0, i8_val(0x10)), + or_(i16_val(16128), and_(v0, i16_val(0x8000))), v0), + bf16_ty); + v1 = bitcast(select(icmp_eq(em1, i8_val(0x1)), + or_(i16_val(16128), and_(v1, i16_val(0x8000))), v1), + bf16_ty); + // 3) x is zero, nothing to do + results.push_back(v0); + results.push_back(v1); + } + return results; +} + +Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, + Value scale) { + Value vBf16 = bitcast(v, bf16_ty); + Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty); + Value scaleIsNan = icmp_eq(scale, i8_val(0xff)); + Value scaleBf16 = bitcast(shl(zext(i16_ty, scale), i16_val(7)), bf16_ty); + Value scaledBf16 = fmul(vBf16, scaleBf16); + // Account for NaN in the scale as per the mxfp specification. 
+ return select(scaleIsNan, nanBf16, scaledBf16); +}; } // namespace LLVM } // namespace mlir diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index 65c647e16b..5d6ad114fe 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -1,3 +1,4 @@ +#include "intel/include/Dialect/TritonIntelGPU/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" @@ -95,9 +96,23 @@ LogicalResult UpcastMXFPOp::inferReturnTypes( if (typeEncoded == ScaleDotElemType::E2M1) { auto oldEncoding = cast(encoding); - auto newVEncoding = DotOperandEncodingAttr::get( - ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(), - oldEncoding.getKWidth() * 2); + auto parentEncoding = oldEncoding.getParent(); + + // Note: For Intel the dot operands layout's kWidth parameter must + // match the parent's dpas layout opsPerChannel. Given that the kWidth + // parameter for the result dot layout is going to be twice the kWidth + // parameter of the operand, we cannot reuse the operand's parent dpas + // layout and we need to materialize a new dpas encoding. + if (auto dpasEncoding = dyn_cast(parentEncoding)) + parentEncoding = intel::DpasEncodingAttr::get( + ctx, dpasEncoding.getRepeatCount(), dpasEncoding.getSystolicDepth(), + dpasEncoding.getExecutionSize(), dpasEncoding.getOpsPerChannel() * 2, + dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(), + dpasEncoding.getSubGroupSize()); + + auto newVEncoding = + DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(), parentEncoding, + oldEncoding.getKWidth() * 2); auto newShape = SmallVector(xShape); newShape.back() *= 2; inferredReturnTypes.push_back( diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp index 07fa634ec7..f8165a7693 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -19,17 +19,6 @@ using namespace mlir::triton::gpu; namespace { -Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, - Value scale) { - Value vBf16 = bitcast(v, bf16_ty); - Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty); - Value scaleIsNan = icmp_eq(scale, i8_val(0xff)); - Value scaleBf16 = bitcast(shl(zext(i16_ty, scale), i16_val(7)), bf16_ty); - Value scaledBf16 = fmul(vBf16, scaleBf16); - // Account for NaN in the scale as per the mxfp specification. - return select(scaleIsNan, nanBf16, scaledBf16); -}; - class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { private: const TargetInfoBase &targetInfo; @@ -83,7 +72,7 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { Value laneId = urem(tid, warpSize); if (isPacked) - xVals = unpackFP4Elements(loc, rewriter, xVals); + xVals = LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, xVals); // Given that MFMA layout for the A tensor arranges thread in a column-major // manner, for the current tid, it's at row (tid % mDim). 
When we set up @@ -110,7 +99,8 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { for (int j = 0; j < 32; ++j) { int index = 32 * i + j; - xVals[index] = mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 16]); + xVals[index] = + LLVM::mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 16]); } } } else { @@ -132,7 +122,8 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { for (int j = 0; j < 32; ++j) { int index = 32 * i + j; - xVals[index] = mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 8]); + xVals[index] = + LLVM::mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 8]); } } } @@ -142,20 +133,6 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { rewriter.replaceOp(op, result); return success(); } - -private: - SmallVector unpackFP4Elements(Location loc, RewriterBase &rewriter, - ArrayRef packed) const { - // Split every fp4x2 into 2 bf16 values. - llvm::SmallVector unpacked; - unpacked.reserve(packed.size() * 2); - for (Value v : packed) { - auto [e0, e1] = LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, v); - unpacked.push_back(e0); - unpacked.push_back(e1); - } - return unpacked; - } }; } // anonymous namespace diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt b/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt index f46c265fa4..3accb81d09 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt @@ -24,6 +24,7 @@ add_triton_library(TritonIntelGPUToLLVM TritonGPUToLLVM.cpp TritonOpsToLLVM.cpp TypeConverter.cpp + UpcastMXFPToLLVM.cpp Utility.cpp ViewOpToLLVM.cpp diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h index 40116a17ca..4ee1a012bc 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -41,6 +41,11 @@ void populateElementwiseOpToLLVMPatterns( ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo, PatternBenefit benefit); +void populateUpcastMXFPToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfo &targetInfo, + PatternBenefit benefit); + void populateBF16CastsLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit); diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h b/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h index b52b3a3b97..69dcd53e2d 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h @@ -253,6 +253,8 @@ class TritonGPUToLLVMPipelineManager { targetInfo, benefit); intel::populateMakeRangeOpToLLVMPattern(typeConverter, targetInfo, patterns, benefit); + intel::populateUpcastMXFPToLLVMPatterns(typeConverter, patterns, + targetInfo, benefit); } intel::populateSPMDOpToLLVMPattern(typeConverter, patterns, targetInfo, diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/UpcastMXFPToLLVM.cpp new file mode 100644 index 0000000000..59d90e2930 --- /dev/null +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -0,0 +1,83 @@ +#include "PatternTritonGPUOpToLLVM.h" + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/ValueRange.h" +#include 
"mlir/Transforms/DialectConversion.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +namespace { + +class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { +private: + const TargetInfoBase &targetInfo; + +public: + UpcastMXFPOpPattern(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(UpcastMXFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto operands = adaptor.getOperands(); + SmallVector xVals = unpackLLElements(loc, operands[0], rewriter); + SmallVector scaleVals = unpackLLElements(loc, operands[1], rewriter); + ScaleDotElemType fpType = op.getFpType(); + + Value tid = tid_val(); + auto mod = op->getParentOfType(); + Value warpSize = + i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod)); + Value warpId = udiv(tid, warpSize); + Value laneId = urem(tid, warpSize); + + if (fpType == ScaleDotElemType::E2M1) + xVals = LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, xVals); + + // Each thread owns elements of 4 mxfp vectors so we need 4 scales + // Letting c = tid / 4 * 2, we need the elements from threads c, c + 1, c + + // 16, c + 17 + auto c = mul(udiv(laneId, i32_val(4)), i32_val(2)); + std::array ci = {c, add(c, i32_val(1)), add(c, i32_val(16)), + add(c, i32_val(17))}; + + for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) { + // column major as per the DotOperandEncoding(opidx=0) layout + auto si = std::array{ + targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[0]), + targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[2]), + targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[1]), + targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[3]), + }; + + for (int j = 0; j < 32; ++j) { + xVals[32 * i + j] = + LLVM::mxfpScaleBf16(rewriter, loc, xVals[32 * i + j], si[j / 8]); + } + } + + Value result = + packLLElements(loc, getTypeConverter(), xVals, rewriter, op.getType()); + rewriter.replaceOp(op, result); + return success(); + } +}; +} // anonymous namespace + +void mlir::triton::intel::populateUpcastMXFPToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfo &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp index 0c9ae00ce1..6cba3f45da 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -30,47 +30,6 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {} - llvm::SmallVector unpackFP4Elements(Location loc, - RewriterBase &rewriter, - ArrayRef vals) const { - - auto fp4x8ToBf16x2 = [&loc, &rewriter](Value v) { - llvm::SmallVector results(4); - for (int i = 0; i < 4; ++i) { - auto v_i = trunc(i8_ty, lshr(v, i32_val(8 * i))); - auto [e0, e1] = LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, v_i); - // Swap as they come packed in big endian - results[i] = or_(zext(i32_ty, e0), shl(zext(i32_ty, e1), i32_val(16))); - } - return 
results; - }; - - // Split fp4x8 into 4 bf16x2 - llvm::SmallVector ret; - ret.reserve(vals.size() * 4); - for (int i = 0; i < vals.size(); ++i) { - auto vs = fp4x8ToBf16x2(vals[i]); - assert(vs.size() == 4); - for (auto v : vs) { - ret.push_back(v); - } - } - // FIXME [Dot LL] - // The DotOperandEncodingAttr without LLs encodes the - // layout as - // e0 e1 - // e2 e3 - // rather than transposed that, as the PTX docs say - // We transpose every block of 4 elements (kWidth = 8 -> 4 bf16x2) - assert(ret.size() % 16 == 0); - for (int i = 0; i < ret.size() / 16; ++i) { - for (int j = 0; j < 4; ++j) { - std::swap(ret[16 * i + j + 4], ret[16 * i + j + 8]); - } - } - return ret; - } - LogicalResult matchAndRewrite(UpcastMXFPOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -90,27 +49,8 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { Value warpId = udiv(tid, warpSize); Value laneId = urem(tid, warpSize); - if (fpType == ScaleDotElemType::E2M1) { - xVals = unpackFP4Elements(loc, rewriter, xVals); - } - - auto scaleBf16x2 = [&loc, &rewriter](Value v, Value s) -> Value { - // Split bf16x2 into 2 bf16, scale each of them, and pack them back - // TODO Is it true that the bfloats are always packed as bf16x2? - auto bf16_0 = bitcast(trunc(i16_ty, v), bf16_ty); - auto bf16_1 = bitcast(trunc(i16_ty, lshr(v, i32_val(16))), bf16_ty); - auto scaleIsNan = icmp_eq(s, i8_val(0xff)); - auto scaleBf16 = bitcast(shl(zext(i16_ty, s), i16_val(7)), bf16_ty); - auto scaledBf16_0 = fmul(bf16_0, scaleBf16); - auto scaledBf16_1 = fmul(bf16_1, scaleBf16); - auto i16_0 = bitcast(scaledBf16_0, i16_ty); - auto i16_1 = bitcast(scaledBf16_1, i16_ty); - auto packed = - or_(zext(i32_ty, i16_0), shl(zext(i32_ty, i16_1), i32_val(16))); - // Account for NaN in the scale as per the mxfp specification - auto packed_nan = select(scaleIsNan, i32_val(0x7fff7fff), packed); - return packed_nan; - }; + if (fpType == ScaleDotElemType::E2M1) + xVals = LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, xVals); // Each thread owns elements of 4 mxfp vectors so we need 4 scales // Letting c = tid / 4 * 2, we need the elements from threads c, c + 1, c + @@ -128,8 +68,9 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[3]), }; - for (int j = 0; j < 16; ++j) { - xVals[16 * i + j] = scaleBf16x2(xVals[16 * i + j], si[j / 4]); + for (int j = 0; j < 32; ++j) { + xVals[32 * i + j] = + LLVM::mxfpScaleBf16(rewriter, loc, xVals[32 * i + j], si[j / 8]); } } From 0f4091d4ec7f6620dd4be360d5162cf4456af599 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Mon, 18 Nov 2024 15:42:18 +0000 Subject: [PATCH 3/9] WIP: tt.scaled_dot Signed-off-by: Tiotto, Ettore --- lib/Dialect/TritonGPU/IR/Ops.cpp | 72 ++++--- .../AccelerateMatmul.cpp | 200 +++++++++++++++++- 2 files changed, 238 insertions(+), 34 deletions(-) diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index 5d6ad114fe..32d1718e64 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -53,14 +53,23 @@ LogicalResult UpcastMXFPOp::verify() { "all dimensions except the last must match between operands"); } - auto dotEncoding = - dyn_cast_or_null(xTy.getEncoding()); + auto layoutX = xTy.getEncoding(); + auto layoutScale = scaleTy.getEncoding(); + if (bool(layoutX) != bool(layoutScale)) { + return emitOpError( + "Expected either both or neither operands to have an encoding"); + } + // Nothing to check if no encoding. 
This is used to infer the return type in + // AccelerateMatmul.cpp + if (!layoutX) { + return success(); + } + + auto dotEncoding = dyn_cast(xTy.getEncoding()); if (!dotEncoding) { return emitOpError("Expected a DotOperandEncodingAttr for values"); } - - auto blockedScale = - dyn_cast_or_null(scaleTy.getEncoding()); + auto blockedScale = dyn_cast(scaleTy.getEncoding()); if (!blockedScale) { return emitOpError("Expected a BlockOperandEncoding for scales"); } @@ -87,36 +96,37 @@ LogicalResult UpcastMXFPOp::inferReturnTypes( auto xShape = xTy.getShape(); auto encoding = xTy.getEncoding(); - if (!encoding) { - return emitOptionalError(loc, "expected an encoding"); - } - if (!mlir::isa(encoding)) { - return emitOptionalError(loc, "expected a dotOperand encoding"); - } if (typeEncoded == ScaleDotElemType::E2M1) { - auto oldEncoding = cast(encoding); - auto parentEncoding = oldEncoding.getParent(); - - // Note: For Intel the dot operands layout's kWidth parameter must - // match the parent's dpas layout opsPerChannel. Given that the kWidth - // parameter for the result dot layout is going to be twice the kWidth - // parameter of the operand, we cannot reuse the operand's parent dpas - // layout and we need to materialize a new dpas encoding. - if (auto dpasEncoding = dyn_cast(parentEncoding)) - parentEncoding = intel::DpasEncodingAttr::get( - ctx, dpasEncoding.getRepeatCount(), dpasEncoding.getSystolicDepth(), - dpasEncoding.getExecutionSize(), dpasEncoding.getOpsPerChannel() * 2, - dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(), - dpasEncoding.getSubGroupSize()); - - auto newVEncoding = - DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(), parentEncoding, - oldEncoding.getKWidth() * 2); + RankedTensorType retTy; + auto newShape = SmallVector(xShape); newShape.back() *= 2; - inferredReturnTypes.push_back( - RankedTensorType::get(newShape, FloatType::getBF16(ctx), newVEncoding)); + if (!encoding) { + retTy = RankedTensorType::get(xShape, FloatType::getBF16(ctx)); + } else { + auto oldEncoding = cast(encoding); + + // Note: For Intel the dot operands layout's kWidth parameter must + // match the parent's dpas layout opsPerChannel. Given that the kWidth + // parameter for the result dot layout is going to be twice the kWidth + // parameter of the operand, we cannot reuse the operand's parent dpas + // layout and we need to materialize a new dpas encoding. 
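+      // (The new layout doubles opsPerChannel so that it stays in sync with
+      // the doubled kWidth of the result encoding created below.)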
+      auto parentEncoding = oldEncoding.getParent();
+      if (auto dpasEncoding = dyn_cast<intel::DpasEncodingAttr>(parentEncoding))
+        parentEncoding = intel::DpasEncodingAttr::get(
+            ctx, dpasEncoding.getRepeatCount(), dpasEncoding.getSystolicDepth(),
+            dpasEncoding.getExecutionSize(),
+            dpasEncoding.getOpsPerChannel() * 2, dpasEncoding.getWarpsPerCTA(),
+            dpasEncoding.getRepCluster(), dpasEncoding.getSubGroupSize());
+
+      auto newVEncoding = DotOperandEncodingAttr::get(
+          ctx, oldEncoding.getOpIdx(), parentEncoding,
+          oldEncoding.getKWidth() * 2);
+      retTy = RankedTensorType::get(newShape, FloatType::getBF16(ctx),
+                                    newVEncoding);
+    }
+    inferredReturnTypes.push_back(retTy);
   } else {
     inferredReturnTypes.push_back(xTy);
   }
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp
index ae385ed960..91227665e2 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp
@@ -1,10 +1,15 @@
+#include "Dialect/TritonIntelGPU/IR/Attributes.h"
+#include "Dialect/TritonIntelGPU/Transforms/Utility.h"
 #include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #include "intel/include/Analysis/DPAS.h"
 #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
 #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -108,6 +113,7 @@ SmallVector<unsigned> getWarpsPerTile(tt::DotOp dotOp,
 
 class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
   const ttg::intel::DPASAnalysis &dpasAnalysis;
+  using TensorValue = TypedValue<RankedTensorType>;
 
 public:
   BlockedToDPAS(MLIRContext *context,
@@ -116,8 +122,6 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
 
   LogicalResult matchAndRewrite(tt::DotOp dotOp,
                                 PatternRewriter &rewriter) const override {
-    using TensorValue = TypedValue<RankedTensorType>;
-
     RankedTensorType oldRetType = dotOp.getType();
     if (!oldRetType.getEncoding() ||
         isa<ttg::intel::DpasEncodingAttr>(oldRetType.getEncoding()))
@@ -219,6 +223,196 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
   }
 };
 
+class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
+  const ttg::intel::DPASAnalysis &dpasAnalysis;
+  using TensorValue = TypedValue<RankedTensorType>;
+
+public:
+  DecomposeScaledBlocked(MLIRContext *context,
+                         const ttg::intel::DPASAnalysis &dpasAnalysis)
+      : OpRewritePattern<tt::DotScaledOp>(context),
+        dpasAnalysis(dpasAnalysis) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(tt::DotScaledOp scaledDotOp,
+                  PatternRewriter &rewriter) const override {
+    RankedTensorType oldRetType = scaledDotOp.getType();
+    if (!oldRetType.getEncoding() ||
+        isa<ttg::intel::DpasEncodingAttr>(oldRetType.getEncoding()))
+      return failure();
+
+    TensorValue a = scaledDotOp.getLhs();
+    TensorValue b = scaledDotOp.getRhs();
+    TensorValue scale = scaledDotOp.getLhsScale();
+    tt::ScaleDotElemType aType = scaledDotOp.getLhsType();
+    tt::ScaleDotElemType bType = scaledDotOp.getRhsType();
+
+    assert(scaledDotOp.getRhsScale() == nullptr && "rhs scale NYI");
+    assert((aType == tt::ScaleDotElemType::E4M3 ||
+            aType == tt::ScaleDotElemType::E5M2 ||
+            aType == tt::ScaleDotElemType::E2M1) &&
+           "NYI: lhs supports fp4 or fp8");
+    assert((bType == tt::ScaleDotElemType::E4M3 ||
+            bType == tt::ScaleDotElemType::E5M2 ||
+            bType == tt::ScaleDotElemType::BF16) &&
+           "NYI: rhs supports fp8 and bf16");
+
+    ttg::intel::DpasEncodingAttr dpasEnc =
getDPASEncoding(rewriter, scaledDotOp); + + auto newRetType = RankedTensorType::get( + oldRetType.getShape(), oldRetType.getElementType(), dpasEnc); + llvm::errs() << "newRetType: " << newRetType << "\n"; + + // convert accumulator + TensorValue oldAcc = scaledDotOp.getC(); + TensorValue newAcc = rewriter.create( + oldAcc.getLoc(), newRetType, oldAcc); + llvm::errs() << "newAcc: " << newAcc << "\n"; + + MLIRContext *ctx = scaledDotOp.getContext(); + auto newAEncoding = ttg::DotOperandEncodingAttr::get( + ctx, 0, dpasEnc, dpasEnc.getOpsPerChannel()); + llvm::errs() << "newAEncoding: " << newAEncoding << "\n"; + + auto mod = scaledDotOp->getParentOfType(); + unsigned threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod); + auto CTALayout = ttg::getCTALayout(oldRetType.getEncoding()); + auto newScaleEncoding = triton::gpu::BlockedEncodingAttr::get( + ctx, {1, 1}, threadsPerWarp, newAEncoding.getWarpsPerCTA(), + newAEncoding.getCTAOrder(), CTALayout); + a = createArg(rewriter, a, 0, aType, newAEncoding, scale, newScaleEncoding); + llvm::errs() << "operand A: " << a << "\n"; + + // Upcast B operand + assert(bType != tt::ScaleDotElemType::E2M1 && "NYI: rhs scale for fp4"); + auto newBEncoding = ttg::DotOperandEncodingAttr::get( + ctx, 1, dpasEnc, dpasEnc.getOpsPerChannel()); + b = createArg(rewriter, b, 1, bType, newBEncoding, + /*scale=*/std::nullopt, /*scaleEncoding=*/std::nullopt); + llvm::errs() << "operand B: " << b << "\n"; + + auto newDot = rewriter.create(scaledDotOp.getLoc(), newRetType, + a, b, newAcc); + llvm::errs() << "dot: " << *newDot << "\n"; + + return success(); + } + +private: + TensorValue createArg(mlir::PatternRewriter &rewriter, TensorValue v, int idx, + tt::ScaleDotElemType type, + std::optional vEncoding, + std::optional opt_scale, + std::optional scaleEncoding) const { + MLIRContext *ctx = rewriter.getContext(); + // Create a new tensor with a given encoding or remove the encoding + auto maybeWithEncoding = [](RankedTensorType ty, + std::optional enc) { + if (enc.has_value()) + return RankedTensorType::get(ty.getShape(), ty.getElementType(), *enc); + + return RankedTensorType::get(ty.getShape(), ty.getElementType()); + }; + + RankedTensorType newVType = maybeWithEncoding(v.getType(), vEncoding); + TensorValue ret = + rewriter.create(v.getLoc(), newVType, v); + + // convert to bf16 + if (type != tt::ScaleDotElemType::E2M1 && + type != tt::ScaleDotElemType::BF16) { + assert(type == tt::ScaleDotElemType::E5M2 || + type == tt::ScaleDotElemType::E4M3); + auto vTypeBf16 = RankedTensorType::get( + newVType.getShape(), rewriter.getBF16Type(), newVType.getEncoding()); + ret = rewriter.create(v.getLoc(), vTypeBf16, ret); + } + if (opt_scale.has_value()) { + TensorValue scale = *opt_scale; + assert(idx == 0 && "NYI: rhs scale"); + RankedTensorType newScaleDotElemType = + maybeWithEncoding(scale.getType(), scaleEncoding); + scale = rewriter.create(scale.getLoc(), + newScaleDotElemType, scale); + ret = rewriter.create(v.getLoc(), ret, scale, type); + } + return ret; + } + + ttg::intel::DpasEncodingAttr + getDPASEncoding(PatternRewriter &rewriter, + tt::DotScaledOp scaledDotOp) const { + llvm::errs() << "at line: " << __LINE__ << "\n"; + MLIRContext *ctx = rewriter.getContext(); + TensorValue a = scaledDotOp.getLhs(); + TensorValue b = scaledDotOp.getRhs(); + TensorValue scale = scaledDotOp.getLhsScale(); + tt::ScaleDotElemType aType = scaledDotOp.getLhsType(); + tt::ScaleDotElemType bType = scaledDotOp.getRhsType(); + + Location loc = scaledDotOp.getLoc(); + 
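+    // Materialize encoding-less copies of both operands and wrap them in a
+    // temporary tt.dot so that the existing getWarpsPerTile() helper can be
+    // reused to size the DPAS layout for the scaled dot.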
RankedTensorType aTType = + createArg(rewriter, a, 0, aType, /*vEncoding=*/std::nullopt, scale, + /*scaleEncoding=*/std::nullopt) + .getType(); + auto aTypeNoEnc = + RankedTensorType::get(aTType.getShape(), aTType.getElementType()); + + llvm::errs() << "aTType: " << aTType << "\n"; + llvm::errs() << "aTypeNoEnc: " << aTypeNoEnc << "\n"; + + a = rewriter.create(loc, aTypeNoEnc, a); + llvm::errs() << "a: " << a << "\n"; + + RankedTensorType bTType = + createArg(rewriter, b, 1, bType, /*vEncoding=*/std::nullopt, + /*scale=*/std::nullopt, /*scaleEncoding=*/std::nullopt) + .getType(); + auto bTypeNoEnc = + RankedTensorType::get(bTType.getShape(), bTType.getElementType()); + llvm::errs() << "bTType: " << bTType << "\n"; + llvm::errs() << "bTypeNoEnc: " << bTypeNoEnc << "\n"; + + b = rewriter.create(loc, bTypeNoEnc, b); + llvm::errs() << "b: " << b << "\n"; + + RankedTensorType oldRetType = scaledDotOp.getType(); + llvm::errs() << "oldRetType: " << oldRetType << "\n"; + auto dotOp = + rewriter.create(loc, oldRetType, a, b, scaledDotOp.getC()); + llvm::errs() << "dotOp: " << dotOp << "\n"; + + ArrayRef retShape = oldRetType.getShape(); + auto mod = scaledDotOp->getParentOfType(); + + unsigned minSGSize = + mod->getAttrOfType( + ttg::intel::TritonIntelGPUDialect::getMinSGSizeAttrName()) + .getInt(); + IntelDPASCapability dpasCap = getDPASCapability(minSGSize); + + unsigned dpasElemBitWidths = + aTType.getElementType().getIntOrFloatBitWidth(); + unsigned opsPerChan = dpasCap.opsChanBitWidths / dpasElemBitWidths; + + SmallVector warpsPerTile = getWarpsPerTile( + dotOp, dpasCap, retShape, ttg::TritonGPUDialect::getNumWarps(mod)); + size_t rank = retShape.size(); + SmallVector repCluster(rank, 1); + + unsigned threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod); + auto dpasEnc = ttg::intel::DpasEncodingAttr::get( + oldRetType.getContext(), dpasCap.repeatCount, dpasCap.systolicDepth, + dpasCap.executionSize, opsPerChan, warpsPerTile, repCluster, + threadsPerWarp); + + llvm::errs() << "dpasEnc: " << dpasEnc << "\n"; + + return dpasEnc; + } +}; + } // namespace static Value promoteOperand(OpBuilder &builder, Location loc, Value operand, @@ -291,7 +485,7 @@ class TritonIntelGPUAccelerateMatmulPass auto &dpasAnalysis = getAnalysis(); RewritePatternSet patterns(context); - patterns.add(context, dpasAnalysis); + patterns.add(context, dpasAnalysis); if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) signalPassFailure(); From d93471007a5b7178e4d647ced82c7c057779d17a Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Fri, 22 Nov 2024 20:40:18 +0000 Subject: [PATCH 4/9] WIP: tt.scaled_dot Signed-off-by: Tiotto, Ettore --- lib/Dialect/TritonGPU/IR/Ops.cpp | 40 +-- .../TritonIntelGPU/accelerate-matmul-pvc.mlir | 29 ++ .../Dialect/TritonIntelGPU/IR/Attributes.h | 4 + .../IR/TritonIntelGPUAttrDefs.td | 13 +- .../lib/Dialect/TritonIntelGPU/IR/Dialect.cpp | 71 ++++- .../AccelerateMatmul.cpp | 255 +++++++----------- 6 files changed, 217 insertions(+), 195 deletions(-) diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index 32d1718e64..9a62fd2db3 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -106,25 +106,31 @@ LogicalResult UpcastMXFPOp::inferReturnTypes( retTy = RankedTensorType::get(xShape, FloatType::getBF16(ctx)); } else { auto oldEncoding = cast(encoding); + Type elemType = FloatType::getBF16(ctx); // Note: For Intel the dot operands layout's kWidth parameter must - // match the parent's dpas layout opsPerChannel. 
Given that the kWidth - // parameter for the result dot layout is going to be twice the kWidth - // parameter of the operand, we cannot reuse the operand's parent dpas - // layout and we need to materialize a new dpas encoding. - auto parentEncoding = oldEncoding.getParent(); - if (auto dpasEncoding = dyn_cast(parentEncoding)) - parentEncoding = intel::DpasEncodingAttr::get( - ctx, dpasEncoding.getRepeatCount(), dpasEncoding.getSystolicDepth(), - dpasEncoding.getExecutionSize(), - dpasEncoding.getOpsPerChannel() * 2, dpasEncoding.getWarpsPerCTA(), - dpasEncoding.getRepCluster(), dpasEncoding.getSubGroupSize()); - - auto newVEncoding = DotOperandEncodingAttr::get( - ctx, oldEncoding.getOpIdx(), parentEncoding, - oldEncoding.getKWidth() * 2); - retTy = RankedTensorType::get(newShape, FloatType::getBF16(ctx), - newVEncoding); + // match the parent's DPAS layout opsPerChannel so we need to materialize + // a new DPAS layout. + Attribute newVEncoding; + if (auto dpasEncoding = + dyn_cast(oldEncoding.getParent())) { + auto mod = operands[0].getDefiningOp()->getParentOfType(); + auto dpasCap = intel::DpasEncodingAttr::getDPASCapability(mod); + auto newDpasEncoding = intel::DpasEncodingAttr::get( + ctx, dpasCap.repeatCount, dpasCap.systolicDepth, + dpasCap.executionSize, + intel::DpasEncodingAttr::getOpsPerChannel(dpasCap, elemType), + dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(), + dpasEncoding.getSubGroupSize()); + newVEncoding = DotOperandEncodingAttr::get( + ctx, oldEncoding.getOpIdx(), newDpasEncoding, + newDpasEncoding.getOpsPerChannel()); + } else { + newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(), + oldEncoding.getParent(), + oldEncoding.getKWidth() * 2); + } + retTy = RankedTensorType::get(newShape, elemType, newVEncoding); } inferredReturnTypes.push_back(retTy); } else { diff --git a/test/TritonIntelGPU/accelerate-matmul-pvc.mlir b/test/TritonIntelGPU/accelerate-matmul-pvc.mlir index 0cb3dc2a44..e46f79f82f 100644 --- a/test/TritonIntelGPU/accelerate-matmul-pvc.mlir +++ b/test/TritonIntelGPU/accelerate-matmul-pvc.mlir @@ -201,3 +201,32 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : tt.return } } + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}> + +module attributes {"triton_gpu.target" = "xpu", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32, "triton_intel_gpu.min_sg_size" = 16 : i32, "triton_intel_gpu.support_dpas"} { + // CHECK: [[BLOCKED:#.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}> + // CHECK: [[BLOCKED1:#.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [4, 1], order = [1, 0]}> + // CHECK: [[BLOCKED2:#.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}> + // CHECK: [[DPAS:#.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}> + // CHECK: [[DPAS1:#.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, 
systolicDepth = 8, executionSize = 16, opsPerChan = 4, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [1, 1], A = [8, 32], B = [32, 16], C = [8, 16]}> + // CHECK: dot_scaled + tt.func @dot_scaled(%a: tensor<128x32xi8, #blocked2>, %scale: tensor<128x2xi8, #blocked1>, %b: tensor<64x128xbf16, #blocked>) -> tensor<128x128xf32, #blocked> { + // CHECK: [[CST:%.*]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32, [[BLOCKED2]]> + // CHECK: [[C:%.*]] = triton_gpu.convert_layout [[CST]] : tensor<128x128xf32, [[BLOCKED2]]> -> tensor<128x128xf32, [[DPAS]]> + // CHECK: [[CVT_ARG0:%.*]] = triton_gpu.convert_layout %arg0 : tensor<128x32xi8, [[BLOCKED]]> -> tensor<128x32xi8, #triton_gpu.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>> + // CHECK: [[CVT_ARG1:%.*]] = triton_gpu.convert_layout %arg1 : tensor<128x2xi8, [[BLOCKED1]]> -> tensor<128x2xi8, [[BLOCKED1]]> + // CHECK: [[A:%.*]] = triton_gpu.upcast_mxfp [[CVT_ARG0]], [[CVT_ARG1]] fp_type = e2m1 : tensor<128x32xi8, #triton_gpu.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>>, tensor<128x2xi8, [[BLOCKED1]]> -> tensor<128x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>> + // CHECK: [[B:%.*]] = triton_gpu.convert_layout %arg2 : tensor<64x128xbf16, [[BLOCKED2]]> -> tensor<64x128xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> + // CHECK: [[D:%.*]] = tt.dot [[A]], [[B]], [[C]] : tensor<128x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>> * tensor<64x128xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> -> tensor<128x128xf32, [[DPAS]]> + // CHECK: [[RES:%.*]] = triton_gpu.convert_layout {{.*}} : tensor<128x128xf32, [[DPAS]]> -> tensor<128x128xf32, [[BLOCKED2]]> + + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %result = tt.dot_scaled %a, %scale, %b, %cst lhs = e2m1 rhs = bf16 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked> + tt.return %result : tensor<128x128xf32, #blocked> + } +} diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/Attributes.h b/third_party/intel/include/Dialect/TritonIntelGPU/IR/Attributes.h index b159017e34..0b5ab80873 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/Attributes.h +++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/Attributes.h @@ -3,6 +3,10 @@ #include "triton/Dialect/TritonGPU/IR/Attributes.h" +namespace mlir { +class ModuleOp; +} + #define GET_ATTRDEF_CLASSES #include "intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.h.inc" diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td index b6d1fdd109..c7858ee184 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td +++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td @@ -74,7 +74,6 @@ along the row (resp. col) dimension. ); let extraClassDeclaration = extraDistributedDeclaration # [{ - SmallVector getDPASInstShapeA() const; SmallVector getDPASInstShapeB() const; SmallVector getDPASInstShapeC() const; @@ -91,7 +90,17 @@ along the row (resp. col) dimension. 
return true; } - SmallVector getContigPerThread(); + SmallVector getContigPerThread() const; + + struct DPASCapability { + uint32_t systolicDepth; + uint32_t repeatCount; + uint32_t executionSize; + uint32_t opsChanBitWidths; + }; + + static DPASCapability getDPASCapability(mlir::ModuleOp mod); + static unsigned getOpsPerChannel(DPASCapability dpasCap, Type elemType); }]; let hasCustomAssemblyFormat = 1; diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp index af81751dda..6b4eddf0c4 100644 --- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp +++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp @@ -380,7 +380,7 @@ SmallVector DpasEncodingAttr::getElemsPerThreadForOperands( return elemsPerThread; }; -SmallVector DpasEncodingAttr::getContigPerThread() { +SmallVector DpasEncodingAttr::getContigPerThread() const { size_t rank = getWarpsPerCTA().size(); assert(rank == 2 || rank == 3); SmallVector contigPerThread(rank, 1); @@ -404,6 +404,56 @@ SmallVector DpasEncodingAttr::getContigPerThread() { "be smaller than the threads required per row."); } +DpasEncodingAttr::DPASCapability +DpasEncodingAttr::getDPASCapability(ModuleOp mod) { + assert(mod && "expected a valid module"); + + if (!mod->hasAttrOfType( + triton::gpu::intel::TritonIntelGPUDialect::getMinSGSizeAttrName())) + return {}; + + unsigned minSGSize = + mod->getAttrOfType( + triton::gpu::intel::TritonIntelGPUDialect::getMinSGSizeAttrName()) + .getInt(); + + switch (minSGSize) { + case 8: { + DPASCapability cap; + cap.systolicDepth = 8; + cap.repeatCount = 8; + cap.executionSize = 8; + cap.opsChanBitWidths = 32; + return cap; + } + case 16: { + DPASCapability cap; + cap.systolicDepth = 8; + cap.repeatCount = 8; + cap.executionSize = 16; + cap.opsChanBitWidths = 32; + return cap; + } + default: + return {}; + } +} + +unsigned +DpasEncodingAttr::getOpsPerChannel(DpasEncodingAttr::DPASCapability dpasCap, + Type elemType) { + if (!elemType.isIntOrFloat()) + llvm::report_fatal_error("unsupported type for DpasEncodingAttr"); + + unsigned dpasElemBitWidths = elemType.getIntOrFloatBitWidth(); + + // We are upcasting FP8 to FP16 + if (elemType.isFloat8E5M2() || elemType.isFloat8E4M3FN()) + dpasElemBitWidths = 2 * dpasElemBitWidths; + + return dpasCap.opsChanBitWidths / dpasElemBitWidths; +} + LogicalResult DpasEncodingAttr::verify( ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, unsigned repeatCount, unsigned systolicDepth, unsigned executionSize, @@ -492,18 +542,14 @@ void DpasEncodingAttr::print(AsmPrinter &printer) const { llvm::ArrayRef rC = shapeC; auto warpsPerCTA = getWarpsPerCTA(); auto repCluster = getRepCluster(); - printer << "<{" - << "repeatCount = " << getRepeatCount() << ", " + printer << "<{" << "repeatCount = " << getRepeatCount() << ", " << "systolicDepth = " << getSystolicDepth() << ", " << "executionSize = " << getExecutionSize() << ", " << "opsPerChan = " << getOpsPerChannel() << ", " << "threadsPerWarp = " << getSubGroupSize() << ", " << "warpsPerCTA = [" << llvm::ArrayRef(warpsPerCTA) << "], " - << "repCluster = [" << repCluster << "], " - << "A = [" << rA << "], " - << "B = [" << rB << "], " - << "C = [" << rC << "]" - << "}>"; + << "repCluster = [" << repCluster << "], " << "A = [" << rA << "], " + << "B = [" << rB << "], " << "C = [" << rC << "]" << "}>"; } std::optional @@ -576,13 +622,10 @@ Attribute WarpEncodingAttr::parse(AsmParser &parser, Type type) { void 
WarpEncodingAttr::print(mlir::AsmPrinter &printer) const { auto threadsPerWarp = getThreadsPerWarp(); auto sizePerThread = getSizePerThread(); - printer << "<{" - << "sizePerThread = [" << llvm::ArrayRef(sizePerThread) - << "]" + printer << "<{" << "sizePerThread = [" + << llvm::ArrayRef(sizePerThread) << "]" << ", threadsPerWarp = [" << llvm::ArrayRef(threadsPerWarp) - << "]" - << ", order = [" << getOrder() << "]" - << "}>"; + << "]" << ", order = [" << getOrder() << "]" << "}>"; } //===----------------------------------------------------------------------===// diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp index 91227665e2..5c0f0a4c79 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp @@ -28,40 +28,10 @@ namespace mlir::triton::gpu::intel { namespace { -struct IntelDPASCapability { - uint32_t systolicDepth; - uint32_t repeatCount; - uint32_t executionSize; - uint32_t opsChanBitWidths; -}; - -IntelDPASCapability getDPASCapability(unsigned minSGSize) { - switch (minSGSize) { - case 8: { - IntelDPASCapability cap; - cap.systolicDepth = 8; - cap.repeatCount = 8; - cap.executionSize = 8; - cap.opsChanBitWidths = 32; - return cap; - } - case 16: { - IntelDPASCapability cap; - cap.systolicDepth = 8; - cap.repeatCount = 8; - cap.executionSize = 16; - cap.opsChanBitWidths = 32; - return cap; - } - default: - return IntelDPASCapability(); - } -} - -SmallVector getWarpsPerTile(tt::DotOp dotOp, - struct IntelDPASCapability dpasCap, - const ArrayRef shape, - unsigned numWarps) { +SmallVector +getWarpsPerTile(tt::DotOp dotOp, + ttg::intel::DpasEncodingAttr::DPASCapability dpasCap, + const ArrayRef shape, unsigned numWarps) { auto filter = [&dotOp](Operation *op) { return op->getParentRegion() == dotOp->getParentRegion(); }; @@ -142,20 +112,10 @@ class BlockedToDPAS : public OpRewritePattern { auto oldAType = cast(a.getType()); auto oldBType = cast(b.getType()); - unsigned minSGSize = - mod->getAttrOfType( - ttg::intel::TritonIntelGPUDialect::getMinSGSizeAttrName()) - .getInt(); - IntelDPASCapability dpasCap = getDPASCapability(minSGSize); - unsigned dpasElemBitWidths = - oldAType.getElementType().getIntOrFloatBitWidth(); - - // We are upcasting FP8 to FP16 - if (oldAType.getElementType().isFloat8E5M2() || - oldAType.getElementType().isFloat8E4M3FN()) - dpasElemBitWidths = 2 * dpasElemBitWidths; - - unsigned opsPerChan = dpasCap.opsChanBitWidths / dpasElemBitWidths; + auto dpasCap = ttg::intel::DpasEncodingAttr::getDPASCapability(mod); + Type elemType = oldAType.getElementType(); + unsigned opsPerChan = + ttg::intel::DpasEncodingAttr::getOpsPerChannel(dpasCap, elemType); SmallVector warpsPerTile = getWarpsPerTile(dotOp, dpasCap, retShape, numWarps); size_t rank = retShape.size(); @@ -168,6 +128,14 @@ class BlockedToDPAS : public OpRewritePattern { threadsPerWarp); if (dpasCap.executionSize == 16 /* PVC */) { + unsigned dpasElemBitWidths = + oldAType.getElementType().getIntOrFloatBitWidth(); + + // We are upcasting FP8 to FP16 + if (oldAType.getElementType().isFloat8E5M2() || + oldAType.getElementType().isFloat8E4M3FN()) + dpasElemBitWidths = 2 * dpasElemBitWidths; + // Enlarge the repCluster size to use the large 2D load for A and B // operands. 
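       // (The enlargement is capped below by the PVC 2D-load row/byte limits.)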
unsigned maxRepClusterM = @@ -241,6 +209,7 @@ class DecomposeScaledBlocked : public OpRewritePattern { isa(oldRetType.getEncoding())) return failure(); + MLIRContext *ctx = scaledDotOp.getContext(); TensorValue a = scaledDotOp.getLhs(); TensorValue b = scaledDotOp.getRhs(); TensorValue scale = scaledDotOp.getLhsScale(); @@ -257,159 +226,121 @@ class DecomposeScaledBlocked : public OpRewritePattern { bType == tt::ScaleDotElemType::BF16 && "NYI: rhs supports fp8 and bf16"); + // Convert accumulator. ttg::intel::DpasEncodingAttr dpasEnc = getDPASEncoding(rewriter, scaledDotOp); - auto newRetType = RankedTensorType::get( oldRetType.getShape(), oldRetType.getElementType(), dpasEnc); - llvm::errs() << "newRetType: " << newRetType << "\n"; - - // convert accumulator TensorValue oldAcc = scaledDotOp.getC(); TensorValue newAcc = rewriter.create( oldAcc.getLoc(), newRetType, oldAcc); - llvm::errs() << "newAcc: " << newAcc << "\n"; - MLIRContext *ctx = scaledDotOp.getContext(); + // Upcast A operand. + auto dpasEncForA = ttg::intel::DpasEncodingAttr::get( + ctx, dpasEnc.getRepeatCount(), dpasEnc.getSystolicDepth(), + dpasEnc.getExecutionSize(), 2 * dpasEnc.getOpsPerChannel(), + dpasEnc.getWarpsPerCTA(), dpasEnc.getRepCluster(), + dpasEnc.getSubGroupSize()); auto newAEncoding = ttg::DotOperandEncodingAttr::get( - ctx, 0, dpasEnc, dpasEnc.getOpsPerChannel()); - llvm::errs() << "newAEncoding: " << newAEncoding << "\n"; + ctx, 0, dpasEncForA, dpasEncForA.getOpsPerChannel()); + a = createArg(rewriter, a, aType, newAEncoding); auto mod = scaledDotOp->getParentOfType(); - unsigned threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod); + unsigned warpSize = ttg::TritonGPUDialect::getThreadsPerWarp(mod); + unsigned instrShapeM = dpasEnc.getDPASInstShapeA()[1]; + SmallVector threadsPerWarp{instrShapeM, warpSize / instrShapeM}; auto CTALayout = ttg::getCTALayout(oldRetType.getEncoding()); - auto newScaleEncoding = triton::gpu::BlockedEncodingAttr::get( + auto newScaleEncoding = ttg::BlockedEncodingAttr::get( ctx, {1, 1}, threadsPerWarp, newAEncoding.getWarpsPerCTA(), newAEncoding.getCTAOrder(), CTALayout); - a = createArg(rewriter, a, 0, aType, newAEncoding, scale, newScaleEncoding); - llvm::errs() << "operand A: " << a << "\n"; + scale = createScale(rewriter, scale, newScaleEncoding); - // Upcast B operand + auto retTypeEncoding = ttg::DotOperandEncodingAttr::get( + ctx, 0, dpasEnc, dpasEnc.getOpsPerChannel()); + a = createUpcastMxfpOp(rewriter, a, scale, aType, retTypeEncoding); + + // Create B operand. 
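+    // B carries no scale here, so it only needs a layout conversion (plus an
+    // fp8 -> bf16 upcast when bType is E5M2 or E4M3).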
assert(bType != tt::ScaleDotElemType::E2M1 && "NYI: rhs scale for fp4"); auto newBEncoding = ttg::DotOperandEncodingAttr::get( ctx, 1, dpasEnc, dpasEnc.getOpsPerChannel()); - b = createArg(rewriter, b, 1, bType, newBEncoding, - /*scale=*/std::nullopt, /*scaleEncoding=*/std::nullopt); - llvm::errs() << "operand B: " << b << "\n"; + b = createArg(rewriter, b, bType, newBEncoding); auto newDot = rewriter.create(scaledDotOp.getLoc(), newRetType, a, b, newAcc); - llvm::errs() << "dot: " << *newDot << "\n"; - + rewriter.replaceOpWithNewOp(scaledDotOp, oldRetType, + newDot); return success(); } private: - TensorValue createArg(mlir::PatternRewriter &rewriter, TensorValue v, int idx, - tt::ScaleDotElemType type, - std::optional vEncoding, - std::optional opt_scale, - std::optional scaleEncoding) const { - MLIRContext *ctx = rewriter.getContext(); - // Create a new tensor with a given encoding or remove the encoding - auto maybeWithEncoding = [](RankedTensorType ty, - std::optional enc) { - if (enc.has_value()) - return RankedTensorType::get(ty.getShape(), ty.getElementType(), *enc); - - return RankedTensorType::get(ty.getShape(), ty.getElementType()); - }; - - RankedTensorType newVType = maybeWithEncoding(v.getType(), vEncoding); - TensorValue ret = - rewriter.create(v.getLoc(), newVType, v); - - // convert to bf16 - if (type != tt::ScaleDotElemType::E2M1 && - type != tt::ScaleDotElemType::BF16) { - assert(type == tt::ScaleDotElemType::E5M2 || - type == tt::ScaleDotElemType::E4M3); - auto vTypeBf16 = RankedTensorType::get( - newVType.getShape(), rewriter.getBF16Type(), newVType.getEncoding()); - ret = rewriter.create(v.getLoc(), vTypeBf16, ret); - } - if (opt_scale.has_value()) { - TensorValue scale = *opt_scale; - assert(idx == 0 && "NYI: rhs scale"); - RankedTensorType newScaleDotElemType = - maybeWithEncoding(scale.getType(), scaleEncoding); - scale = rewriter.create(scale.getLoc(), - newScaleDotElemType, scale); - ret = rewriter.create(v.getLoc(), ret, scale, type); - } - return ret; - } - ttg::intel::DpasEncodingAttr getDPASEncoding(PatternRewriter &rewriter, tt::DotScaledOp scaledDotOp) const { - llvm::errs() << "at line: " << __LINE__ << "\n"; - MLIRContext *ctx = rewriter.getContext(); - TensorValue a = scaledDotOp.getLhs(); - TensorValue b = scaledDotOp.getRhs(); - TensorValue scale = scaledDotOp.getLhsScale(); - tt::ScaleDotElemType aType = scaledDotOp.getLhsType(); - tt::ScaleDotElemType bType = scaledDotOp.getRhsType(); - - Location loc = scaledDotOp.getLoc(); - RankedTensorType aTType = - createArg(rewriter, a, 0, aType, /*vEncoding=*/std::nullopt, scale, - /*scaleEncoding=*/std::nullopt) - .getType(); - auto aTypeNoEnc = - RankedTensorType::get(aTType.getShape(), aTType.getElementType()); - - llvm::errs() << "aTType: " << aTType << "\n"; - llvm::errs() << "aTypeNoEnc: " << aTypeNoEnc << "\n"; - - a = rewriter.create(loc, aTypeNoEnc, a); - llvm::errs() << "a: " << a << "\n"; - - RankedTensorType bTType = - createArg(rewriter, b, 1, bType, /*vEncoding=*/std::nullopt, - /*scale=*/std::nullopt, /*scaleEncoding=*/std::nullopt) - .getType(); - auto bTypeNoEnc = - RankedTensorType::get(bTType.getShape(), bTType.getElementType()); - llvm::errs() << "bTType: " << bTType << "\n"; - llvm::errs() << "bTypeNoEnc: " << bTypeNoEnc << "\n"; - - b = rewriter.create(loc, bTypeNoEnc, b); - llvm::errs() << "b: " << b << "\n"; - - RankedTensorType oldRetType = scaledDotOp.getType(); - llvm::errs() << "oldRetType: " << oldRetType << "\n"; - auto dotOp = - rewriter.create(loc, oldRetType, a, b, 
                               scaledDotOp.getC());
-    llvm::errs() << "dotOp: " << dotOp << "\n";
-
-    ArrayRef<int64_t> retShape = oldRetType.getShape();
     auto mod = scaledDotOp->getParentOfType<ModuleOp>();
+    auto dpasCap = ttg::intel::DpasEncodingAttr::getDPASCapability(mod);
+    Type elemType = scaledDotOp.getRhs().getType().getElementType();
+    unsigned opsPerChan =
+        ttg::intel::DpasEncodingAttr::getOpsPerChannel(dpasCap, elemType);
 
-    unsigned minSGSize =
-        mod->getAttrOfType<IntegerAttr>(
-               ttg::intel::TritonIntelGPUDialect::getMinSGSizeAttrName())
-            .getInt();
-    IntelDPASCapability dpasCap = getDPASCapability(minSGSize);
-
-    unsigned dpasElemBitWidths =
-        aTType.getElementType().getIntOrFloatBitWidth();
-    unsigned opsPerChan = dpasCap.opsChanBitWidths / dpasElemBitWidths;
+    unsigned numWarps = ttg::TritonGPUDialect::getNumWarps(mod);
+    SmallVector<unsigned> warpsPerTile = {numWarps, 1};
 
-    SmallVector<unsigned> warpsPerTile = getWarpsPerTile(
-        dotOp, dpasCap, retShape, ttg::TritonGPUDialect::getNumWarps(mod));
+    ArrayRef<int64_t> retShape = scaledDotOp.getType().getShape();
     size_t rank = retShape.size();
     SmallVector<unsigned> repCluster(rank, 1);
     unsigned threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod);
-    auto dpasEnc = ttg::intel::DpasEncodingAttr::get(
-        oldRetType.getContext(), dpasCap.repeatCount, dpasCap.systolicDepth,
+
+    return ttg::intel::DpasEncodingAttr::get(
+        rewriter.getContext(), dpasCap.repeatCount, dpasCap.systolicDepth,
         dpasCap.executionSize, opsPerChan, warpsPerTile, repCluster,
         threadsPerWarp);
+  }
 
-    llvm::errs() << "dpasEnc: " << dpasEnc << "\n";
+  TensorValue createArg(PatternRewriter &rewriter, TensorValue v,
+                        tt::ScaleDotElemType type, Attribute vEncoding) const {
+    RankedTensorType vType = v.getType();
+    auto newVType = RankedTensorType::get(vType.getShape(),
+                                          vType.getElementType(), vEncoding);
+    TensorValue ret =
+        rewriter.create<ttg::ConvertLayoutOp>(v.getLoc(), newVType, v);
+    if (type != tt::ScaleDotElemType::E2M1 &&
+        type != tt::ScaleDotElemType::BF16) {
+      // Convert to bf16.
+      assert(type == tt::ScaleDotElemType::E5M2 ||
+             type == tt::ScaleDotElemType::E4M3);
+      auto vTypeBf16 = RankedTensorType::get(
+          newVType.getShape(), rewriter.getBF16Type(), newVType.getEncoding());
+      ret = rewriter.create<tt::FpToFpOp>(v.getLoc(), vTypeBf16, ret);
+    }
+    return ret;
+  }
+
+  TensorValue createScale(PatternRewriter &rewriter, TensorValue scale,
+                          Attribute scaleEncoding) const {
+    RankedTensorType scaleType = scale.getType();
+    auto newScaleDotElemType = RankedTensorType::get(
+        scaleType.getShape(), scaleType.getElementType(), scaleEncoding);
+    return rewriter.create<ttg::ConvertLayoutOp>(scale.getLoc(),
+                                                 newScaleDotElemType, scale);
+  }
 
-    return dpasEnc;
+  TensorValue createUpcastMxfpOp(PatternRewriter &rewriter, TensorValue a,
+                                 TensorValue scale, tt::ScaleDotElemType type,
+                                 Attribute retTypeEncoding) const {
+    auto aType = cast<RankedTensorType>(a.getType());
+    auto retType = RankedTensorType::get(
+        aType.getShape(), aType.getElementType(), retTypeEncoding);
+    if (type == tt::ScaleDotElemType::E2M1) {
+      SmallVector<int64_t> newShape(aType.getShape());
+      newShape.back() *= 2;
+      retType = RankedTensorType::get(
+          newShape, FloatType::getBF16(rewriter.getContext()), retTypeEncoding);
+    }
+    // TODO: Check whether constructing without explicit retType works.
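+    // For E2M1 each i8 packs two fp4 values, so the upcast doubles K: e.g. a
+    // tensor<128x32xi8> operand becomes tensor<128x64xbf16>, with one e8m0
+    // scale per 32 elements (tensor<128x2xi8> in accelerate-matmul-pvc.mlir).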
+ return rewriter.create(a.getLoc(), retType, a, scale, + type); } }; From c230c6f0a7ee716fb2b10e3c1273c645340c5906 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Tue, 26 Nov 2024 19:13:56 +0000 Subject: [PATCH 5/9] WIP: tt.scaled_dot Signed-off-by: Tiotto, Ettore --- lib/Dialect/TritonGPU/IR/Ops.cpp | 9 +-- .../IR/TritonIntelGPUAttrDefs.td | 25 ++++-- .../lib/Dialect/TritonIntelGPU/IR/Dialect.cpp | 36 ++------- .../lib/TritonIntelGPUToLLVM/CMakeLists.txt | 1 - .../PatternTritonGPUOpToLLVM.h | 5 -- .../TritonIntelGPUToLLVM/PipelineManager.h | 2 - .../TritonIntelGPUToLLVM/UpcastMXFPToLLVM.cpp | 44 +++++++--- .../AccelerateMatmul.cpp | 80 ++++++++++++------- 8 files changed, 112 insertions(+), 90 deletions(-) diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index e4e2d412d1..be4b7c4873 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -115,6 +115,7 @@ LogicalResult UpcastMXFPOp::inferReturnTypes( retTy = RankedTensorType::get(xShape, FloatType::getBF16(ctx)); } else { auto oldEncoding = cast(encoding); + const int opIdx = oldEncoding.getOpIdx(); const bool hasBatch = xShape.size() == 3; const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch; @@ -127,12 +128,10 @@ LogicalResult UpcastMXFPOp::inferReturnTypes( Attribute newVEncoding; if (auto dpasEncoding = dyn_cast(oldEncoding.getParent())) { - auto mod = operands[0].getDefiningOp()->getParentOfType(); - auto dpasCap = intel::DpasEncodingAttr::getDPASCapability(mod); auto newDpasEncoding = intel::DpasEncodingAttr::get( - ctx, dpasCap.repeatCount, dpasCap.systolicDepth, - dpasCap.executionSize, - intel::DpasEncodingAttr::getOpsPerChannel(dpasCap, elemType), + ctx, dpasEncoding.getRepeatCount(), dpasEncoding.getSystolicDepth(), + dpasEncoding.getExecutionSize(), + intel::DpasEncodingAttr::getOpsPerChannel(elemType), dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(), dpasEncoding.getSubGroupSize()); newVEncoding = DotOperandEncodingAttr::get( diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td index 637c1d8ae6..a52780f029 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td +++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td @@ -93,14 +93,27 @@ along the row (resp. col) dimension. 
SmallVector getContigPerThread() const; struct DPASCapability { - uint32_t systolicDepth; - uint32_t repeatCount; - uint32_t executionSize; - uint32_t opsChanBitWidths; + DPASCapability(unsigned minSGSize) : executionSize(minSGSize) {} + DPASCapability() = default; + + bool isPVC() const { + return executionSize == 16; + } + bool isFalconShore() const { + return executionSize == 16; + } + bool isATSM() const { + return executionSize == 8; + } + + static constexpr unsigned systolicDepth = 8u; + static constexpr unsigned repeatCount = 8u; + static constexpr unsigned opsChanBitWidths = 32u; + unsigned executionSize = 0u; }; - + static DPASCapability getDPASCapability(mlir::ModuleOp mod); - static unsigned getOpsPerChannel(DPASCapability dpasCap, Type elemType); + static unsigned getOpsPerChannel(Type elemType); }]; let hasCustomAssemblyFormat = 1; diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp index 84adb7ca05..5b8d7887af 100644 --- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp +++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp @@ -384,51 +384,27 @@ SmallVector DpasEncodingAttr::getContigPerThread() const { DpasEncodingAttr::DPASCapability DpasEncodingAttr::getDPASCapability(ModuleOp mod) { assert(mod && "expected a valid module"); - if (!mod->hasAttrOfType( triton::gpu::intel::TritonIntelGPUDialect::getMinSGSizeAttrName())) - return {}; + return DPASCapability(); unsigned minSGSize = mod->getAttrOfType( triton::gpu::intel::TritonIntelGPUDialect::getMinSGSizeAttrName()) .getInt(); - - switch (minSGSize) { - case 8: { - DPASCapability cap; - cap.systolicDepth = 8; - cap.repeatCount = 8; - cap.executionSize = 8; - cap.opsChanBitWidths = 32; - return cap; - } - case 16: { - DPASCapability cap; - cap.systolicDepth = 8; - cap.repeatCount = 8; - cap.executionSize = 16; - cap.opsChanBitWidths = 32; - return cap; - } - default: - return {}; - } + assert(minSGSize == 8 || minSGSize == 16 && "unsupported minSGSize"); + return DPASCapability(minSGSize); } -unsigned -DpasEncodingAttr::getOpsPerChannel(DpasEncodingAttr::DPASCapability dpasCap, - Type elemType) { +unsigned DpasEncodingAttr::getOpsPerChannel(Type elemType) { if (!elemType.isIntOrFloat()) llvm::report_fatal_error("unsupported type for DpasEncodingAttr"); unsigned dpasElemBitWidths = elemType.getIntOrFloatBitWidth(); - - // We are upcasting FP8 to FP16 if (elemType.isFloat8E5M2() || elemType.isFloat8E4M3FN()) - dpasElemBitWidths = 2 * dpasElemBitWidths; + dpasElemBitWidths *= 2; // We are upcasting FP8 to FP16. 
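+  // With 32-bit channels this yields, e.g., 32 / 16 = 2 ops per channel for
+  // bf16/fp16, 32 / 8 = 4 for i8, and 2 for fp8 (counted as 16 bits after
+  // the upcast), matching the opsPerChan values in the lit tests.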
- return dpasCap.opsChanBitWidths / dpasElemBitWidths; + return DPASCapability::opsChanBitWidths / dpasElemBitWidths; } LogicalResult DpasEncodingAttr::verify( diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt b/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt index 211c2b185a..4e86cbd2f2 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt @@ -24,7 +24,6 @@ add_triton_library(TritonIntelGPUToLLVM TritonGPUToLLVM.cpp TritonOpsToLLVM.cpp TypeConverter.cpp - UpcastMXFPToLLVM.cpp Utility.cpp ViewOpToLLVM.cpp diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h index 4ee1a012bc..40116a17ca 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -41,11 +41,6 @@ void populateElementwiseOpToLLVMPatterns( ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo, PatternBenefit benefit); -void populateUpcastMXFPToLLVMPatterns(LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, - const TargetInfo &targetInfo, - PatternBenefit benefit); - void populateBF16CastsLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit); diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h b/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h index b569857fe1..f9ddb046c7 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h @@ -248,8 +248,6 @@ class TritonGPUToLLVMPipelineManager { targetInfo, benefit); intel::populateMakeRangeOpToLLVMPattern(typeConverter, targetInfo, patterns, benefit); - intel::populateUpcastMXFPToLLVMPatterns(typeConverter, patterns, - targetInfo, benefit); } intel::populateSPMDOpToLLVMPattern(typeConverter, patterns, targetInfo, diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/UpcastMXFPToLLVM.cpp index 59d90e2930..c0f4e12e8a 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/UpcastMXFPToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -30,26 +30,45 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { LogicalResult matchAndRewrite(UpcastMXFPOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - auto operands = adaptor.getOperands(); - SmallVector xVals = unpackLLElements(loc, operands[0], rewriter); - SmallVector scaleVals = unpackLLElements(loc, operands[1], rewriter); + return failure(); + + // TODO: Implement this +#if 0 ScaleDotElemType fpType = op.getFpType(); + Location loc = op.getLoc(); + SmallVector xVals = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + SmallVector scaleVals = + unpackLLElements(loc, adaptor.getScale(), rewriter); - Value tid = tid_val(); auto mod = op->getParentOfType(); - Value warpSize = - i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod)); + llvm::errs() << "mod: " << mod << "\n"; + llvm::errs() << "adaptor.getScale(): " << adaptor.getScale() << "\n"; + + LDBG("x: " << xVals.size() << " x " << xVals.front().getType()); + LDBG("scale: " << scaleVals.size() << " x " << scaleVals.front().getType()); + + bool isPacked = fpType == ScaleDotElemType::E2M1; + if (xVals.size() != 
scaleVals.size() * (isPacked ? 16 : 32)) + return rewriter.notifyMatchFailure(op, "unsupported problem size"); + + unsigned numThreads = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + Value warpSize = i32_val(numThreads); + Value tid = tid_val(); Value warpId = udiv(tid, warpSize); Value laneId = urem(tid, warpSize); - if (fpType == ScaleDotElemType::E2M1) + if (fpType == ScaleDotElemType::E2M1) { + llvm::errs() << "at line: " << __LINE__ << "\n"; xVals = LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, xVals); + } + + llvm::errs() << "xVals size: " << xVals.size() << "\n"; // Each thread owns elements of 4 mxfp vectors so we need 4 scales // Letting c = tid / 4 * 2, we need the elements from threads c, c + 1, c + // 16, c + 17 - auto c = mul(udiv(laneId, i32_val(4)), i32_val(2)); + LLVM::MulOp c = mul(udiv(laneId, i32_val(4)), i32_val(2)); std::array ci = {c, add(c, i32_val(1)), add(c, i32_val(16)), add(c, i32_val(17))}; @@ -63,8 +82,10 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { }; for (int j = 0; j < 32; ++j) { - xVals[32 * i + j] = - LLVM::mxfpScaleBf16(rewriter, loc, xVals[32 * i + j], si[j / 8]); + int index = 32 * i + j; + llvm::errs() << "index: " << index << "\n"; + xVals[index] = + LLVM::mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 8]); } } @@ -72,6 +93,7 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { packLLElements(loc, getTypeConverter(), xVals, rewriter, op.getType()); rewriter.replaceOp(op, result); return success(); +#endif } }; } // anonymous namespace diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp index a74cc3afce..ff88d287dc 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp @@ -115,7 +115,7 @@ class BlockedToDPAS : public OpRewritePattern { auto dpasCap = ttg::intel::DpasEncodingAttr::getDPASCapability(mod); Type elemType = oldAType.getElementType(); unsigned opsPerChan = - ttg::intel::DpasEncodingAttr::getOpsPerChannel(dpasCap, elemType); + ttg::intel::DpasEncodingAttr::getOpsPerChannel(elemType); SmallVector warpsPerTile = getWarpsPerTile(dotOp, dpasCap, retShape, numWarps); size_t rank = retShape.size(); @@ -127,7 +127,7 @@ class BlockedToDPAS : public OpRewritePattern { dpasCap.executionSize, opsPerChan, warpsPerTile, repCluster, threadsPerWarp); - if (dpasCap.executionSize == 16 /* PVC */) { + if (dpasCap.isPVC() || dpasCap.isFalconShore()) { unsigned dpasElemBitWidths = oldAType.getElementType().getIntOrFloatBitWidth(); @@ -205,46 +205,64 @@ class DecomposeScaledBlocked : public OpRewritePattern { matchAndRewrite(tt::DotScaledOp scaledDotOp, PatternRewriter &rewriter) const override { RankedTensorType oldRetType = scaledDotOp.getType(); - if (!oldRetType.getEncoding() || - isa(oldRetType.getEncoding())) - return failure(); + if (!isa_and_nonnull(oldRetType.getEncoding())) + return rewriter.notifyMatchFailure( + scaledDotOp, "expected blocked encoding result tensor"); + + unsigned rank = oldRetType.getRank(); + if (rank == 3) + return rewriter.notifyMatchFailure(scaledDotOp, "NYI: 3d case"); - MLIRContext *ctx = scaledDotOp.getContext(); TensorValue a = scaledDotOp.getLhs(); TensorValue b = scaledDotOp.getRhs(); - TensorValue scale = scaledDotOp.getLhsScale(); - tt::ScaleDotElemType aType = scaledDotOp.getLhsType(); - tt::ScaleDotElemType bType = scaledDotOp.getRhsType(); - - assert((aType == 
                   tt::ScaleDotElemType::E4M3 ||
-            aType == tt::ScaleDotElemType::E5M2 ||
-            aType == tt::ScaleDotElemType::E2M1) &&
-           "NYI: lhs supports fp4 or fp8");
-    assert(bType == tt::ScaleDotElemType::E4M3 ||
-           bType == tt::ScaleDotElemType::E5M2 ||
-           bType == tt::ScaleDotElemType::BF16 &&
-               "NYI: rhs supports fp8 and bf16");
+    TensorValue aScale = scaledDotOp.getLhsScale();
+    TensorValue bScale = scaledDotOp.getRhsScale();
+    if (aScale && bScale)
+      return rewriter.notifyMatchFailure(scaledDotOp,
+                                         "NYI: both LHS and RHS scale");
+    if (bScale)
+      return rewriter.notifyMatchFailure(scaledDotOp, "NYI: RHS scale");
+    assert(aScale && "LHS scale missing");
+
+    tt::ScaleDotElemType aElemType = scaledDotOp.getLhsType();
+    tt::ScaleDotElemType bElemType = scaledDotOp.getRhsType();
+
+    auto supportsTypes = [](tt::ScaleDotElemType elemType) {
+      return elemType == tt::ScaleDotElemType::E2M1 ||
+             elemType == tt::ScaleDotElemType::E4M3 ||
+             elemType == tt::ScaleDotElemType::E5M2 ||
+             elemType == tt::ScaleDotElemType::BF16;
+    };
+    if (!supportsTypes(aElemType) || !supportsTypes(bElemType))
+      return rewriter.notifyMatchFailure(scaledDotOp, "NYI: mxfp6 operand");
+
+    MLIRContext *ctx = scaledDotOp.getContext();
+    auto mod = scaledDotOp->getParentOfType<ModuleOp>();
 
     // Convert accumulator.
     ttg::intel::DpasEncodingAttr dpasEnc =
         getDPASEncoding(rewriter, scaledDotOp);
     auto newRetType = RankedTensorType::get(
         oldRetType.getShape(), oldRetType.getElementType(), dpasEnc);
     TensorValue oldAcc = scaledDotOp.getC();
     TensorValue newAcc = rewriter.create<ttg::ConvertLayoutOp>(
         oldAcc.getLoc(), newRetType, oldAcc);
 
+    unsigned opsPerChannel = dpasEnc.getOpsPerChannel();
+    if (aElemType == tt::ScaleDotElemType::E2M1)
+      opsPerChannel *= 2;
+
     // Upcast A operand.
     auto dpasEncForA = ttg::intel::DpasEncodingAttr::get(
         ctx, dpasEnc.getRepeatCount(), dpasEnc.getSystolicDepth(),
-        dpasEnc.getExecutionSize(), 2 * dpasEnc.getOpsPerChannel(),
-        dpasEnc.getWarpsPerCTA(), dpasEnc.getRepCluster(),
-        dpasEnc.getSubGroupSize());
+        dpasEnc.getExecutionSize(), opsPerChannel, dpasEnc.getWarpsPerCTA(),
+        dpasEnc.getRepCluster(), dpasEnc.getSubGroupSize());
     auto newAEncoding = ttg::DotOperandEncodingAttr::get(
         ctx, 0, dpasEncForA, dpasEncForA.getOpsPerChannel());
-    a = createArg(rewriter, a, aType, newAEncoding);
+    a = createArg(rewriter, a, aElemType, newAEncoding);
 
-    auto mod = scaledDotOp->getParentOfType<ModuleOp>();
     unsigned warpSize = ttg::TritonGPUDialect::getThreadsPerWarp(mod);
     unsigned instrShapeM = dpasEnc.getDPASInstShapeA()[1];
     SmallVector<unsigned> threadsPerWarp{instrShapeM, warpSize / instrShapeM};
@@ -252,26 +270,29 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
     auto CTALayout = ttg::getCTALayout(oldRetType.getEncoding());
     auto newScaleEncoding = ttg::BlockedEncodingAttr::get(
         ctx, {1, 1}, threadsPerWarp, newAEncoding.getWarpsPerCTA(),
         newAEncoding.getCTAOrder(), CTALayout);
-    scale = createScale(rewriter, scale, newScaleEncoding);
+    aScale = createScale(rewriter, aScale, newScaleEncoding);
 
     auto retTypeEncoding = ttg::DotOperandEncodingAttr::get(
         ctx, 0, dpasEnc, dpasEnc.getOpsPerChannel());
-    a = createUpcastMxfpOp(rewriter, a, scale, aType, retTypeEncoding);
+    Value scaledA =
+        createUpcastMxfpOp(rewriter, a, aScale, aElemType, retTypeEncoding);
 
     // Create B operand.
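+    // Note: the two dot operands end up with different kWidth values: the raw
+    // fp4 payload of A is converted with the doubled opsPerChannel encoding
+    // (kWidth = 4 in the lit test), while the upcast result and B use
+    // dpasEnc's native opsPerChannel (kWidth = 2 for bf16).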
- assert(bType != tt::ScaleDotElemType::E2M1 && "NYI: rhs scale for fp4"); + assert(bElemType != tt::ScaleDotElemType::E2M1 && "NYI: rhs scale for fp4"); auto newBEncoding = ttg::DotOperandEncodingAttr::get( ctx, 1, dpasEnc, dpasEnc.getOpsPerChannel()); - b = createArg(rewriter, b, bType, newBEncoding); + b = createArg(rewriter, b, bElemType, newBEncoding); auto newDot = rewriter.create(scaledDotOp.getLoc(), newRetType, - a, b, newAcc); + scaledA, b, newAcc); rewriter.replaceOpWithNewOp(scaledDotOp, oldRetType, newDot); return success(); } private: + // TODO: this works only when the operand to scale is on the LHS, extend to + // support scaling the RHS operand. ttg::intel::DpasEncodingAttr getDPASEncoding(PatternRewriter &rewriter, tt::DotScaledOp scaledDotOp) const { @@ -279,8 +300,7 @@ class DecomposeScaledBlocked : public OpRewritePattern { auto dpasCap = ttg::intel::DpasEncodingAttr::getDPASCapability(mod); Type elemType = scaledDotOp.getRhs().getType().getElementType(); unsigned opsPerChan = - ttg::intel::DpasEncodingAttr::getOpsPerChannel(dpasCap, elemType); - + ttg::intel::DpasEncodingAttr::getOpsPerChannel(elemType); unsigned numWarps = ttg::TritonGPUDialect::getNumWarps(mod); SmallVector warpsPerTile = {numWarps, 1}; From 5772d3f94bf436b32735fad2606a86e3c23a730e Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Tue, 26 Nov 2024 19:16:06 +0000 Subject: [PATCH 6/9] WIP: tt.scaled_dot Signed-off-by: Tiotto, Ettore --- .../TritonIntelGPUToLLVM/UpcastMXFPToLLVM.cpp | 105 ------------------ 1 file changed, 105 deletions(-) delete mode 100644 third_party/intel/lib/TritonIntelGPUToLLVM/UpcastMXFPToLLVM.cpp diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/UpcastMXFPToLLVM.cpp deleted file mode 100644 index c0f4e12e8a..0000000000 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/UpcastMXFPToLLVM.cpp +++ /dev/null @@ -1,105 +0,0 @@ -#include "PatternTritonGPUOpToLLVM.h" - -#include "mlir/Conversion/LLVMCommon/Pattern.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/ValueRange.h" -#include "mlir/Transforms/DialectConversion.h" -#include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include - -using namespace mlir; -using namespace mlir::triton; -using namespace mlir::triton::gpu; - -namespace { - -class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { -private: - const TargetInfoBase &targetInfo; - -public: - UpcastMXFPOpPattern(LLVMTypeConverter &typeConverter, - const TargetInfoBase &targetInfo, PatternBenefit benefit) - : ConvertOpToLLVMPattern(typeConverter, benefit), - targetInfo(targetInfo) {} - - LogicalResult - matchAndRewrite(UpcastMXFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - return failure(); - - // TODO: Implement this -#if 0 - ScaleDotElemType fpType = op.getFpType(); - Location loc = op.getLoc(); - SmallVector xVals = - unpackLLElements(loc, adaptor.getSrc(), rewriter); - SmallVector scaleVals = - unpackLLElements(loc, adaptor.getScale(), rewriter); - - auto mod = op->getParentOfType(); - llvm::errs() << "mod: " << mod << "\n"; - llvm::errs() << "adaptor.getScale(): " << adaptor.getScale() << "\n"; - - LDBG("x: " << xVals.size() << " x " << xVals.front().getType()); - LDBG("scale: " << scaleVals.size() << " x " << scaleVals.front().getType()); - - bool isPacked = fpType == ScaleDotElemType::E2M1; - if 
(xVals.size() != scaleVals.size() * (isPacked ? 16 : 32)) - return rewriter.notifyMatchFailure(op, "unsupported problem size"); - - unsigned numThreads = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); - Value warpSize = i32_val(numThreads); - Value tid = tid_val(); - Value warpId = udiv(tid, warpSize); - Value laneId = urem(tid, warpSize); - - if (fpType == ScaleDotElemType::E2M1) { - llvm::errs() << "at line: " << __LINE__ << "\n"; - xVals = LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, xVals); - } - - llvm::errs() << "xVals size: " << xVals.size() << "\n"; - - // Each thread owns elements of 4 mxfp vectors so we need 4 scales - // Letting c = tid / 4 * 2, we need the elements from threads c, c + 1, c + - // 16, c + 17 - LLVM::MulOp c = mul(udiv(laneId, i32_val(4)), i32_val(2)); - std::array ci = {c, add(c, i32_val(1)), add(c, i32_val(16)), - add(c, i32_val(17))}; - - for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) { - // column major as per the DotOperandEncoding(opidx=0) layout - auto si = std::array{ - targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[0]), - targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[2]), - targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[1]), - targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[3]), - }; - - for (int j = 0; j < 32; ++j) { - int index = 32 * i + j; - llvm::errs() << "index: " << index << "\n"; - xVals[index] = - LLVM::mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 8]); - } - } - - Value result = - packLLElements(loc, getTypeConverter(), xVals, rewriter, op.getType()); - rewriter.replaceOp(op, result); - return success(); -#endif - } -}; -} // anonymous namespace - -void mlir::triton::intel::populateUpcastMXFPToLLVMPatterns( - LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - const TargetInfo &targetInfo, PatternBenefit benefit) { - patterns.add(typeConverter, targetInfo, benefit); -} From bbf250d0a3da7285750a59911b8c6efbe4a9d824 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Tue, 26 Nov 2024 21:13:12 +0000 Subject: [PATCH 7/9] Fix precommit Signed-off-by: Tiotto, Ettore --- .../include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td index a52780f029..ffac434298 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td +++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td @@ -111,7 +111,7 @@ along the row (resp. col) dimension. 
    static constexpr unsigned opsChanBitWidths = 32u;
     unsigned executionSize = 0u;
   };
-
+
   static DPASCapability getDPASCapability(mlir::ModuleOp mod);
   static unsigned getOpsPerChannel(Type elemType);
 }];

From bd1094c7eb1b0238544da003cb83d78abec9c6e8 Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore"
Date: Fri, 29 Nov 2024 14:56:33 +0000
Subject: [PATCH 8/9] Address code review comments

Signed-off-by: Tiotto, Ettore
---
 .../IR/TritonIntelGPUAttrDefs.td              |  2 +-
 .../lib/Dialect/TritonIntelGPU/IR/Dialect.cpp | 22 +++++++++----------
 2 files changed, 11 insertions(+), 13 deletions(-)

diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
index ffac434298..8cb10fcfe4 100644
--- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
+++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
@@ -93,7 +93,7 @@ along the row (resp. col) dimension.
     SmallVector<unsigned> getContigPerThread() const;

     struct DPASCapability {
-      DPASCapability(unsigned minSGSize) : executionSize(minSGSize) {}
+      explicit DPASCapability(unsigned minSGSize) : executionSize(minSGSize) {}
       DPASCapability() = default;

       bool isPVC() const {
diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
index 5b8d7887af..661d2bb9cb 100644
--- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
+++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
@@ -384,21 +384,19 @@ SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() const {
 DpasEncodingAttr::DPASCapability
 DpasEncodingAttr::getDPASCapability(ModuleOp mod) {
   assert(mod && "expected a valid module");
-  if (!mod->hasAttrOfType<IntegerAttr>(
-          triton::gpu::intel::TritonIntelGPUDialect::getMinSGSizeAttrName()))
-    return DPASCapability();
-
-  unsigned minSGSize =
-      mod->getAttrOfType<IntegerAttr>(
-              triton::gpu::intel::TritonIntelGPUDialect::getMinSGSizeAttrName())
-          .getInt();
-  assert(minSGSize == 8 || minSGSize == 16 && "unsupported minSGSize");
-  return DPASCapability(minSGSize);
+
+  if (auto minSGSizeAttr = mod->getAttrOfType<IntegerAttr>(
+          triton::gpu::intel::TritonIntelGPUDialect::getMinSGSizeAttrName())) {
+    unsigned minSGSize = minSGSizeAttr.getInt();
+    assert((minSGSize == 8 || minSGSize == 16) && "unsupported minSGSize");
+    return DPASCapability(minSGSize);
+  }
+
+  return DPASCapability();
 }

 unsigned DpasEncodingAttr::getOpsPerChannel(Type elemType) {
-  if (!elemType.isIntOrFloat())
-    llvm::report_fatal_error("unsupported type for DpasEncodingAttr");
+  assert(elemType.isIntOrFloat() && "unsupported type for DpasEncodingAttr");
   unsigned dpasElemBitWidths = elemType.getIntOrFloatBitWidth();
   if (elemType.isFloat8E5M2() || elemType.isFloat8E4M3FN())

From 75da01c47b6346a7b559567cad46df243da7b3e5 Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore"
Date: Fri, 29 Nov 2024 17:01:22 +0000
Subject: [PATCH 9/9] Change triton_gpu --> ttg

Signed-off-by: Tiotto, Ettore
---
 .../TritonIntelGPU/accelerate-matmul-pvc.mlir | 30 +++++++++----------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/test/TritonIntelGPU/accelerate-matmul-pvc.mlir b/test/TritonIntelGPU/accelerate-matmul-pvc.mlir
index 81e7c22770..ecfecf09be 100644
--- a/test/TritonIntelGPU/accelerate-matmul-pvc.mlir
+++ b/test/TritonIntelGPU/accelerate-matmul-pvc.mlir
@@ -204,26 +204,26 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
 // -----

-#blocked = #triton_gpu.blocked<{sizePerThread =
[1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}> - -module attributes {"triton_gpu.target" = "xpu", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32, "triton_intel_gpu.min_sg_size" = 16 : i32, "triton_intel_gpu.support_dpas"} { - // CHECK: [[BLOCKED:#.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}> - // CHECK: [[BLOCKED1:#.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [4, 1], order = [1, 0]}> - // CHECK: [[BLOCKED2:#.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}> + +module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "triton_intel_gpu.min_sg_size" = 16 : i32, "triton_intel_gpu.support_dpas"} { + // CHECK: [[BLOCKED:#.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}> + // CHECK: [[BLOCKED1:#.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [4, 1], order = [1, 0]}> + // CHECK: [[BLOCKED2:#.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}> // CHECK: [[DPAS:#.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}> // CHECK: [[DPAS1:#.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 4, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [1, 1], A = [8, 32], B = [32, 16], C = [8, 16]}> // CHECK: dot_scaled tt.func @dot_scaled(%a: tensor<128x32xi8, #blocked2>, %scale: tensor<128x2xi8, #blocked1>, %b: tensor<64x128xbf16, #blocked>) -> tensor<128x128xf32, #blocked> { // CHECK: [[CST:%.*]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32, [[BLOCKED2]]> - // CHECK: [[C:%.*]] = triton_gpu.convert_layout [[CST]] : tensor<128x128xf32, [[BLOCKED2]]> -> tensor<128x128xf32, [[DPAS]]> - // CHECK: [[CVT_ARG0:%.*]] = triton_gpu.convert_layout %arg0 : tensor<128x32xi8, [[BLOCKED]]> -> tensor<128x32xi8, #triton_gpu.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>> - // CHECK: [[CVT_ARG1:%.*]] = triton_gpu.convert_layout %arg1 : tensor<128x2xi8, [[BLOCKED1]]> -> tensor<128x2xi8, [[BLOCKED1]]> - // CHECK: [[A:%.*]] = triton_gpu.upcast_mxfp [[CVT_ARG0]], [[CVT_ARG1]] fp_type = e2m1 : tensor<128x32xi8, #triton_gpu.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>>, tensor<128x2xi8, [[BLOCKED1]]> -> tensor<128x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>> - // CHECK: [[B:%.*]] = triton_gpu.convert_layout %arg2 : tensor<64x128xbf16, [[BLOCKED2]]> -> tensor<64x128xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = 
[[DPAS]], kWidth = 2}>> - // CHECK: [[D:%.*]] = tt.dot [[A]], [[B]], [[C]] : tensor<128x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>> * tensor<64x128xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> -> tensor<128x128xf32, [[DPAS]]> - // CHECK: [[RES:%.*]] = triton_gpu.convert_layout {{.*}} : tensor<128x128xf32, [[DPAS]]> -> tensor<128x128xf32, [[BLOCKED2]]> + // CHECK: [[C:%.*]] = ttg.convert_layout [[CST]] : tensor<128x128xf32, [[BLOCKED2]]> -> tensor<128x128xf32, [[DPAS]]> + // CHECK: [[CVT_ARG0:%.*]] = ttg.convert_layout %arg0 : tensor<128x32xi8, [[BLOCKED]]> -> tensor<128x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>> + // CHECK: [[CVT_ARG1:%.*]] = ttg.convert_layout %arg1 : tensor<128x2xi8, [[BLOCKED1]]> -> tensor<128x2xi8, [[BLOCKED1]]> + // CHECK: [[A:%.*]] = ttg.upcast_mxfp [[CVT_ARG0]], [[CVT_ARG1]] fp_type = e2m1 : tensor<128x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>>, tensor<128x2xi8, [[BLOCKED1]]> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>> + // CHECK: [[B:%.*]] = ttg.convert_layout %arg2 : tensor<64x128xbf16, [[BLOCKED2]]> -> tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> + // CHECK: [[D:%.*]] = tt.dot [[A]], [[B]], [[C]] : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>> * tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> -> tensor<128x128xf32, [[DPAS]]> + // CHECK: [[RES:%.*]] = ttg.convert_layout {{.*}} : tensor<128x128xf32, [[DPAS]]> -> tensor<128x128xf32, [[BLOCKED2]]> %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> %result = tt.dot_scaled %a scale %scale, %b, %cst lhs = e2m1 rhs = bf16 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked>
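A minimal sketch of how the capability API introduced by this series composes. This mirrors the calls made by BlockedToDPAS above rather than documenting new API surface; `mod`, `ctx`, `numWarps`, and `elemType` are assumed to be in scope:

  // Query the module-level DPAS capability; with no min_sg_size attribute the
  // default-constructed capability has executionSize == 0.
  auto dpasCap = ttg::intel::DpasEncodingAttr::getDPASCapability(mod);
  if (dpasCap.isPVC()) { // executionSize == 16
    // 32-bit channels: bf16 -> 2 ops per channel, i8 -> 4, fp8 -> 2.
    unsigned opsPerChan =
        ttg::intel::DpasEncodingAttr::getOpsPerChannel(elemType);
    SmallVector<unsigned> warpsPerTile{numWarps, 1};
    SmallVector<unsigned> repCluster{1, 1};
    auto dpasEnc = ttg::intel::DpasEncodingAttr::get(
        ctx, dpasCap.repeatCount, dpasCap.systolicDepth, dpasCap.executionSize,
        opsPerChan, warpsPerTile, repCluster, /*threadsPerWarp=*/16);
  }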