Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 49 additions & 29 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16309,7 +16309,12 @@ namespace {
// apply a combine.
struct CombineResult;

enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
enum ExtKind : uint8_t {
ZExt = 1 << 0,
SExt = 1 << 1,
FPExt = 1 << 2,
BF16Ext = 1 << 3
};
/// Helper class for folding sign/zero extensions.
/// In particular, this class is used for the following combines:
/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
Expand Down Expand Up @@ -16344,8 +16349,10 @@ struct NodeExtensionHelper {
/// instance, a splat constant (e.g., 3), would support being both sign and
/// zero extended.
bool SupportsSExt;
/// Records if this operand is like being floating-Point extended.
/// Records if this operand is like being floating point extended.
bool SupportsFPExt;
/// Records if this operand is extended from bf16.
bool SupportsBF16Ext;
/// This boolean captures whether we care if this operand would still be
/// around after the folding happens.
bool EnforceOneUse;
Expand Down Expand Up @@ -16381,6 +16388,7 @@ struct NodeExtensionHelper {
case ExtKind::ZExt:
return RISCVISD::VZEXT_VL;
case ExtKind::FPExt:
case ExtKind::BF16Ext:
return RISCVISD::FP_EXTEND_VL;
}
llvm_unreachable("Unknown ExtKind enum");
Expand All @@ -16402,13 +16410,6 @@ struct NodeExtensionHelper {
if (Source.getValueType() == NarrowVT)
return Source;

// vfmadd_vl -> vfwmadd_vl can take bf16 operands
if (Source.getValueType().getVectorElementType() == MVT::bf16) {
assert(Root->getSimpleValueType(0).getVectorElementType() == MVT::f32 &&
Root->getOpcode() == RISCVISD::VFMADD_VL);
return Source;
}

unsigned ExtOpc = getExtOpc(*SupportsExt);

// If we need an extension, we should be changing the type.
Expand Down Expand Up @@ -16451,7 +16452,8 @@ struct NodeExtensionHelper {
// Determine the narrow size.
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;

MVT EltVT = SupportsExt == ExtKind::FPExt
MVT EltVT = SupportsExt == ExtKind::BF16Ext ? MVT::bf16
: SupportsExt == ExtKind::FPExt
? MVT::getFloatingPointVT(NarrowSize)
: MVT::getIntegerVT(NarrowSize);

Expand Down Expand Up @@ -16628,17 +16630,13 @@ struct NodeExtensionHelper {
EnforceOneUse = false;
}

bool isSupportedFPExtend(SDNode *Root, MVT NarrowEltVT,
const RISCVSubtarget &Subtarget) {
// Any f16 extension will need zvfh
if (NarrowEltVT == MVT::f16 && !Subtarget.hasVInstructionsF16())
return false;
// The only bf16 extension we can do is vfmadd_vl -> vfwmadd_vl with
// zvfbfwma
if (NarrowEltVT == MVT::bf16 && (!Subtarget.hasStdExtZvfbfwma() ||
Root->getOpcode() != RISCVISD::VFMADD_VL))
return false;
return true;
/// Return true if a floating-point extend whose narrow (source) element type
/// is \p NarrowEltVT can be used for widening: f32 sources are always
/// accepted, and f16 sources additionally require vector f16 support.
bool isSupportedFPExtend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) {
  if (NarrowEltVT == MVT::f32)
    return true;
  return NarrowEltVT == MVT::f16 && Subtarget.hasVInstructionsF16();
}

/// Return true if a bf16 -> f32 widening is available: the narrow (source)
/// element type must be bf16 and the subtarget must provide Zvfbfwma.
bool isSupportedBF16Extend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) {
  return Subtarget.hasStdExtZvfbfwma() && NarrowEltVT == MVT::bf16;
}

/// Helper method to set the various fields of this struct based on the
Expand All @@ -16648,6 +16646,7 @@ struct NodeExtensionHelper {
SupportsZExt = false;
SupportsSExt = false;
SupportsFPExt = false;
SupportsBF16Ext = false;
EnforceOneUse = true;
unsigned Opc = OrigOperand.getOpcode();
// For the nodes we handle below, we end up using their inputs directly: see
Expand Down Expand Up @@ -16679,9 +16678,11 @@ struct NodeExtensionHelper {
case RISCVISD::FP_EXTEND_VL: {
MVT NarrowEltVT =
OrigOperand.getOperand(0).getSimpleValueType().getVectorElementType();
if (!isSupportedFPExtend(Root, NarrowEltVT, Subtarget))
break;
SupportsFPExt = true;
if (isSupportedFPExtend(NarrowEltVT, Subtarget))
SupportsFPExt = true;
if (isSupportedBF16Extend(NarrowEltVT, Subtarget))
SupportsBF16Ext = true;

break;
}
case ISD::SPLAT_VECTOR:
Expand All @@ -16698,16 +16699,16 @@ struct NodeExtensionHelper {
if (Op.getOpcode() != ISD::FP_EXTEND)
break;

if (!isSupportedFPExtend(Root, Op.getOperand(0).getSimpleValueType(),
Subtarget))
break;

unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
unsigned ScalarBits = Op.getOperand(0).getValueSizeInBits();
if (NarrowSize != ScalarBits)
break;

SupportsFPExt = true;
if (isSupportedFPExtend(Op.getOperand(0).getSimpleValueType(), Subtarget))
SupportsFPExt = true;
if (isSupportedBF16Extend(Op.getOperand(0).getSimpleValueType(),
Subtarget))
SupportsBF16Ext = true;
break;
}
default:
Expand Down Expand Up @@ -16940,6 +16941,11 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
/*RHSExt=*/{ExtKind::FPExt});
if ((AllowExtMask & ExtKind::BF16Ext) && LHS.SupportsBF16Ext &&
RHS.SupportsBF16Ext)
return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
Root, LHS, /*LHSExt=*/{ExtKind::BF16Ext}, RHS,
/*RHSExt=*/{ExtKind::BF16Ext});
return std::nullopt;
}

Expand Down Expand Up @@ -17022,6 +17028,18 @@ canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
Subtarget);
}

/// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS)).
///
/// \returns std::nullopt if the pattern doesn't match, or a CombineResult that
/// can be used to apply the pattern.
static std::optional<CombineResult>
canFoldToVWWithBF16EXT(SDNode *Root, const NodeExtensionHelper &LHSHelper,
                       const NodeExtensionHelper &RHSHelper, SelectionDAG &DAG,
                       const RISCVSubtarget &Subtarget) {
  // Delegate to the common matcher, restricting the allowed extension kind to
  // bf16 extends only.
  return canFoldToVWWithSameExtensionImpl(Root, LHSHelper, RHSHelper,
                                          ExtKind::BF16Ext, DAG, Subtarget);
}

/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
///
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
Expand Down Expand Up @@ -17061,6 +17079,8 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
case RISCVISD::VFNMADD_VL:
case RISCVISD::VFNMSUB_VL:
Strategies.push_back(canFoldToVWWithSameExtension);
if (Root->getOpcode() == RISCVISD::VFMADD_VL)
Strategies.push_back(canFoldToVWWithBF16EXT);
break;
case ISD::MUL:
case RISCVISD::MUL_VL:
Expand Down
58 changes: 54 additions & 4 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmaccbf16.ll
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zvfbfwma -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFWMA
; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfbfwma -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFWMA
; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zvfbfmin -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFMIN
; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfbfmin -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFMIN
; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zvfh,+zvfbfwma -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFWMA
; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfh,+zvfbfwma -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFWMA
; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zvfh,+zvfbfmin -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFMIN
; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfh,+zvfbfmin -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFMIN

define <1 x float> @vfwmaccbf16_vv_v1f32(<1 x float> %a, <1 x bfloat> %b, <1 x bfloat> %c) {
; ZVFBFWMA-LABEL: vfwmaccbf16_vv_v1f32:
Expand Down Expand Up @@ -295,3 +295,53 @@ define <32 x float> @vfwmaccbf32_vf_v32f32(<32 x float> %a, bfloat %b, <32 x bfl
%res = call <32 x float> @llvm.fma.v32f32(<32 x float> %b.ext, <32 x float> %c.ext, <32 x float> %a)
ret <32 x float> %res
}

; A bf16 scalar is fpext'ed, splatted, and fed to a vector fma. With Zvfbfwma
; the whole pattern folds to vfwmaccbf16.vf; with only Zvfbfmin the scalar is
; widened via integer moves (fmv.x.w / slli / fmv.w.x), the vector operand via
; vfwcvtbf16.f.f.v, and a plain vfmacc.vf is emitted.
define <4 x float> @vfwmaccbf16_vf_v4f32_scalar_extend(<4 x float> %rd, bfloat %a, <4 x bfloat> %b) local_unnamed_addr #0 {
; ZVFBFWMA-LABEL: vfwmaccbf16_vf_v4f32_scalar_extend:
; ZVFBFWMA: # %bb.0:
; ZVFBFWMA-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
; ZVFBFWMA-NEXT: vfwmaccbf16.vf v8, fa0, v9
; ZVFBFWMA-NEXT: ret
;
; ZVFBFMIN-LABEL: vfwmaccbf16_vf_v4f32_scalar_extend:
; ZVFBFMIN: # %bb.0:
; ZVFBFMIN-NEXT: fmv.x.w a0, fa0
; ZVFBFMIN-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
; ZVFBFMIN-NEXT: vfwcvtbf16.f.f.v v10, v9
; ZVFBFMIN-NEXT: slli a0, a0, 16
; ZVFBFMIN-NEXT: fmv.w.x fa5, a0
; ZVFBFMIN-NEXT: vsetvli zero, zero, e32, m1, ta, ma
; ZVFBFMIN-NEXT: vfmacc.vf v8, fa5, v10
; ZVFBFMIN-NEXT: ret
%b_ext = fpext <4 x bfloat> %b to <4 x float>
%a_extend = fpext bfloat %a to float
%a_insert = insertelement <4 x float> poison, float %a_extend, i64 0
%a_shuffle = shufflevector <4 x float> %a_insert, <4 x float> poison, <4 x i32> zeroinitializer
%fma = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %a_shuffle, <4 x float> %b_ext, <4 x float> %rd)
ret <4 x float> %fma
}

; Negative test with a mix of bfloat and half fpext.
; The two multiplicands extend from different narrow element types, so no
; single widening fma applies: both configurations emit explicit converts
; (vfwcvt.f.f.v for the half operand, vfwcvtbf16.f.f.v for the bfloat operand)
; followed by a non-widening vfmacc.vv.
define <4 x float> @mix(<4 x float> %rd, <4 x half> %a, <4 x bfloat> %b) {
; ZVFBFWMA-LABEL: mix:
; ZVFBFWMA: # %bb.0:
; ZVFBFWMA-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
; ZVFBFWMA-NEXT: vfwcvt.f.f.v v11, v9
; ZVFBFWMA-NEXT: vfwcvtbf16.f.f.v v9, v10
; ZVFBFWMA-NEXT: vsetvli zero, zero, e32, m1, ta, ma
; ZVFBFWMA-NEXT: vfmacc.vv v8, v11, v9
; ZVFBFWMA-NEXT: ret
;
; ZVFBFMIN-LABEL: mix:
; ZVFBFMIN: # %bb.0:
; ZVFBFMIN-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
; ZVFBFMIN-NEXT: vfwcvt.f.f.v v11, v9
; ZVFBFMIN-NEXT: vfwcvtbf16.f.f.v v9, v10
; ZVFBFMIN-NEXT: vsetvli zero, zero, e32, m1, ta, ma
; ZVFBFMIN-NEXT: vfmacc.vv v8, v11, v9
; ZVFBFMIN-NEXT: ret
%a_ext = fpext <4 x half> %a to <4 x float>
%b_ext = fpext <4 x bfloat> %b to <4 x float>
%fma = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %a_ext, <4 x float> %b_ext, <4 x float> %rd)
ret <4 x float> %fma
}
Loading