Skip to content

Commit 9f72fab

Browse files
authored
[SPIRV] Fix vector bitcast check in LegalizePointerCast (llvm#164997)
The previous check for vector bitcasts in `loadVectorFromVector` only compared the number of elements, which is insufficient when the element types differ. This can lead to incorrect assumptions about the validity of the cast. This commit replaces the element count check with a comparison of the total size of the vectors in bits. This ensures that the bitcast is only performed between vectors of the same size, preventing potential miscompilations. Part of llvm#153091
1 parent 19bf0ad commit 9f72fab

File tree

3 files changed

+31
-4
lines changed

3 files changed

+31
-4
lines changed

llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,16 +73,23 @@ class SPIRVLegalizePointerCast : public FunctionPass {
7373
// Returns the loaded value.
7474
Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType,
7575
FixedVectorType *TargetType, Value *Source) {
76-
assert(TargetType->getNumElements() <= SourceType->getNumElements());
7776
LoadInst *NewLoad = B.CreateLoad(SourceType, Source);
7877
buildAssignType(B, SourceType, NewLoad);
7978
Value *AssignValue = NewLoad;
8079
if (TargetType->getElementType() != SourceType->getElementType()) {
80+
const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout();
81+
[[maybe_unused]] TypeSize TargetTypeSize =
82+
DL.getTypeSizeInBits(TargetType);
83+
[[maybe_unused]] TypeSize SourceTypeSize =
84+
DL.getTypeSizeInBits(SourceType);
85+
assert(TargetTypeSize == SourceTypeSize);
8186
AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast,
8287
{TargetType, SourceType}, {NewLoad});
8388
buildAssignType(B, TargetType, AssignValue);
89+
return AssignValue;
8490
}
8591

92+
assert(TargetType->getNumElements() < SourceType->getNumElements());
8693
SmallVector<int> Mask(/* Size= */ TargetType->getNumElements());
8794
for (unsigned I = 0; I < TargetType->getNumElements(); ++I)
8895
Mask[I] = I;

llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
define void @case1() local_unnamed_addr {
1717
; CHECK: %[[#BUFFER_LOAD:]] = OpLoad %[[#FLOAT4]] %{{[0-9]+}} Aligned 16
1818
; CHECK: %[[#CAST_LOAD:]] = OpBitcast %[[#INT4]] %[[#BUFFER_LOAD]]
19-
; CHECK: %[[#VEC_SHUFFLE:]] = OpVectorShuffle %[[#INT4]] %[[#CAST_LOAD]] %[[#CAST_LOAD]] 0 1 2 3
2019
%1 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str)
2120
%2 = tail call target("spirv.VulkanBuffer", [0 x <4 x i32>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4i32_12_1t(i32 0, i32 5, i32 1, i32 0, ptr nonnull @.str.2)
2221
%3 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4f32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) %1, i32 0)
@@ -29,8 +28,7 @@ define void @case1() local_unnamed_addr {
2928
define void @case2() local_unnamed_addr {
3029
; CHECK: %[[#BUFFER_LOAD:]] = OpLoad %[[#FLOAT4]] %{{[0-9]+}} Aligned 16
3130
; CHECK: %[[#CAST_LOAD:]] = OpBitcast %[[#INT4]] %[[#BUFFER_LOAD]]
32-
; CHECK: %[[#VEC_SHUFFLE:]] = OpVectorShuffle %[[#INT4]] %[[#CAST_LOAD]] %[[#CAST_LOAD]] 0 1 2 3
33-
; CHECK: %[[#VEC_TRUNCATE:]] = OpVectorShuffle %[[#INT3]] %[[#VEC_SHUFFLE]] %[[#UNDEF_INT4]] 0 1 2
31+
; CHECK: %[[#VEC_TRUNCATE:]] = OpVectorShuffle %[[#INT3]] %[[#CAST_LOAD]] %[[#UNDEF_INT4]] 0 1 2
3432
%1 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str)
3533
%2 = tail call target("spirv.VulkanBuffer", [0 x <3 x i32>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v3i32_12_1t(i32 0, i32 5, i32 1, i32 0, ptr nonnull @.str.3)
3634
%3 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4f32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) %1, i32 0)

llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,25 @@ entry:
2626
store <4 x i32> %6, ptr addrspace(11) %7, align 16
2727
ret void
2828
}
29+
30+
; This tests a load from a pointer that has been bitcast between vector types
31+
; which share the same total bit-width but have different numbers of elements.
32+
; Tests that legalize-pointer-casts works correctly by moving the bitcast to
33+
; the element that was loaded.
34+
35+
define void @main2() local_unnamed_addr #0 {
36+
entry:
37+
; CHECK: %[[LOAD:[0-9]+]] = OpLoad %[[#v2_double]] {{.*}}
38+
; CHECK: %[[BITCAST1:[0-9]+]] = OpBitcast %[[#v4_uint]] %[[LOAD]]
39+
; CHECK: %[[BITCAST2:[0-9]+]] = OpBitcast %[[#v2_double]] %[[BITCAST1]]
40+
; CHECK: OpStore {{%[0-9]+}} %[[BITCAST2]] {{.*}}
41+
42+
%0 = tail call target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v2f64_12_1t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str.2)
43+
%2 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2f64_12_1t(target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) %0, i32 0)
44+
%3 = load <4 x i32>, ptr addrspace(11) %2
45+
%4 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2f64_12_1t(target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) %0, i32 1)
46+
store <4 x i32> %3, ptr addrspace(11) %4
47+
ret void
48+
}
49+
50+
attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }

0 commit comments

Comments
 (0)