Skip to content

Commit ed13aeb

Browse files
committed
[SPIRV] Add legalization for long vectors
This patch introduces the necessary infrastructure to legalize vector operations on vectors that are longer than what the SPIR-V target supports. For instance, shaders only support vectors of up to 4 elements. The legalization is done by splitting the long vectors into smaller vectors of a legal size.

Specifically, this patch does the following:
- Introduces `vectorElementCountIsGreaterThan` and `vectorElementCountIsLessThanOrEqualTo` legality predicates.
- Adds legalization rules for `G_SHUFFLE_VECTOR`, `G_EXTRACT_VECTOR_ELT`, `G_BUILD_VECTOR`, `G_CONCAT_VECTORS`, `G_SPLAT_VECTOR`, and `G_UNMERGE_VALUES`.
- Handles `G_BITCAST` of long vectors by converting them to `@llvm.spv.bitcast` intrinsics which are then legalized.
- Updates `selectUnmergeValues` to handle extraction of both scalars and vectors from a larger vector, using `OpCompositeExtract` and `OpVectorShuffle` respectively.
- Adds a test case to verify the legalization of a bitcast between a `<8 x i32>` and `<4 x f64>`, which is a pattern generated by HLSL's `asuint` and `asdouble` intrinsics.

Fixes: #165444
1 parent e2fa040 commit ed13aeb

File tree

6 files changed

+268
-23
lines changed

6 files changed

+268
-23
lines changed

llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,16 @@ LLVM_ABI LegalityPredicate scalarWiderThan(unsigned TypeIdx, unsigned Size);
314314
LLVM_ABI LegalityPredicate scalarOrEltNarrowerThan(unsigned TypeIdx,
315315
unsigned Size);
316316

317+
/// True iff the specified type index is a fixed vector with an element count
318+
/// that's greater than \p Size.
319+
LLVM_ABI LegalityPredicate vectorElementCountIsGreaterThan(unsigned TypeIdx,
320+
unsigned Size);
321+
322+
/// True iff the specified type index is a fixed vector with an element count
323+
/// that's less than or equal to \p Size.
324+
LLVM_ABI LegalityPredicate
325+
vectorElementCountIsLessThanOrEqualTo(unsigned TypeIdx, unsigned Size);
326+
317327
/// True iff the specified type index is a scalar or a vector with an element
318328
/// type that's wider than the given size.
319329
LLVM_ABI LegalityPredicate scalarOrEltWiderThan(unsigned TypeIdx,

llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,26 @@ LegalityPredicate LegalityPredicates::scalarOrEltNarrowerThan(unsigned TypeIdx,
155155
};
156156
}
157157

158+
/// Builds a predicate that holds when the type at \p TypeIdx is a fixed
/// vector with strictly more than \p Size elements. Any type that is not a
/// fixed vector makes the predicate false.
LegalityPredicate
LegalityPredicates::vectorElementCountIsGreaterThan(unsigned TypeIdx,
                                                    unsigned Size) {
  return [=](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    // Test the kind first so the element count is only read from fixed
    // vectors.
    if (!Ty.isFixedVector())
      return false;
    return Ty.getNumElements() > Size;
  };
}
167+
168+
/// Builds a predicate that holds when the type at \p TypeIdx is a fixed
/// vector with at most \p Size elements. Any type that is not a fixed vector
/// makes the predicate false.
LegalityPredicate
LegalityPredicates::vectorElementCountIsLessThanOrEqualTo(unsigned TypeIdx,
                                                          unsigned Size) {
  return [=](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    // Test the kind first so the element count is only read from fixed
    // vectors.
    if (!Ty.isFixedVector())
      return false;
    return Ty.getNumElements() <= Size;
  };
}
177+
158178
LegalityPredicate LegalityPredicates::scalarOrEltWiderThan(unsigned TypeIdx,
159179
unsigned Size) {
160180
return [=](const LegalityQuery &Query) {

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1526,33 +1526,57 @@ bool SPIRVInstructionSelector::selectUnmergeValues(MachineInstr &I) const {
15261526
unsigned ArgI = I.getNumOperands() - 1;
15271527
Register SrcReg =
15281528
I.getOperand(ArgI).isReg() ? I.getOperand(ArgI).getReg() : Register(0);
1529-
SPIRVType *DefType =
1529+
SPIRVType *SrcType =
15301530
SrcReg.isValid() ? GR.getSPIRVTypeForVReg(SrcReg) : nullptr;
1531-
if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
1531+
if (!SrcType || SrcType->getOpcode() != SPIRV::OpTypeVector)
15321532
report_fatal_error(
15331533
"cannot select G_UNMERGE_VALUES with a non-vector argument");
15341534

15351535
SPIRVType *ScalarType =
1536-
GR.getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
1536+
GR.getSPIRVTypeForVReg(SrcType->getOperand(1).getReg());
15371537
MachineBasicBlock &BB = *I.getParent();
15381538
bool Res = false;
1539+
unsigned CurrentIndex = 0;
15391540
for (unsigned i = 0; i < I.getNumDefs(); ++i) {
15401541
Register ResVReg = I.getOperand(i).getReg();
15411542
SPIRVType *ResType = GR.getSPIRVTypeForVReg(ResVReg);
15421543
if (!ResType) {
1543-
// There was no "assign type" actions, let's fix this now
1544-
ResType = ScalarType;
1544+
LLT ResLLT = MRI->getType(ResVReg);
1545+
assert(ResLLT.isValid());
1546+
if (ResLLT.isVector()) {
1547+
ResType = GR.getOrCreateSPIRVVectorType(
1548+
ScalarType, ResLLT.getNumElements(), I, TII);
1549+
} else {
1550+
ResType = ScalarType;
1551+
}
15451552
MRI->setRegClass(ResVReg, GR.getRegClass(ResType));
1546-
MRI->setType(ResVReg, LLT::scalar(GR.getScalarOrVectorBitWidth(ResType)));
15471553
GR.assignSPIRVTypeToVReg(ResType, ResVReg, *GR.CurMF);
15481554
}
1549-
auto MIB =
1550-
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
1551-
.addDef(ResVReg)
1552-
.addUse(GR.getSPIRVTypeID(ResType))
1553-
.addUse(SrcReg)
1554-
.addImm(static_cast<int64_t>(i));
1555-
Res |= MIB.constrainAllUses(TII, TRI, RBI);
1555+
1556+
if (ResType->getOpcode() == SPIRV::OpTypeVector) {
1557+
Register UndefReg = GR.getOrCreateUndef(I, SrcType, TII);
1558+
auto MIB =
1559+
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpVectorShuffle))
1560+
.addDef(ResVReg)
1561+
.addUse(GR.getSPIRVTypeID(ResType))
1562+
.addUse(SrcReg)
1563+
.addUse(UndefReg);
1564+
unsigned NumElements = GR.getScalarOrVectorComponentCount(ResType);
1565+
for (unsigned j = 0; j < NumElements; ++j) {
1566+
MIB.addImm(CurrentIndex + j);
1567+
}
1568+
CurrentIndex += NumElements;
1569+
Res |= MIB.constrainAllUses(TII, TRI, RBI);
1570+
} else {
1571+
auto MIB =
1572+
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
1573+
.addDef(ResVReg)
1574+
.addUse(GR.getSPIRVTypeID(ResType))
1575+
.addUse(SrcReg)
1576+
.addImm(CurrentIndex);
1577+
CurrentIndex++;
1578+
Res |= MIB.constrainAllUses(TII, TRI, RBI);
1579+
}
15561580
}
15571581
return Res;
15581582
}

llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp

Lines changed: 113 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@
1414
#include "SPIRV.h"
1515
#include "SPIRVGlobalRegistry.h"
1616
#include "SPIRVSubtarget.h"
17+
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
1718
#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
1819
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
1920
#include "llvm/CodeGen/MachineInstr.h"
2021
#include "llvm/CodeGen/MachineRegisterInfo.h"
2122
#include "llvm/CodeGen/TargetOpcodes.h"
23+
#include "llvm/IR/IntrinsicsSPIRV.h"
2224

2325
using namespace llvm;
2426
using namespace llvm::LegalizeActions;
@@ -101,6 +103,10 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
101103
v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1,
102104
v16s8, v16s16, v16s32, v16s64};
103105

106+
auto allShaderVectors = {v2s1, v2s8, v2s16, v2s32, v2s64,
107+
v3s1, v3s8, v3s16, v3s32, v3s64,
108+
v4s1, v4s8, v4s16, v4s32, v4s64};
109+
104110
auto allScalarsAndVectors = {
105111
s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
106112
v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
@@ -126,6 +132,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
126132

127133
auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p10, p11, p12};
128134

135+
auto &allowedVectorTypes = ST.isShader() ? allShaderVectors : allVectors;
136+
129137
bool IsExtendedInts =
130138
ST.canUseExtension(
131139
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
@@ -148,14 +156,63 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
148156
return IsExtendedInts && Ty.isValid();
149157
};
150158

151-
for (auto Opc : getTypeFoldingSupportedOpcodes())
152-
getActionDefinitionsBuilder(Opc).custom();
159+
uint32_t MaxVectorSize = ST.isShader() ? 4 : 16;
160+
161+
for (auto Opc : getTypeFoldingSupportedOpcodes()) {
162+
if (Opc != G_EXTRACT_VECTOR_ELT)
163+
getActionDefinitionsBuilder(Opc).custom();
164+
}
153165

154-
getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
166+
getActionDefinitionsBuilder(G_SHUFFLE_VECTOR)
167+
.legalForCartesianProduct(allowedVectorTypes, allowedVectorTypes)
168+
.moreElementsToNextPow2(0)
169+
.lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
170+
.moreElementsToNextPow2(1)
171+
.lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize))
172+
.alwaysLegal();
155173

156-
// TODO: add proper rules for vectors legalization.
157-
getActionDefinitionsBuilder(
158-
{G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR})
174+
getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
175+
.legalIf(vectorElementCountIsLessThanOrEqualTo(1, MaxVectorSize))
176+
.moreElementsToNextPow2(1)
177+
.fewerElementsIf(vectorElementCountIsGreaterThan(1, MaxVectorSize),
178+
LegalizeMutations::changeElementCountTo(
179+
1, ElementCount::getFixed(MaxVectorSize)))
180+
.custom();
181+
182+
// Illegal G_UNMERGE_VALUES instructions should be handled
183+
// during the combine phase.
184+
getActionDefinitionsBuilder(G_BUILD_VECTOR)
185+
.legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
186+
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
187+
LegalizeMutations::changeElementCountTo(
188+
0, ElementCount::getFixed(MaxVectorSize)));
189+
190+
// When entering the legalizer, there should be no G_BITCAST instructions.
191+
// They should all be calls to the `spv_bitcast` intrinsic. The call to
192+
// the intrinsic will be converted to a G_BITCAST during legalization if
193+
// the vectors are not legal. After using the rules to legalize a G_BITCAST,
194+
// we turn it back into a call to the intrinsic with a custom rule to avoid
195+
// potential machine verifier failures.
196+
getActionDefinitionsBuilder(G_BITCAST)
197+
.moreElementsToNextPow2(0)
198+
.moreElementsToNextPow2(1)
199+
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
200+
LegalizeMutations::changeElementCountTo(
201+
0, ElementCount::getFixed(MaxVectorSize)))
202+
.lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize))
203+
.custom();
204+
205+
getActionDefinitionsBuilder(G_CONCAT_VECTORS)
206+
.legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
207+
.moreElementsToNextPow2(0)
208+
.lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
209+
.alwaysLegal();
210+
211+
// G_SPLAT_VECTOR: results with at most MaxVectorSize elements are legal.
// Longer results are first widened to a power-of-two element count and then
// reduced to MaxVectorSize elements.
//
// The reduction must use changeElementCountTo. changeElementSizeTo takes a
// *type index* as its second argument, so handing it MaxVectorSize (4 or 16)
// would look up a type index that the query does not have.
getActionDefinitionsBuilder(G_SPLAT_VECTOR)
    .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
    .moreElementsToNextPow2(0)
    .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
                     LegalizeMutations::changeElementCountTo(
                         0, ElementCount::getFixed(MaxVectorSize)))
    .alwaysLegal();
160217

161218
// Vector Reduction Operations
@@ -164,17 +221,18 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
164221
G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
165222
G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
166223
G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
167-
.legalFor(allVectors)
224+
.legalFor(allowedVectorTypes)
168225
.scalarize(1)
169226
.lower();
170227

171228
getActionDefinitionsBuilder({G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL})
172229
.scalarize(2)
173230
.lower();
174231

175-
// Merge/Unmerge
176-
// TODO: add proper legalization rules.
177-
getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal();
232+
// Illegal G_UNMERGE_VALUES instructions should be handled
233+
// during the combine phase.
234+
getActionDefinitionsBuilder(G_UNMERGE_VALUES)
235+
.legalIf(vectorElementCountIsLessThanOrEqualTo(1, MaxVectorSize));
178236

179237
getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
180238
.legalIf(all(typeInSet(0, allPtrs), typeInSet(1, allPtrs)));
@@ -287,6 +345,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
287345
// Pointer-handling.
288346
getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
289347

348+
getActionDefinitionsBuilder(G_GLOBAL_VALUE).legalFor(allPtrs);
349+
290350
// Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
291351
getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
292352

@@ -374,6 +434,11 @@ bool SPIRVLegalizerInfo::legalizeCustom(
374434
default:
375435
// TODO: implement legalization for other opcodes.
376436
return true;
437+
case TargetOpcode::G_BITCAST:
438+
return legalizeBitcast(Helper, MI);
439+
case TargetOpcode::G_INTRINSIC:
440+
return legalizeIntrinsic(Helper, MI);
441+
377442
case TargetOpcode::G_IS_FPCLASS:
378443
return legalizeIsFPClass(Helper, MI, LocObserver);
379444
case TargetOpcode::G_ICMP: {
@@ -400,6 +465,44 @@ bool SPIRVLegalizerInfo::legalizeCustom(
400465
}
401466
}
402467

468+
/// Legalization hook for intrinsic calls. Only `spv_bitcast` is inspected:
/// when its result or its source is a vector with more elements than the
/// subtarget limit, the call is replaced by a G_BITCAST over the same
/// registers and \p MI is erased. In every other case \p MI is left as is.
///
/// \returns true in all cases.
bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
                                           MachineInstr &MI) const {
  if (cast<GIntrinsic>(MI).getIntrinsicID() != Intrinsic::spv_bitcast)
    return true;

  MachineIRBuilder &Builder = Helper.MIRBuilder;
  MachineRegisterInfo &RegInfo = *Builder.getMRI();
  const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();

  // The result is operand 0 and the bitcast source is operand 2.
  const Register Dst = MI.getOperand(0).getReg();
  const Register Src = MI.getOperand(2).getReg();

  // Same limit as the one used when the legalization rules are set up:
  // 4 elements for shaders, 16 otherwise.
  const int32_t MaxVectorSize = ST.isShader() ? 4 : 16;
  const auto ExceedsLimit = [MaxVectorSize](LLT Ty) {
    return Ty.isVector() && Ty.getNumElements() > MaxVectorSize;
  };

  if (ExceedsLimit(RegInfo.getType(Dst)) ||
      ExceedsLimit(RegInfo.getType(Src))) {
    // Swap the intrinsic for a generic bitcast so the G_BITCAST rules can
    // break the long vector apart.
    Builder.buildBitcast(Dst, Src);
    MI.eraseFromParent();
  }
  return true;
}
494+
495+
/// Custom action for G_BITCAST: emits an `spv_bitcast` intrinsic call that
/// defines the same result register from the same source register, then
/// erases \p MI.
///
/// \returns true in all cases.
bool SPIRVLegalizerInfo::legalizeBitcast(LegalizerHelper &Helper,
                                         MachineInstr &MI) const {
  const Register Dst = MI.getOperand(0).getReg();
  const Register Src = MI.getOperand(1).getReg();

  // The intrinsic builder takes its results as a list; wrap the single def.
  const SmallVector<Register, 1> Defs(1, Dst);
  Helper.MIRBuilder.buildIntrinsic(Intrinsic::spv_bitcast, Defs).addUse(Src);
  MI.eraseFromParent();
  return true;
}
505+
403506
// Note this code was copied from LegalizerHelper::lowerISFPCLASS and adjusted
404507
// to ensure that all instructions created during the lowering have SPIR-V types
405508
// assigned to them.

llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,15 @@ class SPIRVLegalizerInfo : public LegalizerInfo {
2929
public:
3030
bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI,
3131
LostDebugLocObserver &LocObserver) const override;
32+
bool legalizeIntrinsic(LegalizerHelper &Helper,
33+
MachineInstr &MI) const override;
34+
3235
SPIRVLegalizerInfo(const SPIRVSubtarget &ST);
3336

3437
private:
3538
bool legalizeIsFPClass(LegalizerHelper &Helper, MachineInstr &MI,
3639
LostDebugLocObserver &LocObserver) const;
40+
bool legalizeBitcast(LegalizerHelper &Helper, MachineInstr &MI) const;
3741
};
3842
} // namespace llvm
3943
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H

0 commit comments

Comments
 (0)