Skip to content

Commit 2e941ea

Browse files
JaydeepChauhan14 authored and mahesh-attarde committed
[X86][GlobalISel] Added support for llvm.set.rounding (llvm#156591)
- This implementation is adapted from **SDAG X86TargetLowering::LowerSET_ROUNDING**.
1 parent dfa32f6 commit 2e941ea

File tree

15 files changed

+519
-28
lines changed

15 files changed

+519
-28
lines changed

llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2463,6 +2463,11 @@ class LLVM_ABI MachineIRBuilder {
24632463
return buildInstr(TargetOpcode::G_GET_ROUNDING, {Dst}, {});
24642464
}
24652465

2466+
/// Build and insert G_SET_ROUNDING
2467+
MachineInstrBuilder buildSetRounding(const SrcOp &Src) {
2468+
return buildInstr(TargetOpcode::G_SET_ROUNDING, {}, {Src});
2469+
}
2470+
24662471
virtual MachineInstrBuilder
24672472
buildInstr(unsigned Opc, ArrayRef<DstOp> DstOps, ArrayRef<SrcOp> SrcOps,
24682473
std::optional<unsigned> Flags = std::nullopt);

llvm/include/llvm/Support/TargetOpcodes.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,7 @@ HANDLE_TARGET_OPCODE(G_SET_FPMODE)
745745
HANDLE_TARGET_OPCODE(G_RESET_FPMODE)
746746

747747
HANDLE_TARGET_OPCODE(G_GET_ROUNDING)
748+
HANDLE_TARGET_OPCODE(G_SET_ROUNDING)
748749

749750
/// Generic pointer offset
750751
HANDLE_TARGET_OPCODE(G_PTR_ADD)

llvm/include/llvm/Target/GenericOpcodes.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1273,6 +1273,12 @@ def G_GET_ROUNDING : GenericInstruction {
12731273
let hasSideEffects = true;
12741274
}
12751275

1276+
// Generic opcode that sets the floating-point rounding mode. It produces no
// result and takes the requested mode as its single source operand.
def G_SET_ROUNDING : GenericInstruction {
  let InOperandList = (ins type0:$src);
  let OutOperandList = (outs);
  let hasSideEffects = true;
}
1281+
12761282
//------------------------------------------------------------------------------
12771283
// Memory ops
12781284
//------------------------------------------------------------------------------

llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2607,6 +2607,9 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
26072607
case Intrinsic::get_rounding:
26082608
MIRBuilder.buildGetRounding(getOrCreateVReg(CI));
26092609
return true;
2610+
case Intrinsic::set_rounding:
2611+
MIRBuilder.buildSetRounding(getOrCreateVReg(*CI.getOperand(0)));
2612+
return true;
26102613
case Intrinsic::vscale: {
26112614
MIRBuilder.buildVScale(getOrCreateVReg(CI), 1);
26122615
return true;

llvm/lib/Target/X86/GISel/X86LegalizerInfo.cpp

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "llvm/CodeGen/TargetOpcodes.h"
2222
#include "llvm/CodeGen/ValueTypes.h"
2323
#include "llvm/IR/DerivedTypes.h"
24+
#include "llvm/IR/IntrinsicsX86.h"
2425
#include "llvm/IR/Type.h"
2526

2627
using namespace llvm;
@@ -110,7 +111,8 @@ X86LegalizerInfo::X86LegalizerInfo(const X86Subtarget &STI,
110111
.legalFor(HasSSE2 || UseX87, {s64})
111112
.legalFor(UseX87, {s80});
112113

113-
getActionDefinitionsBuilder(G_GET_ROUNDING).customFor({s32});
114+
getActionDefinitionsBuilder({G_GET_ROUNDING, G_SET_ROUNDING})
115+
.customFor({s32});
114116

115117
// merge/unmerge
116118
for (unsigned Op : {G_MERGE_VALUES, G_UNMERGE_VALUES}) {
@@ -622,6 +624,8 @@ bool X86LegalizerInfo::legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI,
622624
return legalizeFPTOSI(MI, MRI, Helper);
623625
case TargetOpcode::G_GET_ROUNDING:
624626
return legalizeGETROUNDING(MI, MRI, Helper);
627+
case TargetOpcode::G_SET_ROUNDING:
628+
return legalizeSETROUNDING(MI, MRI, Helper);
625629
}
626630
llvm_unreachable("expected switch to return");
627631
}
@@ -864,6 +868,133 @@ bool X86LegalizerInfo::legalizeGETROUNDING(MachineInstr &MI,
864868
return true;
865869
}
866870

871+
/// Custom legalization of G_SET_ROUNDING.
///
/// Rewrites the rounding-control field of the x87 FP control word (bits 11:10)
/// and, when the subtarget has SSE1, of MXCSR (bits 14:13), going through a
/// 4-byte stack temporary that is shared by both read-modify-write sequences.
///
/// \param MI     the G_SET_ROUNDING instruction; operand 0 is the requested
///               rounding mode.
/// \param MRI    register info used to look up the definition of the operand.
/// \param Helper legalizer helper providing the MachineIRBuilder and the stack
///               temporary.
/// \returns true after \p MI has been replaced and erased; false (after
///          emitting an error diagnostic) when the operand is a constant that
///          has no X86 encoding. In the failure case nothing is emitted.
bool X86LegalizerInfo::legalizeSETROUNDING(MachineInstr &MI,
                                           MachineRegisterInfo &MRI,
                                           LegalizerHelper &Helper) const {
  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
  MachineFunction &MF = MIRBuilder.getMF();
  Register Src = MI.getOperand(0).getReg();
  const LLT s8 = LLT::scalar(8);
  const LLT s16 = LLT::scalar(16);
  const LLT s32 = LLT::scalar(32);

  // Resolve a constant rounding mode first, so that an unsupported value is
  // rejected before any replacement instruction or stack object is created.
  auto *SrcDef = MRI.getVRegDef(Src);
  const bool IsConstantRM =
      SrcDef && SrcDef->getOpcode() == TargetOpcode::G_CONSTANT;
  int FieldVal = 0;
  if (IsConstantRM) {
    uint64_t RM = getIConstantFromReg(Src, MRI).getZExtValue();
    FieldVal = X86::getRoundingModeX86(RM);

    if (FieldVal == X86::rmInvalid) {
      LLVMContext &C = MF.getFunction().getContext();
      C.diagnose(DiagnosticInfoUnsupported(
          MF.getFunction(), "rounding mode is not supported by X86 hardware",
          DiagnosticLocation(MI.getDebugLoc()), DS_Error));
      return false;
    }
  }

  // Allocate stack slot for control word and MXCSR (4 bytes).
  int MemSize = 4;
  Align Alignment = Align(4);
  MachinePointerInfo PtrInfo;
  auto StackTemp = Helper.createStackTemporary(TypeSize::getFixed(MemSize),
                                               Alignment, PtrInfo);
  Register StackPtr = StackTemp.getReg(0);

  // Store the current FP control word to the slot.
  auto StoreMMO =
      MF.getMachineMemOperand(PtrInfo, MachineMemOperand::MOStore, 2, Align(2));
  MIRBuilder.buildInstr(X86::G_FNSTCW16)
      .addUse(StackPtr)
      .addMemOperand(StoreMMO);

  // Read it back as a 16-bit value.
  auto LoadMMO =
      MF.getMachineMemOperand(PtrInfo, MachineMemOperand::MOLoad, 2, Align(2));
  auto CWD16 = MIRBuilder.buildLoad(s16, StackPtr, *LoadMMO);

  // Clear RM field (bits 11:10)
  auto ClearedCWD =
      MIRBuilder.buildAnd(s16, CWD16, MIRBuilder.buildConstant(s16, 0xf3ff));

  Register RMBits;
  Register MXCSRRMBits;

  if (IsConstantRM) {
    // FieldVal is already positioned in bits 11:10, which is the x87 control
    // word layout. MXCSR keeps the same field three bits higher (bits 14:13),
    // so only the MXCSR constant is shifted.
    RMBits = MIRBuilder.buildConstant(s16, FieldVal).getReg(0);
    MXCSRRMBits = MIRBuilder.buildConstant(s32, FieldVal << 3).getReg(0);
  } else {
    // Convert Src (rounding mode) to bits for control word
    // (0xc9 << (2 * Src + 4)) & 0xc00
    auto Src32 = MIRBuilder.buildZExtOrTrunc(s32, Src);
    auto ShiftAmt = MIRBuilder.buildAdd(
        s32, MIRBuilder.buildShl(s32, Src32, MIRBuilder.buildConstant(s32, 1)),
        MIRBuilder.buildConstant(s32, 4));
    auto ShiftAmt8 = MIRBuilder.buildTrunc(s8, ShiftAmt);
    auto Shifted = MIRBuilder.buildShl(s16, MIRBuilder.buildConstant(s16, 0xc9),
                                       ShiftAmt8);
    RMBits =
        MIRBuilder.buildAnd(s16, Shifted, MIRBuilder.buildConstant(s16, 0xc00))
            .getReg(0);

    // For non-constant case, we still need to compute MXCSR bits dynamically
    auto RMBits32 = MIRBuilder.buildZExt(s32, RMBits);
    MXCSRRMBits =
        MIRBuilder.buildShl(s32, RMBits32, MIRBuilder.buildConstant(s32, 3))
            .getReg(0);
  }

  // Update rounding mode bits. The RM field was cleared above and RMBits only
  // occupies that field, so the two operands share no set bits.
  auto NewCWD =
      MIRBuilder.buildOr(s16, ClearedCWD, RMBits, MachineInstr::Disjoint);

  // Store new FP Control Word to stack
  auto StoreNewMMO =
      MF.getMachineMemOperand(PtrInfo, MachineMemOperand::MOStore, 2, Align(2));
  MIRBuilder.buildStore(NewCWD, StackPtr, *StoreNewMMO);

  // Load FP control word from the slot using G_FLDCW16
  auto LoadNewMMO =
      MF.getMachineMemOperand(PtrInfo, MachineMemOperand::MOLoad, 2, Align(2));
  MIRBuilder.buildInstr(X86::G_FLDCW16)
      .addUse(StackPtr)
      .addMemOperand(LoadNewMMO);

  if (Subtarget.hasSSE1()) {
    // Store MXCSR to stack (use STMXCSR)
    auto StoreMXCSRMMO = MF.getMachineMemOperand(
        PtrInfo, MachineMemOperand::MOStore, 4, Align(4));
    MIRBuilder.buildInstr(TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS)
        .addIntrinsicID(Intrinsic::x86_sse_stmxcsr)
        .addUse(StackPtr)
        .addMemOperand(StoreMXCSRMMO);

    // Load MXCSR from stack
    auto LoadMXCSRMMO = MF.getMachineMemOperand(
        PtrInfo, MachineMemOperand::MOLoad, 4, Align(4));
    auto MXCSR = MIRBuilder.buildLoad(s32, StackPtr, *LoadMXCSRMMO);

    // Clear RM field (bits 14:13)
    auto ClearedMXCSR = MIRBuilder.buildAnd(
        s32, MXCSR, MIRBuilder.buildConstant(s32, 0xffff9fff));

    // Update rounding mode bits
    auto NewMXCSR = MIRBuilder.buildOr(s32, ClearedMXCSR, MXCSRRMBits);

    // Store new MXCSR to stack
    auto StoreNewMXCSRMMO = MF.getMachineMemOperand(
        PtrInfo, MachineMemOperand::MOStore, 4, Align(4));
    MIRBuilder.buildStore(NewMXCSR, StackPtr, *StoreNewMXCSRMMO);

    // Load MXCSR from stack (use LDMXCSR)
    auto LoadNewMXCSRMMO = MF.getMachineMemOperand(
        PtrInfo, MachineMemOperand::MOLoad, 4, Align(4));
    MIRBuilder.buildInstr(TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS)
        .addIntrinsicID(Intrinsic::x86_sse_ldmxcsr)
        .addUse(StackPtr)
        .addMemOperand(LoadNewMXCSRMMO);
  }

  MI.eraseFromParent();
  return true;
}
997+
867998
bool X86LegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
868999
MachineInstr &MI) const {
8691000
return true;

llvm/lib/Target/X86/GISel/X86LegalizerInfo.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ class X86LegalizerInfo : public LegalizerInfo {
5757

5858
bool legalizeGETROUNDING(MachineInstr &MI, MachineRegisterInfo &MRI,
5959
LegalizerHelper &Helper) const;
60+
61+
bool legalizeSETROUNDING(MachineInstr &MI, MachineRegisterInfo &MRI,
62+
LegalizerHelper &Helper) const;
6063
};
6164
} // namespace llvm
6265
#endif

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5346,6 +5346,19 @@ bool isConstantSplat(SDValue Op, APInt &SplatVal, bool AllowPartialUndefs) {
53465346

53475347
return false;
53485348
}
5349+
5350+
int getRoundingModeX86(unsigned RM) {
5351+
switch (static_cast<::llvm::RoundingMode>(RM)) {
5352+
// clang-format off
5353+
case ::llvm::RoundingMode::NearestTiesToEven: return X86::rmToNearest; break;
5354+
case ::llvm::RoundingMode::TowardNegative: return X86::rmDownward; break;
5355+
case ::llvm::RoundingMode::TowardPositive: return X86::rmUpward; break;
5356+
case ::llvm::RoundingMode::TowardZero: return X86::rmTowardZero; break;
5357+
default:
5358+
return X86::rmInvalid; // Invalid rounding mode
5359+
}
5360+
}
5361+
53495362
} // namespace X86
53505363
} // namespace llvm
53515364

@@ -28698,16 +28711,14 @@ SDValue X86TargetLowering::LowerSET_ROUNDING(SDValue Op,
2869828711
SDValue RMBits;
2869928712
if (auto *CVal = dyn_cast<ConstantSDNode>(NewRM)) {
2870028713
uint64_t RM = CVal->getZExtValue();
28701-
int FieldVal;
28702-
switch (static_cast<RoundingMode>(RM)) {
28703-
// clang-format off
28704-
case RoundingMode::NearestTiesToEven: FieldVal = X86::rmToNearest; break;
28705-
case RoundingMode::TowardNegative: FieldVal = X86::rmDownward; break;
28706-
case RoundingMode::TowardPositive: FieldVal = X86::rmUpward; break;
28707-
case RoundingMode::TowardZero: FieldVal = X86::rmTowardZero; break;
28708-
default:
28709-
llvm_unreachable("rounding mode is not supported by X86 hardware");
28710-
// clang-format on
28714+
int FieldVal = X86::getRoundingModeX86(RM);
28715+
28716+
if (FieldVal == X86::rmInvalid) {
28717+
LLVMContext &C = MF.getFunction().getContext();
28718+
C.diagnose(DiagnosticInfoUnsupported(
28719+
MF.getFunction(), "rounding mode is not supported by X86 hardware",
28720+
DiagnosticLocation(DL.getDebugLoc()), DS_Error));
28721+
return {};
2871128722
}
2871228723
RMBits = DAG.getConstant(FieldVal, DL, MVT::i16);
2871328724
} else {

llvm/lib/Target/X86/X86ISelLowering.h

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,13 +1004,14 @@ namespace llvm {
10041004
/// Current rounding mode is represented in bits 11:10 of FPSR. These
10051005
/// values are same as corresponding constants for rounding mode used
10061006
/// in glibc.
1007-
enum RoundingMode {
1008-
rmToNearest = 0, // FE_TONEAREST
1009-
rmDownward = 1 << 10, // FE_DOWNWARD
1010-
rmUpward = 2 << 10, // FE_UPWARD
1011-
rmTowardZero = 3 << 10, // FE_TOWARDZERO
1012-
rmMask = 3 << 10 // Bit mask selecting rounding mode
1013-
};
1007+
enum RoundingMode {
1008+
rmInvalid = -1, // Sentinel for a rounding mode with no X86 encoding
1009+
rmToNearest = 0, // FE_TONEAREST
1010+
rmDownward = 1 << 10, // FE_DOWNWARD
1011+
rmUpward = 2 << 10, // FE_UPWARD
1012+
rmTowardZero = 3 << 10, // FE_TOWARDZERO
1013+
rmMask = 3 << 10 // Bit mask selecting rounding mode
1014+
};
10141015
}
10151016

10161017
/// Define some predicates that are used for node matching.
@@ -1058,6 +1059,10 @@ namespace llvm {
10581059
/// functions.
10591060
bool isExtendedSwiftAsyncFrameSupported(const X86Subtarget &Subtarget,
10601061
const MachineFunction &MF);
1062+
1063+
/// Convert LLVM rounding mode to X86 rounding mode.
1064+
int getRoundingModeX86(unsigned RM);
1065+
10611066
} // end namespace X86
10621067

10631068
//===--------------------------------------------------------------------===//

llvm/lib/Target/X86/X86InstrGISel.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@ def G_FNSTCW16 : X86GenericInstruction {
3434
let mayStore = true;
3535
}
3636

37+
// X86 generic instruction that loads the x87 FP control word from the memory
// addressed by $src. It defines no registers.
def G_FLDCW16 : X86GenericInstruction {
  let InOperandList = (ins ptype0:$src);
  let OutOperandList = (outs);
  let mayLoad = true;
  let hasSideEffects = true;
}
43+
3744
def : GINodeEquiv<G_FILD, X86fild>;
3845
def : GINodeEquiv<G_FIST, X86fp_to_mem>;
3946
def : GINodeEquiv<G_FNSTCW16, X86fp_cwd_get16>;
47+
def : GINodeEquiv<G_FLDCW16, X86fp_cwd_set16>;

llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,9 @@
642642
# DEBUG-NEXT: G_GET_ROUNDING (opcode {{[0-9]+}}): 1 type index, 0 imm indices
643643
# DEBUG-NEXT:.. type index coverage check SKIPPED: no rules defined
644644
# DEBUG-NEXT:.. imm index coverage check SKIPPED: no rules defined
645+
# DEBUG-NEXT: G_SET_ROUNDING (opcode {{[0-9]+}}): 1 type index, 0 imm indices
646+
# DEBUG-NEXT:.. type index coverage check SKIPPED: no rules defined
647+
# DEBUG-NEXT:.. imm index coverage check SKIPPED: no rules defined
645648
# DEBUG-NEXT: G_PTR_ADD (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
646649
# DEBUG-NEXT: .. the first uncovered type index: 2, OK
647650
# DEBUG-NEXT: .. the first uncovered imm index: 0, OK

0 commit comments

Comments
 (0)