Skip to content

Commit ca05058

Browse files
authored
[IA][RISCV] Recognize deinterleaved loads that could lower to strided segmented loads (#151612)
Turn the following deinterleaved load patterns ``` %l = masked.load(%ptr, /*mask=*/110110110110, /*passthru=*/poison) %f0 = shufflevector %l, [0, 3, 6, 9] %f1 = shufflevector %l, [1, 4, 7, 10] %f2 = shufflevector %l, [2, 5, 8, 11] ``` into ``` %s = riscv.vlsseg2(/*passthru=*/poison, %ptr, /*mask=*/1111) %f0 = extractvalue %s, 0 %f1 = extractvalue %s, 1 %f2 = poison ``` The mask `110110110110` is regarded as 'gap mask' since it effectively skips the entire third field / component. Similarly, turning the following snippet ``` %l = masked.load(%ptr, /*mask=*/110000110000, /*passthru=*/poison) %f0 = shufflevector %l, [0, 3, 6, 9] %f1 = shufflevector %l, [1, 4, 7, 10] ``` into ``` %s = riscv.vlsseg2(/*passthru=*/poison, %ptr, /*mask=*/1010) %f0 = extractvalue %s, 0 %f1 = extractvalue %s, 1 ``` Right now this patch only tries to detect gap mask from a constant mask supplied to a masked.load/vp.load.
1 parent 6f939da commit ca05058

File tree

11 files changed

+357
-82
lines changed

11 files changed

+357
-82
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3202,10 +3202,12 @@ class LLVM_ABI TargetLoweringBase {
32023202
/// \p Shuffles is the shufflevector list to DE-interleave the loaded vector.
32033203
/// \p Indices is the corresponding indices for each shufflevector.
32043204
/// \p Factor is the interleave factor.
3205+
/// \p GapMask is a mask with zeros for components / fields that may not be
3206+
/// accessed.
32053207
virtual bool lowerInterleavedLoad(Instruction *Load, Value *Mask,
32063208
ArrayRef<ShuffleVectorInst *> Shuffles,
3207-
ArrayRef<unsigned> Indices,
3208-
unsigned Factor) const {
3209+
ArrayRef<unsigned> Indices, unsigned Factor,
3210+
const APInt &GapMask) const {
32093211
return false;
32103212
}
32113213

llvm/lib/CodeGen/InterleavedAccessPass.cpp

Lines changed: 78 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -268,13 +268,16 @@ static Value *getMaskOperand(IntrinsicInst *II) {
268268
}
269269
}
270270

271-
// Return the corresponded deinterleaved mask, or nullptr if there is no valid
272-
// mask.
273-
static Value *getMask(Value *WideMask, unsigned Factor,
274-
ElementCount LeafValueEC);
275-
276-
static Value *getMask(Value *WideMask, unsigned Factor,
277-
VectorType *LeafValueTy) {
271+
// Return a pair of
272+
// (1) The corresponding deinterleaved mask, or nullptr if there is no valid
273+
// mask.
274+
// (2) A gap mask: a mask in which inactive lanes represent fields that are
275+
// effectively skipped (i.e. "gaps").
276+
static std::pair<Value *, APInt> getMask(Value *WideMask, unsigned Factor,
277+
ElementCount LeafValueEC);
278+
279+
static std::pair<Value *, APInt> getMask(Value *WideMask, unsigned Factor,
280+
VectorType *LeafValueTy) {
278281
return getMask(WideMask, Factor, LeafValueTy->getElementCount());
279282
}
280283

@@ -379,22 +382,25 @@ bool InterleavedAccessImpl::lowerInterleavedLoad(
379382
replaceBinOpShuffles(BinOpShuffles.getArrayRef(), Shuffles, Load);
380383

381384
Value *Mask = nullptr;
385+
auto GapMask = APInt::getAllOnes(Factor);
382386
if (LI) {
383387
LLVM_DEBUG(dbgs() << "IA: Found an interleaved load: " << *Load << "\n");
384388
} else {
385389
// Check mask operand. Handle both all-true/false and interleaved mask.
386-
Mask = getMask(getMaskOperand(II), Factor, VecTy);
390+
std::tie(Mask, GapMask) = getMask(getMaskOperand(II), Factor, VecTy);
387391
if (!Mask)
388392
return false;
389393

390394
LLVM_DEBUG(dbgs() << "IA: Found an interleaved vp.load or masked.load: "
391395
<< *Load << "\n");
396+
LLVM_DEBUG(dbgs() << "IA: With nominal factor " << Factor
397+
<< " and actual factor " << GapMask.popcount() << "\n");
392398
}
393399

394400
// Try to create target specific intrinsics to replace the load and
395401
// shuffles.
396402
if (!TLI->lowerInterleavedLoad(cast<Instruction>(Load), Mask, Shuffles,
397-
Indices, Factor))
403+
Indices, Factor, GapMask))
398404
// If Extracts is not empty, tryReplaceExtracts made changes earlier.
399405
return !Extracts.empty() || BinOpShuffleChanged;
400406

@@ -536,10 +542,15 @@ bool InterleavedAccessImpl::lowerInterleavedStore(
536542
} else {
537543
// Check mask operand. Handle both all-true/false and interleaved mask.
538544
unsigned LaneMaskLen = NumStoredElements / Factor;
539-
Mask = getMask(getMaskOperand(II), Factor,
540-
ElementCount::getFixed(LaneMaskLen));
545+
APInt GapMask(Factor, 0);
546+
std::tie(Mask, GapMask) = getMask(getMaskOperand(II), Factor,
547+
ElementCount::getFixed(LaneMaskLen));
541548
if (!Mask)
542549
return false;
550+
// We don't yet support gap masks for stores. Yet it is possible that we
551+
// already changed the IR, hence returning true here.
552+
if (GapMask.popcount() != Factor)
553+
return true;
543554

544555
LLVM_DEBUG(dbgs() << "IA: Found an interleaved vp.store or masked.store: "
545556
<< *Store << "\n");
@@ -556,34 +567,64 @@ bool InterleavedAccessImpl::lowerInterleavedStore(
556567
return true;
557568
}
558569

559-
static Value *getMask(Value *WideMask, unsigned Factor,
560-
ElementCount LeafValueEC) {
570+
// A wide mask <1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0> could be used to skip the
571+
// last field in a factor-of-three interleaved store or deinterleaved load (in
572+
// which case LeafMaskLen is 4). Such (wide) mask is also known as gap mask.
573+
// This helper function tries to detect this pattern and clears the skipped
574+
// fields' bits in GapMask; its popcount is the actual factor being accessed,
// which is 2 in this example.
575+
static void getGapMask(const Constant &MaskConst, unsigned Factor,
576+
unsigned LeafMaskLen, APInt &GapMask) {
577+
assert(GapMask.getBitWidth() == Factor);
578+
for (unsigned F = 0U; F < Factor; ++F) {
579+
bool AllZero = true;
580+
for (unsigned Idx = 0U; Idx < LeafMaskLen; ++Idx) {
581+
Constant *C = MaskConst.getAggregateElement(F + Idx * Factor);
582+
if (!C->isZeroValue()) {
583+
AllZero = false;
584+
break;
585+
}
586+
}
587+
// All mask bits on this field are zero, skipping it.
588+
if (AllZero)
589+
GapMask.clearBit(F);
590+
}
591+
}
592+
593+
static std::pair<Value *, APInt> getMask(Value *WideMask, unsigned Factor,
594+
ElementCount LeafValueEC) {
595+
auto GapMask = APInt::getAllOnes(Factor);
596+
561597
if (auto *IMI = dyn_cast<IntrinsicInst>(WideMask)) {
562598
if (unsigned F = getInterleaveIntrinsicFactor(IMI->getIntrinsicID());
563599
F && F == Factor && llvm::all_equal(IMI->args())) {
564-
return IMI->getArgOperand(0);
600+
return {IMI->getArgOperand(0), GapMask};
565601
}
566602
}
567603

568604
if (auto *ConstMask = dyn_cast<Constant>(WideMask)) {
569605
if (auto *Splat = ConstMask->getSplatValue())
570606
// All-ones or all-zeros mask.
571-
return ConstantVector::getSplat(LeafValueEC, Splat);
607+
return {ConstantVector::getSplat(LeafValueEC, Splat), GapMask};
572608

573609
if (LeafValueEC.isFixed()) {
574610
unsigned LeafMaskLen = LeafValueEC.getFixedValue();
611+
// First, check if we use a gap mask to skip some of the factors / fields.
612+
getGapMask(*ConstMask, Factor, LeafMaskLen, GapMask);
613+
575614
SmallVector<Constant *, 8> LeafMask(LeafMaskLen, nullptr);
576615
// If this is a fixed-length constant mask, each lane / leaf has to
577616
// use the same mask. This is done by checking if every group with Factor
578617
// number of elements in the interleaved mask has homogeneous values.
579618
for (unsigned Idx = 0U; Idx < LeafMaskLen * Factor; ++Idx) {
619+
if (!GapMask[Idx % Factor])
620+
continue;
580621
Constant *C = ConstMask->getAggregateElement(Idx);
581622
if (LeafMask[Idx / Factor] && LeafMask[Idx / Factor] != C)
582-
return nullptr;
623+
return {nullptr, GapMask};
583624
LeafMask[Idx / Factor] = C;
584625
}
585626

586-
return ConstantVector::get(LeafMask);
627+
return {ConstantVector::get(LeafMask), GapMask};
587628
}
588629
}
589630

@@ -603,12 +644,13 @@ static Value *getMask(Value *WideMask, unsigned Factor,
603644
auto *LeafMaskTy =
604645
VectorType::get(Type::getInt1Ty(SVI->getContext()), LeafValueEC);
605646
IRBuilder<> Builder(SVI);
606-
return Builder.CreateExtractVector(LeafMaskTy, SVI->getOperand(0),
607-
uint64_t(0));
647+
return {Builder.CreateExtractVector(LeafMaskTy, SVI->getOperand(0),
648+
uint64_t(0)),
649+
GapMask};
608650
}
609651
}
610652

611-
return nullptr;
653+
return {nullptr, GapMask};
612654
}
613655

614656
bool InterleavedAccessImpl::lowerDeinterleaveIntrinsic(
@@ -639,9 +681,16 @@ bool InterleavedAccessImpl::lowerDeinterleaveIntrinsic(
639681
return false;
640682

641683
// Check mask operand. Handle both all-true/false and interleaved mask.
642-
Mask = getMask(getMaskOperand(II), Factor, getDeinterleavedVectorType(DI));
684+
APInt GapMask(Factor, 0);
685+
std::tie(Mask, GapMask) =
686+
getMask(getMaskOperand(II), Factor, getDeinterleavedVectorType(DI));
643687
if (!Mask)
644688
return false;
689+
// We don't yet support gap masks when deinterleaving using intrinsics.
690+
// Yet it is possible that we already changed the IR, hence returning true
691+
// here.
692+
if (GapMask.popcount() != Factor)
693+
return true;
645694

646695
LLVM_DEBUG(dbgs() << "IA: Found a vp.load or masked.load with deinterleave"
647696
<< " intrinsic " << *DI << " and factor = "
@@ -680,10 +729,16 @@ bool InterleavedAccessImpl::lowerInterleaveIntrinsic(
680729
II->getIntrinsicID() != Intrinsic::vp_store)
681730
return false;
682731
// Check mask operand. Handle both all-true/false and interleaved mask.
683-
Mask = getMask(getMaskOperand(II), Factor,
684-
cast<VectorType>(InterleaveValues[0]->getType()));
732+
APInt GapMask(Factor, 0);
733+
std::tie(Mask, GapMask) =
734+
getMask(getMaskOperand(II), Factor,
735+
cast<VectorType>(InterleaveValues[0]->getType()));
685736
if (!Mask)
686737
return false;
738+
// We don't yet support gap masks when interleaving using intrinsics. Yet
739+
// it is possible that we already changed the IR, hence returning true here.
740+
if (GapMask.popcount() != Factor)
741+
return true;
687742

688743
LLVM_DEBUG(dbgs() << "IA: Found a vp.store or masked.store with interleave"
689744
<< " intrinsic " << *IntII << " and factor = "

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17390,7 +17390,7 @@ static Function *getStructuredStoreFunction(Module *M, unsigned Factor,
1739017390
/// %vec1 = extractelement { <4 x i32>, <4 x i32> } %ld2, i32 1
1739117391
bool AArch64TargetLowering::lowerInterleavedLoad(
1739217392
Instruction *Load, Value *Mask, ArrayRef<ShuffleVectorInst *> Shuffles,
17393-
ArrayRef<unsigned> Indices, unsigned Factor) const {
17393+
ArrayRef<unsigned> Indices, unsigned Factor, const APInt &GapMask) const {
1739417394
assert(Factor >= 2 && Factor <= getMaxSupportedInterleaveFactor() &&
1739517395
"Invalid interleave factor");
1739617396
assert(!Shuffles.empty() && "Empty shufflevector input");
@@ -17400,7 +17400,7 @@ bool AArch64TargetLowering::lowerInterleavedLoad(
1740017400
auto *LI = dyn_cast<LoadInst>(Load);
1740117401
if (!LI)
1740217402
return false;
17403-
assert(!Mask && "Unexpected mask on a load");
17403+
assert(!Mask && GapMask.popcount() == Factor && "Unexpected mask on a load");
1740417404

1740517405
const DataLayout &DL = LI->getDataLayout();
1740617406

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,8 @@ class AArch64TargetLowering : public TargetLowering {
222222

223223
bool lowerInterleavedLoad(Instruction *Load, Value *Mask,
224224
ArrayRef<ShuffleVectorInst *> Shuffles,
225-
ArrayRef<unsigned> Indices,
226-
unsigned Factor) const override;
225+
ArrayRef<unsigned> Indices, unsigned Factor,
226+
const APInt &GapMask) const override;
227227
bool lowerInterleavedStore(Instruction *Store, Value *Mask,
228228
ShuffleVectorInst *SVI,
229229
unsigned Factor) const override;

llvm/lib/Target/ARM/ARMISelLowering.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21745,7 +21745,7 @@ unsigned ARMTargetLowering::getMaxSupportedInterleaveFactor() const {
2174521745
/// %vec1 = extractelement { <4 x i32>, <4 x i32> } %vld2, i32 1
2174621746
bool ARMTargetLowering::lowerInterleavedLoad(
2174721747
Instruction *Load, Value *Mask, ArrayRef<ShuffleVectorInst *> Shuffles,
21748-
ArrayRef<unsigned> Indices, unsigned Factor) const {
21748+
ArrayRef<unsigned> Indices, unsigned Factor, const APInt &GapMask) const {
2174921749
assert(Factor >= 2 && Factor <= getMaxSupportedInterleaveFactor() &&
2175021750
"Invalid interleave factor");
2175121751
assert(!Shuffles.empty() && "Empty shufflevector input");
@@ -21755,7 +21755,7 @@ bool ARMTargetLowering::lowerInterleavedLoad(
2175521755
auto *LI = dyn_cast<LoadInst>(Load);
2175621756
if (!LI)
2175721757
return false;
21758-
assert(!Mask && "Unexpected mask on a load");
21758+
assert(!Mask && GapMask.popcount() == Factor && "Unexpected mask on a load");
2175921759

2176021760
auto *VecTy = cast<FixedVectorType>(Shuffles[0]->getType());
2176121761
Type *EltTy = VecTy->getElementType();

llvm/lib/Target/ARM/ARMISelLowering.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -685,8 +685,8 @@ class VectorType;
685685

686686
bool lowerInterleavedLoad(Instruction *Load, Value *Mask,
687687
ArrayRef<ShuffleVectorInst *> Shuffles,
688-
ArrayRef<unsigned> Indices,
689-
unsigned Factor) const override;
688+
ArrayRef<unsigned> Indices, unsigned Factor,
689+
const APInt &GapMask) const override;
690690
bool lowerInterleavedStore(Instruction *Store, Value *Mask,
691691
ShuffleVectorInst *SVI,
692692
unsigned Factor) const override;

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -431,8 +431,8 @@ class RISCVTargetLowering : public TargetLowering {
431431

432432
bool lowerInterleavedLoad(Instruction *Load, Value *Mask,
433433
ArrayRef<ShuffleVectorInst *> Shuffles,
434-
ArrayRef<unsigned> Indices,
435-
unsigned Factor) const override;
434+
ArrayRef<unsigned> Indices, unsigned Factor,
435+
const APInt &GapMask) const override;
436436

437437
bool lowerInterleavedStore(Instruction *Store, Value *Mask,
438438
ShuffleVectorInst *SVI,

llvm/lib/Target/RISCV/RISCVInterleavedAccess.cpp

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ static const Intrinsic::ID FixedVlsegIntrIds[] = {
6363
Intrinsic::riscv_seg6_load_mask, Intrinsic::riscv_seg7_load_mask,
6464
Intrinsic::riscv_seg8_load_mask};
6565

66+
static const Intrinsic::ID FixedVlssegIntrIds[] = {
67+
Intrinsic::riscv_sseg2_load_mask, Intrinsic::riscv_sseg3_load_mask,
68+
Intrinsic::riscv_sseg4_load_mask, Intrinsic::riscv_sseg5_load_mask,
69+
Intrinsic::riscv_sseg6_load_mask, Intrinsic::riscv_sseg7_load_mask,
70+
Intrinsic::riscv_sseg8_load_mask};
71+
6672
static const Intrinsic::ID ScalableVlsegIntrIds[] = {
6773
Intrinsic::riscv_vlseg2_mask, Intrinsic::riscv_vlseg3_mask,
6874
Intrinsic::riscv_vlseg4_mask, Intrinsic::riscv_vlseg5_mask,
@@ -197,9 +203,15 @@ static bool getMemOperands(unsigned Factor, VectorType *VTy, Type *XLenTy,
197203
/// %vec1 = extractelement { <4 x i32>, <4 x i32> } %ld2, i32 1
198204
bool RISCVTargetLowering::lowerInterleavedLoad(
199205
Instruction *Load, Value *Mask, ArrayRef<ShuffleVectorInst *> Shuffles,
200-
ArrayRef<unsigned> Indices, unsigned Factor) const {
206+
ArrayRef<unsigned> Indices, unsigned Factor, const APInt &GapMask) const {
201207
assert(Indices.size() == Shuffles.size());
208+
assert(GapMask.getBitWidth() == Factor);
202209

210+
// We only support cases where the skipped fields are the trailing ones.
211+
// TODO: Lower to strided load if there is only a single active field.
212+
unsigned MaskFactor = GapMask.popcount();
213+
if (MaskFactor < 2 || !GapMask.isMask())
214+
return false;
203215
IRBuilder<> Builder(Load);
204216

205217
const DataLayout &DL = Load->getDataLayout();
@@ -208,20 +220,37 @@ bool RISCVTargetLowering::lowerInterleavedLoad(
208220

209221
Value *Ptr, *VL;
210222
Align Alignment;
211-
if (!getMemOperands(Factor, VTy, XLenTy, Load, Ptr, Mask, VL, Alignment))
223+
if (!getMemOperands(MaskFactor, VTy, XLenTy, Load, Ptr, Mask, VL, Alignment))
212224
return false;
213225

214226
Type *PtrTy = Ptr->getType();
215227
unsigned AS = PtrTy->getPointerAddressSpace();
216-
if (!isLegalInterleavedAccessType(VTy, Factor, Alignment, AS, DL))
228+
if (!isLegalInterleavedAccessType(VTy, MaskFactor, Alignment, AS, DL))
217229
return false;
218230

219-
CallInst *VlsegN = Builder.CreateIntrinsic(
220-
FixedVlsegIntrIds[Factor - 2], {VTy, PtrTy, XLenTy}, {Ptr, Mask, VL});
231+
CallInst *SegLoad = nullptr;
232+
if (MaskFactor < Factor) {
233+
// Lower to strided segmented load.
234+
unsigned ScalarSizeInBytes = DL.getTypeStoreSize(VTy->getElementType());
235+
Value *Stride = ConstantInt::get(XLenTy, Factor * ScalarSizeInBytes);
236+
SegLoad = Builder.CreateIntrinsic(FixedVlssegIntrIds[MaskFactor - 2],
237+
{VTy, PtrTy, XLenTy, XLenTy},
238+
{Ptr, Stride, Mask, VL});
239+
} else {
240+
// Lower to normal segmented load.
241+
SegLoad = Builder.CreateIntrinsic(FixedVlsegIntrIds[Factor - 2],
242+
{VTy, PtrTy, XLenTy}, {Ptr, Mask, VL});
243+
}
221244

222245
for (unsigned i = 0; i < Shuffles.size(); i++) {
223-
Value *SubVec = Builder.CreateExtractValue(VlsegN, Indices[i]);
224-
Shuffles[i]->replaceAllUsesWith(SubVec);
246+
unsigned FactorIdx = Indices[i];
247+
if (FactorIdx >= MaskFactor) {
248+
// Replace masked-off factors (that are still extracted) with poison.
249+
Shuffles[i]->replaceAllUsesWith(PoisonValue::get(VTy));
250+
} else {
251+
Value *SubVec = Builder.CreateExtractValue(SegLoad, FactorIdx);
252+
Shuffles[i]->replaceAllUsesWith(SubVec);
253+
}
225254
}
226255

227256
return true;

llvm/lib/Target/X86/X86ISelLowering.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1663,8 +1663,8 @@ namespace llvm {
16631663
/// instructions/intrinsics.
16641664
bool lowerInterleavedLoad(Instruction *Load, Value *Mask,
16651665
ArrayRef<ShuffleVectorInst *> Shuffles,
1666-
ArrayRef<unsigned> Indices,
1667-
unsigned Factor) const override;
1666+
ArrayRef<unsigned> Indices, unsigned Factor,
1667+
const APInt &GapMask) const override;
16681668

16691669
/// Lower interleaved store(s) into target specific
16701670
/// instructions/intrinsics.

llvm/lib/Target/X86/X86InterleavedAccess.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -802,7 +802,7 @@ bool X86InterleavedAccessGroup::lowerIntoOptimizedSequence() {
802802
// Currently, lowering is supported for 4x64 bits with Factor = 4 on AVX.
803803
bool X86TargetLowering::lowerInterleavedLoad(
804804
Instruction *Load, Value *Mask, ArrayRef<ShuffleVectorInst *> Shuffles,
805-
ArrayRef<unsigned> Indices, unsigned Factor) const {
805+
ArrayRef<unsigned> Indices, unsigned Factor, const APInt &GapMask) const {
806806
assert(Factor >= 2 && Factor <= getMaxSupportedInterleaveFactor() &&
807807
"Invalid interleave factor");
808808
assert(!Shuffles.empty() && "Empty shufflevector input");
@@ -812,7 +812,7 @@ bool X86TargetLowering::lowerInterleavedLoad(
812812
auto *LI = dyn_cast<LoadInst>(Load);
813813
if (!LI)
814814
return false;
815-
assert(!Mask && "Unexpected mask on a load");
815+
assert(!Mask && GapMask.popcount() == Factor && "Unexpected mask on a load");
816816

817817
// Create an interleaved access group.
818818
IRBuilder<> Builder(LI);

0 commit comments

Comments
 (0)