Skip to content

Commit 840da2e

Browse files
[SandboxIR] Implement CmpInst, FCmpInst, and ICmpInst (#106301)
As in the description. Not sure the macros for "WRAP_XXX" add value or not, but do save some boiler plate. Maybe there is a better way.
1 parent 143f3fc commit 840da2e

File tree

8 files changed

+500
-6
lines changed

8 files changed

+500
-6
lines changed

llvm/include/llvm/SandboxIR/SandboxIR.h

Lines changed: 170 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,17 @@
5454
// | |
5555
// | +- ZExtInst
5656
// |
57-
// +- CallBase -----------+- CallBrInst
58-
// | |
59-
// +- CmpInst +- CallInst
60-
// | |
61-
// +- ExtractElementInst +- InvokeInst
57+
// +- CallBase --------+- CallBrInst
58+
// | |
59+
// | +- CallInst
60+
// | |
61+
// | +- InvokeInst
62+
// |
63+
// +- CmpInst ---------+- ICmpInst
64+
// | |
65+
// | +- FCmpInst
66+
// |
67+
// +- ExtractElementInst
6268
// |
6369
// +- GetElementPtrInst
6470
// |
@@ -158,6 +164,9 @@ class BinaryOperator;
158164
class PossiblyDisjointInst;
159165
class AtomicRMWInst;
160166
class AtomicCmpXchgInst;
167+
class CmpInst;
168+
class ICmpInst;
169+
class FCmpInst;
161170

162171
/// Iterator for the `Use` edges of a User's operands.
163172
/// \Returns the operand `Use` when dereferenced.
@@ -304,6 +313,7 @@ class Value {
304313
friend class PHINode; // For getting `Val`.
305314
friend class UnreachableInst; // For getting `Val`.
306315
friend class CatchSwitchAddHandler; // For `Val`.
316+
friend class CmpInst; // For getting `Val`.
307317
friend class ConstantArray; // For `Val`.
308318
friend class ConstantStruct; // For `Val`.
309319

@@ -1076,6 +1086,7 @@ class Instruction : public sandboxir::User {
10761086
friend class CastInst; // For getTopmostLLVMInstruction().
10771087
friend class PHINode; // For getTopmostLLVMInstruction().
10781088
friend class UnreachableInst; // For getTopmostLLVMInstruction().
1089+
friend class CmpInst; // For getTopmostLLVMInstruction().
10791090

10801091
/// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
10811092
/// order.
@@ -1232,6 +1243,7 @@ template <typename LLVMT> class SingleLLVMInstructionImpl : public Instruction {
12321243
friend class UnaryInstruction;
12331244
friend class CallBase;
12341245
friend class FuncletPadInst;
1246+
friend class CmpInst;
12351247

12361248
Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
12371249
return getOperandUseDefault(OpIdx, Verify);
@@ -3425,6 +3437,151 @@ class PHINode final : public SingleLLVMInstructionImpl<llvm::PHINode> {
34253437
// uint32_t ToIdx = 0)
34263438
};
34273439

3440+
// Wraps a static function that takes a single Predicate parameter
3441+
// LLVMValType should be the type of the wrapped class
3442+
#define WRAP_STATIC_PREDICATE(FunctionName) \
3443+
static auto FunctionName(Predicate P) { return LLVMValType::FunctionName(P); }
3444+
// Wraps a member function that takes no parameters
3445+
// LLVMValType should be the type of the wrapped class
3446+
#define WRAP_MEMBER(FunctionName) \
3447+
auto FunctionName() const { return cast<LLVMValType>(Val)->FunctionName(); }
3448+
// Wraps both--a common idiom in the CmpInst classes
3449+
#define WRAP_BOTH(FunctionName) \
3450+
WRAP_STATIC_PREDICATE(FunctionName) \
3451+
WRAP_MEMBER(FunctionName)
3452+
3453+
class CmpInst : public SingleLLVMInstructionImpl<llvm::CmpInst> {
3454+
protected:
3455+
using LLVMValType = llvm::CmpInst;
3456+
/// Use Context::createCmpInst(). Don't call the constructor directly.
3457+
CmpInst(llvm::CmpInst *CI, Context &Ctx, ClassID Id, Opcode Opc)
3458+
: SingleLLVMInstructionImpl(Id, Opc, CI, Ctx) {}
3459+
friend Context; // for CmpInst()
3460+
static Value *createCommon(Value *Cond, Value *True, Value *False,
3461+
const Twine &Name, IRBuilder<> &Builder,
3462+
Context &Ctx);
3463+
3464+
public:
3465+
using Predicate = llvm::CmpInst::Predicate;
3466+
3467+
static CmpInst *create(Predicate Pred, Value *S1, Value *S2,
3468+
Instruction *InsertBefore, Context &Ctx,
3469+
const Twine &Name = "");
3470+
static CmpInst *createWithCopiedFlags(Predicate Pred, Value *S1, Value *S2,
3471+
const Instruction *FlagsSource,
3472+
Instruction *InsertBefore, Context &Ctx,
3473+
const Twine &Name = "");
3474+
void setPredicate(Predicate P);
3475+
void swapOperands();
3476+
3477+
WRAP_MEMBER(getPredicate);
3478+
WRAP_BOTH(isFPPredicate);
3479+
WRAP_BOTH(isIntPredicate);
3480+
WRAP_STATIC_PREDICATE(getPredicateName);
3481+
WRAP_BOTH(getInversePredicate);
3482+
WRAP_BOTH(getOrderedPredicate);
3483+
WRAP_BOTH(getUnorderedPredicate);
3484+
WRAP_BOTH(getSwappedPredicate);
3485+
WRAP_BOTH(isStrictPredicate);
3486+
WRAP_BOTH(isNonStrictPredicate);
3487+
WRAP_BOTH(getStrictPredicate);
3488+
WRAP_BOTH(getNonStrictPredicate);
3489+
WRAP_BOTH(getFlippedStrictnessPredicate);
3490+
WRAP_MEMBER(isCommutative);
3491+
WRAP_BOTH(isEquality);
3492+
WRAP_BOTH(isRelational);
3493+
WRAP_BOTH(isSigned);
3494+
WRAP_BOTH(getSignedPredicate);
3495+
WRAP_BOTH(getUnsignedPredicate);
3496+
WRAP_BOTH(getFlippedSignednessPredicate);
3497+
WRAP_BOTH(isTrueWhenEqual);
3498+
WRAP_BOTH(isFalseWhenEqual);
3499+
WRAP_BOTH(isUnsigned);
3500+
WRAP_STATIC_PREDICATE(isOrdered);
3501+
WRAP_STATIC_PREDICATE(isUnordered);
3502+
3503+
static bool isImpliedTrueByMatchingCmp(Predicate Pred1, Predicate Pred2) {
3504+
return llvm::CmpInst::isImpliedTrueByMatchingCmp(Pred1, Pred2);
3505+
}
3506+
static bool isImpliedFalseByMatchingCmp(Predicate Pred1, Predicate Pred2) {
3507+
return llvm::CmpInst::isImpliedFalseByMatchingCmp(Pred1, Pred2);
3508+
}
3509+
3510+
/// Method for support type inquiry through isa, cast, and dyn_cast:
3511+
static bool classof(const Value *From) {
3512+
return From->getSubclassID() == ClassID::ICmp ||
3513+
From->getSubclassID() == ClassID::FCmp;
3514+
}
3515+
3516+
/// Create a result type for fcmp/icmp
3517+
static Type *makeCmpResultType(Type *OpndType);
3518+
3519+
#ifndef NDEBUG
3520+
void dumpOS(raw_ostream &OS) const override;
3521+
LLVM_DUMP_METHOD void dump() const;
3522+
#endif
3523+
};
3524+
3525+
class ICmpInst : public CmpInst {
3526+
/// Use Context::createICmpInst(). Don't call the constructor directly.
3527+
ICmpInst(llvm::ICmpInst *CI, Context &Ctx)
3528+
: CmpInst(CI, Ctx, ClassID::ICmp, Opcode::ICmp) {}
3529+
friend class Context; // For constructor.
3530+
using LLVMValType = llvm::ICmpInst;
3531+
3532+
public:
3533+
void swapOperands();
3534+
3535+
WRAP_BOTH(getSignedPredicate);
3536+
WRAP_BOTH(getUnsignedPredicate);
3537+
WRAP_BOTH(isEquality);
3538+
WRAP_MEMBER(isCommutative);
3539+
WRAP_MEMBER(isRelational);
3540+
WRAP_STATIC_PREDICATE(isGT);
3541+
WRAP_STATIC_PREDICATE(isLT);
3542+
WRAP_STATIC_PREDICATE(isGE);
3543+
WRAP_STATIC_PREDICATE(isLE);
3544+
3545+
static auto predicates() { return llvm::ICmpInst::predicates(); }
3546+
static bool compare(const APInt &LHS, const APInt &RHS,
3547+
ICmpInst::Predicate Pred) {
3548+
return llvm::ICmpInst::compare(LHS, RHS, Pred);
3549+
}
3550+
3551+
static bool classof(const Value *From) {
3552+
return From->getSubclassID() == ClassID::ICmp;
3553+
}
3554+
};
3555+
3556+
class FCmpInst : public CmpInst {
3557+
/// Use Context::createFCmpInst(). Don't call the constructor directly.
3558+
FCmpInst(llvm::FCmpInst *CI, Context &Ctx)
3559+
: CmpInst(CI, Ctx, ClassID::FCmp, Opcode::FCmp) {}
3560+
friend class Context; // For constructor.
3561+
using LLVMValType = llvm::FCmpInst;
3562+
3563+
public:
3564+
void swapOperands();
3565+
3566+
WRAP_BOTH(isEquality);
3567+
WRAP_MEMBER(isCommutative);
3568+
WRAP_MEMBER(isRelational);
3569+
3570+
static auto predicates() { return llvm::FCmpInst::predicates(); }
3571+
static bool compare(const APFloat &LHS, const APFloat &RHS,
3572+
FCmpInst::Predicate Pred) {
3573+
return llvm::FCmpInst::compare(LHS, RHS, Pred);
3574+
}
3575+
3576+
static bool classof(const Value *From) {
3577+
return From->getSubclassID() == ClassID::FCmp;
3578+
}
3579+
};
3580+
3581+
#undef WRAP_STATIC_PREDICATE
3582+
#undef WRAP_MEMBER
3583+
#undef WRAP_BOTH
3584+
34283585
/// An LLLVM Instruction that has no SandboxIR equivalent class gets mapped to
34293586
/// an OpaqueInstr.
34303587
class OpaqueInst : public SingleLLVMInstructionImpl<llvm::Instruction> {
@@ -3445,6 +3602,8 @@ class Context {
34453602
LLVMContext &LLVMCtx;
34463603
friend class Type; // For LLVMCtx.
34473604
friend class PointerType; // For LLVMCtx.
3605+
friend class CmpInst; // For LLVMCtx. TODO: cleanup when sandboxir::VectorType
3606+
// is complete
34483607
friend class IntegerType; // For LLVMCtx.
34493608
friend class StructType; // For LLVMCtx.
34503609
Tracker IRTracker;
@@ -3572,6 +3731,12 @@ class Context {
35723731
friend PHINode; // For createPHINode()
35733732
UnreachableInst *createUnreachableInst(llvm::UnreachableInst *UI);
35743733
friend UnreachableInst; // For createUnreachableInst()
3734+
CmpInst *createCmpInst(llvm::CmpInst *I);
3735+
friend CmpInst; // For createCmpInst()
3736+
ICmpInst *createICmpInst(llvm::ICmpInst *I);
3737+
friend ICmpInst; // For createICmpInst()
3738+
FCmpInst *createFCmpInst(llvm::FCmpInst *I);
3739+
friend FCmpInst; // For createFCmpInst()
35753740

35763741
public:
35773742
Context(LLVMContext &LLVMCtx)

llvm/include/llvm/SandboxIR/SandboxIRValues.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ DEF_INSTR(Cast, OPCODES(\
113113
), CastInst)
114114
DEF_INSTR(PHI, OP(PHI), PHINode)
115115
DEF_INSTR(Unreachable, OP(Unreachable), UnreachableInst)
116+
DEF_INSTR(ICmp, OP(ICmp), FCmpInst)
117+
DEF_INSTR(FCmp, OP(FCmp), ICmpInst)
116118

117119
// clang-format on
118120
#ifdef DEF_VALUE

llvm/include/llvm/SandboxIR/Tracker.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class CatchSwitchInst;
6363
class SwitchInst;
6464
class ConstantInt;
6565
class ShuffleVectorInst;
66-
66+
class CmpInst;
6767
/// The base class for IR Change classes.
6868
class IRChangeBase {
6969
protected:
@@ -130,6 +130,19 @@ class PHIAddIncoming : public IRChangeBase {
130130
#endif
131131
};
132132

133+
class CmpSwapOperands : public IRChangeBase {
134+
CmpInst *Cmp;
135+
136+
public:
137+
CmpSwapOperands(CmpInst *Cmp);
138+
void revert(Tracker &Tracker) final;
139+
void accept() final {}
140+
#ifndef NDEBUG
141+
void dump(raw_ostream &OS) const final { OS << "CmpSwapOperands"; }
142+
LLVM_DUMP_METHOD void dump() const final;
143+
#endif
144+
};
145+
133146
/// Tracks swapping a Use with another Use.
134147
class UseSwap : public IRChangeBase {
135148
Use ThisUse;

llvm/include/llvm/SandboxIR/Type.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ class Type {
5050
friend class ConstantArray; // For LLVMTy.
5151
friend class ConstantStruct; // For LLVMTy.
5252
friend class ConstantVector; // For LLVMTy.
53+
friend class CmpInst; // For LLVMTy. TODO: Cleanup after sandboxir::VectorType
54+
// is more complete.
5355

5456
// Friend all instruction classes because `create()` functions use LLVMTy.
5557
#define DEF_INSTR(ID, OPCODE, CLASS) friend class CLASS;

llvm/lib/SandboxIR/SandboxIR.cpp

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2735,6 +2735,16 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
27352735
It->second = std::unique_ptr<PHINode>(new PHINode(LLVMPhi, *this));
27362736
return It->second.get();
27372737
}
2738+
case llvm::Instruction::ICmp: {
2739+
auto *LLVMICmp = cast<llvm::ICmpInst>(LLVMV);
2740+
It->second = std::unique_ptr<ICmpInst>(new ICmpInst(LLVMICmp, *this));
2741+
return It->second.get();
2742+
}
2743+
case llvm::Instruction::FCmp: {
2744+
auto *LLVMFCmp = cast<llvm::FCmpInst>(LLVMV);
2745+
It->second = std::unique_ptr<FCmpInst>(new FCmpInst(LLVMFCmp, *this));
2746+
return It->second.get();
2747+
}
27382748
case llvm::Instruction::Unreachable: {
27392749
auto *LLVMUnreachable = cast<llvm::UnreachableInst>(LLVMV);
27402750
It->second = std::unique_ptr<UnreachableInst>(
@@ -2922,6 +2932,79 @@ PHINode *Context::createPHINode(llvm::PHINode *I) {
29222932
auto NewPtr = std::unique_ptr<PHINode>(new PHINode(I, *this));
29232933
return cast<PHINode>(registerValue(std::move(NewPtr)));
29242934
}
2935+
ICmpInst *Context::createICmpInst(llvm::ICmpInst *I) {
2936+
auto NewPtr = std::unique_ptr<ICmpInst>(new ICmpInst(I, *this));
2937+
return cast<ICmpInst>(registerValue(std::move(NewPtr)));
2938+
}
2939+
FCmpInst *Context::createFCmpInst(llvm::FCmpInst *I) {
2940+
auto NewPtr = std::unique_ptr<FCmpInst>(new FCmpInst(I, *this));
2941+
return cast<FCmpInst>(registerValue(std::move(NewPtr)));
2942+
}
2943+
CmpInst *CmpInst::create(Predicate P, Value *S1, Value *S2,
2944+
Instruction *InsertBefore, Context &Ctx,
2945+
const Twine &Name) {
2946+
auto &Builder = Ctx.getLLVMIRBuilder();
2947+
Builder.SetInsertPoint(InsertBefore->getTopmostLLVMInstruction());
2948+
auto *LLVMI = Builder.CreateCmp(P, S1->Val, S2->Val, Name);
2949+
if (dyn_cast<llvm::ICmpInst>(LLVMI))
2950+
return Ctx.createICmpInst(cast<llvm::ICmpInst>(LLVMI));
2951+
return Ctx.createFCmpInst(cast<llvm::FCmpInst>(LLVMI));
2952+
}
2953+
CmpInst *CmpInst::createWithCopiedFlags(Predicate P, Value *S1, Value *S2,
2954+
const Instruction *F,
2955+
Instruction *InsertBefore, Context &Ctx,
2956+
const Twine &Name) {
2957+
CmpInst *Inst = create(P, S1, S2, InsertBefore, Ctx, Name);
2958+
cast<llvm::CmpInst>(Inst->Val)->copyIRFlags(F->Val);
2959+
return Inst;
2960+
}
2961+
2962+
Type *CmpInst::makeCmpResultType(Type *OpndType) {
2963+
if (auto *VT = dyn_cast<VectorType>(OpndType)) {
2964+
// TODO: Cleanup when we have more complete support for
2965+
// sandboxir::VectorType
2966+
return OpndType->getContext().getType(llvm::VectorType::get(
2967+
llvm::Type::getInt1Ty(OpndType->getContext().LLVMCtx),
2968+
cast<llvm::VectorType>(VT->LLVMTy)->getElementCount()));
2969+
}
2970+
return Type::getInt1Ty(OpndType->getContext());
2971+
}
2972+
2973+
void CmpInst::setPredicate(Predicate P) {
2974+
Ctx.getTracker()
2975+
.emplaceIfTracking<
2976+
GenericSetter<&CmpInst::getPredicate, &CmpInst::setPredicate>>(this);
2977+
cast<llvm::CmpInst>(Val)->setPredicate(P);
2978+
}
2979+
2980+
void CmpInst::swapOperands() {
2981+
if (ICmpInst *IC = dyn_cast<ICmpInst>(this))
2982+
IC->swapOperands();
2983+
else
2984+
cast<FCmpInst>(this)->swapOperands();
2985+
}
2986+
2987+
void ICmpInst::swapOperands() {
2988+
Ctx.getTracker().emplaceIfTracking<CmpSwapOperands>(this);
2989+
cast<llvm::ICmpInst>(Val)->swapOperands();
2990+
}
2991+
2992+
void FCmpInst::swapOperands() {
2993+
Ctx.getTracker().emplaceIfTracking<CmpSwapOperands>(this);
2994+
cast<llvm::FCmpInst>(Val)->swapOperands();
2995+
}
2996+
2997+
#ifndef NDEBUG
2998+
void CmpInst::dumpOS(raw_ostream &OS) const {
2999+
dumpCommonPrefix(OS);
3000+
dumpCommonSuffix(OS);
3001+
}
3002+
3003+
void CmpInst::dump() const {
3004+
dumpOS(dbgs());
3005+
dbgs() << "\n";
3006+
}
3007+
#endif // NDEBUG
29253008

29263009
Value *Context::getValue(llvm::Value *V) const {
29273010
auto It = LLVMValueToValueMap.find(V);

llvm/lib/SandboxIR/Tracker.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,14 @@ void ShuffleVectorSetMask::dump() const {
248248
}
249249
#endif
250250

251+
CmpSwapOperands::CmpSwapOperands(CmpInst *Cmp) : Cmp(Cmp) {}
252+
253+
void CmpSwapOperands::revert(Tracker &Tracker) { Cmp->swapOperands(); }
254+
void CmpSwapOperands::dump() const {
255+
dump(dbgs());
256+
dbgs() << "\n";
257+
}
258+
251259
void Tracker::save() { State = TrackerState::Record; }
252260

253261
void Tracker::revert() {

0 commit comments

Comments
 (0)