Skip to content

Commit 0e4bf69

Browse files
committed
fixup! Address review comments
1 parent f7242df commit 0e4bf69

File tree

3 files changed

+214
-70
lines changed

3 files changed

+214
-70
lines changed

llvm/lib/CodeGen/InterleavedAccessPass.cpp

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -630,23 +630,41 @@ getVectorDeinterleaveFactor(IntrinsicInst *II,
630630
return true;
631631
}
632632

633-
// Return nullptr if the value corresponds to a all-true mask. Otherwise,
634-
// return the value that is corresponded to a deinterleaved mask.
635-
static Value *getMask(Value *WideMask, unsigned Factor) {
633+
// Return the corresponded deinterleaved mask, or nullptr if there is no valid
634+
// mask.
635+
static Value *getMask(Value *WideMask, unsigned Factor,
636+
VectorType *LeafValueTy) {
637+
Value *MaskVal = nullptr;
638+
636639
using namespace llvm::PatternMatch;
637640
if (auto *IMI = dyn_cast<IntrinsicInst>(WideMask)) {
638641
SmallVector<Value *, 8> Operands;
639642
SmallVector<Instruction *, 8> DeadInsts;
640643
if (getVectorInterleaveFactor(IMI, Operands, DeadInsts)) {
641644
assert(!Operands.empty());
642645
if (Operands.size() == Factor && llvm::all_equal(Operands))
643-
return Operands[0];
646+
MaskVal = Operands[0];
644647
}
645648
}
646-
if (match(WideMask, m_AllOnes()))
647-
return WideMask;
648649

649-
return nullptr;
650+
if (match(WideMask, m_AllOnes())) {
651+
// Scale the vector length.
652+
ElementCount OrigEC =
653+
cast<VectorType>(WideMask->getType())->getElementCount();
654+
MaskVal =
655+
ConstantVector::getSplat(OrigEC.divideCoefficientBy(Factor),
656+
cast<Constant>(WideMask)->getSplatValue());
657+
}
658+
659+
if (MaskVal) {
660+
// Check if the vector length of mask matches that of the leaf values.
661+
auto *MaskTy = cast<VectorType>(MaskVal->getType());
662+
if (!MaskTy->getElementType()->isIntegerTy(/*Bitwidth=*/1) ||
663+
MaskTy->getElementCount() != LeafValueTy->getElementCount())
664+
return nullptr;
665+
}
666+
667+
return MaskVal;
650668
}
651669

652670
bool InterleavedAccessImpl::lowerDeinterleaveIntrinsic(
@@ -668,7 +686,8 @@ bool InterleavedAccessImpl::lowerDeinterleaveIntrinsic(
668686
return false;
669687
// Check mask operand. Handle both all-true and interleaved mask.
670688
Value *WideMask = VPLoad->getOperand(1);
671-
Value *Mask = getMask(WideMask, Factor);
689+
Value *Mask = getMask(WideMask, Factor,
690+
cast<VectorType>(DeinterleaveValues[0]->getType()));
672691
if (!Mask)
673692
return false;
674693

@@ -720,7 +739,8 @@ bool InterleavedAccessImpl::lowerInterleaveIntrinsic(
720739
return false;
721740

722741
Value *WideMask = VPStore->getOperand(2);
723-
Value *Mask = getMask(WideMask, Factor);
742+
Value *Mask = getMask(WideMask, Factor,
743+
cast<VectorType>(InterleaveValues[0]->getType()));
724744
if (!Mask)
725745
return false;
726746

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "llvm/ADT/SmallSet.h"
2323
#include "llvm/ADT/Statistic.h"
2424
#include "llvm/Analysis/MemoryLocation.h"
25+
#include "llvm/Analysis/ValueTracking.h"
2526
#include "llvm/Analysis/VectorUtils.h"
2627
#include "llvm/CodeGen/MachineFrameInfo.h"
2728
#include "llvm/CodeGen/MachineFunction.h"
@@ -22529,6 +22530,22 @@ bool RISCVTargetLowering::lowerInterleaveIntrinsicToStore(
2252922530
return true;
2253022531
}
2253122532

22533+
static bool isMultipleOfN(const Value *V, const DataLayout &DL, unsigned N) {
22534+
assert(N);
22535+
if (N == 1)
22536+
return true;
22537+
22538+
if (isPowerOf2_32(N)) {
22539+
KnownBits KB = llvm::computeKnownBits(V, DL);
22540+
return KB.countMinTrailingZeros() >= Log2_32(N);
22541+
} else {
22542+
using namespace PatternMatch;
22543+
// Right now we're only recognizing the simplest pattern.
22544+
uint64_t C;
22545+
return match(V, m_c_Mul(m_Value(), m_ConstantInt(C))) && C && C % N == 0;
22546+
}
22547+
}
22548+
2253222549
/// Lower an interleaved vp.load into a vlsegN intrinsic.
2253322550
///
2253422551
/// E.g. Lower an interleaved vp.load (Factor = 2):
@@ -22586,6 +22603,9 @@ bool RISCVTargetLowering::lowerDeinterleavedIntrinsicToVPLoad(
2258622603

2258722604
IRBuilder<> Builder(Load);
2258822605
Value *WideEVL = Load->getArgOperand(2);
22606+
if (!isMultipleOfN(WideEVL, Load->getDataLayout(), Factor))
22607+
return false;
22608+
2258922609
auto *XLenTy = Type::getIntNTy(Load->getContext(), Subtarget.getXLen());
2259022610
Value *EVL = Builder.CreateZExtOrTrunc(
2259122611
Builder.CreateUDiv(WideEVL, ConstantInt::get(WideEVL->getType(), Factor)),
@@ -22740,6 +22760,9 @@ bool RISCVTargetLowering::lowerInterleavedIntrinsicToVPStore(
2274022760

2274122761
IRBuilder<> Builder(Store);
2274222762
Value *WideEVL = Store->getArgOperand(3);
22763+
if (!isMultipleOfN(WideEVL, Store->getDataLayout(), Factor))
22764+
return false;
22765+
2274322766
auto *XLenTy = Type::getIntNTy(Store->getContext(), Subtarget.getXLen());
2274422767
Value *EVL = Builder.CreateZExtOrTrunc(
2274522768
Builder.CreateUDiv(WideEVL, ConstantInt::get(WideEVL->getType(), Factor)),

0 commit comments

Comments
 (0)