Skip to content

Commit 10ab51c

Browse files
implement arith Constrained Floating-Point Intrinsics via Tablegen; add float Saturation decoration
1 parent 34e8354 commit 10ab51c

File tree

9 files changed

+78
-52
lines changed

9 files changed

+78
-52
lines changed

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,10 @@ static bool hasType(const MCInst &MI, const MCInstrInfo &MII) {
6565
// If we define an output, and have at least one other argument.
6666
if (MCDesc.getNumDefs() == 1 && MCDesc.getNumOperands() >= 2) {
6767
// Check if we define an ID, and take a type as operand 1.
68-
auto &DefOpInfo = MCDesc.operands()[0];
69-
auto &FirstArgOpInfo = MCDesc.operands()[1];
70-
return DefOpInfo.RegClass >= 0 && FirstArgOpInfo.RegClass >= 0 &&
71-
DefOpInfo.RegClass != SPIRV::TYPERegClassID &&
72-
FirstArgOpInfo.RegClass == SPIRV::TYPERegClassID;
68+
return MCDesc.operands()[0].RegClass >= 0 &&
69+
MCDesc.operands()[1].RegClass >= 0 &&
70+
MCDesc.operands()[0].RegClass != SPIRV::TYPERegClassID &&
71+
MCDesc.operands()[1].RegClass == SPIRV::TYPERegClassID;
7372
}
7473
return false;
7574
}

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1293,23 +1293,35 @@ void SPIRVEmitIntrinsics::preprocessCompositeConstants(IRBuilder<> &B) {
12931293
}
12941294
}
12951295

1296+
static void createDecorationIntrinsic(Instruction *I, MDNode *Node,
1297+
IRBuilder<> &B) {
1298+
LLVMContext &Ctx = I->getContext();
1299+
setInsertPointAfterDef(B, I);
1300+
B.CreateIntrinsic(Intrinsic::spv_assign_decoration, {I->getType()},
1301+
{I, MetadataAsValue::get(Ctx, MDNode::get(Ctx, {Node}))});
1302+
}
1303+
12961304
static void createRoundingModeDecoration(Instruction *I,
12971305
unsigned RoundingModeDeco,
12981306
IRBuilder<> &B) {
12991307
LLVMContext &Ctx = I->getContext();
13001308
Type *Int32Ty = Type::getInt32Ty(Ctx);
1301-
setInsertPointAfterDef(B, I);
1302-
B.CreateIntrinsic(
1303-
Intrinsic::spv_assign_decoration, {I->getType()},
1304-
{I,
1305-
MetadataAsValue::get(
1306-
Ctx,
1307-
MDNode::get(
1308-
Ctx, {MDNode::get(
1309-
Ctx, {ConstantAsMetadata::get(ConstantInt::get(
1310-
Int32Ty, SPIRV::Decoration::FPRoundingMode)),
1311-
ConstantAsMetadata::get(ConstantInt::get(
1312-
Int32Ty, RoundingModeDeco))})}))});
1309+
MDNode *RoundingModeNode = MDNode::get(
1310+
Ctx,
1311+
{ConstantAsMetadata::get(
1312+
ConstantInt::get(Int32Ty, SPIRV::Decoration::FPRoundingMode)),
1313+
ConstantAsMetadata::get(ConstantInt::get(Int32Ty, RoundingModeDeco))});
1314+
createDecorationIntrinsic(I, RoundingModeNode, B);
1315+
}
1316+
1317+
static void createSaturatedConversionDecoration(Instruction *I,
1318+
IRBuilder<> &B) {
1319+
LLVMContext &Ctx = I->getContext();
1320+
Type *Int32Ty = Type::getInt32Ty(Ctx);
1321+
MDNode *SaturatedConversionNode =
1322+
MDNode::get(Ctx, {ConstantAsMetadata::get(ConstantInt::get(
1323+
Int32Ty, SPIRV::Decoration::SaturatedConversion))});
1324+
createDecorationIntrinsic(I, SaturatedConversionNode, B);
13131325
}
13141326

13151327
Instruction *SPIRVEmitIntrinsics::visitCallInst(CallInst &Call) {
@@ -1912,10 +1924,13 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
19121924
SmallVector<StringRef, 8> Parts;
19131925
S.split(Parts, "_", -1, false);
19141926
if (Parts.size() > 1) {
1915-
// Convert the tip about rounding mode into a decoration record.
1927+
// Convert the info about rounding mode into a decoration record.
19161928
unsigned RoundingModeDeco = roundingModeMDToDecorationConst(Parts[1]);
19171929
if (RoundingModeDeco != std::numeric_limits<unsigned>::max())
19181930
createRoundingModeDecoration(CI, RoundingModeDeco, B);
1931+
// Check if the SaturatedConversion info is present.
1932+
if (Parts[1] == "sat")
1933+
createSaturatedConversionDecoration(CI, B);
19191934
}
19201935
}
19211936
}

llvm/lib/Target/SPIRV/SPIRVInstrInfo.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,23 +491,29 @@ def OpFNegate: UnOpTyped<"OpFNegate", 127, fID, fneg>;
491491
def OpFNegateV: UnOpTyped<"OpFNegate", 127, vfID, fneg>;
492492
defm OpIAdd: BinOpTypedGen<"OpIAdd", 128, add, 0, 1>;
493493
defm OpFAdd: BinOpTypedGen<"OpFAdd", 129, fadd, 1, 1>;
494+
defm OpStrictFAdd: BinOpTypedGen<"OpFAdd", 129, strict_fadd, 1, 1>;
494495

495496
defm OpISub: BinOpTypedGen<"OpISub", 130, sub, 0, 1>;
496497
defm OpFSub: BinOpTypedGen<"OpFSub", 131, fsub, 1, 1>;
498+
defm OpStrictFSub: BinOpTypedGen<"OpFSub", 131, strict_fsub, 1, 1>;
497499

498500
defm OpIMul: BinOpTypedGen<"OpIMul", 132, mul, 0, 1>;
499501
defm OpFMul: BinOpTypedGen<"OpFMul", 133, fmul, 1, 1>;
502+
defm OpStrictFMul: BinOpTypedGen<"OpFMul", 133, strict_fmul, 1, 1>;
500503

501504
defm OpUDiv: BinOpTypedGen<"OpUDiv", 134, udiv, 0, 1>;
502505
defm OpSDiv: BinOpTypedGen<"OpSDiv", 135, sdiv, 0, 1>;
503506
defm OpFDiv: BinOpTypedGen<"OpFDiv", 136, fdiv, 1, 1>;
507+
defm OpStrictFDiv: BinOpTypedGen<"OpFDiv", 136, strict_fdiv, 1, 1>;
504508

505509
defm OpUMod: BinOpTypedGen<"OpUMod", 137, urem, 0, 1>;
506510
defm OpSRem: BinOpTypedGen<"OpSRem", 138, srem, 0, 1>;
507511

508512
def OpSMod: BinOp<"OpSMod", 139>;
509513

510514
defm OpFRem: BinOpTypedGen<"OpFRem", 140, frem, 1, 1>;
515+
defm OpStrictFRem: BinOpTypedGen<"OpFRem", 140, strict_frem, 1, 1>;
516+
511517
def OpFMod: BinOp<"OpFMod", 141>;
512518

513519
def OpVectorTimesScalar: BinOp<"OpVectorTimesScalar", 142>;

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -557,19 +557,12 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
557557
case TargetOpcode::G_UCMP:
558558
return selectSUCmp(ResVReg, ResType, I, false);
559559

560+
case TargetOpcode::G_STRICT_FMA:
560561
case TargetOpcode::G_FMA:
561562
return selectExtInst(ResVReg, ResType, I, CL::fma, GL::Fma);
562563

563-
case TargetOpcode::G_STRICT_FSQRT:
564-
case TargetOpcode::G_STRICT_FADD:
565-
case TargetOpcode::G_STRICT_FSUB:
566-
case TargetOpcode::G_STRICT_FMUL:
567-
case TargetOpcode::G_STRICT_FDIV:
568-
case TargetOpcode::G_STRICT_FREM:
569564
case TargetOpcode::G_STRICT_FLDEXP:
570-
return false;
571-
case TargetOpcode::G_STRICT_FMA:
572-
return selectExtInst(ResVReg, ResType, I, CL::fma, GL::Fma);
565+
return selectExtInst(ResVReg, ResType, I, CL::ldexp);
573566

574567
case TargetOpcode::G_FPOW:
575568
return selectExtInst(ResVReg, ResType, I, CL::pow, GL::Pow);
@@ -629,6 +622,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
629622
case TargetOpcode::G_FTANH:
630623
return selectExtInst(ResVReg, ResType, I, CL::tanh, GL::Tanh);
631624

625+
case TargetOpcode::G_STRICT_FSQRT:
632626
case TargetOpcode::G_FSQRT:
633627
return selectExtInst(ResVReg, ResType, I, CL::sqrt, GL::Sqrt);
634628

llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,21 @@ using namespace llvm::LegalityPredicates;
2727
static const std::set<unsigned> TypeFoldingSupportingOpcs = {
2828
TargetOpcode::G_ADD,
2929
TargetOpcode::G_FADD,
30+
TargetOpcode::G_STRICT_FADD,
3031
TargetOpcode::G_SUB,
3132
TargetOpcode::G_FSUB,
33+
TargetOpcode::G_STRICT_FSUB,
3234
TargetOpcode::G_MUL,
3335
TargetOpcode::G_FMUL,
36+
TargetOpcode::G_STRICT_FMUL,
3437
TargetOpcode::G_SDIV,
3538
TargetOpcode::G_UDIV,
3639
TargetOpcode::G_FDIV,
40+
TargetOpcode::G_STRICT_FDIV,
3741
TargetOpcode::G_SREM,
3842
TargetOpcode::G_UREM,
3943
TargetOpcode::G_FREM,
44+
TargetOpcode::G_STRICT_FREM,
4045
TargetOpcode::G_FNEG,
4146
TargetOpcode::G_CONSTANT,
4247
TargetOpcode::G_FCONSTANT,
@@ -219,10 +224,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
219224
.legalFor(allIntScalarsAndVectors)
220225
.legalIf(extendedScalarsAndVectors);
221226

222-
getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);
223-
224-
getActionDefinitionsBuilder({G_STRICT_FSQRT, G_STRICT_FADD, G_STRICT_FSUB, G_STRICT_FMUL,
225-
G_STRICT_FDIV, G_STRICT_FREM, G_STRICT_FMA})
227+
getActionDefinitionsBuilder({G_FMA, G_STRICT_FMA})
226228
.legalFor(allFloatScalarsAndVectors);
227229

228230
getActionDefinitionsBuilder(G_STRICT_FLDEXP)

llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,11 @@ extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
5555
MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR);
5656
} // namespace llvm
5757

58-
static bool isMetaInstrGET(unsigned Opcode) {
58+
static bool isMetaInstr(unsigned Opcode) {
5959
return Opcode == SPIRV::GET_ID || Opcode == SPIRV::GET_fID ||
6060
Opcode == SPIRV::GET_pID || Opcode == SPIRV::GET_vID ||
61-
Opcode == SPIRV::GET_vfID || Opcode == SPIRV::GET_vpID;
61+
Opcode == SPIRV::GET_vfID || Opcode == SPIRV::GET_vpID ||
62+
Opcode == SPIRV::ASSIGN_TYPE;
6263
}
6364

6465
static bool mayBeInserted(unsigned Opcode) {
@@ -128,7 +129,7 @@ static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
128129
if (isTypeFoldingSupported(Opcode)) {
129130
// Check if the instruction newly generated or already processed
130131
MachineInstr *NextMI = I.getNextNode();
131-
if (NextMI && isMetaInstrGET(NextMI->getOpcode()))
132+
if (NextMI && isMetaInstr(NextMI->getOpcode()))
132133
continue;
133134
// Restore usual instructions pattern for the newly inserted
134135
// instruction

llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,8 +458,10 @@ void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
458458
assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg()));
459459
MachineInstr &AssignTypeInst =
460460
*(MRI.use_instr_begin(MI.getOperand(0).getReg()));
461+
SPIRVType *SpvTypeRes = GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg());
461462
auto NewReg =
462-
createNewIdReg(nullptr, MI.getOperand(0).getReg(), MRI, *GR).first;
463+
createNewIdReg(SpvTypeRes, MI.getOperand(0).getReg(), MRI, *GR).first;
464+
GR->assignSPIRVTypeToVReg(SpvTypeRes, NewReg, MIB.getMF());
463465
AssignTypeInst.getOperand(1).setReg(NewReg);
464466
MI.getOperand(0).setReg(NewReg);
465467
MIB.setInsertPt(*MI.getParent(), MI.getIterator());

llvm/test/CodeGen/SPIRV/instructions/integer-casts.ll

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@
1414
; CHECK-DAG: OpName [[ZEXT8_16:%.*]] "u8tou16"
1515
; CHECK-DAG: OpName [[ZEXT16_32:%.*]] "u16tou32"
1616

17+
; CHECK-DAG: OpName %[[#R17:]] "r17"
18+
; CHECK-DAG: OpName %[[#R18:]] "r18"
19+
; CHECK-DAG: OpName %[[#R19:]] "r19"
20+
; CHECK-DAG: OpName %[[#R20:]] "r20"
21+
; CHECK-DAG: OpName %[[#R21:]] "r21"
22+
1723
; CHECK-DAG: OpName [[TRUNC32_16v4:%.*]] "i32toi16v4"
1824
; CHECK-DAG: OpName [[TRUNC32_8v4:%.*]] "i32toi8v4"
1925
; CHECK-DAG: OpName [[TRUNC16_8v4:%.*]] "i16toi8v4"
@@ -24,10 +30,11 @@
2430
; CHECK-DAG: OpName [[ZEXT8_16v4:%.*]] "u8tou16v4"
2531
; CHECK-DAG: OpName [[ZEXT16_32v4:%.*]] "u16tou32v4"
2632

27-
; CHECK-DAG: OpDecorate %[[#R17:]] FPRoundingMode RTZ
28-
; CHECK-DAG: OpDecorate %[[#R18:]] FPRoundingMode RTE
29-
; CHECK-DAG: OpDecorate %[[#R19:]] FPRoundingMode RTP
30-
; CHECK-DAG: OpDecorate %[[#R20:]] FPRoundingMode RTN
33+
; CHECK-DAG: OpDecorate %[[#R17]] FPRoundingMode RTZ
34+
; CHECK-DAG: OpDecorate %[[#R18]] FPRoundingMode RTE
35+
; CHECK-DAG: OpDecorate %[[#R19]] FPRoundingMode RTP
36+
; CHECK-DAG: OpDecorate %[[#R20]] FPRoundingMode RTN
37+
; CHECK-DAG: OpDecorate %[[#R21]] SaturatedConversion
3138

3239
; CHECK-DAG: [[F32:%.*]] = OpTypeFloat 32
3340
; CHECK-DAG: [[F16:%.*]] = OpTypeFloat 16
@@ -260,10 +267,11 @@ define <4 x i32> @u16tou32v4(<4 x i16> %a) {
260267
; CHECK: %[[#]] = OpSConvert [[U32v4]] %[[#]]
261268
; CHECK: %[[#]] = OpConvertUToF [[F32]] %[[#]]
262269
; CHECK: %[[#]] = OpConvertUToF [[F32]] %[[#]]
263-
; CHECK: %[[#R17:]] = OpFConvert [[F32v2]] %[[#]]
264-
; CHECK: %[[#R18:]] = OpFConvert [[F32v2]] %[[#]]
265-
; CHECK: %[[#R19:]] = OpFConvert [[F32v2]] %[[#]]
266-
; CHECK: %[[#R20:]] = OpFConvert [[F32v2]] %[[#]]
270+
; CHECK: %[[#R17]] = OpFConvert [[F32v2]] %[[#]]
271+
; CHECK: %[[#R18]] = OpFConvert [[F32v2]] %[[#]]
272+
; CHECK: %[[#R19]] = OpFConvert [[F32v2]] %[[#]]
273+
; CHECK: %[[#R20]] = OpFConvert [[F32v2]] %[[#]]
274+
; CHECK: %[[#R21]] = OpConvertFToU [[U8]] %[[#]]
267275
; CHECK: OpFunctionEnd
268276
define dso_local spir_kernel void @test_wrappers(ptr addrspace(4) %arg, i64 %arg_ptr, <4 x i8> %arg_v2) {
269277
%r1 = call spir_func i32 @__spirv_ConvertFToU(float 0.000000e+00)
@@ -286,6 +294,7 @@ define dso_local spir_kernel void @test_wrappers(ptr addrspace(4) %arg, i64 %arg
286294
%r18 = call spir_func <2 x float> @_Z28__spirv_FConvert_Rfloat2_rteDv2_DF16_(<2 x half> noundef <half 0xH409A, half 0xH439A>)
287295
%r19 = call spir_func <2 x float> @_Z28__spirv_FConvert_Rfloat2_rtpDv2_DF16_(<2 x half> noundef <half 0xH409A, half 0xH439A>)
288296
%r20 = call spir_func <2 x float> @_Z28__spirv_FConvert_Rfloat2_rtnDv2_DF16_(<2 x half> noundef <half 0xH409A, half 0xH439A>)
297+
%r21 = call spir_func i8 @_Z30__spirv_ConvertFToU_Ruchar_satf(float noundef 42.0)
289298
ret void
290299
}
291300

@@ -309,3 +318,4 @@ declare dso_local spir_func <2 x float> @_Z28__spirv_FConvert_Rfloat2_rtzDv2_DF1
309318
declare dso_local spir_func <2 x float> @_Z28__spirv_FConvert_Rfloat2_rteDv2_DF16_(<2 x half> noundef)
310319
declare dso_local spir_func <2 x float> @_Z28__spirv_FConvert_Rfloat2_rtpDv2_DF16_(<2 x half> noundef)
311320
declare dso_local spir_func <2 x float> @_Z28__spirv_FConvert_Rfloat2_rtnDv2_DF16_(<2 x half> noundef)
321+
declare dso_local spir_func i8 @_Z30__spirv_ConvertFToU_Ruchar_satf(float)

llvm/test/CodeGen/SPIRV/llvm-intrinsics/constrained-arithmetic.ll

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,19 @@
2121
; CHECK: OpFDiv %[[#]] %[[#di]]
2222
; CHECK: OpFSub %[[#]] %[[#su]]
2323
; CHECK: OpFMul %[[#]] %[[#mu]]
24-
; CHECK: OpFMul
25-
; CHECK: OpFAdd
2624
; CHECK: OpExtInst %[[#]] %[[#]] %[[#]] fma
2725
; CHECK: OpFRem
2826

2927
; Function Attrs: norecurse nounwind strictfp
3028
define dso_local spir_kernel void @test(float %a, i32 %in, i32 %ui) local_unnamed_addr #0 !kernel_arg_addr_space !5 !kernel_arg_access_qual !6 !kernel_arg_type !7 !kernel_arg_base_type !7 !kernel_arg_type_qual !8 !kernel_arg_buffer_location !9 {
3129
entry:
3230
%add = tail call float @llvm.experimental.constrained.fadd.f32(float %a, float %a, metadata !"round.tonearest", metadata !"fpexcept.strict") #2
33-
%div = tail call float @llvm.experimental.constrained.fdiv.f32(float %add, float %add, metadata !"round.towardzero", metadata !"fpexcept.strict") #2, !fpmath !10
34-
%sub = tail call float @llvm.experimental.constrained.fsub.f32(float %div, float %div, metadata !"round.upward", metadata !"fpexcept.strict") #2
35-
%mul = tail call float @llvm.experimental.constrained.fmul.f32(float %sub, float %sub, metadata !"round.downward", metadata !"fpexcept.strict") #2
36-
; TODO: @llvm.experimental.constrained.fmuladd is not supported at the moment
37-
; %0 = tail call float @llvm.experimental.constrained.fmuladd.f32(float %mul, float %mul, float %mul, metadata !"round.tonearestaway", metadata !"fpexcept.strict") #2
38-
%r1 = tail call float @llvm.experimental.constrained.fma.f32(float %a, float %a, float %a, metadata !"round.dynamic", metadata !"fpexcept.strict") #2
39-
%r2 = tail call float @llvm.experimental.constrained.frem.f32(float %a, float %a, metadata !"round.dynamic", metadata !"fpexcept.strict") #2
31+
%add2 = fadd float %a, %a
32+
; %div = tail call float @llvm.experimental.constrained.fdiv.f32(float %a, float %a, metadata !"round.towardzero", metadata !"fpexcept.strict") #2, !fpmath !10
33+
; %sub = tail call float @llvm.experimental.constrained.fsub.f32(float %a, float %a, metadata !"round.upward", metadata !"fpexcept.strict") #2
34+
; %mul = tail call float @llvm.experimental.constrained.fmul.f32(float %a, float %a, metadata !"round.downward", metadata !"fpexcept.strict") #2
35+
; %r1 = tail call float @llvm.experimental.constrained.fma.f32(float %a, float %a, float %a, metadata !"round.dynamic", metadata !"fpexcept.strict") #2
36+
; %r2 = tail call float @llvm.experimental.constrained.frem.f32(float %a, float %a, metadata !"round.dynamic", metadata !"fpexcept.strict") #2
4037
ret void
4138
}
4239

0 commit comments

Comments
 (0)