Skip to content

Commit b0ad9c2

Browse files
authored
[SPIR-V] Fix asdouble issue in SPIRV codegen to correctly generate OpBitCast instruction. (#161891)
Generate `OpBitCast` instruction for pointer cast operation if the element type is different. The HLSL for the unit test is ```hlsl StructuredBuffer<uint2> In : register(t0); RWStructuredBuffer<double2> Out : register(u2); [numthreads(1,1,1)] void main() { Out[0] = asdouble(In[0], In[1]); } ``` Resolves #153513
1 parent be29612 commit b0ad9c2

File tree

2 files changed

+52
-1
lines changed

2 files changed

+52
-1
lines changed

llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,31 @@ class SPIRVLegalizePointerCast : public FunctionPass {
188188
FixedVectorType *SrcType = cast<FixedVectorType>(Src->getType());
189189
FixedVectorType *DstType =
190190
cast<FixedVectorType>(GR->findDeducedElementType(Dst));
191-
assert(DstType->getNumElements() >= SrcType->getNumElements());
191+
auto dstNumElements = DstType->getNumElements();
192+
auto srcNumElements = SrcType->getNumElements();
193+
194+
// if the element type differs, it is a bitcast.
195+
if (DstType->getElementType() != SrcType->getElementType()) {
196+
// Support bitcast between vectors of different sizes only if
197+
// the total bitwidth is the same.
198+
auto dstBitWidth =
199+
DstType->getElementType()->getScalarSizeInBits() * dstNumElements;
200+
auto srcBitWidth =
201+
SrcType->getElementType()->getScalarSizeInBits() * srcNumElements;
202+
assert(dstBitWidth == srcBitWidth &&
203+
"Unsupported bitcast between vectors of different sizes.");
204+
205+
Src =
206+
B.CreateIntrinsic(Intrinsic::spv_bitcast, {DstType, SrcType}, {Src});
207+
buildAssignType(B, DstType, Src);
208+
SrcType = DstType;
209+
210+
StoreInst *SI = B.CreateStore(Src, Dst);
211+
SI->setAlignment(Alignment);
212+
return SI;
213+
}
192214

215+
assert(DstType->getNumElements() >= SrcType->getNumElements());
193216
LoadInst *LI = B.CreateLoad(DstType, Dst);
194217
LI->setAlignment(Alignment);
195218
Value *OldValues = LI;
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-vulkan-compute %s -o - | FileCheck %s --match-full-lines
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
5+
; CHECK-DAG: %[[#v2_uint:]] = OpTypeVector %[[#uint]] 2
6+
; CHECK-DAG: %[[#double:]] = OpTypeFloat 64
7+
; CHECK-DAG: %[[#v2_double:]] = OpTypeVector %[[#double]] 2
8+
; CHECK-DAG: %[[#v4_uint:]] = OpTypeVector %[[#uint]] 4
9+
@.str = private unnamed_addr constant [3 x i8] c"In\00", align 1
10+
@.str.2 = private unnamed_addr constant [4 x i8] c"Out\00", align 1
11+
12+
define void @main() local_unnamed_addr #0 {
13+
entry:
14+
%0 = tail call target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v2i32_12_0t(i32 0, i32 0, i32 1, i32 0, ptr nonnull @.str)
15+
%1 = tail call target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v2f64_12_1t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str.2)
16+
%2 = tail call noundef align 8 dereferenceable(8) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2i32_12_0t(target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) %0, i32 0)
17+
%3 = load <2 x i32>, ptr addrspace(11) %2, align 8
18+
%4 = tail call noundef align 8 dereferenceable(8) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2i32_12_0t(target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) %0, i32 1)
19+
%5 = load <2 x i32>, ptr addrspace(11) %4, align 8
20+
; CHECK: %[[#tmp:]] = OpVectorShuffle %[[#v4_uint]] {{%[0-9]+}} {{%[0-9]+}} 0 2 1 3
21+
%6 = shufflevector <2 x i32> %3, <2 x i32> %5, <4 x i32> <i32 0, i32 2, i32 1, i32 3>
22+
; CHECK: %[[#access:]] = OpAccessChain {{.*}}
23+
%7 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2f64_12_1t(target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) %1, i32 0)
24+
; CHECK: %[[#bitcast:]] = OpBitcast %[[#v2_double]] %[[#tmp]]
25+
; CHECK: OpStore %[[#access]] %[[#bitcast]] Aligned 16
26+
store <4 x i32> %6, ptr addrspace(11) %7, align 16
27+
ret void
28+
}

0 commit comments

Comments
 (0)