Skip to content

Commit 22cfc51

Browse files
committed
[SPIRV] Legalize long vectors in GlobalISel
This commit introduces support for legalizing long vectors (vectors with more than 4 elements) in the SPIR-V backend using GlobalISel. This is primarily for shader compilation where the GLSL_std_450 instruction set is available.

The main changes include:
- Adding legalization rules for vector operations (G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_EXTRACT_VECTOR_ELT, G_BITCAST, G_CONCAT_VECTORS) to split vectors with more than 4 elements into smaller vectors.
- Enhancing the SPIRVPostLegalizer with a worklist-based approach to correctly process instructions and types generated during legalization.
- Lowering G_EXTRACT_VECTOR_ELT to a spv_extractelt intrinsic.
- Refining the handling of G_BITCAST to legalize non-pointer bitcasts.
- Marking many SPIR-V operations as pure to aid optimization.
1 parent 6bee6b2 commit 22cfc51

File tree

11 files changed

+520
-145
lines changed

11 files changed

+520
-145
lines changed

llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,11 @@ LLVM_ABI LegalityPredicate scalarWiderThan(unsigned TypeIdx, unsigned Size);
314314
LLVM_ABI LegalityPredicate scalarOrEltNarrowerThan(unsigned TypeIdx,
315315
unsigned Size);
316316

317+
/// True iff the specified type index is a fixed vector with an element count
318+
/// that's greater than the given size.
319+
LLVM_ABI LegalityPredicate vectorElementCountIsGreaterThan(unsigned TypeIdx,
320+
unsigned Size);
321+
317322
/// True iff the specified type index is a scalar or a vector with an element
318323
/// type that's wider than the given size.
319324
LLVM_ABI LegalityPredicate scalarOrEltWiderThan(unsigned TypeIdx,

llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,16 @@ LegalityPredicate LegalityPredicates::scalarOrEltNarrowerThan(unsigned TypeIdx,
155155
};
156156
}
157157

158+
// Builds a predicate that holds when the type at index TypeIdx of the query is
// a fixed vector whose number of elements is strictly greater than Size.
// Any type that is not a fixed vector makes the predicate return false.
LegalityPredicate
159+
LegalityPredicates::vectorElementCountIsGreaterThan(unsigned TypeIdx,
160+
unsigned Size) {
161+
162+
// TypeIdx and Size are captured by value, so the returned predicate does not
// refer back to this call's arguments.
return [=](const LegalityQuery &Query) {
163+
const LLT QueryTy = Query.Types[TypeIdx];
164+
// isFixedVector() is checked first, so getNumElements() is only evaluated
// for fixed vectors.
return QueryTy.isFixedVector() && QueryTy.getNumElements() > Size;
165+
};
166+
}
167+
158168
LegalityPredicate LegalityPredicates::scalarOrEltWiderThan(unsigned TypeIdx,
159169
unsigned Size) {
160170
return [=](const LegalityQuery &Query) {

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,35 @@ SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(
116116
return SpirvType;
117117
}
118118

119+
// Assigns a SPIR-V type to VReg, deriving it from the register's LLT and its
// register class rather than from an IR type.
//
// - Vector LLT: the element type is a float type when VReg is in vfIDRegClass
//   and an integer type otherwise; the result comes from assignVectTypeToVReg.
// - Scalar LLT: a float type when VReg is in fIDRegClass, an integer type
//   otherwise.
// - Any other LLT reaches llvm_unreachable.
//
// VReg must already have a valid LLT (asserted). I is forwarded to the type
// creation/assignment helpers. Returns whatever the selected assign*ToVReg
// helper returns.
SPIRVType *SPIRVGlobalRegistry::assignLltTypeToVReg(Register VReg,
120+
MachineInstr &I) {
121+
MachineRegisterInfo &MRI = CurMF->getRegInfo();
122+
LLT Type = MRI.getType(VReg);
123+
assert(Type.isValid());
124+
125+
const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
126+
const SPIRVInstrInfo &TII = *ST.getInstrInfo();
127+
128+
if (Type.isVector()) {
129+
const LLT ElemTy = Type.getElementType();
130+
// NOTE(review): getFixedValue() presumably requires a non-scalable vector,
// while the guard above is isVector() — confirm scalable vectors cannot
// reach this point.
const unsigned NumElements = Type.getElementCount().getFixedValue();
131+
const unsigned ElemBitwidth = ElemTy.getScalarSizeInBits();
132+
SPIRVType *BaseType = nullptr;
133+
// The LLT alone is not used to pick float vs. integer here; the register
// class makes that choice.
if (MRI.getRegClass(VReg) == &SPIRV::vfIDRegClass)
134+
BaseType = getOrCreateSPIRVFloatType(ElemBitwidth, I, TII);
135+
else
136+
BaseType = getOrCreateSPIRVIntegerType(ElemBitwidth, I, TII);
137+
return assignVectTypeToVReg(BaseType, NumElements, VReg, I, TII);
138+
}
139+
if (Type.isScalar()) {
140+
const unsigned Bitwidth = Type.getScalarSizeInBits();
141+
// Same register-class based choice as for vectors, using the scalar float
// class.
if (MRI.getRegClass(VReg) == &SPIRV::fIDRegClass)
142+
return assignFloatTypeToVReg(Bitwidth, VReg, I, TII);
143+
return assignIntTypeToVReg(Bitwidth, VReg, I, TII);
144+
}
145+
// Neither a vector nor a scalar LLT: not handled by this function.
llvm_unreachable("Not implemented LLT");
146+
}
147+
119148
SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
120149
const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
121150
SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
288288
SPIRVType *assignVectTypeToVReg(SPIRVType *BaseType, unsigned NumElements,
289289
Register VReg, MachineInstr &I,
290290
const SPIRVInstrInfo &TII);
291+
SPIRVType *assignLltTypeToVReg(Register VReg, MachineInstr &I);
291292

292293
// In cases where the SPIR-V type is already known, this function can be
293294
// used to map it to the given VReg via an ASSIGN_TYPE instruction.

llvm/lib/Target/SPIRV/SPIRVInstrFormats.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@ class Op<bits<16> Opcode, dag outs, dag ins, string asmstr, list<dag> pattern =
2525
let Pattern = pattern;
2626
}
2727

28+
// Variant of Op for instructions declared to have no side effects: it forwards
// all of its arguments to Op unchanged and only overrides hasSideEffects.
class PureOp<bits<16> Opcode, dag outs, dag ins, string asmstr,
29+
list<dag> pattern = []> : Op<Opcode, outs, ins, asmstr, pattern> {
30+
// The only field this class sets on top of Op.
let hasSideEffects = 0;
31+
}
32+
2833
class UnknownOp<dag outs, dag ins, string asmstr, list<dag> pattern = []>
2934
: Op<0, outs, ins, asmstr, pattern> {
3035
let isPseudo = 1;

llvm/lib/Target/SPIRV/SPIRVInstrInfo.td

Lines changed: 108 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -163,52 +163,74 @@ def OpExecutionModeId: Op<331, (outs), (ins ID:$entry, ExecutionMode:$mode, vari
163163

164164
// 3.42.6 Type-Declaration Instructions
165165

166-
def OpTypeVoid: Op<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
167-
def OpTypeBool: Op<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
168-
def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
169-
"$type = OpTypeInt $width $signedness">;
170-
def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width, variable_ops),
171-
"$type = OpTypeFloat $width">;
172-
def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
173-
"$type = OpTypeVector $compType $compCount">;
174-
def OpTypeMatrix: Op<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount),
175-
"$type = OpTypeMatrix $colType $colCount">;
176-
def OpTypeImage: Op<25, (outs TYPE:$res), (ins TYPE:$sampTy, Dim:$dim, i32imm:$depth,
177-
i32imm:$arrayed, i32imm:$MS, i32imm:$sampled, ImageFormat:$imFormat, variable_ops),
178-
"$res = OpTypeImage $sampTy $dim $depth $arrayed $MS $sampled $imFormat">;
179-
def OpTypeSampler: Op<26, (outs TYPE:$res), (ins), "$res = OpTypeSampler">;
180-
def OpTypeSampledImage: Op<27, (outs TYPE:$res), (ins TYPE:$imageType),
181-
"$res = OpTypeSampledImage $imageType">;
182-
def OpTypeArray: Op<28, (outs TYPE:$type), (ins TYPE:$elementType, ID:$length),
183-
"$type = OpTypeArray $elementType $length">;
184-
def OpTypeRuntimeArray: Op<29, (outs TYPE:$type), (ins TYPE:$elementType),
185-
"$type = OpTypeRuntimeArray $elementType">;
186-
def OpTypeStruct: Op<30, (outs TYPE:$res), (ins variable_ops), "$res = OpTypeStruct">;
187-
def OpTypeStructContinuedINTEL: Op<6090, (outs), (ins variable_ops),
188-
"OpTypeStructContinuedINTEL">;
189-
def OpTypeOpaque: Op<31, (outs TYPE:$res), (ins StringImm:$name, variable_ops),
190-
"$res = OpTypeOpaque $name">;
191-
def OpTypePointer: Op<32, (outs TYPE:$res), (ins StorageClass:$storage, TYPE:$type),
192-
"$res = OpTypePointer $storage $type">;
193-
def OpTypeFunction: Op<33, (outs TYPE:$funcType), (ins TYPE:$returnType, variable_ops),
194-
"$funcType = OpTypeFunction $returnType">;
195-
def OpTypeEvent: Op<34, (outs TYPE:$res), (ins), "$res = OpTypeEvent">;
196-
def OpTypeDeviceEvent: Op<35, (outs TYPE:$res), (ins), "$res = OpTypeDeviceEvent">;
197-
def OpTypeReserveId: Op<36, (outs TYPE:$res), (ins), "$res = OpTypeReserveId">;
198-
def OpTypeQueue: Op<37, (outs TYPE:$res), (ins), "$res = OpTypeQueue">;
199-
def OpTypePipe: Op<38, (outs TYPE:$res), (ins AccessQualifier:$a), "$res = OpTypePipe $a">;
200-
def OpTypeForwardPointer: Op<39, (outs), (ins TYPE:$ptrType, StorageClass:$storageClass),
201-
"OpTypeForwardPointer $ptrType $storageClass">;
202-
def OpTypePipeStorage: Op<322, (outs TYPE:$res), (ins), "$res = OpTypePipeStorage">;
203-
def OpTypeNamedBarrier: Op<327, (outs TYPE:$res), (ins), "$res = OpTypeNamedBarrier">;
204-
def OpTypeAccelerationStructureNV: Op<5341, (outs TYPE:$res), (ins),
205-
"$res = OpTypeAccelerationStructureNV">;
206-
def OpTypeCooperativeMatrixNV: Op<5358, (outs TYPE:$res),
207-
(ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols),
208-
"$res = OpTypeCooperativeMatrixNV $compType $scope $rows $cols">;
209-
def OpTypeCooperativeMatrixKHR: Op<4456, (outs TYPE:$res),
210-
(ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols, ID:$use),
211-
"$res = OpTypeCooperativeMatrixKHR $compType $scope $rows $cols $use">;
166+
def OpTypeVoid : PureOp<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
167+
def OpTypeBool : PureOp<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
168+
def OpTypeInt
169+
: PureOp<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
170+
"$type = OpTypeInt $width $signedness">;
171+
def OpTypeFloat
172+
: PureOp<22, (outs TYPE:$type), (ins i32imm:$width, variable_ops),
173+
"$type = OpTypeFloat $width">;
174+
def OpTypeVector
175+
: PureOp<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
176+
"$type = OpTypeVector $compType $compCount">;
177+
def OpTypeMatrix
178+
: PureOp<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount),
179+
"$type = OpTypeMatrix $colType $colCount">;
180+
def OpTypeImage : PureOp<25, (outs TYPE:$res),
181+
(ins TYPE:$sampTy, Dim:$dim, i32imm:$depth,
182+
i32imm:$arrayed, i32imm:$MS, i32imm:$sampled,
183+
ImageFormat:$imFormat, variable_ops),
184+
"$res = OpTypeImage $sampTy $dim $depth $arrayed $MS "
185+
"$sampled $imFormat">;
186+
def OpTypeSampler : PureOp<26, (outs TYPE:$res), (ins), "$res = OpTypeSampler">;
187+
def OpTypeSampledImage : PureOp<27, (outs TYPE:$res), (ins TYPE:$imageType),
188+
"$res = OpTypeSampledImage $imageType">;
189+
def OpTypeArray
190+
: PureOp<28, (outs TYPE:$type), (ins TYPE:$elementType, ID:$length),
191+
"$type = OpTypeArray $elementType $length">;
192+
def OpTypeRuntimeArray : PureOp<29, (outs TYPE:$type), (ins TYPE:$elementType),
193+
"$type = OpTypeRuntimeArray $elementType">;
194+
def OpTypeStruct
195+
: PureOp<30, (outs TYPE:$res), (ins variable_ops), "$res = OpTypeStruct">;
196+
def OpTypeStructContinuedINTEL
197+
: PureOp<6090, (outs), (ins variable_ops), "OpTypeStructContinuedINTEL">;
198+
def OpTypeOpaque
199+
: PureOp<31, (outs TYPE:$res), (ins StringImm:$name, variable_ops),
200+
"$res = OpTypeOpaque $name">;
201+
def OpTypePointer
202+
: PureOp<32, (outs TYPE:$res), (ins StorageClass:$storage, TYPE:$type),
203+
"$res = OpTypePointer $storage $type">;
204+
def OpTypeFunction
205+
: PureOp<33, (outs TYPE:$funcType), (ins TYPE:$returnType, variable_ops),
206+
"$funcType = OpTypeFunction $returnType">;
207+
def OpTypeEvent : PureOp<34, (outs TYPE:$res), (ins), "$res = OpTypeEvent">;
208+
def OpTypeDeviceEvent
209+
: PureOp<35, (outs TYPE:$res), (ins), "$res = OpTypeDeviceEvent">;
210+
def OpTypeReserveId
211+
: PureOp<36, (outs TYPE:$res), (ins), "$res = OpTypeReserveId">;
212+
def OpTypeQueue : PureOp<37, (outs TYPE:$res), (ins), "$res = OpTypeQueue">;
213+
def OpTypePipe : PureOp<38, (outs TYPE:$res), (ins AccessQualifier:$a),
214+
"$res = OpTypePipe $a">;
215+
def OpTypeForwardPointer
216+
: PureOp<39, (outs), (ins TYPE:$ptrType, StorageClass:$storageClass),
217+
"OpTypeForwardPointer $ptrType $storageClass">;
218+
def OpTypePipeStorage
219+
: PureOp<322, (outs TYPE:$res), (ins), "$res = OpTypePipeStorage">;
220+
def OpTypeNamedBarrier
221+
: PureOp<327, (outs TYPE:$res), (ins), "$res = OpTypeNamedBarrier">;
222+
def OpTypeAccelerationStructureNV
223+
: PureOp<5341, (outs TYPE:$res), (ins),
224+
"$res = OpTypeAccelerationStructureNV">;
225+
def OpTypeCooperativeMatrixNV
226+
: PureOp<5358, (outs TYPE:$res),
227+
(ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols),
228+
"$res = OpTypeCooperativeMatrixNV $compType $scope $rows $cols">;
229+
def OpTypeCooperativeMatrixKHR
230+
: PureOp<4456, (outs TYPE:$res),
231+
(ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols, ID:$use),
232+
"$res = OpTypeCooperativeMatrixKHR $compType $scope $rows $cols "
233+
"$use">;
212234

213235
// 3.42.7 Constant-Creation Instructions
214236

@@ -222,31 +244,46 @@ defm OpConstant: IntFPImm<43, "OpConstant">;
222244

223245
def ConstPseudoTrue: IntImmLeaf<i64, [{ return Imm.getBitWidth() == 1 && Imm.getZExtValue() == 1; }]>;
224246
def ConstPseudoFalse: IntImmLeaf<i64, [{ return Imm.getBitWidth() == 1 && Imm.getZExtValue() == 0; }]>;
225-
def OpConstantTrue: Op<41, (outs iID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantTrue $src_ty",
226-
[(set iID:$dst, (assigntype ConstPseudoTrue, TYPE:$src_ty))]>;
227-
def OpConstantFalse: Op<42, (outs iID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantFalse $src_ty",
228-
[(set iID:$dst, (assigntype ConstPseudoFalse, TYPE:$src_ty))]>;
229-
230-
def OpConstantComposite: Op<44, (outs ID:$res), (ins TYPE:$type, variable_ops),
231-
"$res = OpConstantComposite $type">;
232-
def OpConstantCompositeContinuedINTEL: Op<6091, (outs), (ins variable_ops),
233-
"OpConstantCompositeContinuedINTEL">;
234-
235-
def OpConstantSampler: Op<45, (outs ID:$res),
236-
(ins TYPE:$t, SamplerAddressingMode:$s, i32imm:$p, SamplerFilterMode:$f),
237-
"$res = OpConstantSampler $t $s $p $f">;
238-
def OpConstantNull: Op<46, (outs ID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantNull $src_ty">;
239-
240-
def OpSpecConstantTrue: Op<48, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantTrue $t">;
241-
def OpSpecConstantFalse: Op<49, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantFalse $t">;
242-
def OpSpecConstant: Op<50, (outs ID:$res), (ins TYPE:$type, i32imm:$imm, variable_ops),
243-
"$res = OpSpecConstant $type $imm">;
244-
def OpSpecConstantComposite: Op<51, (outs ID:$res), (ins TYPE:$type, variable_ops),
245-
"$res = OpSpecConstantComposite $type">;
246-
def OpSpecConstantCompositeContinuedINTEL: Op<6092, (outs), (ins variable_ops),
247-
"OpSpecConstantCompositeContinuedINTEL">;
248-
def OpSpecConstantOp: Op<52, (outs ID:$res), (ins TYPE:$t, SpecConstantOpOperands:$c, ID:$o, variable_ops),
249-
"$res = OpSpecConstantOp $t $c $o">;
247+
def OpConstantTrue
248+
: PureOp<41, (outs iID:$dst), (ins TYPE:$src_ty),
249+
"$dst = OpConstantTrue $src_ty",
250+
[(set iID:$dst, (assigntype ConstPseudoTrue, TYPE:$src_ty))]>;
251+
def OpConstantFalse
252+
: PureOp<42, (outs iID:$dst), (ins TYPE:$src_ty),
253+
"$dst = OpConstantFalse $src_ty",
254+
[(set iID:$dst, (assigntype ConstPseudoFalse, TYPE:$src_ty))]>;
255+
256+
def OpConstantComposite
257+
: PureOp<44, (outs ID:$res), (ins TYPE:$type, variable_ops),
258+
"$res = OpConstantComposite $type">;
259+
def OpConstantCompositeContinuedINTEL
260+
: PureOp<6091, (outs), (ins variable_ops),
261+
"OpConstantCompositeContinuedINTEL">;
262+
263+
def OpConstantSampler : PureOp<45, (outs ID:$res),
264+
(ins TYPE:$t, SamplerAddressingMode:$s,
265+
i32imm:$p, SamplerFilterMode:$f),
266+
"$res = OpConstantSampler $t $s $p $f">;
267+
def OpConstantNull : PureOp<46, (outs ID:$dst), (ins TYPE:$src_ty),
268+
"$dst = OpConstantNull $src_ty">;
269+
270+
def OpSpecConstantTrue
271+
: PureOp<48, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantTrue $t">;
272+
def OpSpecConstantFalse
273+
: PureOp<49, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantFalse $t">;
274+
def OpSpecConstant
275+
: PureOp<50, (outs ID:$res), (ins TYPE:$type, i32imm:$imm, variable_ops),
276+
"$res = OpSpecConstant $type $imm">;
277+
def OpSpecConstantComposite
278+
: PureOp<51, (outs ID:$res), (ins TYPE:$type, variable_ops),
279+
"$res = OpSpecConstantComposite $type">;
280+
def OpSpecConstantCompositeContinuedINTEL
281+
: PureOp<6092, (outs), (ins variable_ops),
282+
"OpSpecConstantCompositeContinuedINTEL">;
283+
def OpSpecConstantOp
284+
: PureOp<52, (outs ID:$res),
285+
(ins TYPE:$t, SpecConstantOpOperands:$c, ID:$o, variable_ops),
286+
"$res = OpSpecConstantOp $t $c $o">;
250287

251288
// 3.42.8 Memory Instructions
252289

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1526,33 +1526,57 @@ bool SPIRVInstructionSelector::selectUnmergeValues(MachineInstr &I) const {
15261526
unsigned ArgI = I.getNumOperands() - 1;
15271527
Register SrcReg =
15281528
I.getOperand(ArgI).isReg() ? I.getOperand(ArgI).getReg() : Register(0);
1529-
SPIRVType *DefType =
1529+
SPIRVType *SrcType =
15301530
SrcReg.isValid() ? GR.getSPIRVTypeForVReg(SrcReg) : nullptr;
1531-
if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
1531+
if (!SrcType || SrcType->getOpcode() != SPIRV::OpTypeVector)
15321532
report_fatal_error(
15331533
"cannot select G_UNMERGE_VALUES with a non-vector argument");
15341534

15351535
SPIRVType *ScalarType =
1536-
GR.getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
1536+
GR.getSPIRVTypeForVReg(SrcType->getOperand(1).getReg());
15371537
MachineBasicBlock &BB = *I.getParent();
15381538
bool Res = false;
1539+
unsigned CurrentIndex = 0;
15391540
for (unsigned i = 0; i < I.getNumDefs(); ++i) {
15401541
Register ResVReg = I.getOperand(i).getReg();
15411542
SPIRVType *ResType = GR.getSPIRVTypeForVReg(ResVReg);
15421543
if (!ResType) {
1543-
// There was no "assign type" actions, let's fix this now
1544-
ResType = ScalarType;
1544+
LLT ResLLT = MRI->getType(ResVReg);
1545+
assert(ResLLT.isValid());
1546+
if (ResLLT.isVector()) {
1547+
ResType = GR.getOrCreateSPIRVVectorType(
1548+
ScalarType, ResLLT.getNumElements(), I, TII);
1549+
} else {
1550+
ResType = ScalarType;
1551+
}
15451552
MRI->setRegClass(ResVReg, GR.getRegClass(ResType));
1546-
MRI->setType(ResVReg, LLT::scalar(GR.getScalarOrVectorBitWidth(ResType)));
15471553
GR.assignSPIRVTypeToVReg(ResType, ResVReg, *GR.CurMF);
15481554
}
1549-
auto MIB =
1550-
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
1551-
.addDef(ResVReg)
1552-
.addUse(GR.getSPIRVTypeID(ResType))
1553-
.addUse(SrcReg)
1554-
.addImm(static_cast<int64_t>(i));
1555-
Res |= MIB.constrainAllUses(TII, TRI, RBI);
1555+
1556+
if (ResType->getOpcode() == SPIRV::OpTypeVector) {
1557+
Register UndefReg = GR.getOrCreateUndef(I, SrcType, TII);
1558+
auto MIB =
1559+
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpVectorShuffle))
1560+
.addDef(ResVReg)
1561+
.addUse(GR.getSPIRVTypeID(ResType))
1562+
.addUse(SrcReg)
1563+
.addUse(UndefReg);
1564+
unsigned NumElements = GR.getScalarOrVectorComponentCount(ResType);
1565+
for (unsigned j = 0; j < NumElements; ++j) {
1566+
MIB.addImm(CurrentIndex + j);
1567+
}
1568+
CurrentIndex += NumElements;
1569+
Res |= MIB.constrainAllUses(TII, TRI, RBI);
1570+
} else {
1571+
auto MIB =
1572+
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
1573+
.addDef(ResVReg)
1574+
.addUse(GR.getSPIRVTypeID(ResType))
1575+
.addUse(SrcReg)
1576+
.addImm(CurrentIndex);
1577+
CurrentIndex++;
1578+
Res |= MIB.constrainAllUses(TII, TRI, RBI);
1579+
}
15561580
}
15571581
return Res;
15581582
}

llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,10 @@ class SPIRVLegalizePointerCast : public FunctionPass {
7373
// Returns the loaded value.
7474
Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType,
7575
FixedVectorType *TargetType, Value *Source) {
76-
assert(TargetType->getNumElements() <= SourceType->getNumElements());
76+
const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout();
77+
[[maybe_unused]] TypeSize TargetTypeSize = DL.getTypeSizeInBits(TargetType);
78+
[[maybe_unused]] TypeSize SourceTypeSize = DL.getTypeSizeInBits(SourceType);
79+
assert(TargetTypeSize <= SourceTypeSize);
7780
LoadInst *NewLoad = B.CreateLoad(SourceType, Source);
7881
buildAssignType(B, SourceType, NewLoad);
7982
Value *AssignValue = NewLoad;

0 commit comments

Comments
 (0)