Skip to content

Commit bbda25d

Browse files
[RISCV][GISEL] Legalize G_INSERT_SUBVECTOR
This code is heavily based on the SelectionDAG lowerINSERT_SUBVECTOR code.
1 parent 6497283 commit bbda25d

File tree: 4 files changed (+577, −0 lines changed)

llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,12 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
581581

582582
SplatActions.clampScalar(1, sXLen, sXLen);
583583

584+
getActionDefinitionsBuilder(G_INSERT_SUBVECTOR)
585+
.customIf(all(typeIsLegalBoolVec(0, BoolVecTys, ST),
586+
typeIsLegalBoolVec(1, BoolVecTys, ST)))
587+
.customIf(all(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST),
588+
typeIsLegalIntOrFPVec(1, IntOrFPVecTys, ST)));
589+
584590
getLegacyLegalizerInfo().computeTables();
585591
}
586592

@@ -915,6 +921,154 @@ bool RISCVLegalizerInfo::legalizeSplatVector(MachineInstr &MI,
915921
return true;
916922
}
917923

924+
// Returns the scalable-vector LLT that holds exactly one vector register
// (LMUL = 1) worth of VecTy's element type, i.e. RVVBitsPerBlock bits.
static LLT getLMUL1Ty(LLT VecTy) {
  const LLT EltTy = VecTy.getElementType();
  assert(EltTy.getSizeInBits() <= 64 && "Unexpected vector LLT");
  const unsigned EltsPerBlock = RISCV::RVVBitsPerBlock / EltTy.getSizeInBits();
  return LLT::scalable_vector(EltsPerBlock, EltTy);
}
931+
932+
// Custom-legalize G_INSERT_SUBVECTOR, mirroring SelectionDAG's
// lowerINSERT_SUBVECTOR: mask vectors are first re-expressed in terms of i8
// elements, then the insert is lowered to subregister-aligned extracts plus a
// vslideup.vx / vmv.v.v over an LMUL=1 intermediate type.
bool RISCVLegalizerInfo::legalizeInsertSubvector(MachineInstr &MI,
                                                 MachineIRBuilder &MIB) const {
  assert(MI.getOpcode() == TargetOpcode::G_INSERT_SUBVECTOR);

  MachineRegisterInfo &MRI = *MIB.getMRI();

  Register Dst = MI.getOperand(0).getReg();
  Register Src1 = MI.getOperand(1).getReg();
  Register Src2 = MI.getOperand(2).getReg();
  uint64_t Idx = MI.getOperand(3).getImm();

  LLT BigTy = MRI.getType(Src1);
  LLT LitTy = MRI.getType(Src2);
  Register BigVec = Src1;
  Register LitVec = Src2;

  // We don't have the ability to slide mask vectors up indexed by their i1
  // elements; the smallest we can do is i8. Often we are able to bitcast to
  // equivalent i8 vectors. Otherwise, we must zero-extend to equivalent i8
  // vectors and truncate down after the insert.
  if (LitTy.getElementType() == LLT::scalar(1) &&
      (Idx != 0 ||
       MRI.getVRegDef(BigVec)->getOpcode() != TargetOpcode::G_IMPLICIT_DEF)) {
    auto BigTyMinElts = BigTy.getElementCount().getKnownMinValue();
    auto LitTyMinElts = LitTy.getElementCount().getKnownMinValue();
    if (BigTyMinElts >= 8 && LitTyMinElts >= 8) {
      assert(Idx % 8 == 0 && "Invalid index");
      assert(BigTyMinElts % 8 == 0 && LitTyMinElts % 8 == 0 &&
             "Unexpected mask vector lowering");
      // Reinterpret both vectors as vectors of i8 holding 8 mask bits each.
      Idx /= 8;
      BigTy = LLT::vector(BigTy.getElementCount().divideCoefficientBy(8), 8);
      LitTy = LLT::vector(LitTy.getElementCount().divideCoefficientBy(8), 8);
      BigVec = MIB.buildBitcast(BigTy, BigVec).getReg(0);
      LitVec = MIB.buildBitcast(LitTy, LitVec).getReg(0);
    } else {
      // We can't slide this mask vector up indexed by its i1 elements.
      // This poses a problem when we wish to insert a scalable vector which
      // can't be re-expressed as a larger type. Just choose the slow path and
      // extend to a larger type, then truncate back down.
      LLT ExtBigTy = BigTy.changeElementType(LLT::scalar(8));
      LLT ExtLitTy = LitTy.changeElementType(LLT::scalar(8));
      auto BigZExt = MIB.buildZExt(ExtBigTy, BigVec);
      auto LitZExt = MIB.buildZExt(ExtLitTy, LitVec);
      auto Insert = MIB.buildInsertSubvector(ExtBigTy, BigZExt, LitZExt, Idx);
      // Truncate back to i1 by comparing against zero.
      auto SplatZero = MIB.buildSplatVector(
          ExtBigTy, MIB.buildConstant(ExtBigTy.getElementType(), 0));
      MIB.buildICmp(CmpInst::Predicate::ICMP_NE, Dst, Insert, SplatZero);
      MI.eraseFromParent();
      return true;
    }
  }

  const RISCVRegisterInfo *TRI = STI.getRegisterInfo();
  MVT LitTyMVT = getMVTForLLT(LitTy);
  unsigned SubRegIdx, RemIdx;
  std::tie(SubRegIdx, RemIdx) =
      RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
          getMVTForLLT(BigTy), LitTyMVT, Idx, TRI);

  RISCVII::VLMUL SubVecLMUL = RISCVTargetLowering::getLMUL(getMVTForLLT(LitTy));
  bool IsSubVecPartReg = SubVecLMUL == RISCVII::VLMUL::LMUL_F2 ||
                         SubVecLMUL == RISCVII::VLMUL::LMUL_F4 ||
                         SubVecLMUL == RISCVII::VLMUL::LMUL_F8;

  // If the Idx has been completely eliminated and this subvector's size is a
  // vector register or a multiple thereof, or the surrounding elements are
  // undef, then this is a subvector insert which naturally aligns to a vector
  // register. These can easily be handled using subregister manipulation, so
  // keep the G_INSERT_SUBVECTOR as-is for instruction selection.
  if (RemIdx == 0 && (!IsSubVecPartReg || MRI.getVRegDef(Src1)->getOpcode() ==
                                              TargetOpcode::G_IMPLICIT_DEF))
    return true;

  // If the subvector is smaller than a vector register, then the insertion
  // must preserve the undisturbed elements of the register. We do this by
  // lowering to an EXTRACT_SUBVECTOR grabbing the nearest LMUL=1 vector type
  // (which resolves to a subregister copy), performing a VSLIDEUP to place the
  // subvector within the vector register, and an INSERT_SUBVECTOR of that
  // LMUL=1 type back into the larger vector (resolving to another subregister
  // operation). See below for how our VSLIDEUP works. We go via a LMUL=1 type
  // to avoid allocating a large register group to hold our subvector.

  // VSLIDEUP works by leaving elements 0<i<OFFSET undisturbed, elements
  // OFFSET<=i<VL set to the "subvector" and vl<=i<VLMAX set to the tail policy
  // (in our case undisturbed). This means we can set up a subvector insertion
  // where OFFSET is the insertion offset, and the VL is the OFFSET plus the
  // size of the subvector.
  const LLT XLenTy(STI.getXLenVT());
  LLT InterLitTy = BigTy;
  Register AlignedExtract = Src1;
  unsigned AlignedIdx = Idx - RemIdx;
  if (TypeSize::isKnownGT(BigTy.getSizeInBits(),
                          getLMUL1Ty(BigTy).getSizeInBits())) {
    InterLitTy = getLMUL1Ty(BigTy);
    // Extract a subvector equal to the nearest full vector register type. This
    // should resolve to a G_EXTRACT on a subreg.
    AlignedExtract =
        MIB.buildExtractSubvector(InterLitTy, BigVec, AlignedIdx).getReg(0);
  }

  // Pad the subvector out to the LMUL=1 intermediate type so it can be used
  // as the slide/move source.
  auto Insert = MIB.buildInsertSubvector(InterLitTy, MIB.buildUndef(InterLitTy),
                                         LitVec, 0);

  auto [Mask, _] = buildDefaultVLOps(BigTy, MIB, MRI);
  auto VL = MIB.buildVScale(XLenTy, LitTy.getElementCount().getKnownMinValue());

  // Use tail agnostic policy if we're inserting over InterLitTy's tail.
  ElementCount EndIndex =
      ElementCount::getScalable(RemIdx) + LitTy.getElementCount();
  uint64_t Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED;
  if (EndIndex == InterLitTy.getElementCount())
    Policy = RISCVII::TAIL_AGNOSTIC;

  // If we're inserting into the lowest elements, use a tail undisturbed
  // vmv.v.v.
  MachineInstrBuilder Inserted;
  if (RemIdx == 0) {
    Inserted = MIB.buildInstr(RISCV::G_VMV_V_V_VL, {InterLitTy},
                              {AlignedExtract, Insert, VL});
  } else {
    auto SlideupAmt = MIB.buildVScale(XLenTy, RemIdx);
    // Construct the vector length corresponding to RemIdx + length(LitTy).
    VL = MIB.buildAdd(XLenTy, SlideupAmt, VL);
    // Slide the padded subvector (Insert), not the raw LitVec: G_VSLIDEUP_VL's
    // $vec operand must have the same type (InterLitTy) as the result and
    // merge operand, matching the SelectionDAG lowering.
    Inserted =
        MIB.buildInstr(RISCV::G_VSLIDEUP_VL, {InterLitTy},
                       {AlignedExtract, Insert, SlideupAmt, Mask, VL, Policy});
  }

  // If required, insert this subvector back into the correct vector register.
  // This should resolve to an INSERT_SUBREG instruction. Note this must insert
  // the computed slide result (Inserted) at the aligned element index — using
  // the original subvector here would discard the vslideup entirely.
  if (TypeSize::isKnownGT(BigTy.getSizeInBits(), InterLitTy.getSizeInBits()))
    Inserted = MIB.buildInsertSubvector(BigTy, BigVec, Inserted, AlignedIdx);

  // We might have bitcast from a mask type: cast back to the original type if
  // required.
  MIB.buildBitcast(Dst, Inserted);

  MI.eraseFromParent();
  return true;
}
1071+
9181072
bool RISCVLegalizerInfo::legalizeCustom(
9191073
LegalizerHelper &Helper, MachineInstr &MI,
9201074
LostDebugLocObserver &LocObserver) const {
@@ -985,6 +1139,8 @@ bool RISCVLegalizerInfo::legalizeCustom(
9851139
return legalizeExt(MI, MIRBuilder);
9861140
case TargetOpcode::G_SPLAT_VECTOR:
9871141
return legalizeSplatVector(MI, MIRBuilder);
1142+
case TargetOpcode::G_INSERT_SUBVECTOR:
1143+
return legalizeInsertSubvector(MI, MIRBuilder);
9881144
case TargetOpcode::G_LOAD:
9891145
case TargetOpcode::G_STORE:
9901146
return legalizeLoadStore(MI, Helper, MIRBuilder);

llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class RISCVLegalizerInfo : public LegalizerInfo {
4646
bool legalizeVScale(MachineInstr &MI, MachineIRBuilder &MIB) const;
4747
bool legalizeExt(MachineInstr &MI, MachineIRBuilder &MIRBuilder) const;
4848
bool legalizeSplatVector(MachineInstr &MI, MachineIRBuilder &MIB) const;
49+
bool legalizeInsertSubvector(MachineInstr &MI, MachineIRBuilder &MIB) const;
4950
bool legalizeLoadStore(MachineInstr &MI, LegalizerHelper &Helper,
5051
MachineIRBuilder &MIB) const;
5152
};

llvm/lib/Target/RISCV/RISCVInstrGISel.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,21 @@ def G_SPLAT_VECTOR_SPLIT_I64_VL : RISCVGenericInstruction {
5757
let InOperandList = (ins type0:$passthru, type1:$hi, type1:$lo, type2:$vl);
5858
let hasSideEffects = false;
5959
}
60+
61+
// Pseudo equivalent to a RISCVISD::VMV_V_V_VL (tail-undisturbed vmv.v.v with
// an explicit VL). Operand order matches the SDNode: the passthru providing
// the undisturbed tail elements, the source vector, and the vector length.
// Note the passthru operand is required — the legalizer builds this
// instruction with three inputs (passthru, vec, vl).
def G_VMV_V_V_VL : RISCVGenericInstruction {
  let OutOperandList = (outs type0:$dst);
  let InOperandList = (ins type0:$passthru, type0:$vec, type1:$vl);
  let hasSideEffects = false;
}
def : GINodeEquiv<G_VMV_V_V_VL, riscv_vmv_v_v_vl>;
68+
69+
// Pseudo equivalent to a RISCVISD::VSLIDEUP_VL.
// Operands: $merge supplies the undisturbed elements below the slide offset
// (and the tail, per $policy); $vec is the vector slid upward; $idx is the
// slide amount; $mask the element mask; $vl the vector length; $policy the
// tail/mask policy immediate.
def G_VSLIDEUP_VL : RISCVGenericInstruction {
  let OutOperandList = (outs type0:$dst);
  let InOperandList = (ins type0:$merge, type0:$vec, type1:$idx, type2:$mask,
                       type3:$vl, type4:$policy);
  let hasSideEffects = false;
}
def : GINodeEquiv<G_VSLIDEUP_VL, riscv_slideup_vl>;
77+

0 commit comments

Comments
 (0)