Skip to content

Commit 3abe11e

Browse files
committed
[SPIRV] Handle ptrcast between array and vector types
This commit adds support for legalizing pointer casts between array and vector types within the SPIRV backend. This is necessary to handle cases where a vector is loaded from or stored to an array, which can occur with HLSL matrix types. The following changes are included: - Added `loadVectorFromArray` to load a vector from an array. - Added `storeArrayFromVector` to store a vector to an array. - Added a test case to verify the functionality.
1 parent a50d036 commit 3abe11e

File tree

2 files changed

+138
-0
lines changed

2 files changed

+138
-0
lines changed

llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,85 @@ class SPIRVLegalizePointerCast : public FunctionPass {
116116
return LI;
117117
}
118118

119+
// Loads elements from an array and constructs a vector.
120+
Value *loadVectorFromArray(IRBuilder<> &B, FixedVectorType *TargetType,
121+
ArrayType *SourceType, Value *Source) {
122+
// Ensure the element types of the array and vector are the same.
123+
assert(TargetType->getElementType() == SourceType->getElementType() &&
124+
"Element types of array and vector must be the same.");
125+
126+
// Load each element of the array.
127+
SmallVector<Value *, 4> LoadedElements;
128+
for (unsigned i = 0; i < TargetType->getNumElements(); ++i) {
129+
// Create a GEP to access the i-th element of the array.
130+
SmallVector<Type *, 2> Types = {Source->getType(), Source->getType()};
131+
SmallVector<Value *, 4> Args;
132+
Args.push_back(B.getInt1(true));
133+
Args.push_back(Source);
134+
Args.push_back(B.getInt32(0));
135+
Args.push_back(ConstantInt::get(B.getInt32Ty(), i));
136+
auto *ElementPtr = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
137+
GR->buildAssignPtr(B, TargetType->getElementType(), ElementPtr);
138+
139+
// Load the value from the element pointer.
140+
Value *Load = B.CreateLoad(TargetType->getElementType(), ElementPtr);
141+
buildAssignType(B, TargetType->getElementType(), Load);
142+
LoadedElements.push_back(Load);
143+
}
144+
145+
// Build the vector from the loaded elements.
146+
Value *NewVector = UndefValue::get(TargetType);
147+
buildAssignType(B, TargetType, NewVector);
148+
149+
for (unsigned i = 0; i < TargetType->getNumElements(); ++i) {
150+
Value *Index = B.getInt32(i);
151+
SmallVector<Type *, 4> Types = {TargetType, TargetType,
152+
TargetType->getElementType(),
153+
Index->getType()};
154+
SmallVector<Value *> Args = {NewVector, LoadedElements[i], Index};
155+
NewVector = B.CreateIntrinsic(Intrinsic::spv_insertelt, {Types}, {Args});
156+
buildAssignType(B, TargetType, NewVector);
157+
}
158+
return NewVector;
159+
}
160+
161+
// Stores elements from a vector into an array.
162+
void storeArrayFromVector(IRBuilder<> &B, Value *SrcVector,
163+
Value *DstArrayPtr, ArrayType *ArrTy,
164+
Align Alignment) {
165+
auto *VecTy = cast<FixedVectorType>(SrcVector->getType());
166+
167+
// Ensure the element types of the array and vector are the same.
168+
assert(VecTy->getElementType() == ArrTy->getElementType() &&
169+
"Element types of array and vector must be the same.");
170+
171+
for (unsigned i = 0; i < VecTy->getNumElements(); ++i) {
172+
// Create a GEP to access the i-th element of the array.
173+
SmallVector<Type *, 2> Types = {DstArrayPtr->getType(),
174+
DstArrayPtr->getType()};
175+
SmallVector<Value *, 4> Args;
176+
Args.push_back(B.getInt1(true));
177+
Args.push_back(DstArrayPtr);
178+
Args.push_back(B.getInt32(0));
179+
Args.push_back(ConstantInt::get(B.getInt32Ty(), i));
180+
auto *ElementPtr = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
181+
GR->buildAssignPtr(B, ArrTy->getElementType(), ElementPtr);
182+
183+
// Extract the element from the vector and store it.
184+
Value *Index = B.getInt32(i);
185+
SmallVector<Type *, 3> EltTypes = {VecTy->getElementType(), VecTy,
186+
Index->getType()};
187+
SmallVector<Value *, 2> EltArgs = {SrcVector, Index};
188+
Value *Element =
189+
B.CreateIntrinsic(Intrinsic::spv_extractelt, {EltTypes}, {EltArgs});
190+
buildAssignType(B, VecTy->getElementType(), Element);
191+
192+
Types = {Element->getType(), ElementPtr->getType()};
193+
Args = {Element, ElementPtr, B.getInt16(2), B.getInt8(Alignment.value())};
194+
B.CreateIntrinsic(Intrinsic::spv_store, {Types}, {Args});
195+
}
196+
}
197+
119198
// Replaces the load instruction to get rid of the ptrcast used as source
120199
// operand.
121200
void transformLoad(IRBuilder<> &B, LoadInst *LI, Value *CastedOperand,
@@ -154,6 +233,8 @@ class SPIRVLegalizePointerCast : public FunctionPass {
154233
// - float v = s.m;
155234
else if (SST && SST->getTypeAtIndex(0u) == ToTy)
156235
Output = loadFirstValueFromAggregate(B, ToTy, OriginalOperand, LI);
236+
else if (SAT && DVT && SAT->getElementType() == DVT->getElementType())
237+
Output = loadVectorFromArray(B, DVT, SAT, OriginalOperand);
157238
else
158239
llvm_unreachable("Unimplemented implicit down-cast from load.");
159240

@@ -288,6 +369,7 @@ class SPIRVLegalizePointerCast : public FunctionPass {
288369
auto *S_VT = dyn_cast<FixedVectorType>(FromTy);
289370
auto *D_ST = dyn_cast<StructType>(ToTy);
290371
auto *D_VT = dyn_cast<FixedVectorType>(ToTy);
372+
auto *D_AT = dyn_cast<ArrayType>(ToTy);
291373

292374
B.SetInsertPoint(BadStore);
293375
if (D_ST && isTypeFirstElementAggregate(FromTy, D_ST))
@@ -296,6 +378,8 @@ class SPIRVLegalizePointerCast : public FunctionPass {
296378
storeVectorFromVector(B, Src, Dst, Alignment);
297379
else if (D_VT && !S_VT && FromTy == D_VT->getElementType())
298380
storeToFirstValueAggregate(B, Src, Dst, D_VT, Alignment);
381+
else if (D_AT && S_VT && S_VT->getElementType() == D_AT->getElementType())
382+
storeArrayFromVector(B, Src, Dst, D_AT, Alignment);
299383
else
300384
llvm_unreachable("Unsupported ptrcast use in store. Please fix.");
301385

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-unknown-vulkan %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -verify-machineinstrs -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-DAG: [[FLOAT:%[0-9]+]] = OpTypeFloat 32
5+
; CHECK-DAG: [[VEC4FLOAT:%[0-9]+]] = OpTypeVector [[FLOAT]] 4
6+
; CHECK-DAG: [[UINT_TYPE:%[0-9]+]] = OpTypeInt 32 0
7+
; CHECK-DAG: [[UINT4:%[0-9]+]] = OpConstant [[UINT_TYPE]] 4
8+
; CHECK-DAG: [[ARRAY4FLOAT:%[0-9]+]] = OpTypeArray [[FLOAT]] [[UINT4]]
9+
; CHECK-DAG: [[PTR_ARRAY4FLOAT:%[0-9]+]] = OpTypePointer Private [[ARRAY4FLOAT]]
10+
; CHECK-DAG: [[G_IN:%[0-9]+]] = OpVariable [[PTR_ARRAY4FLOAT]] Private
11+
; CHECK-DAG: [[G_OUT:%[0-9]+]] = OpVariable [[PTR_ARRAY4FLOAT]] Private
12+
; CHECK-DAG: [[UINT0:%[0-9]+]] = OpConstant [[UINT_TYPE]] 0
13+
; CHECK-DAG: [[UINT1:%[0-9]+]] = OpConstant [[UINT_TYPE]] 1
14+
; CHECK-DAG: [[UINT2:%[0-9]+]] = OpConstant [[UINT_TYPE]] 2
15+
; CHECK-DAG: [[UINT3:%[0-9]+]] = OpConstant [[UINT_TYPE]] 3
16+
; CHECK-DAG: [[PTR_FLOAT:%[0-9]+]] = OpTypePointer Private [[FLOAT]]
17+
; CHECK-DAG: [[UNDEF_VEC:%[0-9]+]] = OpUndef [[VEC4FLOAT]]
18+
19+
@G_in = internal addrspace(10) global [4 x float] zeroinitializer
20+
@G_out = internal addrspace(10) global [4 x float] zeroinitializer
21+
22+
define spir_func void @main() {
23+
entry:
24+
; CHECK: [[GEP0:%[0-9]+]] = OpInBoundsAccessChain [[PTR_FLOAT]] [[G_IN]] [[UINT0]]
25+
; CHECK-NEXT: [[LOAD0:%[0-9]+]] = OpLoad [[FLOAT]] [[GEP0]]
26+
; CHECK-NEXT: [[GEP1:%[0-9]+]] = OpInBoundsAccessChain [[PTR_FLOAT]] [[G_IN]] [[UINT1]]
27+
; CHECK-NEXT: [[LOAD1:%[0-9]+]] = OpLoad [[FLOAT]] [[GEP1]]
28+
; CHECK-NEXT: [[GEP2:%[0-9]+]] = OpInBoundsAccessChain [[PTR_FLOAT]] [[G_IN]] [[UINT2]]
29+
; CHECK-NEXT: [[LOAD2:%[0-9]+]] = OpLoad [[FLOAT]] [[GEP2]]
30+
; CHECK-NEXT: [[GEP3:%[0-9]+]] = OpInBoundsAccessChain [[PTR_FLOAT]] [[G_IN]] [[UINT3]]
31+
; CHECK-NEXT: [[LOAD3:%[0-9]+]] = OpLoad [[FLOAT]] [[GEP3]]
32+
; CHECK-NEXT: [[VEC_INSERT0:%[0-9]+]] = OpCompositeInsert [[VEC4FLOAT]] [[LOAD0]] [[UNDEF_VEC]] 0
33+
; CHECK-NEXT: [[VEC_INSERT1:%[0-9]+]] = OpCompositeInsert [[VEC4FLOAT]] [[LOAD1]] [[VEC_INSERT0]] 1
34+
; CHECK-NEXT: [[VEC_INSERT2:%[0-9]+]] = OpCompositeInsert [[VEC4FLOAT]] [[LOAD2]] [[VEC_INSERT1]] 2
35+
; CHECK-NEXT: [[VEC:%[0-9]+]] = OpCompositeInsert [[VEC4FLOAT]] [[LOAD3]] [[VEC_INSERT2]] 3
36+
%0 = load <4 x float>, ptr addrspace(10) @G_in, align 64
37+
38+
; CHECK-NEXT: [[GEP_OUT0:%[0-9]+]] = OpInBoundsAccessChain [[PTR_FLOAT]] [[G_OUT]] [[UINT0]]
39+
; CHECK-NEXT: [[VEC_EXTRACT0:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[VEC]] 0
40+
; CHECK-NEXT: OpStore [[GEP_OUT0]] [[VEC_EXTRACT0]]
41+
; CHECK-NEXT: [[GEP_OUT1:%[0-9]+]] = OpInBoundsAccessChain [[PTR_FLOAT]] [[G_OUT]] [[UINT1]]
42+
; CHECK-NEXT: [[VEC_EXTRACT1:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[VEC]] 1
43+
; CHECK-NEXT: OpStore [[GEP_OUT1]] [[VEC_EXTRACT1]]
44+
; CHECK-NEXT: [[GEP_OUT2:%[0-9]+]] = OpInBoundsAccessChain [[PTR_FLOAT]] [[G_OUT]] [[UINT2]]
45+
; CHECK-NEXT: [[VEC_EXTRACT2:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[VEC]] 2
46+
; CHECK-NEXT: OpStore [[GEP_OUT2]] [[VEC_EXTRACT2]]
47+
; CHECK-NEXT: [[GEP_OUT3:%[0-9]+]] = OpInBoundsAccessChain [[PTR_FLOAT]] [[G_OUT]] [[UINT3]]
48+
; CHECK-NEXT: [[VEC_EXTRACT3:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[VEC]] 3
49+
; CHECK-NEXT: OpStore [[GEP_OUT3]] [[VEC_EXTRACT3]]
50+
store <4 x float> %0, ptr addrspace(10) @G_out, align 64
51+
52+
; CHECK-NEXT: OpReturn
53+
ret void
54+
}

0 commit comments

Comments
 (0)