
Commit 7f2bbba

[AArch64][ARM] Optimize more tbl/tbx calls into shufflevector (#169748)
Resolves #169701.

This PR extends the existing InstCombine fold that turns `tbl1` intrinsics into `shufflevector` when the mask operand is constant. Before this change, it only handled 64-bit `tbl1` intrinsics with no out-of-bounds indices. It now supports both 64-bit and 128-bit vectors and covers the full range of `tbl1`-`tbl4` and `tbx1`-`tbx4`, as long as at most two of the input operands are actually indexed into. For `tbl`, we need a dummy vector of zeroes if there are any out-of-bounds indices; for `tbx`, we use the "fallback" operand instead. Either of these takes up one of the two `shufflevector` operands.

This works a lot like #169110, with some added complexity because we need to handle multiple operands. I raised a couple of questions in that PR that still need to be answered:

- Is it correct to check `isa<UndefValue>` for each mask index and set the output mask index to -1 if so? This is later folded to a poison value, and I'm not sure about the subtle differences between poison and undef and when one can be substituted for the other. As I mentioned in #169110, the existing x86 fold (`simplifyX86vpermilvar`) already behaves this way when it comes to undef.
- How can I write an Alive2 proof for this? It's very hard to find good documentation or tutorials about Alive2.

As with #169110, most of the regression test cases were generated using Claude. Everything else was written by me.
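As a rough before/after illustration (hand-written IR, not copied from the PR's regression tests): a 128-bit `tbl1` with a constant descending mask is just a byte reverse, and with this change it can be folded like so:

```llvm
declare <16 x i8> @llvm.aarch64.neon.tbl1.v16i8(<16 x i8>, <16 x i8>)

define <16 x i8> @byte_reverse(<16 x i8> %table) {
  %r = call <16 x i8> @llvm.aarch64.neon.tbl1.v16i8(<16 x i8> %table,
      <16 x i8> <i8 15, i8 14, i8 13, i8 12, i8 11, i8 10, i8 9, i8 8,
                 i8 7, i8 6, i8 5, i8 4, i8 3, i8 2, i8 1, i8 0>)
  ret <16 x i8> %r
}

; Expected rewrite (roughly): every index is in range, so the second
; shufflevector operand stays poison.
;   %r = shufflevector <16 x i8> %table, <16 x i8> poison,
;       <16 x i32> <i32 15, i32 14, i32 13, i32 12, i32 11, i32 10, i32 9,
;                   i32 8, i32 7, i32 6, i32 5, i32 4, i32 3, i32 2, i32 1,
;                   i32 0>
```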
1 parent: c66eb25

File tree

5 files changed (+589, -123 lines)


llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

Lines changed: 113 additions & 23 deletions
```diff
@@ -737,42 +737,119 @@ static Instruction *foldCtpop(IntrinsicInst &II, InstCombinerImpl &IC) {
   return nullptr;
 }
 
-/// Convert a table lookup to shufflevector if the mask is constant.
-/// This could benefit tbl1 if the mask is { 7,6,5,4,3,2,1,0 }, in
-/// which case we could lower the shufflevector with rev64 instructions
-/// as it's actually a byte reverse.
-static Value *simplifyNeonTbl1(const IntrinsicInst &II,
-                               InstCombiner::BuilderTy &Builder) {
+/// Convert `tbl`/`tbx` intrinsics to shufflevector if the mask is constant, and
+/// at most two source operands are actually referenced.
+static Instruction *simplifyNeonTbl(IntrinsicInst &II, InstCombiner &IC,
+                                    bool IsExtension) {
   // Bail out if the mask is not a constant.
-  auto *C = dyn_cast<Constant>(II.getArgOperand(1));
+  auto *C = dyn_cast<Constant>(II.getArgOperand(II.arg_size() - 1));
   if (!C)
     return nullptr;
 
-  auto *VecTy = cast<FixedVectorType>(II.getType());
-  unsigned NumElts = VecTy->getNumElements();
+  auto *RetTy = cast<FixedVectorType>(II.getType());
+  unsigned NumIndexes = RetTy->getNumElements();
 
-  // Only perform this transformation for <8 x i8> vector types.
-  if (!VecTy->getElementType()->isIntegerTy(8) || NumElts != 8)
+  // Only perform this transformation for <8 x i8> and <16 x i8> vector types.
+  if (!RetTy->getElementType()->isIntegerTy(8) ||
+      (NumIndexes != 8 && NumIndexes != 16))
     return nullptr;
 
-  int Indexes[8];
+  // For tbx instructions, the first argument is the "fallback" vector, which
+  // has the same length as the mask and return type.
+  unsigned int StartIndex = (unsigned)IsExtension;
+  auto *SourceTy =
+      cast<FixedVectorType>(II.getArgOperand(StartIndex)->getType());
+  // Note that the element count of each source vector does *not* need to be the
+  // same as the element count of the return type and mask! All source vectors
+  // must have the same element count as each other, though.
+  unsigned NumElementsPerSource = SourceTy->getNumElements();
+
+  // There are no tbl/tbx intrinsics for which the destination size exceeds the
+  // source size. However, our definitions of the intrinsics, at least in
+  // IntrinsicsAArch64.td, allow for arbitrary destination vector sizes, so it
+  // *could* technically happen.
+  if (NumIndexes > NumElementsPerSource)
+    return nullptr;
+
+  // The tbl/tbx intrinsics take several source operands followed by a mask
+  // operand.
+  unsigned int NumSourceOperands = II.arg_size() - 1 - (unsigned)IsExtension;
 
-  for (unsigned I = 0; I < NumElts; ++I) {
+  // Map input operands to shuffle indices. This also helpfully deduplicates the
+  // input arguments, in case the same value is passed as an argument multiple
+  // times.
+  SmallDenseMap<Value *, unsigned, 2> ValueToShuffleSlot;
+  Value *ShuffleOperands[2] = {PoisonValue::get(SourceTy),
+                               PoisonValue::get(SourceTy)};
+
+  int Indexes[16];
+  for (unsigned I = 0; I < NumIndexes; ++I) {
     Constant *COp = C->getAggregateElement(I);
 
-    if (!COp || !isa<ConstantInt>(COp))
+    if (!COp || (!isa<UndefValue>(COp) && !isa<ConstantInt>(COp)))
       return nullptr;
 
-    Indexes[I] = cast<ConstantInt>(COp)->getLimitedValue();
+    if (isa<UndefValue>(COp)) {
+      Indexes[I] = -1;
+      continue;
+    }
+
+    uint64_t Index = cast<ConstantInt>(COp)->getZExtValue();
+    // The index of the input argument that this index references (0 = first
+    // source argument, etc).
+    unsigned SourceOperandIndex = Index / NumElementsPerSource;
+    // The index of the element at that source operand.
+    unsigned SourceOperandElementIndex = Index % NumElementsPerSource;
+
+    Value *SourceOperand;
+    if (SourceOperandIndex >= NumSourceOperands) {
+      // This index is out of bounds. Map it to index into either the fallback
+      // vector (tbx) or vector of zeroes (tbl).
+      SourceOperandIndex = NumSourceOperands;
+      if (IsExtension) {
+        // For out-of-bounds indices in tbx, choose the `I`th element of the
+        // fallback.
+        SourceOperand = II.getArgOperand(0);
+        SourceOperandElementIndex = I;
+      } else {
+        // Otherwise, choose some element from the dummy vector of zeroes (we'll
+        // always choose the first).
+        SourceOperand = Constant::getNullValue(SourceTy);
+        SourceOperandElementIndex = 0;
+      }
+    } else {
+      SourceOperand = II.getArgOperand(SourceOperandIndex + StartIndex);
+    }
+
+    // The source operand may be the fallback vector, which may not have the
+    // same number of elements as the source vector. In that case, we *could*
+    // choose to extend its length with another shufflevector, but it's simpler
+    // to just bail instead.
+    if (cast<FixedVectorType>(SourceOperand->getType())->getNumElements() !=
+        NumElementsPerSource)
+      return nullptr;
 
-    // Make sure the mask indices are in range.
-    if ((unsigned)Indexes[I] >= NumElts)
+    // We now know the source operand referenced by this index. Make it a
+    // shufflevector operand, if it isn't already.
+    unsigned NumSlots = ValueToShuffleSlot.size();
+    // This shuffle references more than two sources, and hence cannot be
+    // represented as a shufflevector.
+    if (NumSlots == 2 && !ValueToShuffleSlot.contains(SourceOperand))
      return nullptr;
+
+    auto [It, Inserted] =
+        ValueToShuffleSlot.try_emplace(SourceOperand, NumSlots);
+    if (Inserted)
+      ShuffleOperands[It->getSecond()] = SourceOperand;
+
+    unsigned RemappedIndex =
+        (It->getSecond() * NumElementsPerSource) + SourceOperandElementIndex;
+    Indexes[I] = RemappedIndex;
  }
 
-  auto *V1 = II.getArgOperand(0);
-  auto *V2 = Constant::getNullValue(V1->getType());
-  return Builder.CreateShuffleVector(V1, V2, ArrayRef(Indexes));
+  Value *Shuf = IC.Builder.CreateShuffleVector(
+      ShuffleOperands[0], ShuffleOperands[1], ArrayRef(Indexes, NumIndexes));
+  return IC.replaceInstUsesWith(II, Shuf);
 }
 
 // Returns true iff the 2 intrinsics have the same operands, limiting the
```
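To make the out-of-bounds handling above concrete, here is a hand-written `tbl1` sketch (illustrative only, not taken from the PR's regression tests). Index 255 is past the end of the 16-byte table, so `tbl` produces zero in those lanes; the fold routes them to a dummy zero vector, which occupies the second `shufflevector` slot:

```llvm
declare <8 x i8> @llvm.aarch64.neon.tbl1.v8i8(<16 x i8>, <8 x i8>)

define <8 x i8> @tbl_out_of_bounds(<16 x i8> %table) {
  %r = call <8 x i8> @llvm.aarch64.neon.tbl1.v8i8(<16 x i8> %table,
      <8 x i8> <i8 0, i8 255, i8 2, i8 255, i8 4, i8 255, i8 6, i8 255>)
  ret <8 x i8> %r
}

; Expected rewrite (roughly): the table becomes shuffle operand 0, a zero
; vector becomes operand 1, and every out-of-range lane reads element 0 of the
; zeros, i.e. shuffle index 16 (NumElementsPerSource == 16 here).
;   %r = shufflevector <16 x i8> %table, <16 x i8> zeroinitializer,
;       <8 x i32> <i32 0, i32 16, i32 2, i32 16, i32 4, i32 16, i32 6, i32 16>
```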
```diff
@@ -3167,10 +3244,23 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     return CallInst::Create(NewFn, CallArgs);
   }
   case Intrinsic::arm_neon_vtbl1:
+  case Intrinsic::arm_neon_vtbl2:
+  case Intrinsic::arm_neon_vtbl3:
+  case Intrinsic::arm_neon_vtbl4:
   case Intrinsic::aarch64_neon_tbl1:
-    if (Value *V = simplifyNeonTbl1(*II, Builder))
-      return replaceInstUsesWith(*II, V);
-    break;
+  case Intrinsic::aarch64_neon_tbl2:
+  case Intrinsic::aarch64_neon_tbl3:
+  case Intrinsic::aarch64_neon_tbl4:
+    return simplifyNeonTbl(*II, *this, /*IsExtension=*/false);
+  case Intrinsic::arm_neon_vtbx1:
+  case Intrinsic::arm_neon_vtbx2:
+  case Intrinsic::arm_neon_vtbx3:
+  case Intrinsic::arm_neon_vtbx4:
+  case Intrinsic::aarch64_neon_tbx1:
+  case Intrinsic::aarch64_neon_tbx2:
+  case Intrinsic::aarch64_neon_tbx3:
+  case Intrinsic::aarch64_neon_tbx4:
+    return simplifyNeonTbl(*II, *this, /*IsExtension=*/true);
 
   case Intrinsic::arm_neon_vmulls:
   case Intrinsic::arm_neon_vmullu:
```
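For the `tbx` flavor, a sketch of the expected behavior (again hand-written, not from the PR's tests): out-of-range lanes keep the corresponding lane of the fallback operand, so the fallback becomes the second `shufflevector` source, with lane `I` remapped to index `NumElementsPerSource + I`. If a constant mask would need more than two distinct sources (for example, two different tables plus the fallback), the transform bails out instead.

```llvm
declare <16 x i8> @llvm.aarch64.neon.tbx1.v16i8(<16 x i8>, <16 x i8>, <16 x i8>)

define <16 x i8> @tbx_fallback(<16 x i8> %fallback, <16 x i8> %table) {
  ; Lanes whose index is 255 are out of range and keep the fallback's value.
  %r = call <16 x i8> @llvm.aarch64.neon.tbx1.v16i8(
      <16 x i8> %fallback, <16 x i8> %table,
      <16 x i8> <i8 0, i8 1, i8 2, i8 3, i8 255, i8 255, i8 255, i8 255,
                 i8 8, i8 9, i8 10, i8 11, i8 255, i8 255, i8 255, i8 255>)
  ret <16 x i8> %r
}

; Expected rewrite (roughly): the table is shuffle operand 0, the fallback is
; operand 1, and each out-of-range lane I maps to index 16 + I.
;   %r = shufflevector <16 x i8> %table, <16 x i8> %fallback,
;       <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 20, i32 21, i32 22, i32 23,
;                   i32 8, i32 9, i32 10, i32 11, i32 28, i32 29, i32 30, i32 31>
```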
