Skip to content

Commit d8aa9c9

Browse files
committed
[SPIR-V] Legalize vector arithmetic and intrinsics for large vectors
This patch improves the legalization of vector operations, particularly focusing on vectors that exceed the maximum supported size (e.g., 4 elements for shaders). This includes better handling for insert and extract element operations, which facilitates the legalization of loads and stores for long vectors—a common pattern when compiling HLSL matrices with Clang.

Key changes include:
- Adding legalization rules for G_FMA, G_INSERT_VECTOR_ELT, and various arithmetic operations to handle splitting of large vectors.
- Updating G_CONCAT_VECTORS and G_SPLAT_VECTOR to be legal for allowed types.
- Implementing custom legalization for G_INSERT_VECTOR_ELT using the spv_insertelt intrinsic.
- Enhancing SPIRVPostLegalizer to deduce types for arithmetic instructions and vector element intrinsics (spv_insertelt, spv_extractelt).
- Refactoring legalizeIntrinsic to uniformly handle vector legalization requirements.

The strategy for insert and extract operations mirrors that of bitcasts: incoming intrinsics are converted to generic MIR instructions (G_INSERT_VECTOR_ELT and G_EXTRACT_VECTOR_ELT) to leverage standard legalization rules (like splitting). After legalization, they are converted back to their respective SPIR-V intrinsics (spv_insertelt, spv_extractelt) because later passes in the backend expect these intrinsics rather than the generic instructions. This ensures that operations on large vectors (e.g., <16 x float>) are correctly broken down into legal sub-vectors.
1 parent 07e63a3 commit d8aa9c9

File tree

4 files changed

+464
-43
lines changed

4 files changed

+464
-43
lines changed

llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp

Lines changed: 83 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
113113
v3s1, v3s8, v3s16, v3s32, v3s64,
114114
v4s1, v4s8, v4s16, v4s32, v4s64};
115115

116+
auto allScalars = {s1, s8, s16, s32};
117+
116118
auto allScalarsAndVectors = {
117119
s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
118120
v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
@@ -172,9 +174,25 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
172174

173175
for (auto Opc : getTypeFoldingSupportedOpcodes()) {
174176
if (Opc != G_EXTRACT_VECTOR_ELT)
175-
getActionDefinitionsBuilder(Opc).custom();
177+
getActionDefinitionsBuilder(Opc)
178+
.customFor(allScalars)
179+
.customFor(allowedVectorTypes)
180+
.moreElementsToNextPow2(0)
181+
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
182+
LegalizeMutations::changeElementCountTo(
183+
0, ElementCount::getFixed(MaxVectorSize)))
184+
.custom();
176185
}
177186

187+
getActionDefinitionsBuilder(TargetOpcode::G_FMA)
188+
.legalFor(allScalars)
189+
.legalFor(allowedVectorTypes)
190+
.moreElementsToNextPow2(0)
191+
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
192+
LegalizeMutations::changeElementCountTo(
193+
0, ElementCount::getFixed(MaxVectorSize)))
194+
.alwaysLegal();
195+
178196
getActionDefinitionsBuilder(G_INTRINSIC_W_SIDE_EFFECTS).custom();
179197

180198
getActionDefinitionsBuilder(G_SHUFFLE_VECTOR)
@@ -192,6 +210,13 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
192210
1, ElementCount::getFixed(MaxVectorSize)))
193211
.custom();
194212

213+
getActionDefinitionsBuilder(G_INSERT_VECTOR_ELT)
214+
.moreElementsToNextPow2(0)
215+
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
216+
LegalizeMutations::changeElementCountTo(
217+
0, ElementCount::getFixed(MaxVectorSize)))
218+
.custom();
219+
195220
// Illegal G_UNMERGE_VALUES instructions should be handled
196221
// during the combine phase.
197222
getActionDefinitionsBuilder(G_BUILD_VECTOR)
@@ -215,14 +240,13 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
215240
.lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize))
216241
.custom();
217242

243+
// If the result is still illegal, the combiner should be able to remove it.
218244
getActionDefinitionsBuilder(G_CONCAT_VECTORS)
219-
.legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
220-
.moreElementsToNextPow2(0)
221-
.lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
222-
.alwaysLegal();
245+
.legalForCartesianProduct(allowedVectorTypes, allowedVectorTypes)
246+
.moreElementsToNextPow2(0);
223247

224248
getActionDefinitionsBuilder(G_SPLAT_VECTOR)
225-
.legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
249+
.legalFor(allowedVectorTypes)
226250
.moreElementsToNextPow2(0)
227251
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
228252
LegalizeMutations::changeElementSizeTo(0, MaxVectorSize))
@@ -458,6 +482,23 @@ static bool legalizeExtractVectorElt(LegalizerHelper &Helper, MachineInstr &MI,
458482
return true;
459483
}
460484

485+
// Custom legalization for G_INSERT_VECTOR_ELT: rewrites the generic
// instruction into an equivalent spv_insertelt intrinsic call and erases the
// original instruction. Per the commit description, this is the "convert
// back" step: later backend passes expect the SPIR-V intrinsic rather than
// the generic opcode.
//
// Helper - supplies the MachineIRBuilder used to emit the intrinsic.
// MI     - the instruction to replace; operands 0..3 are read as the
//          destination, source vector, inserted value and index registers.
// GR     - not used by this function.
//
// Always returns true.
static bool legalizeInsertVectorElt(LegalizerHelper &Helper, MachineInstr &MI,
486+
SPIRVGlobalRegistry *GR) {
487+
MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
488+
Register DstReg = MI.getOperand(0).getReg();
489+
Register SrcReg = MI.getOperand(1).getReg();
490+
Register ValReg = MI.getOperand(2).getReg();
491+
Register IdxReg = MI.getOperand(3).getReg();
492+
493+
// Emit spv_insertelt defining the same destination register, passing the
// source vector, the value and the index as uses in that order.
MIRBuilder
494+
.buildIntrinsic(Intrinsic::spv_insertelt, ArrayRef<Register>{DstReg})
495+
.addUse(SrcReg)
496+
.addUse(ValReg)
497+
.addUse(IdxReg);
498+
// The generic instruction is now redundant; remove it.
MI.eraseFromParent();
499+
return true;
500+
}
501+
461502
static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType,
462503
LegalizerHelper &Helper,
463504
MachineRegisterInfo &MRI,
@@ -483,6 +524,8 @@ bool SPIRVLegalizerInfo::legalizeCustom(
483524
return legalizeBitcast(Helper, MI);
484525
case TargetOpcode::G_EXTRACT_VECTOR_ELT:
485526
return legalizeExtractVectorElt(Helper, MI, GR);
527+
case TargetOpcode::G_INSERT_VECTOR_ELT:
528+
return legalizeInsertVectorElt(Helper, MI, GR);
486529
case TargetOpcode::G_INTRINSIC:
487530
case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
488531
return legalizeIntrinsic(Helper, MI);
@@ -512,6 +555,15 @@ bool SPIRVLegalizerInfo::legalizeCustom(
512555
}
513556
}
514557

558+
// Returns true when Ty is a vector type whose element count requires
// legalization for the given subtarget, i.e. when either
//   * it has more than 4 elements and the count is not a power of two, or
//   * it has more elements than the subtarget maximum (4 when
//     ST.isShader() is true, 16 otherwise).
// Non-vector types never need vector legalization and yield false.
static bool needsVectorLegalization(const LLT &Ty, const SPIRVSubtarget &ST) {
559+
if (!Ty.isVector())
560+
return false;
561+
unsigned NumElements = Ty.getNumElements();
562+
// Maximum vector width: 4 for shader subtargets, 16 otherwise.
unsigned MaxVectorSize = ST.isShader() ? 4 : 16;
563+
return (NumElements > 4 && !isPowerOf2_32(NumElements)) ||
564+
NumElements > MaxVectorSize;
565+
}
566+
515567
bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
516568
MachineInstr &MI) const {
517569
LLVM_DEBUG(dbgs() << "legalizeIntrinsic: " << MI);
@@ -528,41 +580,38 @@ bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
528580
LLT DstTy = MRI.getType(DstReg);
529581
LLT SrcTy = MRI.getType(SrcReg);
530582

531-
int32_t MaxVectorSize = ST.isShader() ? 4 : 16;
532-
533-
bool DstNeedsLegalization = false;
534-
bool SrcNeedsLegalization = false;
535-
536-
if (DstTy.isVector()) {
537-
if (DstTy.getNumElements() > 4 &&
538-
!isPowerOf2_32(DstTy.getNumElements())) {
539-
DstNeedsLegalization = true;
540-
}
541-
542-
if (DstTy.getNumElements() > MaxVectorSize) {
543-
DstNeedsLegalization = true;
544-
}
545-
}
546-
547-
if (SrcTy.isVector()) {
548-
if (SrcTy.getNumElements() > 4 &&
549-
!isPowerOf2_32(SrcTy.getNumElements())) {
550-
SrcNeedsLegalization = true;
551-
}
552-
553-
if (SrcTy.getNumElements() > MaxVectorSize) {
554-
SrcNeedsLegalization = true;
555-
}
556-
}
557-
558583
// If an spv_bitcast needs to be legalized, we convert it to G_BITCAST to
559584
// allow using the generic legalization rules.
560-
if (DstNeedsLegalization || SrcNeedsLegalization) {
585+
if (needsVectorLegalization(DstTy, ST) ||
586+
needsVectorLegalization(SrcTy, ST)) {
561587
LLVM_DEBUG(dbgs() << "Replacing with a G_BITCAST\n");
562588
MIRBuilder.buildBitcast(DstReg, SrcReg);
563589
MI.eraseFromParent();
564590
}
565591
return true;
592+
} else if (IntrinsicID == Intrinsic::spv_insertelt) {
593+
Register DstReg = MI.getOperand(0).getReg();
594+
LLT DstTy = MRI.getType(DstReg);
595+
596+
if (needsVectorLegalization(DstTy, ST)) {
597+
Register SrcReg = MI.getOperand(2).getReg();
598+
Register ValReg = MI.getOperand(3).getReg();
599+
Register IdxReg = MI.getOperand(4).getReg();
600+
MIRBuilder.buildInsertVectorElement(DstReg, SrcReg, ValReg, IdxReg);
601+
MI.eraseFromParent();
602+
}
603+
return true;
604+
} else if (IntrinsicID == Intrinsic::spv_extractelt) {
605+
Register SrcReg = MI.getOperand(2).getReg();
606+
LLT SrcTy = MRI.getType(SrcReg);
607+
608+
if (needsVectorLegalization(SrcTy, ST)) {
609+
Register DstReg = MI.getOperand(0).getReg();
610+
Register IdxReg = MI.getOperand(3).getReg();
611+
MIRBuilder.buildExtractVectorElement(DstReg, SrcReg, IdxReg);
612+
MI.eraseFromParent();
613+
}
614+
return true;
566615
}
567616
return true;
568617
}

llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "SPIRV.h"
1717
#include "SPIRVSubtarget.h"
1818
#include "SPIRVUtils.h"
19+
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
1920
#include "llvm/IR/IntrinsicsSPIRV.h"
2021
#include "llvm/Support/Debug.h"
2122
#include <stack>
@@ -66,8 +67,9 @@ static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF,
6667
for (unsigned i = 0; i < I->getNumDefs() && !ScalarType; ++i) {
6768
for (const auto &Use :
6869
MRI.use_nodbg_instructions(I->getOperand(i).getReg())) {
69-
assert(Use.getOpcode() == TargetOpcode::G_BUILD_VECTOR &&
70-
"Expected use of G_UNMERGE_VALUES to be a G_BUILD_VECTOR");
70+
if (Use.getOpcode() != TargetOpcode::G_BUILD_VECTOR)
71+
continue;
72+
7173
if (auto *VecType =
7274
GR->getSPIRVTypeForVReg(Use.getOperand(0).getReg())) {
7375
ScalarType = GR->getScalarOrVectorComponentType(VecType);
@@ -133,10 +135,10 @@ static SPIRVType *deduceTypeFromOperandRange(MachineInstr *I,
133135
return ResType;
134136
}
135137

136-
static SPIRVType *deduceTypeForResultRegister(MachineInstr *Use,
137-
Register UseRegister,
138-
SPIRVGlobalRegistry *GR,
139-
MachineIRBuilder &MIB) {
138+
static SPIRVType *deduceTypeFromResultRegister(MachineInstr *Use,
139+
Register UseRegister,
140+
SPIRVGlobalRegistry *GR,
141+
MachineIRBuilder &MIB) {
140142
for (const MachineOperand &MO : Use->defs()) {
141143
if (!MO.isReg())
142144
continue;
@@ -159,16 +161,43 @@ static SPIRVType *deduceTypeFromUses(Register Reg, MachineFunction &MF,
159161
MachineRegisterInfo &MRI = MF.getRegInfo();
160162
for (MachineInstr &Use : MRI.use_nodbg_instructions(Reg)) {
161163
SPIRVType *ResType = nullptr;
164+
LLVM_DEBUG(dbgs() << "Looking at use " << Use);
162165
switch (Use.getOpcode()) {
163166
case TargetOpcode::G_BUILD_VECTOR:
164167
case TargetOpcode::G_EXTRACT_VECTOR_ELT:
165168
case TargetOpcode::G_UNMERGE_VALUES:
166-
LLVM_DEBUG(dbgs() << "Looking at use " << Use << "\n");
167-
ResType = deduceTypeForResultRegister(&Use, Reg, GR, MIB);
169+
case TargetOpcode::G_ADD:
170+
case TargetOpcode::G_SUB:
171+
case TargetOpcode::G_MUL:
172+
case TargetOpcode::G_SDIV:
173+
case TargetOpcode::G_UDIV:
174+
case TargetOpcode::G_SREM:
175+
case TargetOpcode::G_UREM:
176+
case TargetOpcode::G_FADD:
177+
case TargetOpcode::G_FSUB:
178+
case TargetOpcode::G_FMUL:
179+
case TargetOpcode::G_FDIV:
180+
case TargetOpcode::G_FREM:
181+
case TargetOpcode::G_FMA:
182+
ResType = deduceTypeFromResultRegister(&Use, Reg, GR, MIB);
183+
break;
184+
case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
185+
case TargetOpcode::G_INTRINSIC: {
186+
auto IntrinsicID = cast<GIntrinsic>(Use).getIntrinsicID();
187+
if (IntrinsicID == Intrinsic::spv_insertelt) {
188+
if (Reg == Use.getOperand(2).getReg())
189+
ResType = deduceTypeFromResultRegister(&Use, Reg, GR, MIB);
190+
} else if (IntrinsicID == Intrinsic::spv_extractelt) {
191+
if (Reg == Use.getOperand(2).getReg())
192+
ResType = deduceTypeFromResultRegister(&Use, Reg, GR, MIB);
193+
}
168194
break;
169195
}
170-
if (ResType)
196+
}
197+
if (ResType) {
198+
LLVM_DEBUG(dbgs() << "Deduced type from use " << *ResType);
171199
return ResType;
200+
}
172201
}
173202
return nullptr;
174203
}

0 commit comments

Comments
 (0)