48 changes: 37 additions & 11 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5168,26 +5168,41 @@ bool AArch64TTIImpl::isProfitableToSinkOperands(
return false;
}
case Instruction::Mul: {
auto ShouldSinkSplatForIndexedVariant = [](Value *V) {
auto VT = MVT::getVT(V->getType(), /*HandleUnknown=*/true);
return (VT == MVT::v4i16 || VT == MVT::v8i16 || VT == MVT::v2i32 ||
VT == MVT::v4i32);
Collaborator: Should we add handling for v16i16 and similar larger types too?

Contributor Author: Good point - I've refactored it to sink any (non-scalable) vector type with i16 or i32 elements, rather than adding all the possible element counts, because that seemed to make more sense - I'm not sure if there's a reason not to do it this way?
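As a rough illustration of the refactor described above (a sketch only, not the exact code from the updated revision), the check could key off the element type of any fixed-width vector instead of enumerating MVTs:

  auto ShouldSinkSplatForIndexedVariant = [](Value *V) {
    // Hypothetical sketch: accept any fixed-width (non-scalable) vector
    // whose elements are i16 or i32, rather than listing element counts.
    auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
    if (!FVTy)
      return false;
    Type *EltTy = FVTy->getElementType();
    return EltTy->isIntegerTy(16) || EltTy->isIntegerTy(32);
  };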
};

int NumZExts = 0, NumSExts = 0;
for (auto &Op : I->operands()) {
// Make sure we are not already sinking this operand
if (any_of(Ops, [&](Use *U) { return U->get() == Op; }))
continue;

if (match(&Op, m_SExt(m_Value()))) {
NumSExts++;
continue;
} else if (match(&Op, m_ZExt(m_Value()))) {
NumZExts++;
if (match(&Op, m_ZExtOrSExt(m_Value()))) {
auto *Ext = cast<Instruction>(Op);
auto *ExtOp = Ext->getOperand(0);
if (isSplatShuffle(ExtOp) && ShouldSinkSplatForIndexedVariant(ExtOp))
Ops.push_back(&Ext->getOperandUse(0));
Ops.push_back(&Op);

if (isa<SExtInst>(Ext))
NumSExts++;
else
NumZExts++;

continue;
}

ShuffleVectorInst *Shuffle = dyn_cast<ShuffleVectorInst>(Op);
if (!Shuffle)
continue;

// If the Shuffle is a splat and the operand is a zext/sext, sinking the
// operand and the s/zext can help create indexed s/umull. This is
// especially useful to prevent i64 mul being scalarized.
if (Shuffle && isSplatShuffle(Shuffle) &&
if (isSplatShuffle(Shuffle) &&
match(Shuffle->getOperand(0), m_ZExtOrSExt(m_Value()))) {
Ops.push_back(&Shuffle->getOperandUse(0));
Ops.push_back(&Op);
@@ -5198,9 +5213,6 @@ bool AArch64TTIImpl::isProfitableToSinkOperands(
continue;
}

if (!Shuffle)
continue;

Value *ShuffleOperand = Shuffle->getOperand(0);
InsertElementInst *Insert = dyn_cast<InsertElementInst>(ShuffleOperand);
if (!Insert)
@@ -5232,12 +5244,26 @@ bool AArch64TTIImpl::isProfitableToSinkOperands(
NumZExts++;
}

Ops.push_back(&Insert->getOperandUse(1));
Ops.push_back(&Shuffle->getOperandUse(0));
Ops.push_back(&Op);
}

// Is it profitable to sink if we found two of the same type of extends.
return !Ops.empty() && (NumSExts == 2 || NumZExts == 2);
// It is profitable to sink if we found two of the same type of extends.
if (!Ops.empty() && (NumSExts == 2 || NumZExts == 2))
return true;

// Otherwise, see if we should sink splats for indexed variants.
if (!ShouldSinkSplatForIndexedVariant(I))
return false;

Ops.clear();
if (isSplatShuffle(I->getOperand(0)))
Ops.push_back(&I->getOperandUse(0));
if (isSplatShuffle(I->getOperand(1)))
Ops.push_back(&I->getOperandUse(1));

return !Ops.empty();
}
default:
return false;
16 changes: 10 additions & 6 deletions llvm/test/CodeGen/AArch64/aarch64-dup-ext-crash.ll
@@ -10,14 +10,18 @@ target triple = "aarch64-unknown-linux-gnu"
define dso_local i32 @dupext_crashtest(i32 %e) local_unnamed_addr {
; CHECK-LABEL: dupext_crashtest:
; CHECK: // %bb.0: // %for.body.lr.ph
; CHECK-NEXT: mov w8, w0
; CHECK-NEXT: dup v0.2s, w8
; CHECK-NEXT: .LBB0_1: // %vector.body
; CHECK-NEXT: // =>This Inner Loop Header: Depth=1
; CHECK-NEXT: ldr d1, [x8]
; CHECK-NEXT: smull v1.2d, v0.2s, v1.2s
; CHECK-NEXT: xtn v1.2s, v1.2d
; CHECK-NEXT: str d1, [x8]
; CHECK-NEXT: ldr d0, [x8]
; CHECK-NEXT: ushll v0.2d, v0.2s, #0
; CHECK-NEXT: fmov x9, d0
; CHECK-NEXT: mov x8, v0.d[1]
; CHECK-NEXT: mul w9, w0, w9
; CHECK-NEXT: mul w8, w0, w8
; CHECK-NEXT: fmov d0, x9
; CHECK-NEXT: mov v0.d[1], x8
; CHECK-NEXT: xtn v0.2s, v0.2d
; CHECK-NEXT: str d0, [x8]
Comment on lines -17 to +24

Contributor Author: This is a regression, but I have a patch that fixes it by teaching https://github.com/llvm/llvm-project/blob/c25c6c32494c8d1038438b6208d42ba40f25270e/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp#L18599 to handle ANY_EXTENDs, which seem to get generated via KnownBits queries when we visit the truncate nodes.

Collaborator: Sounds good. Is it possible to write a separate test for it too, with the anyext already in place?

Contributor Author: It seemed to make sense to put this into a separate, follow-up PR - see #118308.

> Is it possible to write a separate test for it too, with the anyext already in place?

I've added the test @dupzext_v2i32_v2i64_trunc in that PR that should generate the anyext via the truncate - I'm not sure how I would do this otherwise, as unless I'm missing something there's no anyext in LLVM IR?
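For reference, a hypothetical reduced IR pattern of the shape discussed above (an assumed example, not the actual @dupzext_v2i32_v2i64_trunc test from #118308): the wide multiply result is only used truncated, so the zero-extends can be relaxed to any_extend during ISel based on the demanded bits.

define <2 x i32> @dup_mul_trunc(i32 %e, <2 x i32> %v) {
  ; Splat the scalar, widen both operands, multiply, and truncate back.
  %conv = zext i32 %e to i64
  %ins = insertelement <2 x i64> poison, i64 %conv, i64 0
  %splat = shufflevector <2 x i64> %ins, <2 x i64> poison, <2 x i32> zeroinitializer
  %ext = zext <2 x i32> %v to <2 x i64>
  %mul = mul <2 x i64> %splat, %ext
  %trunc = trunc <2 x i64> %mul to <2 x i32>
  ret <2 x i32> %trunc
}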
; CHECK-NEXT: b .LBB0_1
for.body.lr.ph:
%conv314 = zext i32 %e to i64