diff --git a/include/triton/Conversion/MLIRTypes.h b/include/triton/Conversion/MLIRTypes.h
index afa1aa989e..dd8d4be4c2 100644
--- a/include/triton/Conversion/MLIRTypes.h
+++ b/include/triton/Conversion/MLIRTypes.h
@@ -26,17 +26,15 @@ inline Type f32Ty(MLIRContext *ctx) { return Float32Type::get(ctx); }
 inline Type f64Ty(MLIRContext *ctx) { return Float64Type::get(ctx); }
 inline Type bf16Ty(MLIRContext *ctx) { return BFloat16Type::get(ctx); }
 
-inline bool isFloat(Type type) {
-  return type.isF32() || type.isF64() || type.isF16() || type.isF128() ||
-         type.isBF16() || type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() ||
-         type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() ||
-         type.isFloat8E5M2FNUZ();
+inline bool isFloat8(Type type) {
+  return isa<Float8E4M3FNType, Float8E4M3FNUZType, Float8E5M2Type,
+             Float8E5M2FNUZType>(type);
 }
 
-inline bool isFloat8(Type type) {
-  return type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() ||
-         type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() ||
-         type.isFloat8E5M2FNUZ();
+inline bool isFloat(Type type) {
+  return type.isF32() || type.isF64() || type.isF16() || type.isF128() ||
+         type.isBF16() || llvm::isa<Float8E4M3B11FNUZType>(type) ||
+         isFloat8(type);
 }
 
 inline bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); }
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index f801af5163..aa4214bd77 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -756,14 +756,14 @@ bool supportMMA(triton::DotOp op, int version) {
       return false;
     if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
           retShapePerCTA[rank - 1] % 8 == 0 &&
-          (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() ||
+          (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy) ||
            aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
           aElemTy.isF32()))) {
       return false;
     }
     // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
     if (op.getMaxNumImpreciseAcc() < 32 &&
-        (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN()) &&
+        (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy)) &&
         cast<RankedTensorType>(op.getType()).getElementType().isF32()) {
       return false;
     }
@@ -784,8 +784,8 @@ bool supportMMA(Value value, int version) {
       cast<TensorOrMemDesc>(value.getType()).getElementType();
   // FP8 is not natively supported on all mma versions but it can always be
   // promoted to fp16 therefore we can always support it.
-  bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() ||
-               elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ();
+  bool isFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
+                         Float8E4M3FNUZType>(elemTy);
   return isFP8 || elemTy.isF16() || elemTy.isBF16() ||
          (elemTy.isF32() && version >= 2) ||
          (elemTy.isInteger(8) && version >= 2);
diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
index 6d7632b1b7..f32891aceb 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -632,7 +632,7 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
     NvidiaMmaEncodingAttr mmaLayout =
         dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
     if (mmaLayout) {
-      bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
+      bool isNativeFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType>(AElType);
       // promote operands for sm < 89 since fp8 mma is not natively supported
       // promote operands for sm >= 90 when mma is not v3
       if (!isNativeFP8 ||
diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
index 21b8e059ca..5a13a64535 100644
--- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -45,9 +45,9 @@ SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
   SmallVector<unsigned> validN;
 
   // MMAv3 with larger instruction shape is preferred.
-  if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FN() ||
-      eltType.isFloat8E4M3FNUZ() || eltType.isF16() || eltType.isBF16() ||
-      eltType.isF32()) {
+  if (llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E4M3FNUZType>(
+          eltType) ||
+      eltType.isF16() || eltType.isBF16() || eltType.isF32()) {
     validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176,
                    168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88, 80,
                    72, 64, 56, 48, 40, 32, 24, 16, 8});
diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
index a171d89339..f49a2555c7 100644
--- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
+++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
@@ -77,8 +77,8 @@ bool WarpGroupDotOp::needsPartialAccumulator() {
   const auto &d = getD();
   auto aTensorTy = cast<TensorOrMemDesc>(a.getType());
   auto aElTy = cast<TensorOrMemDesc>(a.getType()).getElementType();
-  bool isFP8 = aElTy.isFloat8E5M2() || aElTy.isFloat8E4M3FN() ||
-               aElTy.isFloat8E5M2FNUZ() || aElTy.isFloat8E4M3FNUZ();
+  bool isFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
+                         Float8E4M3FNUZType>(aElTy);
   bool accFP32 =
       cast<RankedTensorType>(d.getType()).getElementType().isF32();
   uint32_t maxNumImpreciseAcc = getMaxNumImpreciseAcc();
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
index 35a2e1a34b..2afe8c847a 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -1019,17 +1019,16 @@ struct FpToFpOpConversion
       return outVals;
     }
     size_t numElements = 4;
-    if (srcElementType.isFloat8E4M3FN() || dstElementType.isFloat8E4M3FN() ||
-        srcElementType.isFloat8E4M3FNUZ() ||
-        dstElementType.isFloat8E4M3FNUZ() ||
-        srcElementType.isFloat8E5M2FNUZ() ||
-        dstElementType.isFloat8E5M2FNUZ()) {
+    if (llvm::isa<Float8E4M3FNType, Float8E4M3FNUZType, Float8E5M2FNUZType>(
+            srcElementType) ||
+        llvm::isa<Float8E4M3FNType, Float8E4M3FNUZType, Float8E5M2FNUZType>(
+            dstElementType)) {
       numElements = 2;
     }
     bool useFP16IntermediateSrc =
-        srcElementType.isF32() && !(isaFamily == AMD::ISAFamily::CDNA3 &&
-                                    (dstElementType.isFloat8E4M3FNUZ() ||
-                                     dstElementType.isFloat8E5M2FNUZ()));
+        srcElementType.isF32() &&
+        !(isaFamily == AMD::ISAFamily::CDNA3 &&
+          (llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(dstElementType)));
     bool isDstFP32 = dstElementType.isF32();
     Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
     Type dstType = isDstFP32 ? f16_ty : dstElementType;
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
index 7ea13142a7..005089aaf7 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
@@ -416,7 +416,7 @@ class BlockedToMFMA : public OpRewritePattern<triton::DotOp> {
     // store instructions, except for fp8 matmul kernels due to regression
     // TODO (lixun): investigate the regression and enable this feature again
     auto aElemTy = mfmaInstr.getElementTypeA();
-    bool isFP8 = aElemTy.isFloat8E5M2FNUZ() || aElemTy.isFloat8E4M3FNUZ();
+    bool isFP8 = llvm::isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(aElemTy);
     bool isTransposed = isChainDot(dotOp) || !isFP8;
     mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
         oldRetType.getContext(),
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp
index 4979ee005b..74306ce241 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp
@@ -20,19 +20,24 @@ static MfmaTypeId chooseAppropriateMfmaId(mlir::Type dataTypeA,
   if (dataTypeA.isInteger(8) && dataTypeB.isInteger(8)) {
     return MfmaTypeId::I8TyId;
   }
-  if (dataTypeA.isFloat8E4M3FNUZ() && dataTypeB.isFloat8E4M3FNUZ()) {
+  if (llvm::isa<Float8E4M3FNUZType>(dataTypeA) &&
+      llvm::isa<Float8E4M3FNUZType>(dataTypeB)) {
     return MfmaTypeId::Fp8Fp8TyId;
   }
-  if (dataTypeA.isFloat8E4M3FNUZ() && dataTypeB.isFloat8E5M2FNUZ()) {
+  if (llvm::isa<Float8E4M3FNUZType>(dataTypeA) &&
+      llvm::isa<Float8E5M2FNUZType>(dataTypeB)) {
     return MfmaTypeId::Fp8Bf8TyId;
   }
-  if (dataTypeA.isFloat8E5M2FNUZ() && dataTypeB.isFloat8E4M3FNUZ()) {
+  if (llvm::isa<Float8E5M2FNUZType>(dataTypeA) &&
+      llvm::isa<Float8E4M3FNUZType>(dataTypeB)) {
     return MfmaTypeId::Bf8Fp8TyId;
   }
-  if (dataTypeA.isFloat8E5M2FNUZ() && dataTypeB.isFloat8E5M2FNUZ()) {
+  if (llvm::isa<Float8E5M2FNUZType>(dataTypeA) &&
+      llvm::isa<Float8E5M2FNUZType>(dataTypeB)) {
     return MfmaTypeId::Bf8Bf8TyId;
   }
-  if (dataTypeA.isFloat8E5M2() && dataTypeB.isFloat8E5M2()) {
+  if (llvm::isa<Float8E5M2Type>(dataTypeA) &&
+      llvm::isa<Float8E5M2Type>(dataTypeB)) {
     return MfmaTypeId::Fp16TyId;
   }
   llvm_unreachable("Unsupported input argument type.");
diff --git a/third_party/intel/lib/Analysis/DPAS.cpp b/third_party/intel/lib/Analysis/DPAS.cpp
index 3b5975c3c3..249fe28cb1 100644
--- a/third_party/intel/lib/Analysis/DPAS.cpp
+++ b/third_party/intel/lib/Analysis/DPAS.cpp
@@ -125,9 +125,9 @@ DPASAnalysis::getDPASType(OpTy op) {
     if (aElemTy.isF32() && op.getInputPrecision() == InputPrecision::TF32)
       return DPASEngineType::FP32_FP32_TF32_TF32;
     // For FP8XFP8->FP32, upcast to FP16
-    if (aElemTy.isFloat8E5M2())
+    if (isa<Float8E5M2Type>(aElemTy))
       return DPASEngineType::FP32_FP32_FP16_FP16;
-    if (aElemTy.isFloat8E4M3FN())
+    if (isa<Float8E4M3FNType>(aElemTy))
       return DPASEngineType::FP32_FP32_FP16_FP16;
   } else if (dElemTy.isF16()) {
     if (aElemTy.isF16())
@@ -147,28 +147,24 @@ DPASAnalysis::getDPASType(OpTy op) {
   if (isa<FloatType>(dElemTy)) {
     if (dElemTy.isF32()) {
-      if (aElemTy.isBF16() &&
-          (bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2()))
+      if (aElemTy.isBF16() && isa<Float8E4M3FNType, Float8E5M2Type>(bElemTy))
         return DPASEngineType::FP32_FP32_BF16_FP8;
       // 2 E2M1 are packed into 1 int8
       if (aElemTy.isBF16() && bElemTy.isInteger(8))
         return DPASEngineType::FP32_FP32_BF16_FP4;
-      if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) &&
-          bElemTy.isBF16())
+      if (isa<Float8E4M3FNType, Float8E5M2Type>(aElemTy) && bElemTy.isBF16())
         return DPASEngineType::FP32_FP32_FP8_BF16;
-      if (aElemTy.isF16() &&
-          (bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2()))
+      if (aElemTy.isF16() && isa<Float8E4M3FNType, Float8E5M2Type>(bElemTy))
        return DPASEngineType::FP32_FP32_FP16_FP8;
       // 2 E2M1 are packed into 1 int8
       if (aElemTy.isF16() && bElemTy.isInteger(8))
         return DPASEngineType::FP32_FP32_FP16_FP4;
-      if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) &&
-          bElemTy.isF16())
+      if (isa<Float8E4M3FNType, Float8E5M2Type>(aElemTy) && bElemTy.isF16())
         return DPASEngineType::FP32_FP32_FP8_FP16;
-      if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) &&
-          (bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2()))
+      if (isa<Float8E4M3FNType, Float8E5M2Type>(aElemTy) &&
+          isa<Float8E4M3FNType, Float8E5M2Type>(bElemTy))
         return DPASEngineType::FP32_FP32_FP8_FP8;
-      if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) &&
+      if (isa<Float8E4M3FNType, Float8E5M2Type>(aElemTy) &&
           bElemTy.isInteger(8))
         return DPASEngineType::FP32_FP32_FP8_FP4;
       if (aElemTy.isInteger(8) && bElemTy.isBF16())
         return DPASEngineType::FP32_FP32_FP4_BF16;
@@ -176,7 +172,7 @@ DPASAnalysis::getDPASType(OpTy op) {
       if (aElemTy.isInteger(8) && bElemTy.isF16())
         return DPASEngineType::FP32_FP32_FP4_FP16;
       if (aElemTy.isInteger(8) &&
-          (bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2()))
+          isa<Float8E4M3FNType, Float8E5M2Type>(bElemTy))
         return DPASEngineType::FP32_FP32_FP4_FP8;
     }
   }
diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
index 8a9c53969d..dcbb5e592d 100644
--- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
+++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
@@ -405,7 +405,7 @@ unsigned DpasEncodingAttr::getOpsPerChannel(Type elemType) {
   assert(elemType.isIntOrFloat() && "unsupported type for DpasEncodingAttr");
 
   unsigned dpasElemBitWidths = elemType.getIntOrFloatBitWidth();
-  if (elemType.isFloat8E5M2() || elemType.isFloat8E4M3FN())
+  if (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(elemType))
     dpasElemBitWidths *= 2; // We are upcasting FP8 to FP16.
 
   return DPASCapability::opsChanBitWidths / dpasElemBitWidths;
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp
index 50697c86dc..2e52eb6cd8 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -960,7 +960,7 @@ struct FpToFpOpConversion
     auto dstElementType = getElementType(op.getResult());
     auto roundingMode = op.getRounding();
 
-    if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FN()) {
+    if (isa<Float8E5M2Type, Float8E4M3FNType>(dstElementType)) {
       assert(roundingMode.has_value() &&
              "Rounding mode must be specified for conversions to fp8");
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp
index 170fc7f355..6182bc853d 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp
@@ -132,8 +132,7 @@ class BlockedToDPAS : public OpRewritePattern<triton::DotOp> {
         oldAType.getElementType().getIntOrFloatBitWidth();
 
     // We are upcasting FP8 to FP16
-    if (oldAType.getElementType().isFloat8E5M2() ||
-        oldAType.getElementType().isFloat8E4M3FN())
+    if (isa<Float8E5M2Type, Float8E4M3FNType>(oldAType.getElementType()))
       dpasElemBitWidths = 2 * dpasElemBitWidths;
 
     // Enlarge the repCluster size to use the large 2D load for A and B
@@ -488,7 +487,7 @@ static void decomposeMixedModeDotOp(ModuleOp mod) {
     Type promoteType;
     if (dpasLayout) {
-      bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
+      bool isNativeFP8 = isa<Float8E5M2Type, Float8E4M3FNType>(AElType);
       // fp8 is not natively supported by the the DPAS instruction, promote it
       // to fp16.
       if (!isNativeFP8)
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
index c5ec00097d..47f7ff3010 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
@@ -299,17 +299,17 @@ TensorCoreType getMmaType(triton::DotOp op) {
       return TensorCoreType::FP32_FP16_FP16_FP32;
     if (aTy.getElementType().isBF16() && bTy.getElementType().isBF16())
       return TensorCoreType::FP32_BF16_BF16_FP32;
-    if (aTy.getElementType().isFloat8E5M2() &&
-        bTy.getElementType().isFloat8E5M2())
+    if (llvm::isa<Float8E5M2Type>(aTy.getElementType()) &&
+        llvm::isa<Float8E5M2Type>(bTy.getElementType()))
       return TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32;
-    if (aTy.getElementType().isFloat8E5M2() &&
-        bTy.getElementType().isFloat8E4M3FN())
+    if (llvm::isa<Float8E5M2Type>(aTy.getElementType()) &&
+        llvm::isa<Float8E4M3FNType>(bTy.getElementType()))
       return TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32;
-    if (aTy.getElementType().isFloat8E4M3FN() &&
-        bTy.getElementType().isFloat8E5M2())
+    if (llvm::isa<Float8E4M3FNType>(aTy.getElementType()) &&
+        llvm::isa<Float8E5M2Type>(bTy.getElementType()))
       return TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32;
-    if (aTy.getElementType().isFloat8E4M3FN() &&
-        bTy.getElementType().isFloat8E4M3FN())
+    if (llvm::isa<Float8E4M3FNType>(aTy.getElementType()) &&
+        llvm::isa<Float8E4M3FNType>(bTy.getElementType()))
       return TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32;
     if (aTy.getElementType().isF32() && bTy.getElementType().isF32() &&
         op.getInputPrecision() == InputPrecision::TF32)
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp
index c347cad988..f2c025ee61 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp
@@ -60,7 +60,8 @@ enum class mxfpKind { mxf8f6f4 = 0, mxf4 = 1, mxf4nvf4 = 2 };
 inline mxfpKind getMXFPKind(ScaleDotElemType typeA, ScaleDotElemType typeB,
                             Type scaleAType, Type scaleBType) {
   if (typeA == ScaleDotElemType::E2M1 && typeB == ScaleDotElemType::E2M1) {
-    if (scaleAType.isFloat8E4M3FN() && scaleBType.isFloat8E4M3FN()) {
+    if (llvm::isa<Float8E4M3FNType>(scaleAType) &&
+        llvm::isa<Float8E4M3FNType>(scaleBType)) {
       return mxfpKind::mxf4nvf4;
     }
     return mxfpKind::mxf4;
@@ -100,9 +101,9 @@ static Value createInstDescriptor(ConversionPatternRewriter &rewriter,
       return 1;
     if (type.isF32())
       return 2;
-    if (type.isFloat8E4M3FN())
+    if (llvm::isa<Float8E4M3FNType>(type))
       return 0;
-    if (type.isFloat8E5M2())
+    if (llvm::isa<Float8E5M2Type>(type))
       return 1;
     llvm_unreachable("Unsupported type.");
   };
@@ -224,7 +225,7 @@ static void createGen5MMA(ConversionPatternRewriter &rewriter, Location loc,
     opcode += "f16";
   else if (srcElementTy.isF32())
     opcode += "tf32";
-  else if (srcElementTy.isFloat8E4M3FN() || srcElementTy.isFloat8E5M2())
+  else if (llvm::isa<Float8E4M3FNType, Float8E5M2Type>(srcElementTy))
     opcode += "f8f6f4";
   else
     assert(0 && "Unsupported type.");
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp
index 7450bc3f4e..f79bbf6ddf 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp
@@ -59,9 +59,9 @@ triton::nvgpu::WGMMAEltType getMmaOperandType(Value a, bool allowTF32) {
     return triton::nvgpu::WGMMAEltType::tf32;
   } else if (aTy.isInteger(8)) {
     return triton::nvgpu::WGMMAEltType::s8;
-  } else if (aTy.isFloat8E5M2()) {
+  } else if (llvm::isa<Float8E5M2Type>(aTy)) {
     return triton::nvgpu::WGMMAEltType::e5m2;
-  } else if (aTy.isFloat8E4M3FN()) {
+  } else if (llvm::isa<Float8E4M3FNType>(aTy)) {
     return triton::nvgpu::WGMMAEltType::e4m3;
   } else {
     llvm::report_fatal_error("Unsupported mma operand type found");
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp
index d489d0a1b1..d9a008b47d 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -466,8 +466,8 @@ struct FpToFpOpConversion
       llvm::errs() << "\n";
       llvm::report_fatal_error("Unsupported rounding mode for conversion.");
     }
-    if (computeCapability < 89 &&
-        (srcTy.isFloat8E4M3FN() || dstTy.isFloat8E4M3FN())) {
+    if (computeCapability < 89 && (llvm::isa<Float8E4M3FNType>(srcTy) ||
+                                   llvm::isa<Float8E4M3FNType>(dstTy))) {
       llvm::errs() << "Conversion from/to f8e4m3nv is only supported on "
                       "compute capability >= 89"
                    << "\n";
@@ -489,7 +489,7 @@ struct FpToFpOpConversion
     auto dstElementType = getElementType(op.getResult());
     auto roundingMode = op.getRounding();
 
-    if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FN()) {
+    if (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(dstElementType)) {
       assert(roundingMode.has_value() &&
              "Rounding mode must be specified for convertsions to fp8");
@@ -526,8 +526,8 @@ struct FpToFpOpConversion
     bool useFP16IntermediateSrc =
         srcElementType.isF32() &&
-        (!(computeCapability >= 90 && (dstElementType.isFloat8E4M3FN() ||
-                                       dstElementType.isFloat8E5M2())) ||
+        (!(computeCapability >= 90 &&
+           (llvm::isa<Float8E4M3FNType, Float8E5M2Type>(dstElementType))) ||
          roundingMode.value() == RoundingMode::RTZ);
     bool isDstFP32 = dstElementType.isF32();
     Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
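
Note on the pattern (editorial addendum, not part of the patch): every hunk above is the same mechanical rewrite. The deprecated mlir::Type member predicates (type.isFloat8E5M2(), type.isFloat8E4M3FN(), ...) are replaced by llvm::isa<> checks against the concrete FP8 type classes, and OR-chains over several predicates collapse into a single variadic isa<>, which returns true when the value is an instance of any listed type. A minimal standalone sketch of the idiom follows; the helper name isAnyNativeFp8 is hypothetical and the type list is illustrative:

    #include "mlir/IR/BuiltinTypes.h" // Float8E5M2Type, Float8E4M3FNType, ...
    #include "llvm/Support/Casting.h" // llvm::isa

    // Before (deprecated member-predicate style):
    //   bool fp8 = ty.isFloat8E5M2() || ty.isFloat8E4M3FN();
    // After: one variadic isa<> over the builtin FP8 type classes.
    static bool isAnyNativeFp8(mlir::Type ty) {
      return llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType>(ty);
    }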