From c2f0ce3b14e466eeb22ce15d32ee85d1f573fb77 Mon Sep 17 00:00:00 2001 From: Craig Topper Date: Wed, 18 Jun 2025 01:50:43 -0700 Subject: [PATCH 1/4] [RISCV][WIP] Treat bf16->f32 as separate ExtKind in combineOp_VLToVWOp_VL. This allows us to better track the narrow type we need and to fix miscompiles if f16->f32 and bf16->f32 extends are mixed. Fixes #144651. Still need to add tests, but it's late and I need sleep. --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 73 ++++++++++++--------- 1 file changed, 42 insertions(+), 31 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index e670567bd1844..f7d447e03af94 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -16309,7 +16309,12 @@ namespace { // apply a combine. struct CombineResult; -enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 }; +enum ExtKind : uint8_t { + ZExt = 1 << 0, + SExt = 1 << 1, + FPExt = 1 << 2, + BF16Ext = 1 << 3 +}; /// Helper class for folding sign/zero extensions. /// In particular, this class is used for the following combines: /// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w @@ -16344,8 +16349,10 @@ struct NodeExtensionHelper { /// instance, a splat constant (e.g., 3), would support being both sign and /// zero extended. bool SupportsSExt; - /// Records if this operand is like being floating-Point extended. + /// Records if this operand is like being floating point extended. bool SupportsFPExt; + /// Records if this operand is extended from bf16. + bool SupportsBF16Ext; /// This boolean captures whether we care if this operand would still be /// around after the folding happens. 
bool EnforceOneUse; @@ -16381,6 +16388,7 @@ struct NodeExtensionHelper { case ExtKind::ZExt: return RISCVISD::VZEXT_VL; case ExtKind::FPExt: + case ExtKind::BF16Ext: return RISCVISD::FP_EXTEND_VL; } llvm_unreachable("Unknown ExtKind enum"); @@ -16402,13 +16410,6 @@ struct NodeExtensionHelper { if (Source.getValueType() == NarrowVT) return Source; - // vfmadd_vl -> vfwmadd_vl can take bf16 operands - if (Source.getValueType().getVectorElementType() == MVT::bf16) { - assert(Root->getSimpleValueType(0).getVectorElementType() == MVT::f32 && - Root->getOpcode() == RISCVISD::VFMADD_VL); - return Source; - } - unsigned ExtOpc = getExtOpc(*SupportsExt); // If we need an extension, we should be changing the type. @@ -16451,7 +16452,8 @@ struct NodeExtensionHelper { // Determine the narrow size. unsigned NarrowSize = VT.getScalarSizeInBits() / 2; - MVT EltVT = SupportsExt == ExtKind::FPExt + MVT EltVT = SupportsExt == ExtKind::BF16Ext ? MVT::bf16 + : SupportsExt == ExtKind::FPExt ? MVT::getFloatingPointVT(NarrowSize) : MVT::getIntegerVT(NarrowSize); @@ -16628,17 +16630,17 @@ struct NodeExtensionHelper { EnforceOneUse = false; } - bool isSupportedFPExtend(SDNode *Root, MVT NarrowEltVT, - const RISCVSubtarget &Subtarget) { + bool isSupportedFPExtend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) { + if (NarrowEltVT == MVT::f32) + return true; // Any f16 extension will need zvfh - if (NarrowEltVT == MVT::f16 && !Subtarget.hasVInstructionsF16()) - return false; - // The only bf16 extension we can do is vfmadd_vl -> vfwmadd_vl with - // zvfbfwma - if (NarrowEltVT == MVT::bf16 && (!Subtarget.hasStdExtZvfbfwma() || - Root->getOpcode() != RISCVISD::VFMADD_VL)) - return false; - return true; + if (NarrowEltVT == MVT::f16 && Subtarget.hasVInstructionsF16()) + return true; + return false; + } + + bool isSupportedBF16Extend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) { + return NarrowEltVT == MVT::bf16 && Subtarget.hasStdExtZvfbfwma(); } /// Helper method to set the various 
fields of this struct based on the @@ -16648,6 +16650,7 @@ struct NodeExtensionHelper { SupportsZExt = false; SupportsSExt = false; SupportsFPExt = false; + SupportsBF16Ext = false; EnforceOneUse = true; unsigned Opc = OrigOperand.getOpcode(); // For the nodes we handle below, we end up using their inputs directly: see @@ -16679,9 +16682,11 @@ struct NodeExtensionHelper { case RISCVISD::FP_EXTEND_VL: { MVT NarrowEltVT = OrigOperand.getOperand(0).getSimpleValueType().getVectorElementType(); - if (!isSupportedFPExtend(Root, NarrowEltVT, Subtarget)) - break; - SupportsFPExt = true; + if (isSupportedFPExtend(NarrowEltVT, Subtarget)) + SupportsFPExt = true; + if (isSupportedBF16Extend(NarrowEltVT, Subtarget)) + SupportsBF16Ext = true; + break; } case ISD::SPLAT_VECTOR: @@ -16698,16 +16703,16 @@ struct NodeExtensionHelper { if (Op.getOpcode() != ISD::FP_EXTEND) break; - if (!isSupportedFPExtend(Root, Op.getOperand(0).getSimpleValueType(), - Subtarget)) - break; - unsigned NarrowSize = VT.getScalarSizeInBits() / 2; unsigned ScalarBits = Op.getOperand(0).getValueSizeInBits(); if (NarrowSize != ScalarBits) break; - SupportsFPExt = true; + if (isSupportedFPExtend(Op.getOperand(0).getSimpleValueType(), Subtarget)) + SupportsFPExt = true; + if (isSupportedBF16Extend(Op.getOperand(0).getSimpleValueType(), + Subtarget)) + SupportsBF16Ext = true; break; } default: @@ -16940,6 +16945,11 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS, return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()), Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS, /*RHSExt=*/{ExtKind::FPExt}); + if ((AllowExtMask & ExtKind::BF16Ext) && LHS.SupportsBF16Ext && + RHS.SupportsBF16Ext && Root->getOpcode() == RISCVISD::VFMADD_VL) + return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()), + Root, LHS, /*LHSExt=*/{ExtKind::BF16Ext}, RHS, + /*RHSExt=*/{ExtKind::BF16Ext}); return std::nullopt; } @@ -16953,9 +16963,10 @@ static std::optional 
canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS, const NodeExtensionHelper &RHS, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { - return canFoldToVWWithSameExtensionImpl( - Root, LHS, RHS, ExtKind::ZExt | ExtKind::SExt | ExtKind::FPExt, DAG, - Subtarget); + return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, + ExtKind::ZExt | ExtKind::SExt | + ExtKind::FPExt | ExtKind::BF16Ext, + DAG, Subtarget); } /// Check if \p Root follows a pattern Root(LHS, ext(RHS)) From aafd72ef88b0386056a3e7b00819460c93220b95 Mon Sep 17 00:00:00 2001 From: Craig Topper Date: Wed, 18 Jun 2025 12:15:34 -0700 Subject: [PATCH 2/4] fixup! add tests --- .../RISCV/rvv/fixed-vectors-vfwmaccbf16.ll | 58 +++++++++++++++++-- 1 file changed, 54 insertions(+), 4 deletions(-) diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmaccbf16.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmaccbf16.ll index 1639f21f243d8..aec970adff51e 100644 --- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmaccbf16.ll +++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmaccbf16.ll @@ -1,8 +1,8 @@ ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 -; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zvfbfwma -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFWMA -; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfbfwma -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFWMA -; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zvfbfmin -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFMIN -; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfbfmin -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFMIN +; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zvfh,+zvfbfwma -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFWMA +; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfh,+zvfbfwma -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFWMA +; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zvfh,+zvfbfmin -verify-machineinstrs | FileCheck %s 
--check-prefix=ZVFBFMIN +; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfh,+zvfbfmin -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFMIN define <1 x float> @vfwmaccbf16_vv_v1f32(<1 x float> %a, <1 x bfloat> %b, <1 x bfloat> %c) { ; ZVFBFWMA-LABEL: vfwmaccbf16_vv_v1f32: @@ -295,3 +295,53 @@ define <32 x float> @vfwmaccbf32_vf_v32f32(<32 x float> %a, bfloat %b, <32 x bfl %res = call <32 x float> @llvm.fma.v32f32(<32 x float> %b.ext, <32 x float> %c.ext, <32 x float> %a) ret <32 x float> %res } + +define <4 x float> @vfwmaccbf16_vf_v4f32_scalar_extend(<4 x float> %rd, bfloat %a, <4 x bfloat> %b) local_unnamed_addr #0 { +; ZVFBFWMA-LABEL: vfwmaccbf16_vf_v4f32_scalar_extend: +; ZVFBFWMA: # %bb.0: +; ZVFBFWMA-NEXT: vsetivli zero, 4, e16, mf2, ta, ma +; ZVFBFWMA-NEXT: vfwmaccbf16.vf v8, fa0, v9 +; ZVFBFWMA-NEXT: ret +; +; ZVFBFMIN-LABEL: vfwmaccbf16_vf_v4f32_scalar_extend: +; ZVFBFMIN: # %bb.0: +; ZVFBFMIN-NEXT: fmv.x.w a0, fa0 +; ZVFBFMIN-NEXT: vsetivli zero, 4, e16, mf2, ta, ma +; ZVFBFMIN-NEXT: vfwcvtbf16.f.f.v v10, v9 +; ZVFBFMIN-NEXT: slli a0, a0, 16 +; ZVFBFMIN-NEXT: fmv.w.x fa5, a0 +; ZVFBFMIN-NEXT: vsetvli zero, zero, e32, m1, ta, ma +; ZVFBFMIN-NEXT: vfmacc.vf v8, fa5, v10 +; ZVFBFMIN-NEXT: ret + %b_ext = fpext <4 x bfloat> %b to <4 x float> + %a_extend = fpext bfloat %a to float + %a_insert = insertelement <4 x float> poison, float %a_extend, i64 0 + %a_shuffle = shufflevector <4 x float> %a_insert, <4 x float> poison, <4 x i32> zeroinitializer + %fma = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %a_shuffle, <4 x float> %b_ext, <4 x float> %rd) + ret <4 x float> %fma +} + +; Negative test with a mix of bfloat and half fpext. 
+define <4 x float> @mix(<4 x float> %rd, <4 x half> %a, <4 x bfloat> %b) { +; ZVFBFWMA-LABEL: mix: +; ZVFBFWMA: # %bb.0: +; ZVFBFWMA-NEXT: vsetivli zero, 4, e16, mf2, ta, ma +; ZVFBFWMA-NEXT: vfwcvt.f.f.v v11, v9 +; ZVFBFWMA-NEXT: vfwcvtbf16.f.f.v v9, v10 +; ZVFBFWMA-NEXT: vsetvli zero, zero, e32, m1, ta, ma +; ZVFBFWMA-NEXT: vfmacc.vv v8, v11, v9 +; ZVFBFWMA-NEXT: ret +; +; ZVFBFMIN-LABEL: mix: +; ZVFBFMIN: # %bb.0: +; ZVFBFMIN-NEXT: vsetivli zero, 4, e16, mf2, ta, ma +; ZVFBFMIN-NEXT: vfwcvt.f.f.v v11, v9 +; ZVFBFMIN-NEXT: vfwcvtbf16.f.f.v v9, v10 +; ZVFBFMIN-NEXT: vsetvli zero, zero, e32, m1, ta, ma +; ZVFBFMIN-NEXT: vfmacc.vv v8, v11, v9 +; ZVFBFMIN-NEXT: ret + %a_ext = fpext <4 x half> %a to <4 x float> + %b_ext = fpext <4 x bfloat> %b to <4 x float> + %fma = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %a_ext, <4 x float> %b_ext, <4 x float> %rd) + ret <4 x float> %fma +} From 274c3f191033ea60d33d11d4929d560811c1b3bd Mon Sep 17 00:00:00 2001 From: Craig Topper Date: Wed, 18 Jun 2025 13:14:20 -0700 Subject: [PATCH 3/4] fixup! Use a separate strategy for this in getSupportedFoldings. This avoids a root opcode check elsewhere. 
--- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 23 ++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index f7d447e03af94..734ec241d6957 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -16946,7 +16946,7 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS, Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS, /*RHSExt=*/{ExtKind::FPExt}); if ((AllowExtMask & ExtKind::BF16Ext) && LHS.SupportsBF16Ext && - RHS.SupportsBF16Ext && Root->getOpcode() == RISCVISD::VFMADD_VL) + RHS.SupportsBF16Ext) return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()), Root, LHS, /*LHSExt=*/{ExtKind::BF16Ext}, RHS, /*RHSExt=*/{ExtKind::BF16Ext}); @@ -16963,10 +16963,9 @@ static std::optional canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS, const NodeExtensionHelper &RHS, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { - return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, - ExtKind::ZExt | ExtKind::SExt | - ExtKind::FPExt | ExtKind::BF16Ext, - DAG, Subtarget); + return canFoldToVWWithSameExtensionImpl( + Root, LHS, RHS, ExtKind::ZExt | ExtKind::SExt | ExtKind::FPExt, DAG, + Subtarget); } /// Check if \p Root follows a pattern Root(LHS, ext(RHS)) @@ -17033,6 +17032,18 @@ canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS, Subtarget); } +/// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS)) +/// +/// \returns std::nullopt if the pattern doesn't match or a CombineResult that +/// can be used to apply the pattern. 
+static std::optional<CombineResult> +canFoldToVWWithBF16EXT(SDNode *Root, const NodeExtensionHelper &LHS, + const NodeExtensionHelper &RHS, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG, + Subtarget); +} + /// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS)) /// /// \returns std::nullopt if the pattern doesn't match or a CombineResult that @@ -17072,6 +17083,8 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) { case RISCVISD::VFNMADD_VL: case RISCVISD::VFNMSUB_VL: Strategies.push_back(canFoldToVWWithSameExtension); + if (Root->getOpcode() == RISCVISD::VFMADD_VL) + Strategies.push_back(canFoldToVWWithBF16EXT); break; case ISD::MUL: case RISCVISD::MUL_VL: From 34a8bc1adf1597b4de72059517e60984aa9f4c88 Mon Sep 17 00:00:00 2001 From: Craig Topper Date: Thu, 19 Jun 2025 19:06:16 -0700 Subject: [PATCH 4/4] fixup! Address review comment --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 734ec241d6957..5aa51cf6a6517 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -16631,12 +16631,8 @@ struct NodeExtensionHelper { } bool isSupportedFPExtend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) { - if (NarrowEltVT == MVT::f32) - return true; - // Any f16 extension will need zvfh - if (NarrowEltVT == MVT::f16 && Subtarget.hasVInstructionsF16()) - return true; - return false; + return (NarrowEltVT == MVT::f32 || + (NarrowEltVT == MVT::f16 && Subtarget.hasVInstructionsF16())); } bool isSupportedBF16Extend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) {