Skip to content

Commit b2c1390

Browse files
Added Support for the Constrained fmuladd
1 parent fc69f25 commit b2c1390

File tree

6 files changed

+106
-1
lines changed

6 files changed

+106
-1
lines changed

llvm/include/llvm/Support/TargetOpcodes.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,7 @@ HANDLE_TARGET_OPCODE(G_FMA)
643643

644644
/// Generic FP multiply and add. Behaves as separate fmul and fadd.
645645
HANDLE_TARGET_OPCODE(G_FMAD)
646+
/// Constrained (strict FP) multiply and add. This is the opcode the
/// IRTranslator selects for llvm.experimental.constrained.fmuladd.
HANDLE_TARGET_OPCODE(G_STRICT_FMULADD)
646647

647648
/// Generic FP division.
648649
HANDLE_TARGET_OPCODE(G_FDIV)

llvm/include/llvm/Target/GenericOpcodes.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1716,6 +1716,7 @@ def G_STRICT_FREM : ConstrainedInstruction<G_FREM>;
17161716
def G_STRICT_FMA : ConstrainedInstruction<G_FMA>;
17171717
def G_STRICT_FSQRT : ConstrainedInstruction<G_FSQRT>;
17181718
def G_STRICT_FLDEXP : ConstrainedInstruction<G_FLDEXP>;
1719+
def G_STRICT_FMULADD : ConstrainedInstruction<G_FMAD>;
17191720

17201721
//------------------------------------------------------------------------------
17211722
// Memory intrinsics

llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2061,6 +2061,8 @@ static unsigned getConstrainedOpcode(Intrinsic::ID ID) {
20612061
return TargetOpcode::G_STRICT_FSQRT;
20622062
case Intrinsic::experimental_constrained_ldexp:
20632063
return TargetOpcode::G_STRICT_FLDEXP;
2064+
case Intrinsic::experimental_constrained_fmuladd:
2065+
return TargetOpcode::G_STRICT_FMULADD;
20642066
default:
20652067
return 0;
20662068
}

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
227227
bool selectExt(Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
228228
bool IsSigned) const;
229229

230+
bool selectStrictFMulAdd(Register ResVReg, const SPIRVType *ResType,
231+
MachineInstr &I) const;
232+
230233
bool selectTrunc(Register ResVReg, const SPIRVType *ResType,
231234
MachineInstr &I) const;
232235

@@ -689,6 +692,9 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
689692
case TargetOpcode::G_FMA:
690693
return selectExtInst(ResVReg, ResType, I, CL::fma, GL::Fma);
691694

695+
case TargetOpcode::G_STRICT_FMULADD:
696+
return selectStrictFMulAdd(ResVReg, ResType, I);
697+
692698
case TargetOpcode::G_STRICT_FLDEXP:
693699
return selectExtInst(ResVReg, ResType, I, CL::ldexp);
694700

@@ -1038,6 +1044,37 @@ bool SPIRVInstructionSelector::selectOpWithSrcs(Register ResVReg,
10381044
return MIB.constrainAllUses(TII, TRI, RBI);
10391045
}
10401046

1047+
/// Selects G_STRICT_FMULADD by expanding it into a separate multiply and add.
///
/// Two instructions are inserted before \p I, both typed with \p ResType:
///   MulTemp = FMul(operand 1, operand 2)
///   ResVReg = FAdd(MulTemp, operand 3)
/// The scalar (S) opcode variants are used when the SPIR-V type of the first
/// multiplicand is OpTypeFloat; every other type takes the vector (V)
/// variants.
///
/// \param ResVReg register that receives the final sum.
/// \param ResType SPIR-V type attached to both emitted instructions.
/// \param I the G_STRICT_FMULADD instruction being selected.
/// \returns true only if the operands of BOTH emitted instructions were
/// constrained successfully.
bool SPIRVInstructionSelector::selectStrictFMulAdd(Register ResVReg,
                                                   const SPIRVType *ResType,
                                                   MachineInstr &I) const {
  MachineBasicBlock &BB = *I.getParent();
  Register MulLHS = I.getOperand(1).getReg();
  Register MulRHS = I.getOperand(2).getReg();
  Register AddRHS = I.getOperand(3).getReg();

  // NOTE(review): assumes the multiplicand always has a registered SPIR-V
  // type at this point; confirm getSPIRVTypeForVReg cannot return null here.
  SPIRVType *MulLHSType = GR.getSPIRVTypeForVReg(MulLHS);
  const bool IsScalar = MulLHSType->getOpcode() == SPIRV::OpTypeFloat;
  const unsigned MulOpcode = IsScalar ? SPIRV::OpFMulS : SPIRV::OpFMulV;
  const unsigned AddOpcode = IsScalar ? SPIRV::OpFAddS : SPIRV::OpFAddV;

  // The intermediate product lives in a fresh vreg of the same register class
  // as the first multiplicand.
  Register MulTemp = MRI->createVirtualRegister(MRI->getRegClass(MulLHS));

  // Keep the result of constraining the multiply: previously it was dropped,
  // so a failure here was silently reported as a successful selection.
  const bool MulOk = BuildMI(BB, I, I.getDebugLoc(), TII.get(MulOpcode))
                         .addDef(MulTemp)
                         .addUse(GR.getSPIRVTypeID(ResType))
                         .addUse(MulLHS)
                         .addUse(MulRHS)
                         .constrainAllUses(TII, TRI, RBI);
  const bool AddOk = BuildMI(BB, I, I.getDebugLoc(), TII.get(AddOpcode))
                         .addDef(ResVReg)
                         .addUse(GR.getSPIRVTypeID(ResType))
                         .addUse(MulTemp)
                         .addUse(AddRHS)
                         .constrainAllUses(TII, TRI, RBI);
  return MulOk && AddOk;
}
1077+
10411078
bool SPIRVInstructionSelector::selectUnOp(Register ResVReg,
10421079
const SPIRVType *ResType,
10431080
MachineInstr &I,

llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
193193
.legalFor(allIntScalarsAndVectors)
194194
.legalIf(extendedScalarsAndVectors);
195195

196-
getActionDefinitionsBuilder({G_FMA, G_STRICT_FMA})
196+
getActionDefinitionsBuilder({G_FMA, G_STRICT_FMA, G_STRICT_FMULADD})
197197
.legalFor(allFloatScalarsAndVectors);
198198

199199
getActionDefinitionsBuilder(G_STRICT_FLDEXP)
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
; RUN: llc -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-DAG: OpDecorate %[[#]] FPRoundingMode RTE
5+
; CHECK-DAG: OpDecorate %[[#]] FPRoundingMode RTZ
6+
; CHECK-DAG: OpDecorate %[[#]] FPRoundingMode RTP
7+
; CHECK-DAG: OpDecorate %[[#]] FPRoundingMode RTN
8+
; CHECK-DAG: OpDecorate %[[#]] FPRoundingMode RTE
9+
10+
; CHECK: OpFMul %[[#]] %[[#]] %[[#]]
11+
; CHECK: OpFAdd %[[#]] %[[#]] %[[#]]
12+
; Scalar float case: calls the constrained fmuladd with %a as all three
; operands, "round.tonearest" and "fpexcept.strict"; the result %r is unused.
define spir_kernel void @test_f32(float %a) {
13+
entry:
14+
%r = tail call float @llvm.experimental.constrained.fmuladd.f32(
15+
float %a, float %a, float %a,
16+
metadata !"round.tonearest", metadata !"fpexcept.strict")
17+
ret void
18+
}
19+
20+
; CHECK: OpFMul %[[#]] %[[#]] %[[#]]
21+
; CHECK: OpFAdd %[[#]] %[[#]] %[[#]]
22+
; Scalar double case: calls the constrained fmuladd with %a as all three
; operands, "round.towardzero" and "fpexcept.strict"; the result %r is unused.
define spir_kernel void @test_f64(double %a) {
23+
entry:
24+
%r = tail call double @llvm.experimental.constrained.fmuladd.f64(
25+
double %a, double %a, double %a,
26+
metadata !"round.towardzero", metadata !"fpexcept.strict")
27+
ret void
28+
}
29+
30+
; CHECK: OpFMul %[[#]] %[[#]] %[[#]]
31+
; CHECK: OpFAdd %[[#]] %[[#]] %[[#]]
32+
; <2 x float> case: calls the constrained fmuladd with %a as all three
; operands, "round.upward" and "fpexcept.strict"; the result %r is unused.
define spir_kernel void @test_v2f32(<2 x float> %a) {
33+
entry:
34+
%r = tail call <2 x float> @llvm.experimental.constrained.fmuladd.v2f32(
35+
<2 x float> %a, <2 x float> %a, <2 x float> %a,
36+
metadata !"round.upward", metadata !"fpexcept.strict")
37+
ret void
38+
}
39+
40+
; CHECK: OpFMul %[[#]] %[[#]] %[[#]]
41+
; CHECK: OpFAdd %[[#]] %[[#]] %[[#]]
42+
; <4 x float> case: calls the constrained fmuladd with %a as all three
; operands, "round.downward" and "fpexcept.strict"; the result %r is unused.
define spir_kernel void @test_v4f32(<4 x float> %a) {
43+
entry:
44+
%r = tail call <4 x float> @llvm.experimental.constrained.fmuladd.v4f32(
45+
<4 x float> %a, <4 x float> %a, <4 x float> %a,
46+
metadata !"round.downward", metadata !"fpexcept.strict")
47+
ret void
48+
}
49+
50+
; CHECK: OpFMul %[[#]] %[[#]] %[[#]]
51+
; CHECK: OpFAdd %[[#]] %[[#]] %[[#]]
52+
; <2 x double> case: calls the constrained fmuladd with %a as all three
; operands, "round.tonearest" and "fpexcept.strict"; the result %r is unused.
define spir_kernel void @test_v2f64(<2 x double> %a) {
53+
entry:
54+
%r = tail call <2 x double> @llvm.experimental.constrained.fmuladd.v2f64(
55+
<2 x double> %a, <2 x double> %a, <2 x double> %a,
56+
metadata !"round.tonearest", metadata !"fpexcept.strict")
57+
ret void
58+
}
59+
60+
declare float @llvm.experimental.constrained.fmuladd.f32(float, float, float, metadata, metadata)
61+
declare double @llvm.experimental.constrained.fmuladd.f64(double, double, double, metadata, metadata)
62+
declare <2 x float> @llvm.experimental.constrained.fmuladd.v2f32(<2 x float>, <2 x float>, <2 x float>, metadata, metadata)
63+
declare <4 x float> @llvm.experimental.constrained.fmuladd.v4f32(<4 x float>, <4 x float>, <4 x float>, metadata, metadata)
64+
declare <2 x double> @llvm.experimental.constrained.fmuladd.v2f64(<2 x double>, <2 x double>, <2 x double>, metadata, metadata)

0 commit comments

Comments
 (0)