Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 265 additions & 0 deletions llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1457,6 +1457,268 @@ static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
return false;
}

/// Match low part of 128-bit multiplication.
///
/// Use counts are checked to prevent total instruction count increase as per
/// contributors guide:
/// https://llvm.org/docs/InstCombineContributorGuide.html#multi-use-handling
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need links to the ContributorGuide here. It might be useful to explain what it is matching instead.

static bool foldMul128Low(Instruction &I) {
auto *Ty = I.getType();
if (!Ty->isIntegerTy(64))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What would happen if this was generalized to more types? Would it trigger a lot more, and would those be profitable for smaller types? It would help make this more generic.

return false;

// (low_accum << 32) | lo(lo(y) * lo(x))
Value *LowAccum = nullptr, *YLowXLow = nullptr;
if (!match(&I, m_c_DisjointOr(
m_OneUse(m_Shl(m_Value(LowAccum), m_SpecificInt(32))),
m_OneUse(
m_And(m_Value(YLowXLow), m_SpecificInt(0xffffffff))))))
return false;

// lo(cross_sum) + hi(lo(y) * lo(x))
Value *CrossSum = nullptr;
if (!match(
LowAccum,
m_c_Add(m_OneUse(m_And(m_Value(CrossSum), m_SpecificInt(0xffffffff))),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this And might be optional, as it only clears top bits that we will ignore in the shift.

m_OneUse(m_LShr(m_Specific(YLowXLow), m_SpecificInt(32))))) ||
LowAccum->hasNUsesOrMore(3))
return false;

// (hi(y) * lo(x)) + (lo(y) * hi(x))
Value *YHigh = nullptr, *XLow = nullptr, *YLowXHigh = nullptr;
if (!match(CrossSum, m_c_Add(m_OneUse(m_c_Mul(m_Value(YHigh), m_Value(XLow))),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

m_c_Mul can be m_Mul as both the operands are m_Value.

m_Value(YLowXHigh))) ||
CrossSum->hasNUsesOrMore(4))
return false;

// lo(y) * lo(x)
Value *YLow = nullptr;
if (!match(YLowXLow, m_c_Mul(m_Value(YLow), m_Specific(XLow))) ||
YLowXLow->hasNUsesOrMore(3))
return false;

// lo(y) * hi(x)
Value *XHigh = nullptr;
if (!match(YLowXHigh, m_c_Mul(m_Specific(YLow), m_Value(XHigh))) ||
!YLowXHigh->hasNUses(2))
return false;

Value *X = nullptr;
// lo(x) = x & 0xffffffff
if (!match(XLow, m_c_And(m_Value(X), m_SpecificInt(0xffffffff))) ||
!XLow->hasNUses(2))
return false;
// hi(x) = x >> 32
if (!match(XHigh, m_LShr(m_Specific(X), m_SpecificInt(32))) ||
!XHigh->hasNUses(2))
return false;

// Same for Y.
Value *Y = nullptr;
if (!match(YLow, m_c_And(m_Value(Y), m_SpecificInt(0xffffffff))) ||
!YLow->hasNUses(2))
return false;
if (!match(YHigh, m_LShr(m_Specific(Y), m_SpecificInt(32))) ||
!YHigh->hasNUses(2))
return false;

IRBuilder<> Builder(&I);
Value *XExt = Builder.CreateZExt(X, Builder.getInt128Ty());
Value *YExt = Builder.CreateZExt(Y, Builder.getInt128Ty());
Value *Mul128 = Builder.CreateMul(XExt, YExt);
Value *Res = Builder.CreateTrunc(Mul128, Builder.getInt64Ty());
Comment on lines +1526 to +1529
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this just be a mul?

I.replaceAllUsesWith(Res);

return true;
}

/// Match the high 64 bits of a 64x64-bit multiplication that was written out
/// by hand from 32-bit halves with an explicit carry, and replace it with
/// `trunc (lshr (mul (zext x), (zext y)), 64)` computed in i128.
///
/// The matched expression tree is (value names follow the locals below; the
/// add nodes and the partially-bound multiplies are matched commutatively):
///
///   %x_lo      = and i64 %x, 4294967295
///   %x_hi      = lshr i64 %x, 32
///   %y_lo      = and i64 %y, 4294967295
///   %y_hi      = lshr i64 %y, 32
///   %ylo_xlo   = mul i64 %y_lo, %x_lo
///   %ylo_xhi   = mul i64 %y_lo, %x_hi
///   %yhi_xlo   = mul i64 %y_hi, %x_lo
///   %yhi_xhi   = mul i64 %y_hi, %x_hi
///   %cross_sum = add i64 %yhi_xlo, %ylo_xhi
///   %cmp       = icmp ult i64 %cross_sum, %ylo_xhi
///   %carry     = select i1 %cmp, i64 4294967296, i64 0
///   %cross_lo  = and i64 %cross_sum, 4294967295
///   %cross_hi  = lshr i64 %cross_sum, 32
///   %t_hi      = lshr i64 %ylo_xlo, 32
///   %low_accum = add i64 %cross_lo, %t_hi
///   %low_hi    = lshr i64 %low_accum, 32
///   %inter     = (%yhi_xhi + %carry) + %cross_hi
///                  or (%cross_hi + %yhi_xhi) + %carry
///   %res       = add i64 %inter, %low_hi
///
/// Use counts are checked on the intermediate values so that the rewrite does
/// not increase the total instruction count.
///
/// \param I  Candidate root instruction (the final add).
/// \returns true if \p I was matched and all its uses were replaced.
static bool foldMul128High(Instruction &I) {
  if (!I.getType()->isIntegerTy(64))
    return false;

  // intermediate_plus_carry + hi(low_accum)
  Value *IntermediatePlusCarry = nullptr, *LowAccum = nullptr;
  if (!match(&I,
             m_c_Add(m_OneUse(m_Value(IntermediatePlusCarry)),
                     m_OneUse(m_LShr(m_Value(LowAccum), m_SpecificInt(32))))))
    return false;

  // match:
  //   (((hi(y) * hi(x)) + carry) + hi(cross_sum))
  // or:
  //   ((hi(cross_sum) + (hi(y) * hi(x))) + carry)
  //
  // Both multiply operands are fresh bindings, so plain m_Mul is equivalent
  // to a commutative matcher here; the binding is positional.
  //
  // NOTE(review): the first shape does not require its inner add to be
  // single-use while the second one does -- confirm this is intended.
  CmpPredicate Pred;
  Value *CrossSum = nullptr, *XHigh = nullptr, *YHigh = nullptr,
        *Carry = nullptr;
  if (!match(IntermediatePlusCarry,
             m_c_Add(m_c_Add(m_OneUse(m_Mul(m_Value(YHigh), m_Value(XHigh))),
                             m_Value(Carry)),
                     m_OneUse(m_LShr(m_Value(CrossSum), m_SpecificInt(32))))) &&
      !match(IntermediatePlusCarry,
             m_c_Add(m_OneUse(m_c_Add(
                         m_OneUse(m_LShr(m_Value(CrossSum), m_SpecificInt(32))),
                         m_OneUse(m_Mul(m_Value(YHigh), m_Value(XHigh))))),
                     m_Value(Carry))))
    return false;

  // (select (icmp ult cross_sum, (lo(y) * hi(x))), (1 << 32), 0)
  //
  // cross_sum is a 64-bit add of two partial products; comparing it against
  // one of its addends detects the wrap-around, which is worth 1 << 32 once
  // cross_sum is shifted into the high half.
  Value *YLowXHigh = nullptr;
  if (!match(Carry,
             m_OneUse(m_Select(m_OneUse(m_ICmp(Pred, m_Specific(CrossSum),
                                               m_Value(YLowXHigh))),
                               m_SpecificInt(4294967296), m_SpecificInt(0)))) ||
      Pred != ICmpInst::ICMP_ULT)
    return false;

  // (hi(y) * lo(x)) + (lo(y) * hi(x))
  Value *XLow = nullptr;
  if (!match(CrossSum,
             m_c_Add(m_OneUse(m_c_Mul(m_Specific(YHigh), m_Value(XLow))),
                     m_Specific(YLowXHigh))) ||
      CrossSum->hasNUsesOrMore(4))
    return false;

  // lo(y) * hi(x)
  Value *YLow = nullptr;
  if (!match(YLowXHigh, m_c_Mul(m_Value(YLow), m_Specific(XHigh))) ||
      !YLowXHigh->hasNUses(2))
    return false;

  // lo(cross_sum) + hi(lo(y) * lo(x))
  Value *YLowXLow = nullptr;
  if (!match(LowAccum,
             m_c_Add(m_OneUse(m_c_And(m_Specific(CrossSum),
                                      m_SpecificInt(0xffffffff))),
                     m_OneUse(m_LShr(m_Value(YLowXLow), m_SpecificInt(32))))) ||
      LowAccum->hasNUsesOrMore(3))
    return false;

  // lo(y) * lo(x)
  //
  // When only doing the high part there's a single use and 2 uses when doing
  // full multiply. Given the low/high patterns are separate, it's non-trivial
  // to vary the number of uses to check this, but applying the optimization
  // when there's an unrelated use when only doing the high part still results
  // in less instructions and is likely profitable, so an upper bound of 2 uses
  // should be fine.
  if (!match(YLowXLow, m_c_Mul(m_Specific(YLow), m_Specific(XLow))) ||
      YLowXLow->hasNUsesOrMore(3))
    return false;

  // lo(x) = x & 0xffffffff
  Value *X = nullptr;
  if (!match(XLow, m_c_And(m_Value(X), m_SpecificInt(0xffffffff))) ||
      !XLow->hasNUses(2))
    return false;
  // hi(x) = x >> 32
  if (!match(XHigh, m_LShr(m_Specific(X), m_SpecificInt(32))) ||
      !XHigh->hasNUses(2))
    return false;

  // Same for Y.
  Value *Y = nullptr;
  if (!match(YLow, m_c_And(m_Value(Y), m_SpecificInt(0xffffffff))) ||
      !YLow->hasNUses(2))
    return false;
  if (!match(YHigh, m_LShr(m_Specific(Y), m_SpecificInt(32))) ||
      !YHigh->hasNUses(2))
    return false;

  // Emit the widened product and take its upper 64 bits.
  IRBuilder<> Builder(&I);
  Value *XExt = Builder.CreateZExt(X, Builder.getInt128Ty());
  Value *YExt = Builder.CreateZExt(Y, Builder.getInt128Ty());
  Value *Mul128 = Builder.CreateMul(XExt, YExt);
  Value *High = Builder.CreateLShr(Mul128, 64);
  Value *Res = Builder.CreateTrunc(High, Builder.getInt64Ty());
  I.replaceAllUsesWith(Res);

  return true;
}

/// Match another variant of the high part of a 128-bit multiplication and
/// replace it with `trunc (lshr (mul (zext x), (zext y)), 64)` computed in
/// i128.
///
///   %t0 = mul nuw i64 %y_lo, %x_lo
///   %t1 = mul nuw i64 %y_lo, %x_hi
///   %t2 = mul nuw i64 %y_hi, %x_lo
///   %t3 = mul nuw i64 %y_hi, %x_hi
///   %t0_hi = lshr i64 %t0, 32
///   %u0 = add nuw i64 %t0_hi, %t1
///   %u0_lo = and i64 %u0, 4294967295
///   %u0_hi = lshr i64 %u0, 32
///   %u1 = add nuw i64 %u0_lo, %t2
///   %u1_hi = lshr i64 %u1, 32
///   %u2 = add nuw i64 %u0_hi, %t3
///   %hw64 = add nuw i64 %u2, %u1_hi
///
/// where %x_lo/%y_lo are `and i64 %v, 4294967295` and %x_hi/%y_hi are
/// `lshr i64 %v, 32`. The nuw flags shown above are not checked by the
/// matcher; only the opcodes, constants and use counts are.
///
/// Use counts are checked on the intermediate values so that the rewrite does
/// not increase the total instruction count.
///
/// \param I  Candidate root instruction (%hw64 above).
/// \returns true if \p I was matched and all its uses were replaced.
static bool foldMul128HighVariant(Instruction &I) {
  if (!I.getType()->isIntegerTy(64))
    return false;

  // hw64 = (hi(u0) + (hi(y) * hi(x)) + (lo(u0) + (hi(y) * lo(x)) >> 32))
  //
  // The hi*hi multiply binds both operands fresh, so plain m_Mul is
  // equivalent to a commutative matcher; the binding is positional. The
  // second multiply is tied back to it through m_Deferred(YHigh) and does
  // need the commutative form.
  Value *U0 = nullptr, *XHigh = nullptr, *YHigh = nullptr, *XLow = nullptr;
  if (!match(
          &I,
          m_c_Add(m_OneUse(m_c_Add(
                      m_OneUse(m_LShr(m_Value(U0), m_SpecificInt(32))),
                      m_OneUse(m_Mul(m_Value(YHigh), m_Value(XHigh))))),
                  m_OneUse(m_LShr(
                      m_OneUse(m_c_Add(
                          m_OneUse(m_c_And(m_Deferred(U0),
                                           m_SpecificInt(0xffffffff))),
                          m_OneUse(m_c_Mul(m_Deferred(YHigh), m_Value(XLow))))),
                      m_SpecificInt(32))))))
    return false;

  // u0 = (hi(lo(y) * lo(x)) + (lo(y) * hi(x)))
  // u0 feeds both the mask and the shift matched above, hence exactly 2 uses.
  Value *YLow = nullptr;
  if (!match(U0,
             m_c_Add(m_OneUse(m_LShr(
                         m_OneUse(m_c_Mul(m_Value(YLow), m_Specific(XLow))),
                         m_SpecificInt(32))),
                     m_OneUse(m_c_Mul(m_Deferred(YLow), m_Specific(XHigh))))) ||
      !U0->hasNUses(2))
    return false;

  // Each half of x and y feeds exactly two of the four partial products.

  // lo(x) = x & 0xffffffff
  Value *X = nullptr;
  if (!match(XLow, m_c_And(m_Value(X), m_SpecificInt(0xffffffff))) ||
      !XLow->hasNUses(2))
    return false;
  // hi(x) = x >> 32
  if (!match(XHigh, m_LShr(m_Specific(X), m_SpecificInt(32))) ||
      !XHigh->hasNUses(2))
    return false;

  // Same for Y.
  Value *Y = nullptr;
  if (!match(YLow, m_c_And(m_Value(Y), m_SpecificInt(0xffffffff))) ||
      !YLow->hasNUses(2))
    return false;
  if (!match(YHigh, m_LShr(m_Specific(Y), m_SpecificInt(32))) ||
      !YHigh->hasNUses(2))
    return false;

  // Emit the widened product and take its upper 64 bits.
  IRBuilder<> Builder(&I);
  Value *XExt = Builder.CreateZExt(X, Builder.getInt128Ty());
  Value *YExt = Builder.CreateZExt(Y, Builder.getInt128Ty());
  Value *Mul128 = Builder.CreateMul(XExt, YExt);
  Value *High = Builder.CreateLShr(Mul128, 64);
  Value *Res = Builder.CreateTrunc(High, Builder.getInt64Ty());
  I.replaceAllUsesWith(Res);

  return true;
}

/// This is the entry point for folds that could be implemented in regular
/// InstCombine, but they are separated because they are not expected to
/// occur frequently and/or have more than a constant-length pattern match.
Expand Down Expand Up @@ -1486,6 +1748,9 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT);
MadeChange |= foldPatternedLoads(I, DL);
MadeChange |= foldICmpOrChain(I, DL, TTI, AA, DT);
MadeChange |= foldMul128Low(I);
MadeChange |= foldMul128High(I);
MadeChange |= foldMul128HighVariant(I);
// NOTE: This function introduces erasing of the instruction `I`, so it
// needs to be called at the end of this sequence, otherwise we may make
// bugs.
Expand Down
Loading