diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index e670567bd1844..5aa51cf6a6517 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16309,7 +16309,12 @@ namespace {
 // apply a combine.
 struct CombineResult;
 
-enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
+enum ExtKind : uint8_t {
+  ZExt = 1 << 0,
+  SExt = 1 << 1,
+  FPExt = 1 << 2,
+  BF16Ext = 1 << 3
+};
 /// Helper class for folding sign/zero extensions.
 /// In particular, this class is used for the following combines:
 /// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
@@ -16344,8 +16349,10 @@ struct NodeExtensionHelper {
   /// instance, a splat constant (e.g., 3), would support being both sign and
   /// zero extended.
   bool SupportsSExt;
-  /// Records if this operand is like being floating-Point extended.
+  /// Records if this operand is like being floating point extended.
   bool SupportsFPExt;
+  /// Records if this operand is extended from bf16.
+  bool SupportsBF16Ext;
   /// This boolean captures whether we care if this operand would still be
   /// around after the folding happens.
   bool EnforceOneUse;
@@ -16381,6 +16388,7 @@ struct NodeExtensionHelper {
     case ExtKind::ZExt:
       return RISCVISD::VZEXT_VL;
     case ExtKind::FPExt:
+    case ExtKind::BF16Ext:
       return RISCVISD::FP_EXTEND_VL;
     }
     llvm_unreachable("Unknown ExtKind enum");
@@ -16402,13 +16410,6 @@ struct NodeExtensionHelper {
     if (Source.getValueType() == NarrowVT)
       return Source;
 
-    // vfmadd_vl -> vfwmadd_vl can take bf16 operands
-    if (Source.getValueType().getVectorElementType() == MVT::bf16) {
-      assert(Root->getSimpleValueType(0).getVectorElementType() == MVT::f32 &&
-             Root->getOpcode() == RISCVISD::VFMADD_VL);
-      return Source;
-    }
-
     unsigned ExtOpc = getExtOpc(*SupportsExt);
 
     // If we need an extension, we should be changing the type.
@@ -16451,7 +16452,8 @@ struct NodeExtensionHelper {
     // Determine the narrow size.
     unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
 
-    MVT EltVT = SupportsExt == ExtKind::FPExt
+    MVT EltVT = SupportsExt == ExtKind::BF16Ext ? MVT::bf16
+                : SupportsExt == ExtKind::FPExt
                     ? MVT::getFloatingPointVT(NarrowSize)
                     : MVT::getIntegerVT(NarrowSize);
 
@@ -16628,17 +16630,13 @@ struct NodeExtensionHelper {
     EnforceOneUse = false;
   }
 
-  bool isSupportedFPExtend(SDNode *Root, MVT NarrowEltVT,
-                           const RISCVSubtarget &Subtarget) {
-    // Any f16 extension will need zvfh
-    if (NarrowEltVT == MVT::f16 && !Subtarget.hasVInstructionsF16())
-      return false;
-    // The only bf16 extension we can do is vfmadd_vl -> vfwmadd_vl with
-    // zvfbfwma
-    if (NarrowEltVT == MVT::bf16 && (!Subtarget.hasStdExtZvfbfwma() ||
-                                     Root->getOpcode() != RISCVISD::VFMADD_VL))
-      return false;
-    return true;
+  bool isSupportedFPExtend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) {
+    return (NarrowEltVT == MVT::f32 ||
+            (NarrowEltVT == MVT::f16 && Subtarget.hasVInstructionsF16()));
+  }
+
+  bool isSupportedBF16Extend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) {
+    return NarrowEltVT == MVT::bf16 && Subtarget.hasStdExtZvfbfwma();
   }
 
   /// Helper method to set the various fields of this struct based on the
@@ -16648,6 +16646,7 @@ struct NodeExtensionHelper {
     SupportsZExt = false;
     SupportsSExt = false;
     SupportsFPExt = false;
+    SupportsBF16Ext = false;
     EnforceOneUse = true;
     unsigned Opc = OrigOperand.getOpcode();
     // For the nodes we handle below, we end up using their inputs directly: see
@@ -16679,9 +16678,11 @@ struct NodeExtensionHelper {
     case RISCVISD::FP_EXTEND_VL: {
       MVT NarrowEltVT =
           OrigOperand.getOperand(0).getSimpleValueType().getVectorElementType();
-      if (!isSupportedFPExtend(Root, NarrowEltVT, Subtarget))
-        break;
-      SupportsFPExt = true;
+      if (isSupportedFPExtend(NarrowEltVT, Subtarget))
+        SupportsFPExt = true;
+      if (isSupportedBF16Extend(NarrowEltVT, Subtarget))
+        SupportsBF16Ext = true;
+
       break;
     }
     case ISD::SPLAT_VECTOR:
@@ -16698,16 +16699,16 @@ struct NodeExtensionHelper {
       if (Op.getOpcode() != ISD::FP_EXTEND)
         break;
 
-      if (!isSupportedFPExtend(Root, Op.getOperand(0).getSimpleValueType(),
-                               Subtarget))
-        break;
-
       unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
       unsigned ScalarBits = Op.getOperand(0).getValueSizeInBits();
       if (NarrowSize != ScalarBits)
         break;
 
-      SupportsFPExt = true;
+      if (isSupportedFPExtend(Op.getOperand(0).getSimpleValueType(), Subtarget))
+        SupportsFPExt = true;
+      if (isSupportedBF16Extend(Op.getOperand(0).getSimpleValueType(),
+                                Subtarget))
+        SupportsBF16Ext = true;
       break;
     }
     default:
@@ -16940,6 +16941,11 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
     return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
                          Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
                          /*RHSExt=*/{ExtKind::FPExt});
+  if ((AllowExtMask & ExtKind::BF16Ext) && LHS.SupportsBF16Ext &&
+      RHS.SupportsBF16Ext)
+    return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
+                         Root, LHS, /*LHSExt=*/{ExtKind::BF16Ext}, RHS,
+                         /*RHSExt=*/{ExtKind::BF16Ext});
   return std::nullopt;
 }
 
@@ -17022,6 +17028,18 @@ canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
                                           Subtarget);
 }
 
+/// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
+///
+/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
+/// can be used to apply the pattern.
+static std::optional<CombineResult>
+canFoldToVWWithBF16EXT(SDNode *Root, const NodeExtensionHelper &LHS,
+                       const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+                       const RISCVSubtarget &Subtarget) {
+  return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
+                                          Subtarget);
+}
+
 /// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
 ///
 /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
@@ -17061,6 +17079,8 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
   case RISCVISD::VFNMADD_VL:
   case RISCVISD::VFNMSUB_VL:
     Strategies.push_back(canFoldToVWWithSameExtension);
+    if (Root->getOpcode() == RISCVISD::VFMADD_VL)
+      Strategies.push_back(canFoldToVWWithBF16EXT);
     break;
   case ISD::MUL:
   case RISCVISD::MUL_VL:
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmaccbf16.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmaccbf16.ll
index 1639f21f243d8..aec970adff51e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmaccbf16.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmaccbf16.ll
@@ -1,8 +1,8 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zvfbfwma -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFWMA
-; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfbfwma -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFWMA
-; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zvfbfmin -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFMIN
-; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfbfmin -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFMIN
+; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zvfh,+zvfbfwma -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFWMA
+; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfh,+zvfbfwma -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFWMA
+; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zvfh,+zvfbfmin -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFMIN
+; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfh,+zvfbfmin -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFMIN
 
 define <1 x float> @vfwmaccbf16_vv_v1f32(<1 x float> %a, <1 x bfloat> %b, <1 x bfloat> %c) {
 ; ZVFBFWMA-LABEL: vfwmaccbf16_vv_v1f32:
@@ -295,3 +295,53 @@ define <32 x float> @vfwmaccbf32_vf_v32f32(<32 x float> %a, bfloat %b, <32 x bfl
   %res = call <32 x float> @llvm.fma.v32f32(<32 x float> %b.ext, <32 x float> %c.ext, <32 x float> %a)
   ret <32 x float> %res
 }
+
+define <4 x float> @vfwmaccbf16_vf_v4f32_scalar_extend(<4 x float> %rd, bfloat %a, <4 x bfloat> %b) local_unnamed_addr #0 {
+; ZVFBFWMA-LABEL: vfwmaccbf16_vf_v4f32_scalar_extend:
+; ZVFBFWMA:       # %bb.0:
+; ZVFBFWMA-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; ZVFBFWMA-NEXT:    vfwmaccbf16.vf v8, fa0, v9
+; ZVFBFWMA-NEXT:    ret
+;
+; ZVFBFMIN-LABEL: vfwmaccbf16_vf_v4f32_scalar_extend:
+; ZVFBFMIN:       # %bb.0:
+; ZVFBFMIN-NEXT:    fmv.x.w a0, fa0
+; ZVFBFMIN-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; ZVFBFMIN-NEXT:    vfwcvtbf16.f.f.v v10, v9
+; ZVFBFMIN-NEXT:    slli a0, a0, 16
+; ZVFBFMIN-NEXT:    fmv.w.x fa5, a0
+; ZVFBFMIN-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; ZVFBFMIN-NEXT:    vfmacc.vf v8, fa5, v10
+; ZVFBFMIN-NEXT:    ret
+  %b_ext = fpext <4 x bfloat> %b to <4 x float>
+  %a_extend = fpext bfloat %a to float
+  %a_insert = insertelement <4 x float> poison, float %a_extend, i64 0
+  %a_shuffle = shufflevector <4 x float> %a_insert, <4 x float> poison, <4 x i32> zeroinitializer
+  %fma = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %a_shuffle, <4 x float> %b_ext, <4 x float> %rd)
+  ret <4 x float> %fma
+}
+
+; Negative test with a mix of bfloat and half fpext.
+define <4 x float> @mix(<4 x float> %rd, <4 x half> %a, <4 x bfloat> %b) {
+; ZVFBFWMA-LABEL: mix:
+; ZVFBFWMA:       # %bb.0:
+; ZVFBFWMA-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; ZVFBFWMA-NEXT:    vfwcvt.f.f.v v11, v9
+; ZVFBFWMA-NEXT:    vfwcvtbf16.f.f.v v9, v10
+; ZVFBFWMA-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; ZVFBFWMA-NEXT:    vfmacc.vv v8, v11, v9
+; ZVFBFWMA-NEXT:    ret
+;
+; ZVFBFMIN-LABEL: mix:
+; ZVFBFMIN:       # %bb.0:
+; ZVFBFMIN-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; ZVFBFMIN-NEXT:    vfwcvt.f.f.v v11, v9
+; ZVFBFMIN-NEXT:    vfwcvtbf16.f.f.v v9, v10
+; ZVFBFMIN-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; ZVFBFMIN-NEXT:    vfmacc.vv v8, v11, v9
+; ZVFBFMIN-NEXT:    ret
+  %a_ext = fpext <4 x half> %a to <4 x float>
+  %b_ext = fpext <4 x bfloat> %b to <4 x float>
+  %fma = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %a_ext, <4 x float> %b_ext, <4 x float> %rd)
+  ret <4 x float> %fma
+}