Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -2463,6 +2463,11 @@ class LLVM_ABI MachineIRBuilder {
return buildInstr(TargetOpcode::G_GET_ROUNDING, {Dst}, {});
}

/// Build and insert G_SET_ROUNDING
MachineInstrBuilder buildSetRounding(const SrcOp &Src) {
return buildInstr(TargetOpcode::G_SET_ROUNDING, {}, {Src});
}

virtual MachineInstrBuilder
buildInstr(unsigned Opc, ArrayRef<DstOp> DstOps, ArrayRef<SrcOp> SrcOps,
std::optional<unsigned> Flags = std::nullopt);
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/Support/TargetOpcodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,7 @@ HANDLE_TARGET_OPCODE(G_SET_FPMODE)
HANDLE_TARGET_OPCODE(G_RESET_FPMODE)

HANDLE_TARGET_OPCODE(G_GET_ROUNDING)
HANDLE_TARGET_OPCODE(G_SET_ROUNDING)

/// Generic pointer offset
HANDLE_TARGET_OPCODE(G_PTR_ADD)
Expand Down
6 changes: 6 additions & 0 deletions llvm/include/llvm/Target/GenericOpcodes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1273,6 +1273,12 @@ def G_GET_ROUNDING : GenericInstruction {
let hasSideEffects = true;
}

def G_SET_ROUNDING : GenericInstruction {
let OutOperandList = (outs);
let InOperandList = (ins type0:$src);
let hasSideEffects = true;
}

//------------------------------------------------------------------------------
// Memory ops
//------------------------------------------------------------------------------
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2607,6 +2607,9 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
case Intrinsic::get_rounding:
MIRBuilder.buildGetRounding(getOrCreateVReg(CI));
return true;
case Intrinsic::set_rounding:
MIRBuilder.buildSetRounding(getOrCreateVReg(*CI.getOperand(0)));
return true;
case Intrinsic::vscale: {
MIRBuilder.buildVScale(getOrCreateVReg(CI), 1);
return true;
Expand Down
133 changes: 132 additions & 1 deletion llvm/lib/Target/X86/GISel/X86LegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/Type.h"

using namespace llvm;
Expand Down Expand Up @@ -110,7 +111,8 @@ X86LegalizerInfo::X86LegalizerInfo(const X86Subtarget &STI,
.legalFor(HasSSE2 || UseX87, {s64})
.legalFor(UseX87, {s80});

getActionDefinitionsBuilder(G_GET_ROUNDING).customFor({s32});
getActionDefinitionsBuilder({G_GET_ROUNDING, G_SET_ROUNDING})
.customFor({s32});

// merge/unmerge
for (unsigned Op : {G_MERGE_VALUES, G_UNMERGE_VALUES}) {
Expand Down Expand Up @@ -617,6 +619,8 @@ bool X86LegalizerInfo::legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI,
return legalizeFPTOSI(MI, MRI, Helper);
case TargetOpcode::G_GET_ROUNDING:
return legalizeGETROUNDING(MI, MRI, Helper);
case TargetOpcode::G_SET_ROUNDING:
return legalizeSETROUNDING(MI, MRI, Helper);
}
llvm_unreachable("expected switch to return");
}
Expand Down Expand Up @@ -859,6 +863,133 @@ bool X86LegalizerInfo::legalizeGETROUNDING(MachineInstr &MI,
return true;
}

bool X86LegalizerInfo::legalizeSETROUNDING(MachineInstr &MI,
MachineRegisterInfo &MRI,
LegalizerHelper &Helper) const {
MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
MachineFunction &MF = MIRBuilder.getMF();
Register Src = MI.getOperand(0).getReg();
const LLT s8 = LLT::scalar(8);
const LLT s16 = LLT::scalar(16);
const LLT s32 = LLT::scalar(32);

// Allocate stack slot for control word and MXCSR (4 bytes).
int MemSize = 4;
Align Alignment = Align(4);
MachinePointerInfo PtrInfo;
auto StackTemp = Helper.createStackTemporary(TypeSize::getFixed(MemSize),
Alignment, PtrInfo);
Register StackPtr = StackTemp.getReg(0);

auto StoreMMO =
MF.getMachineMemOperand(PtrInfo, MachineMemOperand::MOStore, 2, Align(2));
MIRBuilder.buildInstr(X86::G_FNSTCW16)
.addUse(StackPtr)
.addMemOperand(StoreMMO);

auto LoadMMO =
MF.getMachineMemOperand(PtrInfo, MachineMemOperand::MOLoad, 2, Align(2));
auto CWD16 = MIRBuilder.buildLoad(s16, StackPtr, *LoadMMO);

// Clear RM field (bits 11:10)
auto ClearedCWD =
MIRBuilder.buildAnd(s16, CWD16, MIRBuilder.buildConstant(s16, 0xf3ff));

// Check if Src is a constant
auto *SrcDef = MRI.getVRegDef(Src);
Register RMBits;
Register MXCSRRMBits;

if (SrcDef && SrcDef->getOpcode() == TargetOpcode::G_CONSTANT) {
uint64_t RM = getIConstantFromReg(Src, MRI).getZExtValue();
int FieldVal = X86::getRoundingModeX86(RM);

if (FieldVal == X86::rmInvalid) {
LLVMContext &C = MF.getFunction().getContext();
C.diagnose(DiagnosticInfoUnsupported(
MF.getFunction(), "rounding mode is not supported by X86 hardware",
DiagnosticLocation(MI.getDebugLoc()), DS_Error));
return false;
}

FieldVal = FieldVal << 3;
RMBits = MIRBuilder.buildConstant(s16, FieldVal).getReg(0);
MXCSRRMBits = MIRBuilder.buildConstant(s32, FieldVal).getReg(0);
} else {
// Convert Src (rounding mode) to bits for control word
// (0xc9 << (2 * Src + 4)) & 0xc00
auto Src32 = MIRBuilder.buildZExtOrTrunc(s32, Src);
auto ShiftAmt = MIRBuilder.buildAdd(
s32, MIRBuilder.buildShl(s32, Src32, MIRBuilder.buildConstant(s32, 1)),
MIRBuilder.buildConstant(s32, 4));
auto ShiftAmt8 = MIRBuilder.buildTrunc(s8, ShiftAmt);
auto Shifted = MIRBuilder.buildShl(s16, MIRBuilder.buildConstant(s16, 0xc9),
ShiftAmt8);
RMBits =
MIRBuilder.buildAnd(s16, Shifted, MIRBuilder.buildConstant(s16, 0xc00))
.getReg(0);

// For non-constant case, we still need to compute MXCSR bits dynamically
auto RMBits32 = MIRBuilder.buildZExt(s32, RMBits);
MXCSRRMBits =
MIRBuilder.buildShl(s32, RMBits32, MIRBuilder.buildConstant(s32, 3))
.getReg(0);
}
// Update rounding mode bits
auto NewCWD =
MIRBuilder.buildOr(s16, ClearedCWD, RMBits, MachineInstr::Disjoint);

// Store new FP Control Word to stack
auto StoreNewMMO =
MF.getMachineMemOperand(PtrInfo, MachineMemOperand::MOStore, 2, Align(2));
MIRBuilder.buildStore(NewCWD, StackPtr, *StoreNewMMO);

// Load FP control word from the slot using G_FLDCW16
auto LoadNewMMO =
MF.getMachineMemOperand(PtrInfo, MachineMemOperand::MOLoad, 2, Align(2));
MIRBuilder.buildInstr(X86::G_FLDCW16)
.addUse(StackPtr)
.addMemOperand(LoadNewMMO);

if (Subtarget.hasSSE1()) {
// Store MXCSR to stack (use STMXCSR)
auto StoreMXCSRMMO = MF.getMachineMemOperand(
PtrInfo, MachineMemOperand::MOStore, 4, Align(4));
MIRBuilder.buildInstr(TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS)
.addIntrinsicID(Intrinsic::x86_sse_stmxcsr)
.addUse(StackPtr)
.addMemOperand(StoreMXCSRMMO);

// Load MXCSR from stack
auto LoadMXCSRMMO = MF.getMachineMemOperand(
PtrInfo, MachineMemOperand::MOLoad, 4, Align(4));
auto MXCSR = MIRBuilder.buildLoad(s32, StackPtr, *LoadMXCSRMMO);

// Clear RM field (bits 14:13)
auto ClearedMXCSR = MIRBuilder.buildAnd(
s32, MXCSR, MIRBuilder.buildConstant(s32, 0xffff9fff));

// Update rounding mode bits
auto NewMXCSR = MIRBuilder.buildOr(s32, ClearedMXCSR, MXCSRRMBits);

// Store new MXCSR to stack
auto StoreNewMXCSRMMO = MF.getMachineMemOperand(
PtrInfo, MachineMemOperand::MOStore, 4, Align(4));
MIRBuilder.buildStore(NewMXCSR, StackPtr, *StoreNewMXCSRMMO);

// Load MXCSR from stack (use LDMXCSR)
auto LoadNewMXCSRMMO = MF.getMachineMemOperand(
PtrInfo, MachineMemOperand::MOLoad, 4, Align(4));
MIRBuilder.buildInstr(TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS)
.addIntrinsicID(Intrinsic::x86_sse_ldmxcsr)
.addUse(StackPtr)
.addMemOperand(LoadNewMXCSRMMO);
}

MI.eraseFromParent();
return true;
}

bool X86LegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
MachineInstr &MI) const {
return true;
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/X86/GISel/X86LegalizerInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ class X86LegalizerInfo : public LegalizerInfo {

bool legalizeGETROUNDING(MachineInstr &MI, MachineRegisterInfo &MRI,
LegalizerHelper &Helper) const;

bool legalizeSETROUNDING(MachineInstr &MI, MachineRegisterInfo &MRI,
LegalizerHelper &Helper) const;
};
} // namespace llvm
#endif
31 changes: 21 additions & 10 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5346,6 +5346,19 @@ bool isConstantSplat(SDValue Op, APInt &SplatVal, bool AllowPartialUndefs) {

return false;
}

int getRoundingModeX86(unsigned RM) {
switch (static_cast<::llvm::RoundingMode>(RM)) {
// clang-format off
case ::llvm::RoundingMode::NearestTiesToEven: return X86::rmToNearest; break;
case ::llvm::RoundingMode::TowardNegative: return X86::rmDownward; break;
case ::llvm::RoundingMode::TowardPositive: return X86::rmUpward; break;
case ::llvm::RoundingMode::TowardZero: return X86::rmTowardZero; break;
default:
return X86::rmInvalid; // Invalid rounding mode
}
}

} // namespace X86
} // namespace llvm

Expand Down Expand Up @@ -28698,16 +28711,14 @@ SDValue X86TargetLowering::LowerSET_ROUNDING(SDValue Op,
SDValue RMBits;
if (auto *CVal = dyn_cast<ConstantSDNode>(NewRM)) {
uint64_t RM = CVal->getZExtValue();
int FieldVal;
switch (static_cast<RoundingMode>(RM)) {
// clang-format off
case RoundingMode::NearestTiesToEven: FieldVal = X86::rmToNearest; break;
case RoundingMode::TowardNegative: FieldVal = X86::rmDownward; break;
case RoundingMode::TowardPositive: FieldVal = X86::rmUpward; break;
case RoundingMode::TowardZero: FieldVal = X86::rmTowardZero; break;
default:
llvm_unreachable("rounding mode is not supported by X86 hardware");
// clang-format on
int FieldVal = X86::getRoundingModeX86(RM);

if (FieldVal == X86::rmInvalid) {
LLVMContext &C = MF.getFunction().getContext();
C.diagnose(DiagnosticInfoUnsupported(
MF.getFunction(), "rounding mode is not supported by X86 hardware",
DiagnosticLocation(DL.getDebugLoc()), DS_Error));
return {};
}
RMBits = DAG.getConstant(FieldVal, DL, MVT::i16);
} else {
Expand Down
19 changes: 12 additions & 7 deletions llvm/lib/Target/X86/X86ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -1004,13 +1004,14 @@ namespace llvm {
/// Current rounding mode is represented in bits 11:10 of FPSR. These
/// values are same as corresponding constants for rounding mode used
/// in glibc.
enum RoundingMode {
rmToNearest = 0, // FE_TONEAREST
rmDownward = 1 << 10, // FE_DOWNWARD
rmUpward = 2 << 10, // FE_UPWARD
rmTowardZero = 3 << 10, // FE_TOWARDZERO
rmMask = 3 << 10 // Bit mask selecting rounding mode
};
enum RoundingMode {
rmInvalid = -1, // For handle Invalid rounding mode
rmToNearest = 0, // FE_TONEAREST
rmDownward = 1 << 10, // FE_DOWNWARD
rmUpward = 2 << 10, // FE_UPWARD
rmTowardZero = 3 << 10, // FE_TOWARDZERO
rmMask = 3 << 10 // Bit mask selecting rounding mode
};
}

/// Define some predicates that are used for node matching.
Expand Down Expand Up @@ -1058,6 +1059,10 @@ namespace llvm {
/// functions.
bool isExtendedSwiftAsyncFrameSupported(const X86Subtarget &Subtarget,
const MachineFunction &MF);

/// Convert LLVM rounding mode to X86 rounding mode.
int getRoundingModeX86(unsigned RM);

} // end namespace X86

//===--------------------------------------------------------------------===//
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/X86/X86InstrGISel.td
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ def G_FNSTCW16 : X86GenericInstruction {
let mayStore = true;
}

def G_FLDCW16 : X86GenericInstruction {
let OutOperandList = (outs);
let InOperandList = (ins ptype0:$src);
let hasSideEffects = true;
let mayLoad = true;
}

def : GINodeEquiv<G_FILD, X86fild>;
def : GINodeEquiv<G_FIST, X86fp_to_mem>;
def : GINodeEquiv<G_FNSTCW16, X86fp_cwd_get16>;
def : GINodeEquiv<G_FLDCW16, X86fp_cwd_set16>;
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,9 @@
# DEBUG-NEXT: G_GET_ROUNDING (opcode {{[0-9]+}}): 1 type index, 0 imm indices
# DEBUG-NEXT:.. type index coverage check SKIPPED: no rules defined
# DEBUG-NEXT:.. imm index coverage check SKIPPED: no rules defined
# DEBUG-NEXT: G_SET_ROUNDING (opcode {{[0-9]+}}): 1 type index, 0 imm indices
# DEBUG-NEXT:.. type index coverage check SKIPPED: no rules defined
# DEBUG-NEXT:.. imm index coverage check SKIPPED: no rules defined
# DEBUG-NEXT: G_PTR_ADD (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
# DEBUG-NEXT: .. the first uncovered type index: 2, OK
# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,9 @@
# DEBUG-NEXT: G_GET_ROUNDING (opcode {{[0-9]+}}): 1 type index, 0 imm indices
# DEBUG-NEXT:.. type index coverage check SKIPPED: no rules defined
# DEBUG-NEXT:.. imm index coverage check SKIPPED: no rules defined
# DEBUG-NEXT: G_SET_ROUNDING (opcode {{[0-9]+}}): 1 type index, 0 imm indices
# DEBUG-NEXT:.. type index coverage check SKIPPED: no rules defined
# DEBUG-NEXT:.. imm index coverage check SKIPPED: no rules defined
# DEBUG-NEXT: G_PTR_ADD (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
# DEBUG-NEXT: .. the first uncovered type index: 2, OK
# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
Expand Down
Loading
Loading