Skip to content

Commit 0c1087b

Browse files
[X86][GlobalISel] Added support for llvm.set.rounding (#156591)
- This implementation is adapted from **SDAG X86TargetLowering::LowerSET_ROUNDING**.
1 parent a05b232 commit 0c1087b

File tree

15 files changed

+519
-28
lines changed

15 files changed

+519
-28
lines changed

llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2463,6 +2463,11 @@ class LLVM_ABI MachineIRBuilder {
24632463
return buildInstr(TargetOpcode::G_GET_ROUNDING, {Dst}, {});
24642464
}
24652465

2466+
/// Build and insert G_SET_ROUNDING
2467+
MachineInstrBuilder buildSetRounding(const SrcOp &Src) {
2468+
return buildInstr(TargetOpcode::G_SET_ROUNDING, {}, {Src});
2469+
}
2470+
24662471
virtual MachineInstrBuilder
24672472
buildInstr(unsigned Opc, ArrayRef<DstOp> DstOps, ArrayRef<SrcOp> SrcOps,
24682473
std::optional<unsigned> Flags = std::nullopt);

llvm/include/llvm/Support/TargetOpcodes.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,7 @@ HANDLE_TARGET_OPCODE(G_SET_FPMODE)
745745
HANDLE_TARGET_OPCODE(G_RESET_FPMODE)
746746

747747
HANDLE_TARGET_OPCODE(G_GET_ROUNDING)
748+
HANDLE_TARGET_OPCODE(G_SET_ROUNDING)
748749

749750
/// Generic pointer offset
750751
HANDLE_TARGET_OPCODE(G_PTR_ADD)

llvm/include/llvm/Target/GenericOpcodes.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1273,6 +1273,12 @@ def G_GET_ROUNDING : GenericInstruction {
12731273
let hasSideEffects = true;
12741274
}
12751275

1276+
def G_SET_ROUNDING : GenericInstruction {
1277+
let OutOperandList = (outs);
1278+
let InOperandList = (ins type0:$src);
1279+
let hasSideEffects = true;
1280+
}
1281+
12761282
//------------------------------------------------------------------------------
12771283
// Memory ops
12781284
//------------------------------------------------------------------------------

llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2607,6 +2607,9 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
26072607
case Intrinsic::get_rounding:
26082608
MIRBuilder.buildGetRounding(getOrCreateVReg(CI));
26092609
return true;
2610+
case Intrinsic::set_rounding:
2611+
MIRBuilder.buildSetRounding(getOrCreateVReg(*CI.getOperand(0)));
2612+
return true;
26102613
case Intrinsic::vscale: {
26112614
MIRBuilder.buildVScale(getOrCreateVReg(CI), 1);
26122615
return true;

llvm/lib/Target/X86/GISel/X86LegalizerInfo.cpp

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "llvm/CodeGen/TargetOpcodes.h"
2222
#include "llvm/CodeGen/ValueTypes.h"
2323
#include "llvm/IR/DerivedTypes.h"
24+
#include "llvm/IR/IntrinsicsX86.h"
2425
#include "llvm/IR/Type.h"
2526

2627
using namespace llvm;
@@ -110,7 +111,8 @@ X86LegalizerInfo::X86LegalizerInfo(const X86Subtarget &STI,
110111
.legalFor(HasSSE2 || UseX87, {s64})
111112
.legalFor(UseX87, {s80});
112113

113-
getActionDefinitionsBuilder(G_GET_ROUNDING).customFor({s32});
114+
getActionDefinitionsBuilder({G_GET_ROUNDING, G_SET_ROUNDING})
115+
.customFor({s32});
114116

115117
// merge/unmerge
116118
for (unsigned Op : {G_MERGE_VALUES, G_UNMERGE_VALUES}) {
@@ -617,6 +619,8 @@ bool X86LegalizerInfo::legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI,
617619
return legalizeFPTOSI(MI, MRI, Helper);
618620
case TargetOpcode::G_GET_ROUNDING:
619621
return legalizeGETROUNDING(MI, MRI, Helper);
622+
case TargetOpcode::G_SET_ROUNDING:
623+
return legalizeSETROUNDING(MI, MRI, Helper);
620624
}
621625
llvm_unreachable("expected switch to return");
622626
}
@@ -859,6 +863,133 @@ bool X86LegalizerInfo::legalizeGETROUNDING(MachineInstr &MI,
859863
return true;
860864
}
861865

866+
bool X86LegalizerInfo::legalizeSETROUNDING(MachineInstr &MI,
867+
MachineRegisterInfo &MRI,
868+
LegalizerHelper &Helper) const {
869+
MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
870+
MachineFunction &MF = MIRBuilder.getMF();
871+
Register Src = MI.getOperand(0).getReg();
872+
const LLT s8 = LLT::scalar(8);
873+
const LLT s16 = LLT::scalar(16);
874+
const LLT s32 = LLT::scalar(32);
875+
876+
// Allocate stack slot for control word and MXCSR (4 bytes).
877+
int MemSize = 4;
878+
Align Alignment = Align(4);
879+
MachinePointerInfo PtrInfo;
880+
auto StackTemp = Helper.createStackTemporary(TypeSize::getFixed(MemSize),
881+
Alignment, PtrInfo);
882+
Register StackPtr = StackTemp.getReg(0);
883+
884+
auto StoreMMO =
885+
MF.getMachineMemOperand(PtrInfo, MachineMemOperand::MOStore, 2, Align(2));
886+
MIRBuilder.buildInstr(X86::G_FNSTCW16)
887+
.addUse(StackPtr)
888+
.addMemOperand(StoreMMO);
889+
890+
auto LoadMMO =
891+
MF.getMachineMemOperand(PtrInfo, MachineMemOperand::MOLoad, 2, Align(2));
892+
auto CWD16 = MIRBuilder.buildLoad(s16, StackPtr, *LoadMMO);
893+
894+
// Clear RM field (bits 11:10)
895+
auto ClearedCWD =
896+
MIRBuilder.buildAnd(s16, CWD16, MIRBuilder.buildConstant(s16, 0xf3ff));
897+
898+
// Check if Src is a constant
899+
auto *SrcDef = MRI.getVRegDef(Src);
900+
Register RMBits;
901+
Register MXCSRRMBits;
902+
903+
if (SrcDef && SrcDef->getOpcode() == TargetOpcode::G_CONSTANT) {
904+
uint64_t RM = getIConstantFromReg(Src, MRI).getZExtValue();
905+
int FieldVal = X86::getRoundingModeX86(RM);
906+
907+
if (FieldVal == X86::rmInvalid) {
908+
LLVMContext &C = MF.getFunction().getContext();
909+
C.diagnose(DiagnosticInfoUnsupported(
910+
MF.getFunction(), "rounding mode is not supported by X86 hardware",
911+
DiagnosticLocation(MI.getDebugLoc()), DS_Error));
912+
return false;
913+
}
914+
915+
FieldVal = FieldVal << 3;
916+
RMBits = MIRBuilder.buildConstant(s16, FieldVal).getReg(0);
917+
MXCSRRMBits = MIRBuilder.buildConstant(s32, FieldVal).getReg(0);
918+
} else {
919+
// Convert Src (rounding mode) to bits for control word
920+
// (0xc9 << (2 * Src + 4)) & 0xc00
921+
auto Src32 = MIRBuilder.buildZExtOrTrunc(s32, Src);
922+
auto ShiftAmt = MIRBuilder.buildAdd(
923+
s32, MIRBuilder.buildShl(s32, Src32, MIRBuilder.buildConstant(s32, 1)),
924+
MIRBuilder.buildConstant(s32, 4));
925+
auto ShiftAmt8 = MIRBuilder.buildTrunc(s8, ShiftAmt);
926+
auto Shifted = MIRBuilder.buildShl(s16, MIRBuilder.buildConstant(s16, 0xc9),
927+
ShiftAmt8);
928+
RMBits =
929+
MIRBuilder.buildAnd(s16, Shifted, MIRBuilder.buildConstant(s16, 0xc00))
930+
.getReg(0);
931+
932+
// For non-constant case, we still need to compute MXCSR bits dynamically
933+
auto RMBits32 = MIRBuilder.buildZExt(s32, RMBits);
934+
MXCSRRMBits =
935+
MIRBuilder.buildShl(s32, RMBits32, MIRBuilder.buildConstant(s32, 3))
936+
.getReg(0);
937+
}
938+
// Update rounding mode bits
939+
auto NewCWD =
940+
MIRBuilder.buildOr(s16, ClearedCWD, RMBits, MachineInstr::Disjoint);
941+
942+
// Store new FP Control Word to stack
943+
auto StoreNewMMO =
944+
MF.getMachineMemOperand(PtrInfo, MachineMemOperand::MOStore, 2, Align(2));
945+
MIRBuilder.buildStore(NewCWD, StackPtr, *StoreNewMMO);
946+
947+
// Load FP control word from the slot using G_FLDCW16
948+
auto LoadNewMMO =
949+
MF.getMachineMemOperand(PtrInfo, MachineMemOperand::MOLoad, 2, Align(2));
950+
MIRBuilder.buildInstr(X86::G_FLDCW16)
951+
.addUse(StackPtr)
952+
.addMemOperand(LoadNewMMO);
953+
954+
if (Subtarget.hasSSE1()) {
955+
// Store MXCSR to stack (use STMXCSR)
956+
auto StoreMXCSRMMO = MF.getMachineMemOperand(
957+
PtrInfo, MachineMemOperand::MOStore, 4, Align(4));
958+
MIRBuilder.buildInstr(TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS)
959+
.addIntrinsicID(Intrinsic::x86_sse_stmxcsr)
960+
.addUse(StackPtr)
961+
.addMemOperand(StoreMXCSRMMO);
962+
963+
// Load MXCSR from stack
964+
auto LoadMXCSRMMO = MF.getMachineMemOperand(
965+
PtrInfo, MachineMemOperand::MOLoad, 4, Align(4));
966+
auto MXCSR = MIRBuilder.buildLoad(s32, StackPtr, *LoadMXCSRMMO);
967+
968+
// Clear RM field (bits 14:13)
969+
auto ClearedMXCSR = MIRBuilder.buildAnd(
970+
s32, MXCSR, MIRBuilder.buildConstant(s32, 0xffff9fff));
971+
972+
// Update rounding mode bits
973+
auto NewMXCSR = MIRBuilder.buildOr(s32, ClearedMXCSR, MXCSRRMBits);
974+
975+
// Store new MXCSR to stack
976+
auto StoreNewMXCSRMMO = MF.getMachineMemOperand(
977+
PtrInfo, MachineMemOperand::MOStore, 4, Align(4));
978+
MIRBuilder.buildStore(NewMXCSR, StackPtr, *StoreNewMXCSRMMO);
979+
980+
// Load MXCSR from stack (use LDMXCSR)
981+
auto LoadNewMXCSRMMO = MF.getMachineMemOperand(
982+
PtrInfo, MachineMemOperand::MOLoad, 4, Align(4));
983+
MIRBuilder.buildInstr(TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS)
984+
.addIntrinsicID(Intrinsic::x86_sse_ldmxcsr)
985+
.addUse(StackPtr)
986+
.addMemOperand(LoadNewMXCSRMMO);
987+
}
988+
989+
MI.eraseFromParent();
990+
return true;
991+
}
992+
862993
bool X86LegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
863994
MachineInstr &MI) const {
864995
return true;

llvm/lib/Target/X86/GISel/X86LegalizerInfo.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ class X86LegalizerInfo : public LegalizerInfo {
5757

5858
bool legalizeGETROUNDING(MachineInstr &MI, MachineRegisterInfo &MRI,
5959
LegalizerHelper &Helper) const;
60+
61+
bool legalizeSETROUNDING(MachineInstr &MI, MachineRegisterInfo &MRI,
62+
LegalizerHelper &Helper) const;
6063
};
6164
} // namespace llvm
6265
#endif

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5346,6 +5346,19 @@ bool isConstantSplat(SDValue Op, APInt &SplatVal, bool AllowPartialUndefs) {
53465346

53475347
return false;
53485348
}
5349+
5350+
int getRoundingModeX86(unsigned RM) {
5351+
switch (static_cast<::llvm::RoundingMode>(RM)) {
5352+
// clang-format off
5353+
case ::llvm::RoundingMode::NearestTiesToEven: return X86::rmToNearest; break;
5354+
case ::llvm::RoundingMode::TowardNegative: return X86::rmDownward; break;
5355+
case ::llvm::RoundingMode::TowardPositive: return X86::rmUpward; break;
5356+
case ::llvm::RoundingMode::TowardZero: return X86::rmTowardZero; break;
5357+
default:
5358+
return X86::rmInvalid; // Invalid rounding mode
5359+
}
5360+
}
5361+
53495362
} // namespace X86
53505363
} // namespace llvm
53515364

@@ -28698,16 +28711,14 @@ SDValue X86TargetLowering::LowerSET_ROUNDING(SDValue Op,
2869828711
SDValue RMBits;
2869928712
if (auto *CVal = dyn_cast<ConstantSDNode>(NewRM)) {
2870028713
uint64_t RM = CVal->getZExtValue();
28701-
int FieldVal;
28702-
switch (static_cast<RoundingMode>(RM)) {
28703-
// clang-format off
28704-
case RoundingMode::NearestTiesToEven: FieldVal = X86::rmToNearest; break;
28705-
case RoundingMode::TowardNegative: FieldVal = X86::rmDownward; break;
28706-
case RoundingMode::TowardPositive: FieldVal = X86::rmUpward; break;
28707-
case RoundingMode::TowardZero: FieldVal = X86::rmTowardZero; break;
28708-
default:
28709-
llvm_unreachable("rounding mode is not supported by X86 hardware");
28710-
// clang-format on
28714+
int FieldVal = X86::getRoundingModeX86(RM);
28715+
28716+
if (FieldVal == X86::rmInvalid) {
28717+
LLVMContext &C = MF.getFunction().getContext();
28718+
C.diagnose(DiagnosticInfoUnsupported(
28719+
MF.getFunction(), "rounding mode is not supported by X86 hardware",
28720+
DiagnosticLocation(DL.getDebugLoc()), DS_Error));
28721+
return {};
2871128722
}
2871228723
RMBits = DAG.getConstant(FieldVal, DL, MVT::i16);
2871328724
} else {

llvm/lib/Target/X86/X86ISelLowering.h

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,13 +1004,14 @@ namespace llvm {
10041004
/// Current rounding mode is represented in bits 11:10 of FPSR. These
10051005
/// values are same as corresponding constants for rounding mode used
10061006
/// in glibc.
1007-
enum RoundingMode {
1008-
rmToNearest = 0, // FE_TONEAREST
1009-
rmDownward = 1 << 10, // FE_DOWNWARD
1010-
rmUpward = 2 << 10, // FE_UPWARD
1011-
rmTowardZero = 3 << 10, // FE_TOWARDZERO
1012-
rmMask = 3 << 10 // Bit mask selecting rounding mode
1013-
};
1007+
enum RoundingMode {
1008+
rmInvalid = -1, // For handle Invalid rounding mode
1009+
rmToNearest = 0, // FE_TONEAREST
1010+
rmDownward = 1 << 10, // FE_DOWNWARD
1011+
rmUpward = 2 << 10, // FE_UPWARD
1012+
rmTowardZero = 3 << 10, // FE_TOWARDZERO
1013+
rmMask = 3 << 10 // Bit mask selecting rounding mode
1014+
};
10141015
}
10151016

10161017
/// Define some predicates that are used for node matching.
@@ -1058,6 +1059,10 @@ namespace llvm {
10581059
/// functions.
10591060
bool isExtendedSwiftAsyncFrameSupported(const X86Subtarget &Subtarget,
10601061
const MachineFunction &MF);
1062+
1063+
/// Convert LLVM rounding mode to X86 rounding mode.
1064+
int getRoundingModeX86(unsigned RM);
1065+
10611066
} // end namespace X86
10621067

10631068
//===--------------------------------------------------------------------===//

llvm/lib/Target/X86/X86InstrGISel.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@ def G_FNSTCW16 : X86GenericInstruction {
3434
let mayStore = true;
3535
}
3636

37+
def G_FLDCW16 : X86GenericInstruction {
38+
let OutOperandList = (outs);
39+
let InOperandList = (ins ptype0:$src);
40+
let hasSideEffects = true;
41+
let mayLoad = true;
42+
}
43+
3744
def : GINodeEquiv<G_FILD, X86fild>;
3845
def : GINodeEquiv<G_FIST, X86fp_to_mem>;
3946
def : GINodeEquiv<G_FNSTCW16, X86fp_cwd_get16>;
47+
def : GINodeEquiv<G_FLDCW16, X86fp_cwd_set16>;

llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,9 @@
642642
# DEBUG-NEXT: G_GET_ROUNDING (opcode {{[0-9]+}}): 1 type index, 0 imm indices
643643
# DEBUG-NEXT:.. type index coverage check SKIPPED: no rules defined
644644
# DEBUG-NEXT:.. imm index coverage check SKIPPED: no rules defined
645+
# DEBUG-NEXT: G_SET_ROUNDING (opcode {{[0-9]+}}): 1 type index, 0 imm indices
646+
# DEBUG-NEXT:.. type index coverage check SKIPPED: no rules defined
647+
# DEBUG-NEXT:.. imm index coverage check SKIPPED: no rules defined
645648
# DEBUG-NEXT: G_PTR_ADD (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
646649
# DEBUG-NEXT: .. the first uncovered type index: 2, OK
647650
# DEBUG-NEXT: .. the first uncovered imm index: 0, OK

0 commit comments

Comments
 (0)