Skip to content

Commit 3d862cf

Browse files
authored
[SPIRV] Add legalization for long vectors (#169665)
This patch introduces the necessary infrastructure to legalize vector operations on vectors that are longer than what the SPIR-V target supports. For instance, shaders only support vectors up to 4 elements. The legalization is done by splitting the long vectors into smaller vectors of a legal size. Specifically, this patch does the following: - Introduces `vectorElementCountIsGreaterThan` and `vectorElementCountIsLessThanOrEqualTo` legality predicates. - Adds legalization rules for `G_SHUFFLE_VECTOR`, `G_EXTRACT_VECTOR_ELT`, `G_BUILD_VECTOR`, `G_CONCAT_VECTORS`, `G_SPLAT_VECTOR`, and `G_UNMERGE_VALUES`. - Handles `G_BITCAST` of long vectors by converting them to `@llvm.spv.bitcast` intrinsics which are then legalized. - Updates `selectUnmergeValues` to handle extraction of both scalars and vectors from a larger vector, using `OpCompositeExtract` and `OpVectorShuffle` respectively. Fixes #165444
1 parent 8a3891c commit 3d862cf

File tree

7 files changed

+454
-24
lines changed

7 files changed

+454
-24
lines changed

llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,16 @@ LLVM_ABI LegalityPredicate scalarWiderThan(unsigned TypeIdx, unsigned Size);
314314
LLVM_ABI LegalityPredicate scalarOrEltNarrowerThan(unsigned TypeIdx,
315315
unsigned Size);
316316

317+
/// True iff the specified type index is a vector with a number of elements
318+
/// that's greater than the given size.
319+
LLVM_ABI LegalityPredicate vectorElementCountIsGreaterThan(unsigned TypeIdx,
320+
unsigned Size);
321+
322+
/// True iff the specified type index is a vector with a number of elements
323+
/// that's less than or equal to the given size.
324+
LLVM_ABI LegalityPredicate
325+
vectorElementCountIsLessThanOrEqualTo(unsigned TypeIdx, unsigned Size);
326+
317327
/// True iff the specified type index is a scalar or a vector with an element
318328
/// type that's wider than the given size.
319329
LLVM_ABI LegalityPredicate scalarOrEltWiderThan(unsigned TypeIdx,

llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,26 @@ LegalityPredicate LegalityPredicates::scalarOrEltNarrowerThan(unsigned TypeIdx,
155155
};
156156
}
157157

158+
LegalityPredicate
159+
LegalityPredicates::vectorElementCountIsGreaterThan(unsigned TypeIdx,
160+
unsigned Size) {
161+
162+
return [=](const LegalityQuery &Query) {
163+
const LLT QueryTy = Query.Types[TypeIdx];
164+
return QueryTy.isFixedVector() && QueryTy.getNumElements() > Size;
165+
};
166+
}
167+
168+
LegalityPredicate
169+
LegalityPredicates::vectorElementCountIsLessThanOrEqualTo(unsigned TypeIdx,
170+
unsigned Size) {
171+
172+
return [=](const LegalityQuery &Query) {
173+
const LLT QueryTy = Query.Types[TypeIdx];
174+
return QueryTy.isFixedVector() && QueryTy.getNumElements() <= Size;
175+
};
176+
}
177+
158178
LegalityPredicate LegalityPredicates::scalarOrEltWiderThan(unsigned TypeIdx,
159179
unsigned Size) {
160180
return [=](const LegalityQuery &Query) {

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1781,33 +1781,57 @@ bool SPIRVInstructionSelector::selectUnmergeValues(MachineInstr &I) const {
17811781
unsigned ArgI = I.getNumOperands() - 1;
17821782
Register SrcReg =
17831783
I.getOperand(ArgI).isReg() ? I.getOperand(ArgI).getReg() : Register(0);
1784-
SPIRVType *DefType =
1784+
SPIRVType *SrcType =
17851785
SrcReg.isValid() ? GR.getSPIRVTypeForVReg(SrcReg) : nullptr;
1786-
if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
1786+
if (!SrcType || SrcType->getOpcode() != SPIRV::OpTypeVector)
17871787
report_fatal_error(
17881788
"cannot select G_UNMERGE_VALUES with a non-vector argument");
17891789

17901790
SPIRVType *ScalarType =
1791-
GR.getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
1791+
GR.getSPIRVTypeForVReg(SrcType->getOperand(1).getReg());
17921792
MachineBasicBlock &BB = *I.getParent();
17931793
bool Res = false;
1794+
unsigned CurrentIndex = 0;
17941795
for (unsigned i = 0; i < I.getNumDefs(); ++i) {
17951796
Register ResVReg = I.getOperand(i).getReg();
17961797
SPIRVType *ResType = GR.getSPIRVTypeForVReg(ResVReg);
17971798
if (!ResType) {
1798-
// There was no "assign type" actions, let's fix this now
1799-
ResType = ScalarType;
1799+
LLT ResLLT = MRI->getType(ResVReg);
1800+
assert(ResLLT.isValid());
1801+
if (ResLLT.isVector()) {
1802+
ResType = GR.getOrCreateSPIRVVectorType(
1803+
ScalarType, ResLLT.getNumElements(), I, TII);
1804+
} else {
1805+
ResType = ScalarType;
1806+
}
18001807
MRI->setRegClass(ResVReg, GR.getRegClass(ResType));
1801-
MRI->setType(ResVReg, LLT::scalar(GR.getScalarOrVectorBitWidth(ResType)));
18021808
GR.assignSPIRVTypeToVReg(ResType, ResVReg, *GR.CurMF);
18031809
}
1804-
auto MIB =
1805-
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
1806-
.addDef(ResVReg)
1807-
.addUse(GR.getSPIRVTypeID(ResType))
1808-
.addUse(SrcReg)
1809-
.addImm(static_cast<int64_t>(i));
1810-
Res |= MIB.constrainAllUses(TII, TRI, RBI);
1810+
1811+
if (ResType->getOpcode() == SPIRV::OpTypeVector) {
1812+
Register UndefReg = GR.getOrCreateUndef(I, SrcType, TII);
1813+
auto MIB =
1814+
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpVectorShuffle))
1815+
.addDef(ResVReg)
1816+
.addUse(GR.getSPIRVTypeID(ResType))
1817+
.addUse(SrcReg)
1818+
.addUse(UndefReg);
1819+
unsigned NumElements = GR.getScalarOrVectorComponentCount(ResType);
1820+
for (unsigned j = 0; j < NumElements; ++j) {
1821+
MIB.addImm(CurrentIndex + j);
1822+
}
1823+
CurrentIndex += NumElements;
1824+
Res |= MIB.constrainAllUses(TII, TRI, RBI);
1825+
} else {
1826+
auto MIB =
1827+
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
1828+
.addDef(ResVReg)
1829+
.addUse(GR.getSPIRVTypeID(ResType))
1830+
.addUse(SrcReg)
1831+
.addImm(CurrentIndex);
1832+
CurrentIndex++;
1833+
Res |= MIB.constrainAllUses(TII, TRI, RBI);
1834+
}
18111835
}
18121836
return Res;
18131837
}

llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp

Lines changed: 181 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,22 @@
1414
#include "SPIRV.h"
1515
#include "SPIRVGlobalRegistry.h"
1616
#include "SPIRVSubtarget.h"
17+
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
1718
#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
1819
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
1920
#include "llvm/CodeGen/MachineInstr.h"
2021
#include "llvm/CodeGen/MachineRegisterInfo.h"
2122
#include "llvm/CodeGen/TargetOpcodes.h"
23+
#include "llvm/IR/IntrinsicsSPIRV.h"
24+
#include "llvm/Support/Debug.h"
25+
#include "llvm/Support/MathExtras.h"
2226

2327
using namespace llvm;
2428
using namespace llvm::LegalizeActions;
2529
using namespace llvm::LegalityPredicates;
2630

31+
#define DEBUG_TYPE "spirv-legalizer"
32+
2733
LegalityPredicate typeOfExtendedScalars(unsigned TypeIdx, bool IsExtendedInts) {
2834
return [IsExtendedInts, TypeIdx](const LegalityQuery &Query) {
2935
const LLT Ty = Query.Types[TypeIdx];
@@ -101,6 +107,10 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
101107
v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1,
102108
v16s8, v16s16, v16s32, v16s64};
103109

110+
auto allShaderVectors = {v2s1, v2s8, v2s16, v2s32, v2s64,
111+
v3s1, v3s8, v3s16, v3s32, v3s64,
112+
v4s1, v4s8, v4s16, v4s32, v4s64};
113+
104114
auto allScalarsAndVectors = {
105115
s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
106116
v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
@@ -126,6 +136,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
126136

127137
auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p10, p11, p12};
128138

139+
auto &allowedVectorTypes = ST.isShader() ? allShaderVectors : allVectors;
140+
129141
bool IsExtendedInts =
130142
ST.canUseExtension(
131143
SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers) ||
@@ -148,14 +160,70 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
148160
return IsExtendedInts && Ty.isValid();
149161
};
150162

151-
for (auto Opc : getTypeFoldingSupportedOpcodes())
152-
getActionDefinitionsBuilder(Opc).custom();
163+
// The universal validation rules in the SPIR-V specification state that
164+
// vector sizes are typically limited to 2, 3, or 4. However, larger vector
165+
// sizes (8 and 16) are enabled when the Kernel capability is present. For
166+
// shader execution models, vector sizes are strictly limited to 4. In
167+
// non-shader contexts, vector sizes of 8 and 16 are also permitted, but
168+
// arbitrary sizes (e.g., 6 or 11) are not.
169+
uint32_t MaxVectorSize = ST.isShader() ? 4 : 16;
170+
171+
for (auto Opc : getTypeFoldingSupportedOpcodes()) {
172+
if (Opc != G_EXTRACT_VECTOR_ELT)
173+
getActionDefinitionsBuilder(Opc).custom();
174+
}
153175

154-
getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
176+
getActionDefinitionsBuilder(G_INTRINSIC_W_SIDE_EFFECTS).custom();
155177

156-
// TODO: add proper rules for vectors legalization.
157-
getActionDefinitionsBuilder(
158-
{G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR})
178+
getActionDefinitionsBuilder(G_SHUFFLE_VECTOR)
179+
.legalForCartesianProduct(allowedVectorTypes, allowedVectorTypes)
180+
.moreElementsToNextPow2(0)
181+
.lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
182+
.moreElementsToNextPow2(1)
183+
.lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize))
184+
.alwaysLegal();
185+
186+
getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
187+
.moreElementsToNextPow2(1)
188+
.fewerElementsIf(vectorElementCountIsGreaterThan(1, MaxVectorSize),
189+
LegalizeMutations::changeElementCountTo(
190+
1, ElementCount::getFixed(MaxVectorSize)))
191+
.custom();
192+
193+
// Illegal G_UNMERGE_VALUES instructions should be handled
194+
// during the combine phase.
195+
getActionDefinitionsBuilder(G_BUILD_VECTOR)
196+
.legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
197+
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
198+
LegalizeMutations::changeElementCountTo(
199+
0, ElementCount::getFixed(MaxVectorSize)));
200+
201+
// When entering the legalizer, there should be no G_BITCAST instructions.
202+
// They should all be calls to the `spv_bitcast` intrinsic. The call to
203+
// the intrinsic will be converted to a G_BITCAST during legalization if
204+
// the vectors are not legal. After using the rules to legalize a G_BITCAST,
205+
// we turn it back into a call to the intrinsic with a custom rule to avoid
206+
// potential machine verifier failures.
207+
getActionDefinitionsBuilder(G_BITCAST)
208+
.moreElementsToNextPow2(0)
209+
.moreElementsToNextPow2(1)
210+
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
211+
LegalizeMutations::changeElementCountTo(
212+
0, ElementCount::getFixed(MaxVectorSize)))
213+
.lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize))
214+
.custom();
215+
216+
getActionDefinitionsBuilder(G_CONCAT_VECTORS)
217+
.legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
218+
.moreElementsToNextPow2(0)
219+
.lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
220+
.alwaysLegal();
221+
222+
getActionDefinitionsBuilder(G_SPLAT_VECTOR)
223+
.legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
224+
.moreElementsToNextPow2(0)
225+
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
226+
LegalizeMutations::changeElementSizeTo(0, MaxVectorSize))
159227
.alwaysLegal();
160228

161229
// Vector Reduction Operations
@@ -164,17 +232,18 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
164232
G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
165233
G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
166234
G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
167-
.legalFor(allVectors)
235+
.legalFor(allowedVectorTypes)
168236
.scalarize(1)
169237
.lower();
170238

171239
getActionDefinitionsBuilder({G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL})
172240
.scalarize(2)
173241
.lower();
174242

175-
// Merge/Unmerge
176-
// TODO: add proper legalization rules.
177-
getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal();
243+
// Illegal G_UNMERGE_VALUES instructions should be handled
244+
// during the combine phase.
245+
getActionDefinitionsBuilder(G_UNMERGE_VALUES)
246+
.legalIf(vectorElementCountIsLessThanOrEqualTo(1, MaxVectorSize));
178247

179248
getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
180249
.legalIf(all(typeInSet(0, allPtrs), typeInSet(1, allPtrs)));
@@ -228,7 +297,14 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
228297
all(typeInSet(0, allPtrsScalarsAndVectors),
229298
typeInSet(1, allPtrsScalarsAndVectors)));
230299

231-
getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal();
300+
getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE})
301+
.legalFor({s1})
302+
.legalFor(allFloatAndIntScalarsAndPtrs)
303+
.legalFor(allowedVectorTypes)
304+
.moreElementsToNextPow2(0)
305+
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
306+
LegalizeMutations::changeElementCountTo(
307+
0, ElementCount::getFixed(MaxVectorSize)));
232308

233309
getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).alwaysLegal();
234310

@@ -287,6 +363,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
287363
// Pointer-handling.
288364
getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
289365

366+
getActionDefinitionsBuilder(G_GLOBAL_VALUE).legalFor(allPtrs);
367+
290368
// Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
291369
getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
292370

@@ -353,6 +431,21 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
353431
verify(*ST.getInstrInfo());
354432
}
355433

434+
static bool legalizeExtractVectorElt(LegalizerHelper &Helper, MachineInstr &MI,
435+
SPIRVGlobalRegistry *GR) {
436+
MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
437+
Register DstReg = MI.getOperand(0).getReg();
438+
Register SrcReg = MI.getOperand(1).getReg();
439+
Register IdxReg = MI.getOperand(2).getReg();
440+
441+
MIRBuilder
442+
.buildIntrinsic(Intrinsic::spv_extractelt, ArrayRef<Register>{DstReg})
443+
.addUse(SrcReg)
444+
.addUse(IdxReg);
445+
MI.eraseFromParent();
446+
return true;
447+
}
448+
356449
static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType,
357450
LegalizerHelper &Helper,
358451
MachineRegisterInfo &MRI,
@@ -374,6 +467,13 @@ bool SPIRVLegalizerInfo::legalizeCustom(
374467
default:
375468
// TODO: implement legalization for other opcodes.
376469
return true;
470+
case TargetOpcode::G_BITCAST:
471+
return legalizeBitcast(Helper, MI);
472+
case TargetOpcode::G_EXTRACT_VECTOR_ELT:
473+
return legalizeExtractVectorElt(Helper, MI, GR);
474+
case TargetOpcode::G_INTRINSIC:
475+
case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
476+
return legalizeIntrinsic(Helper, MI);
377477
case TargetOpcode::G_IS_FPCLASS:
378478
return legalizeIsFPClass(Helper, MI, LocObserver);
379479
case TargetOpcode::G_ICMP: {
@@ -400,6 +500,76 @@ bool SPIRVLegalizerInfo::legalizeCustom(
400500
}
401501
}
402502

503+
bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
504+
MachineInstr &MI) const {
505+
LLVM_DEBUG(dbgs() << "legalizeIntrinsic: " << MI);
506+
507+
MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
508+
MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
509+
const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();
510+
511+
auto IntrinsicID = cast<GIntrinsic>(MI).getIntrinsicID();
512+
if (IntrinsicID == Intrinsic::spv_bitcast) {
513+
LLVM_DEBUG(dbgs() << "Found a bitcast instruction\n");
514+
Register DstReg = MI.getOperand(0).getReg();
515+
Register SrcReg = MI.getOperand(2).getReg();
516+
LLT DstTy = MRI.getType(DstReg);
517+
LLT SrcTy = MRI.getType(SrcReg);
518+
519+
int32_t MaxVectorSize = ST.isShader() ? 4 : 16;
520+
521+
bool DstNeedsLegalization = false;
522+
bool SrcNeedsLegalization = false;
523+
524+
if (DstTy.isVector()) {
525+
if (DstTy.getNumElements() > 4 &&
526+
!isPowerOf2_32(DstTy.getNumElements())) {
527+
DstNeedsLegalization = true;
528+
}
529+
530+
if (DstTy.getNumElements() > MaxVectorSize) {
531+
DstNeedsLegalization = true;
532+
}
533+
}
534+
535+
if (SrcTy.isVector()) {
536+
if (SrcTy.getNumElements() > 4 &&
537+
!isPowerOf2_32(SrcTy.getNumElements())) {
538+
SrcNeedsLegalization = true;
539+
}
540+
541+
if (SrcTy.getNumElements() > MaxVectorSize) {
542+
SrcNeedsLegalization = true;
543+
}
544+
}
545+
546+
// If an spv_bitcast needs to be legalized, we convert it to G_BITCAST to
547+
// allow using the generic legalization rules.
548+
if (DstNeedsLegalization || SrcNeedsLegalization) {
549+
LLVM_DEBUG(dbgs() << "Replacing with a G_BITCAST\n");
550+
MIRBuilder.buildBitcast(DstReg, SrcReg);
551+
MI.eraseFromParent();
552+
}
553+
return true;
554+
}
555+
return true;
556+
}
557+
558+
bool SPIRVLegalizerInfo::legalizeBitcast(LegalizerHelper &Helper,
559+
MachineInstr &MI) const {
560+
// Once the G_BITCAST is using vectors that are allowed, we turn it back into
561+
// an spv_bitcast to avoid verifier problems when the register types are the
562+
// same for the source and the result. Note that the SPIR-V types associated
563+
// with the bitcast can be different even if the register types are the same.
564+
MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
565+
Register DstReg = MI.getOperand(0).getReg();
566+
Register SrcReg = MI.getOperand(1).getReg();
567+
SmallVector<Register, 1> DstRegs = {DstReg};
568+
MIRBuilder.buildIntrinsic(Intrinsic::spv_bitcast, DstRegs).addUse(SrcReg);
569+
MI.eraseFromParent();
570+
return true;
571+
}
572+
403573
// Note this code was copied from LegalizerHelper::lowerISFPCLASS and adjusted
404574
// to ensure that all instructions created during the lowering have SPIR-V types
405575
// assigned to them.

0 commit comments

Comments
 (0)