From 75379aa5688ebb6496453889d49e07575e30a165 Mon Sep 17 00:00:00 2001 From: Andrey Pavlenko Date: Sun, 26 Jan 2025 11:52:28 +0100 Subject: [PATCH 1/3] Replace isF...() LLVM API calls with the corresponding isa<...>() The isF...() methods have been removed in the main LLVM branch: https://github.com/llvm/llvm-project/pull/123326 --- include/triton/Conversion/MLIRTypes.h | 12 +++++------ lib/Analysis/Utility.cpp | 9 +++++---- .../TritonGPU/Transforms/AccelerateMatmul.cpp | 3 ++- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 6 +++--- lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp | 4 ++-- .../ElementwiseOpToLLVM.cpp | 15 +++++++------- .../AccelerateAMDMatmul.cpp | 3 ++- .../lib/TritonAMDGPUTransforms/MfmaGroup.cpp | 14 ++++++++----- third_party/intel/lib/Analysis/DPAS.cpp | 20 +++++++++---------- .../lib/Dialect/TritonIntelGPU/IR/Dialect.cpp | 3 ++- .../ElementwiseOpToLLVM.cpp | 3 ++- .../AccelerateMatmul.cpp | 7 ++++--- .../DotOpToLLVM/MMAv2.cpp | 16 +++++++-------- .../DotOpToLLVM/WGMMA.cpp | 4 ++-- .../ElementwiseOpToLLVM.cpp | 9 +++++---- 15 files changed, 70 insertions(+), 58 deletions(-) diff --git a/include/triton/Conversion/MLIRTypes.h b/include/triton/Conversion/MLIRTypes.h index afa1aa989e..1fa8543a14 100644 --- a/include/triton/Conversion/MLIRTypes.h +++ b/include/triton/Conversion/MLIRTypes.h @@ -28,15 +28,15 @@ inline Type bf16Ty(MLIRContext *ctx) { return BFloat16Type::get(ctx); } inline bool isFloat(Type type) { return type.isF32() || type.isF64() || type.isF16() || type.isF128() || - type.isBF16() || type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() || - type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() || - type.isFloat8E5M2FNUZ(); + type.isBF16() || isa(type) || + isa(type) || isa(type) || + isa(type) || isa(type); } inline bool isFloat8(Type type) { - return type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() || - type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() || - type.isFloat8E5M2FNUZ(); + return isa(type) || isa(type) || + isa(type) || isa(type) || + 
isa(type); } inline bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); } diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index f801af5163..ac3274a213 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -756,14 +756,14 @@ bool supportMMA(triton::DotOp op, int version) { return false; if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 && retShapePerCTA[rank - 1] % 8 == 0 && - (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() || + (isa(aElemTy) || isa(aElemTy) || aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() || aElemTy.isF32()))) { return false; } // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op. if (op.getMaxNumImpreciseAcc() < 32 && - (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN()) && + (isa(aElemTy) || isa(aElemTy)) && cast(op.getType()).getElementType().isF32()) { return false; } @@ -784,8 +784,9 @@ bool supportMMA(Value value, int version) { cast(value.getType()).getElementType(); // FP8 is not natively supported on all mma versions but it can always be // promoted to fp16 therefore we can always support it. 
- bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() || - elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ(); + bool isFP8 = isa(elemTy) || isa(elemTy) || + isa(elemTy) || + isa(elemTy); return isFP8 || elemTy.isF16() || elemTy.isBF16() || (elemTy.isF32() && version >= 2) || (elemTy.isInteger(8) && version >= 2); diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 6d7632b1b7..47707cc672 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -632,7 +632,8 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) { NvidiaMmaEncodingAttr mmaLayout = dyn_cast(D.getType().getEncoding()); if (mmaLayout) { - bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN(); + bool isNativeFP8 = + isa(AElType) || isa(AElType); // promote operands for sm < 89 since fp8 mma is not natively supported // promote operands for sm >= 90 when mma is not v3 if (!isNativeFP8 || diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 21b8e059ca..4914333d3e 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -45,9 +45,9 @@ SmallVector mmaVersionToInstrShape(int version, SmallVector validN; // MMAv3 with larger instruction shape is preferred. 
- if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FN() || - eltType.isFloat8E4M3FNUZ() || eltType.isF16() || eltType.isBF16() || - eltType.isF32()) { + if (isa(eltType) || isa(eltType) || + isa(eltType) || eltType.isF16() || + eltType.isBF16() || eltType.isF32()) { validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176, 168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88, 80, 72, 64, 56, 48, 40, 32, 24, 16, 8}); diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index a171d89339..4c321c27d4 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -77,8 +77,8 @@ bool WarpGroupDotOp::needsPartialAccumulator() { const auto &d = getD(); auto aTensorTy = cast(a.getType()); auto aElTy = cast(a.getType()).getElementType(); - bool isFP8 = aElTy.isFloat8E5M2() || aElTy.isFloat8E4M3FN() || - aElTy.isFloat8E5M2FNUZ() || aElTy.isFloat8E4M3FNUZ(); + bool isFP8 = isa(aElTy) || isa(aElTy) || + isa(aElTy) || isa(aElTy); bool accFP32 = cast(d.getType()).getElementType().isF32(); uint32_t maxNumImpreciseAcc = getMaxNumImpreciseAcc(); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp index 35a2e1a34b..ee2441fcc5 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -1019,17 +1019,18 @@ struct FpToFpOpConversion return outVals; } size_t numElements = 4; - if (srcElementType.isFloat8E4M3FN() || dstElementType.isFloat8E4M3FN() || - srcElementType.isFloat8E4M3FNUZ() || - dstElementType.isFloat8E4M3FNUZ() || - srcElementType.isFloat8E5M2FNUZ() || - dstElementType.isFloat8E5M2FNUZ()) { + if (isa(srcElementType) || + isa(dstElementType) || + isa(srcElementType) || + isa(dstElementType) || + isa(srcElementType) || + isa(dstElementType)) { numElements = 2; } bool useFP16IntermediateSrc = srcElementType.isF32() 
&& !(isaFamily == AMD::ISAFamily::CDNA3 && - (dstElementType.isFloat8E4M3FNUZ() || - dstElementType.isFloat8E5M2FNUZ())); + (isa(dstElementType) || + isa(dstElementType))); bool isDstFP32 = dstElementType.isF32(); Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType; Type dstType = isDstFP32 ? f16_ty : dstElementType; diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp index 7ea13142a7..f7057266d7 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -416,7 +416,8 @@ class BlockedToMFMA : public OpRewritePattern { // store instructions, except for fp8 matmul kernels due to regression // TODO (lixun): investigate the regression and enable this feature again auto aElemTy = mfmaInstr.getElementTypeA(); - bool isFP8 = aElemTy.isFloat8E5M2FNUZ() || aElemTy.isFloat8E4M3FNUZ(); + bool isFP8 = + isa(aElemTy) || isa(aElemTy); bool isTransposed = isChainDot(dotOp) || !isFP8; mfmaEnc = ttg::AMDMfmaEncodingAttr::get( oldRetType.getContext(), diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp index 4979ee005b..59c102c1ed 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp @@ -20,19 +20,23 @@ static MfmaTypeId chooseAppropriateMfmaId(mlir::Type dataTypeA, if (dataTypeA.isInteger(8) && dataTypeB.isInteger(8)) { return MfmaTypeId::I8TyId; } - if (dataTypeA.isFloat8E4M3FNUZ() && dataTypeB.isFloat8E4M3FNUZ()) { + if (isa(dataTypeA) && + isa(dataTypeB)) { return MfmaTypeId::Fp8Fp8TyId; } - if (dataTypeA.isFloat8E4M3FNUZ() && dataTypeB.isFloat8E5M2FNUZ()) { + if (isa(dataTypeA) && + isa(dataTypeB)) { return MfmaTypeId::Fp8Bf8TyId; } - if (dataTypeA.isFloat8E5M2FNUZ() && dataTypeB.isFloat8E4M3FNUZ()) { + if (isa(dataTypeA) 
&& + isa(dataTypeB)) { return MfmaTypeId::Bf8Fp8TyId; } - if (dataTypeA.isFloat8E5M2FNUZ() && dataTypeB.isFloat8E5M2FNUZ()) { + if (isa(dataTypeA) && + isa(dataTypeB)) { return MfmaTypeId::Bf8Bf8TyId; } - if (dataTypeA.isFloat8E5M2() && dataTypeB.isFloat8E5M2()) { + if (isa(dataTypeA) && isa(dataTypeB)) { return MfmaTypeId::Fp16TyId; } llvm_unreachable("Unsupported input argument type."); diff --git a/third_party/intel/lib/Analysis/DPAS.cpp b/third_party/intel/lib/Analysis/DPAS.cpp index 3b5975c3c3..89710ce4f2 100644 --- a/third_party/intel/lib/Analysis/DPAS.cpp +++ b/third_party/intel/lib/Analysis/DPAS.cpp @@ -125,9 +125,9 @@ DPASAnalysis::getDPASType(OpTy op) { if (aElemTy.isF32() && op.getInputPrecision() == InputPrecision::TF32) return DPASEngineType::FP32_FP32_TF32_TF32; // For FP8XFP8->FP32, upcast to FP16 - if (aElemTy.isFloat8E5M2()) + if (isa(aElemTy)) return DPASEngineType::FP32_FP32_FP16_FP16; - if (aElemTy.isFloat8E4M3FN()) + if (isa(aElemTy)) return DPASEngineType::FP32_FP32_FP16_FP16; } else if (dElemTy.isF16()) { if (aElemTy.isF16()) @@ -148,27 +148,27 @@ DPASAnalysis::getDPASType(OpTy op) { if (isa(dElemTy)) { if (dElemTy.isF32()) { if (aElemTy.isBF16() && - (bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2())) + (isa(bElemTy) || isa(bElemTy))) return DPASEngineType::FP32_FP32_BF16_FP8; // 2 E2M1 are packed into 1 int8 if (aElemTy.isBF16() && bElemTy.isInteger(8)) return DPASEngineType::FP32_FP32_BF16_FP4; - if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) && + if ((isa(aElemTy) || isa(aElemTy)) && bElemTy.isBF16()) return DPASEngineType::FP32_FP32_FP8_BF16; if (aElemTy.isF16() && - (bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2())) + (isa(bElemTy) || isa(bElemTy))) return DPASEngineType::FP32_FP32_FP16_FP8; // 2 E2M1 are packed into 1 int8 if (aElemTy.isF16() && bElemTy.isInteger(8)) return DPASEngineType::FP32_FP32_FP16_FP4; - if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) && + if ((isa(aElemTy) || isa(aElemTy)) && 
bElemTy.isF16()) return DPASEngineType::FP32_FP32_FP8_FP16; - if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) && - (bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2())) + if ((isa(aElemTy) || isa(aElemTy)) && + (isa(bElemTy) || isa(bElemTy))) return DPASEngineType::FP32_FP32_FP8_FP8; - if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) && + if ((isa(aElemTy) || isa(aElemTy)) && bElemTy.isInteger(8)) return DPASEngineType::FP32_FP32_FP8_FP4; if (aElemTy.isInteger(8) && bElemTy.isBF16()) @@ -176,7 +176,7 @@ DPASAnalysis::getDPASType(OpTy op) { if (aElemTy.isInteger(8) && bElemTy.isF16()) return DPASEngineType::FP32_FP32_FP4_FP16; if (aElemTy.isInteger(8) && - (bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2())) + (isa(bElemTy) || isa(bElemTy))) return DPASEngineType::FP32_FP32_FP4_FP8; } } diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp index 8a9c53969d..f920c1180d 100644 --- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp +++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp @@ -405,7 +405,8 @@ unsigned DpasEncodingAttr::getOpsPerChannel(Type elemType) { assert(elemType.isIntOrFloat() && "unsupported type for DpasEncodingAttr"); unsigned dpasElemBitWidths = elemType.getIntOrFloatBitWidth(); - if (elemType.isFloat8E5M2() || elemType.isFloat8E4M3FN()) + if (llvm::isa(elemType) || + llvm::isa(elemType)) dpasElemBitWidths *= 2; // We are upcasting FP8 to FP16. 
return DPASCapability::opsChanBitWidths / dpasElemBitWidths; diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp index 50697c86dc..e78fa2f719 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -960,7 +960,8 @@ struct FpToFpOpConversion auto dstElementType = getElementType(op.getResult()); auto roundingMode = op.getRounding(); - if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FN()) { + if (isa(dstElementType) || + isa(dstElementType)) { assert(roundingMode.has_value() && "Rounding mode must be specified for conversions to fp8"); diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp index 170fc7f355..a499147609 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp @@ -132,8 +132,8 @@ class BlockedToDPAS : public OpRewritePattern { oldAType.getElementType().getIntOrFloatBitWidth(); // We are upcasting FP8 to FP16 - if (oldAType.getElementType().isFloat8E5M2() || - oldAType.getElementType().isFloat8E4M3FN()) + if (isa(oldAType.getElementType()) || + isa(oldAType.getElementType())) dpasElemBitWidths = 2 * dpasElemBitWidths; // Enlarge the repCluster size to use the large 2D load for A and B @@ -488,7 +488,8 @@ static void decomposeMixedModeDotOp(ModuleOp mod) { Type promoteType; if (dpasLayout) { - bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN(); + bool isNativeFP8 = + isa(AElType) || isa(AElType); // fp8 is not natively supported by the the DPAS instruction, promote it // to fp16. 
if (!isNativeFP8) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp index c5ec00097d..06901280d0 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp @@ -299,17 +299,17 @@ TensorCoreType getMmaType(triton::DotOp op) { return TensorCoreType::FP32_FP16_FP16_FP32; if (aTy.getElementType().isBF16() && bTy.getElementType().isBF16()) return TensorCoreType::FP32_BF16_BF16_FP32; - if (aTy.getElementType().isFloat8E5M2() && - bTy.getElementType().isFloat8E5M2()) + if (isa(aTy.getElementType()) && + isa(bTy.getElementType())) return TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32; - if (aTy.getElementType().isFloat8E5M2() && - bTy.getElementType().isFloat8E4M3FN()) + if (isa(aTy.getElementType()) && + isa(bTy.getElementType())) return TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32; - if (aTy.getElementType().isFloat8E4M3FN() && - bTy.getElementType().isFloat8E5M2()) + if (isa(aTy.getElementType()) && + isa(bTy.getElementType())) return TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32; - if (aTy.getElementType().isFloat8E4M3FN() && - bTy.getElementType().isFloat8E4M3FN()) + if (isa(aTy.getElementType()) && + isa(bTy.getElementType())) return TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32; if (aTy.getElementType().isF32() && bTy.getElementType().isF32() && op.getInputPrecision() == InputPrecision::TF32) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp index 7450bc3f4e..3f2f78d9e7 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -59,9 +59,9 @@ triton::nvgpu::WGMMAEltType getMmaOperandType(Value a, bool allowTF32) { return triton::nvgpu::WGMMAEltType::tf32; } else if (aTy.isInteger(8)) { return 
triton::nvgpu::WGMMAEltType::s8; - } else if (aTy.isFloat8E5M2()) { + } else if (isa(aTy)) { return triton::nvgpu::WGMMAEltType::e5m2; - } else if (aTy.isFloat8E4M3FN()) { + } else if (isa(aTy)) { return triton::nvgpu::WGMMAEltType::e4m3; } else { llvm::report_fatal_error("Unsupported mma operand type found"); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp index d489d0a1b1..8e37a4ad10 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -467,7 +467,7 @@ struct FpToFpOpConversion llvm::report_fatal_error("Unsupported rounding mode for conversion."); } if (computeCapability < 89 && - (srcTy.isFloat8E4M3FN() || dstTy.isFloat8E4M3FN())) { + (isa(srcTy) || isa(dstTy))) { llvm::errs() << "Conversion from/to f8e4m3nv is only supported on " "compute capability >= 89" << "\n"; @@ -489,7 +489,8 @@ struct FpToFpOpConversion auto dstElementType = getElementType(op.getResult()); auto roundingMode = op.getRounding(); - if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FN()) { + if (isa(dstElementType) || + isa(dstElementType)) { assert(roundingMode.has_value() && "Rounding mode must be specified for convertsions to fp8"); @@ -526,8 +527,8 @@ struct FpToFpOpConversion bool useFP16IntermediateSrc = srcElementType.isF32() && - (!(computeCapability >= 90 && (dstElementType.isFloat8E4M3FN() || - dstElementType.isFloat8E5M2())) || + (!(computeCapability >= 90 && (isa(dstElementType) || + isa(dstElementType))) || roundingMode.value() == RoundingMode::RTZ); bool isDstFP32 = dstElementType.isF32(); Type srcType = useFP16IntermediateSrc ? 
f16_ty : srcElementType; From f9e8aa568febcddca1c569bfd44545cda55826bd Mon Sep 17 00:00:00 2001 From: Andrey Pavlenko Date: Tue, 28 Jan 2025 18:41:34 +0100 Subject: [PATCH 2/3] Use isa<> with multiple templates, synchronized with triton-lang/triton#5684 --- include/triton/Conversion/MLIRTypes.h | 16 +++++++-------- lib/Analysis/Utility.cpp | 9 ++++----- .../TritonGPU/Transforms/AccelerateMatmul.cpp | 3 +-- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 6 +++--- lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp | 4 ++-- .../ElementwiseOpToLLVM.cpp | 16 +++++++-------- .../AccelerateAMDMatmul.cpp | 3 +-- .../lib/TritonAMDGPUTransforms/MfmaGroup.cpp | 19 +++++++++--------- third_party/intel/lib/Analysis/DPAS.cpp | 20 ++++++++----------- .../lib/Dialect/TritonIntelGPU/IR/Dialect.cpp | 3 +-- .../ElementwiseOpToLLVM.cpp | 3 +-- .../AccelerateMatmul.cpp | 6 ++---- .../DotOpToLLVM/MMAv2.cpp | 16 +++++++-------- .../DotOpToLLVM/WGMMA.cpp | 4 ++-- .../ElementwiseOpToLLVM.cpp | 11 +++++----- 15 files changed, 62 insertions(+), 77 deletions(-) diff --git a/include/triton/Conversion/MLIRTypes.h b/include/triton/Conversion/MLIRTypes.h index 1fa8543a14..dd8d4be4c2 100644 --- a/include/triton/Conversion/MLIRTypes.h +++ b/include/triton/Conversion/MLIRTypes.h @@ -26,17 +26,15 @@ inline Type f32Ty(MLIRContext *ctx) { return Float32Type::get(ctx); } inline Type f64Ty(MLIRContext *ctx) { return Float64Type::get(ctx); } inline Type bf16Ty(MLIRContext *ctx) { return BFloat16Type::get(ctx); } -inline bool isFloat(Type type) { - return type.isF32() || type.isF64() || type.isF16() || type.isF128() || - type.isBF16() || isa(type) || - isa(type) || isa(type) || - isa(type) || isa(type); +inline bool isFloat8(Type type) { + return isa(type); } -inline bool isFloat8(Type type) { - return isa(type) || isa(type) || - isa(type) || isa(type) || - isa(type); +inline bool isFloat(Type type) { + return type.isF32() || type.isF64() || type.isF16() || type.isF128() || + type.isBF16() || llvm::isa(type) || + 
isFloat8(type); } inline bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); } diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index ac3274a213..aa4214bd77 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -756,14 +756,14 @@ bool supportMMA(triton::DotOp op, int version) { return false; if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 && retShapePerCTA[rank - 1] % 8 == 0 && - (isa(aElemTy) || isa(aElemTy) || + (llvm::isa(aElemTy) || aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() || aElemTy.isF32()))) { return false; } // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op. if (op.getMaxNumImpreciseAcc() < 32 && - (isa(aElemTy) || isa(aElemTy)) && + (llvm::isa(aElemTy)) && cast(op.getType()).getElementType().isF32()) { return false; } @@ -784,9 +784,8 @@ bool supportMMA(Value value, int version) { cast(value.getType()).getElementType(); // FP8 is not natively supported on all mma versions but it can always be // promoted to fp16 therefore we can always support it. 
- bool isFP8 = isa(elemTy) || isa(elemTy) || - isa(elemTy) || - isa(elemTy); + bool isFP8 = llvm::isa(elemTy); return isFP8 || elemTy.isF16() || elemTy.isBF16() || (elemTy.isF32() && version >= 2) || (elemTy.isInteger(8) && version >= 2); diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 47707cc672..f32891aceb 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -632,8 +632,7 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) { NvidiaMmaEncodingAttr mmaLayout = dyn_cast(D.getType().getEncoding()); if (mmaLayout) { - bool isNativeFP8 = - isa(AElType) || isa(AElType); + bool isNativeFP8 = llvm::isa(AElType); // promote operands for sm < 89 since fp8 mma is not natively supported // promote operands for sm >= 90 when mma is not v3 if (!isNativeFP8 || diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 4914333d3e..5a13a64535 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -45,9 +45,9 @@ SmallVector mmaVersionToInstrShape(int version, SmallVector validN; // MMAv3 with larger instruction shape is preferred. 
- if (isa(eltType) || isa(eltType) || - isa(eltType) || eltType.isF16() || - eltType.isBF16() || eltType.isF32()) { + if (llvm::isa( + eltType) || + eltType.isF16() || eltType.isBF16() || eltType.isF32()) { validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176, 168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88, 80, 72, 64, 56, 48, 40, 32, 24, 16, 8}); diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index 4c321c27d4..f49a2555c7 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -77,8 +77,8 @@ bool WarpGroupDotOp::needsPartialAccumulator() { const auto &d = getD(); auto aTensorTy = cast(a.getType()); auto aElTy = cast(a.getType()).getElementType(); - bool isFP8 = isa(aElTy) || isa(aElTy) || - isa(aElTy) || isa(aElTy); + bool isFP8 = llvm::isa(aElTy); bool accFP32 = cast(d.getType()).getElementType().isF32(); uint32_t maxNumImpreciseAcc = getMaxNumImpreciseAcc(); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp index ee2441fcc5..2afe8c847a 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -1019,18 +1019,16 @@ struct FpToFpOpConversion return outVals; } size_t numElements = 4; - if (isa(srcElementType) || - isa(dstElementType) || - isa(srcElementType) || - isa(dstElementType) || - isa(srcElementType) || - isa(dstElementType)) { + if (llvm::isa( + srcElementType) || + llvm::isa( + dstElementType)) { numElements = 2; } bool useFP16IntermediateSrc = - srcElementType.isF32() && !(isaFamily == AMD::ISAFamily::CDNA3 && - (isa(dstElementType) || - isa(dstElementType))); + srcElementType.isF32() && + !(isaFamily == AMD::ISAFamily::CDNA3 && + (llvm::isa(dstElementType))); bool isDstFP32 = dstElementType.isF32(); Type srcType = useFP16IntermediateSrc ? 
f16_ty : srcElementType; Type dstType = isDstFP32 ? f16_ty : dstElementType; diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp index f7057266d7..005089aaf7 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -416,8 +416,7 @@ class BlockedToMFMA : public OpRewritePattern { // store instructions, except for fp8 matmul kernels due to regression // TODO (lixun): investigate the regression and enable this feature again auto aElemTy = mfmaInstr.getElementTypeA(); - bool isFP8 = - isa(aElemTy) || isa(aElemTy); + bool isFP8 = llvm::isa(aElemTy); bool isTransposed = isChainDot(dotOp) || !isFP8; mfmaEnc = ttg::AMDMfmaEncodingAttr::get( oldRetType.getContext(), diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp index 59c102c1ed..74306ce241 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp @@ -20,23 +20,24 @@ static MfmaTypeId chooseAppropriateMfmaId(mlir::Type dataTypeA, if (dataTypeA.isInteger(8) && dataTypeB.isInteger(8)) { return MfmaTypeId::I8TyId; } - if (isa(dataTypeA) && - isa(dataTypeB)) { + if (llvm::isa(dataTypeA) && + llvm::isa(dataTypeB)) { return MfmaTypeId::Fp8Fp8TyId; } - if (isa(dataTypeA) && - isa(dataTypeB)) { + if (llvm::isa(dataTypeA) && + llvm::isa(dataTypeB)) { return MfmaTypeId::Fp8Bf8TyId; } - if (isa(dataTypeA) && - isa(dataTypeB)) { + if (llvm::isa(dataTypeA) && + llvm::isa(dataTypeB)) { return MfmaTypeId::Bf8Fp8TyId; } - if (isa(dataTypeA) && - isa(dataTypeB)) { + if (llvm::isa(dataTypeA) && + llvm::isa(dataTypeB)) { return MfmaTypeId::Bf8Bf8TyId; } - if (isa(dataTypeA) && isa(dataTypeB)) { + if (llvm::isa(dataTypeA) && + llvm::isa(dataTypeB)) { return MfmaTypeId::Fp16TyId; } 
llvm_unreachable("Unsupported input argument type."); diff --git a/third_party/intel/lib/Analysis/DPAS.cpp b/third_party/intel/lib/Analysis/DPAS.cpp index 89710ce4f2..249fe28cb1 100644 --- a/third_party/intel/lib/Analysis/DPAS.cpp +++ b/third_party/intel/lib/Analysis/DPAS.cpp @@ -147,28 +147,24 @@ DPASAnalysis::getDPASType(OpTy op) { if (isa(dElemTy)) { if (dElemTy.isF32()) { - if (aElemTy.isBF16() && - (isa(bElemTy) || isa(bElemTy))) + if (aElemTy.isBF16() && isa(bElemTy)) return DPASEngineType::FP32_FP32_BF16_FP8; // 2 E2M1 are packed into 1 int8 if (aElemTy.isBF16() && bElemTy.isInteger(8)) return DPASEngineType::FP32_FP32_BF16_FP4; - if ((isa(aElemTy) || isa(aElemTy)) && - bElemTy.isBF16()) + if (isa(aElemTy) && bElemTy.isBF16()) return DPASEngineType::FP32_FP32_FP8_BF16; - if (aElemTy.isF16() && - (isa(bElemTy) || isa(bElemTy))) + if (aElemTy.isF16() && isa(bElemTy)) return DPASEngineType::FP32_FP32_FP16_FP8; // 2 E2M1 are packed into 1 int8 if (aElemTy.isF16() && bElemTy.isInteger(8)) return DPASEngineType::FP32_FP32_FP16_FP4; - if ((isa(aElemTy) || isa(aElemTy)) && - bElemTy.isF16()) + if (isa(aElemTy) && bElemTy.isF16()) return DPASEngineType::FP32_FP32_FP8_FP16; - if ((isa(aElemTy) || isa(aElemTy)) && - (isa(bElemTy) || isa(bElemTy))) + if (isa(aElemTy) && + isa(bElemTy)) return DPASEngineType::FP32_FP32_FP8_FP8; - if ((isa(aElemTy) || isa(aElemTy)) && + if (isa(aElemTy) && bElemTy.isInteger(8)) return DPASEngineType::FP32_FP32_FP8_FP4; if (aElemTy.isInteger(8) && bElemTy.isBF16()) @@ -176,7 +172,7 @@ DPASAnalysis::getDPASType(OpTy op) { if (aElemTy.isInteger(8) && bElemTy.isF16()) return DPASEngineType::FP32_FP32_FP4_FP16; if (aElemTy.isInteger(8) && - (isa(bElemTy) || isa(bElemTy))) + isa(bElemTy)) return DPASEngineType::FP32_FP32_FP4_FP8; } } diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp index f920c1180d..dcbb5e592d 100644 --- 
a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp +++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp @@ -405,8 +405,7 @@ unsigned DpasEncodingAttr::getOpsPerChannel(Type elemType) { assert(elemType.isIntOrFloat() && "unsupported type for DpasEncodingAttr"); unsigned dpasElemBitWidths = elemType.getIntOrFloatBitWidth(); - if (llvm::isa(elemType) || - llvm::isa(elemType)) + if (llvm::isa(elemType)) dpasElemBitWidths *= 2; // We are upcasting FP8 to FP16. return DPASCapability::opsChanBitWidths / dpasElemBitWidths; diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp index e78fa2f719..2e52eb6cd8 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -960,8 +960,7 @@ struct FpToFpOpConversion auto dstElementType = getElementType(op.getResult()); auto roundingMode = op.getRounding(); - if (isa(dstElementType) || - isa(dstElementType)) { + if (isa(dstElementType)) { assert(roundingMode.has_value() && "Rounding mode must be specified for conversions to fp8"); diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp index a499147609..6182bc853d 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp @@ -132,8 +132,7 @@ class BlockedToDPAS : public OpRewritePattern { oldAType.getElementType().getIntOrFloatBitWidth(); // We are upcasting FP8 to FP16 - if (isa(oldAType.getElementType()) || - isa(oldAType.getElementType())) + if (isa(oldAType.getElementType())) dpasElemBitWidths = 2 * dpasElemBitWidths; // Enlarge the repCluster size to use the large 2D load for A and B @@ -488,8 +487,7 @@ static void decomposeMixedModeDotOp(ModuleOp mod) { Type promoteType; if 
(dpasLayout) { - bool isNativeFP8 = - isa(AElType) || isa(AElType); + bool isNativeFP8 = isa(AElType); // fp8 is not natively supported by the the DPAS instruction, promote it // to fp16. if (!isNativeFP8) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp index 06901280d0..47f7ff3010 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp @@ -299,17 +299,17 @@ TensorCoreType getMmaType(triton::DotOp op) { return TensorCoreType::FP32_FP16_FP16_FP32; if (aTy.getElementType().isBF16() && bTy.getElementType().isBF16()) return TensorCoreType::FP32_BF16_BF16_FP32; - if (isa(aTy.getElementType()) && - isa(bTy.getElementType())) + if (llvm::isa(aTy.getElementType()) && + llvm::isa(bTy.getElementType())) return TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32; - if (isa(aTy.getElementType()) && - isa(bTy.getElementType())) + if (llvm::isa(aTy.getElementType()) && + llvm::isa(bTy.getElementType())) return TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32; - if (isa(aTy.getElementType()) && - isa(bTy.getElementType())) + if (llvm::isa(aTy.getElementType()) && + llvm::isa(bTy.getElementType())) return TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32; - if (isa(aTy.getElementType()) && - isa(bTy.getElementType())) + if (llvm::isa(aTy.getElementType()) && + llvm::isa(bTy.getElementType())) return TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32; if (aTy.getElementType().isF32() && bTy.getElementType().isF32() && op.getInputPrecision() == InputPrecision::TF32) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp index 3f2f78d9e7..f79bbf6ddf 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -59,9 +59,9 @@ 
triton::nvgpu::WGMMAEltType getMmaOperandType(Value a, bool allowTF32) { return triton::nvgpu::WGMMAEltType::tf32; } else if (aTy.isInteger(8)) { return triton::nvgpu::WGMMAEltType::s8; - } else if (isa(aTy)) { + } else if (llvm::isa(aTy)) { return triton::nvgpu::WGMMAEltType::e5m2; - } else if (isa(aTy)) { + } else if (llvm::isa(aTy)) { return triton::nvgpu::WGMMAEltType::e4m3; } else { llvm::report_fatal_error("Unsupported mma operand type found"); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp index 8e37a4ad10..d9a008b47d 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -466,8 +466,8 @@ struct FpToFpOpConversion llvm::errs() << "\n"; llvm::report_fatal_error("Unsupported rounding mode for conversion."); } - if (computeCapability < 89 && - (isa(srcTy) || isa(dstTy))) { + if (computeCapability < 89 && (llvm::isa(srcTy) || + llvm::isa(dstTy))) { llvm::errs() << "Conversion from/to f8e4m3nv is only supported on " "compute capability >= 89" << "\n"; @@ -489,8 +489,7 @@ struct FpToFpOpConversion auto dstElementType = getElementType(op.getResult()); auto roundingMode = op.getRounding(); - if (isa(dstElementType) || - isa(dstElementType)) { + if (llvm::isa(dstElementType)) { assert(roundingMode.has_value() && "Rounding mode must be specified for convertsions to fp8"); @@ -527,8 +526,8 @@ struct FpToFpOpConversion bool useFP16IntermediateSrc = srcElementType.isF32() && - (!(computeCapability >= 90 && (isa(dstElementType) || - isa(dstElementType))) || + (!(computeCapability >= 90 && + (llvm::isa(dstElementType))) || roundingMode.value() == RoundingMode::RTZ); bool isDstFP32 = dstElementType.isF32(); Type srcType = useFP16IntermediateSrc ? 
f16_ty : srcElementType; From c4895970131f5bf3cff3d9bbb057b6b50d72ba2f Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Wed, 29 Jan 2025 12:45:24 -0600 Subject: [PATCH 3/3] Synchronized with triton-lang/triton#5684 --- .../lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp index c347cad988..f2c025ee61 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp @@ -60,7 +60,8 @@ enum class mxfpKind { mxf8f6f4 = 0, mxf4 = 1, mxf4nvf4 = 2 }; inline mxfpKind getMXFPKind(ScaleDotElemType typeA, ScaleDotElemType typeB, Type scaleAType, Type scaleBType) { if (typeA == ScaleDotElemType::E2M1 && typeB == ScaleDotElemType::E2M1) { - if (scaleAType.isFloat8E4M3FN() && scaleBType.isFloat8E4M3FN()) { + if (llvm::isa(scaleAType) && + llvm::isa(scaleBType)) { return mxfpKind::mxf4nvf4; } return mxfpKind::mxf4; @@ -100,9 +101,9 @@ static Value createInstDescriptor(ConversionPatternRewriter &rewriter, return 1; if (type.isF32()) return 2; - if (type.isFloat8E4M3FN()) + if (llvm::isa(type)) return 0; - if (type.isFloat8E5M2()) + if (llvm::isa(type)) return 1; llvm_unreachable("Unsupported type."); }; @@ -224,7 +225,7 @@ static void createGen5MMA(ConversionPatternRewriter &rewriter, Location loc, opcode += "f16"; else if (srcElementTy.isF32()) opcode += "tf32"; - else if (srcElementTy.isFloat8E4M3FN() || srcElementTy.isFloat8E5M2()) + else if (llvm::isa(srcElementTy)) opcode += "f8f6f4"; else assert(0 && "Unsupported type.");