Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions llvm/include/llvm/IR/Operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ class OverflowingBinaryOperator : public Operator {
return NoWrapKind;
}

/// Return true if the instruction is commutative
bool isCommutative() const { return Instruction::isCommutative(getOpcode()); }

static bool classof(const Instruction *I) {
return I->getOpcode() == Instruction::Add ||
I->getOpcode() == Instruction::Sub ||
Expand Down
73 changes: 73 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1503,6 +1503,76 @@ foldMinimumOverTrailingOrLeadingZeroCount(Value *I0, Value *I1,
ConstantInt::getTrue(ZeroUndef->getType()));
}

/// Return whether "X LOp (Y ROp Z)" is always equal to
/// "(X LOp Y) ROp (X LOp Z)".
static bool leftDistributesOverRight(Instruction::BinaryOps LOp, bool HasNUW,
bool HasNSW, Intrinsic::ID ROp) {
switch (ROp) {
case Intrinsic::umax:
case Intrinsic::umin:
return hasNUW && LOp == Instruction::Add;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can combine these cases since there are no functions that distribute over umax but not umin (or vice versa).

  case Intrinsic::umax:
  case Intrinsic::umin:
    return hasNUW && LOp == Instruction::Add;

Proof sketch: Let f be an arbitrary binary function and x, y, z be arbitrary bit vectors. Suppose that (f u (umax v w)) = (umax (f u v) (f u w)) for all u, v, w. Observe that (umin u v) = (xor u v (umax u v)) for all u, v. Then (umin (f x y) (f x z)) = (xor (f x y) (f x z) (umax (f x y) (f x z))) = (xor (f x y) (f x z) (f x (umax y z))). The case (f x y) = (f x z) is trivial, hence suppose they are not equal. Then (f x (umax y z)) is equal to either (f x y) or (f x z), leaving the other as the result of the xor, which equals (f x (umin y z)), as required.

(Similar for smin/smax.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Completely understand. Fixed.

case Intrinsic::smax:
case Intrinsic::smin:
return hasNSW && LOp == Instruction::Add;
default:
return false;
}
}

// Attempts to factorise a common term
// in an instruction that has the form "(A op' B) op (C op' D)
// where op is an intrinsic and op' is a binop
static Value *
foldIntrinsicUsingDistributiveLaws(IntrinsicInst *II,
InstCombiner::BuilderTy &Builder) {
Value *LHS = II->getOperand(0), *RHS = II->getOperand(1);
Intrinsic::ID TopLevelOpcode = II->getIntrinsicID();

OverflowingBinaryOperator *Op0 = dyn_cast<OverflowingBinaryOperator>(LHS);
OverflowingBinaryOperator *Op1 = dyn_cast<OverflowingBinaryOperator>(RHS);

if (!Op0 || !Op1)
return nullptr;

if (Op0->getOpcode() != Op1->getOpcode())
return nullptr;

if (!Op0->hasOneUse() || !Op1->hasOneUse())
return nullptr;

Instruction::BinaryOps InnerOpcode =
static_cast<Instruction::BinaryOps>(Op0->getOpcode());
bool HasNUW = Op0->hasNoUnsignedWrap() && Op1->hasNoUnsignedWrap();
bool HasNSW = Op0->hasNoSignedWrap() && Op1->hasNoSignedWrap();

if (!leftDistributesOverRight(InnerOpcode, HasNUW, HasNSW, TopLevelOpcode))
return nullptr;

assert(II->isCommutative() && Op0->isCommutative() &&
"Only inner and outer commutative op codes are supported.");

Value *A = Op0->getOperand(0);
Value *B = Op0->getOperand(1);
Value *C = Op1->getOperand(0);
Value *D = Op1->getOperand(1);

// Attempts to swap variables such that A always equals C
if (A != C && A != D)
std::swap(A, B);
if (A == C || A == D) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: I'm just curious but does this work for constant As/Cs or splat vectors? For example,

define i8 @f(i8 %x, i8 %y) {
    %add1 = add nuw i8 %x, 42
    %add2 = add nuw i8 %y, 42
    %umin = call i8 @llvm.umin.i8(i8 %add1, i8 %add2)
    ret i8 %umin
}

and

define <4 x i8> @src(<4 x i8> %x, <4 x i8> %y) {
    %add1 = add nuw <4 x i8> %x, <i8 42, i8 42, i8 42, i8 42>
    %add2 = add nuw <4 x i8> %y, <i8 42, i8 42, i8 42, i8 42>
    %umin = call <4 x i8> @llvm.umin.v4i8(<4 x i8> %add1, <4 x i8> %add2)
    ret <4 x i8> %umin
}

It might be a good idea to add such a test to the precommitted tests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would work for constants/splats because the optimisation itself doesn't distinguish between different types of operands. Sure. Will add a test.

if (A != C)
std::swap(C, D);
Value *NewIntrinsic = Builder.CreateBinaryIntrinsic(TopLevelOpcode, B, D);
BinaryOperator *NewBinop =
cast<BinaryOperator>(Builder.CreateBinOp(InnerOpcode, NewIntrinsic, A));
NewBinop->setHasNoSignedWrap(HasNSW);
NewBinop->setHasNoUnsignedWrap(HasNUW);
return NewBinop;
}

return nullptr;
}

/// CallInst simplification. This mostly only handles folding of intrinsic
/// instructions. For normal calls, it allows visitCallBase to do the heavy
/// lifting.
Expand Down Expand Up @@ -1928,6 +1998,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}
}

if (Value *V = foldIntrinsicUsingDistributiveLaws(II, Builder))
return replaceInstUsesWith(*II, V);

break;
}
case Intrinsic::scmp: {
Expand Down
Loading
Loading