Skip to content

Commit 0372fb7

Browse files
authored
Fix assertion on splat of groupshared scalar (microsoft#6930)
When splatting a groupshared scalar, we would trip an "Invalid constantexpr cast!" assertion. This happened while evaluating the ImplicitCastExpr that turns the groupshared scalar into a vector, because the scalar expression was in a different address space (groupshared) than the target vector (local). The fix is to insert an lvalue-to-rvalue cast, when necessary, while looking up the vector member expression; i.e. when the base expression is an lvalue and the swizzle contains duplicate elements.
1 parent d9a5e97 commit 0372fb7

File tree

4 files changed

+38
-12
lines changed

4 files changed

+38
-12
lines changed

tools/clang/lib/Sema/SemaHLSL.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8444,8 +8444,19 @@ ExprResult HLSLExternalSource::LookupVectorMemberExprForHLSL(
84448444
ExprValueKind VK = positions.ContainsDuplicateElements()
84458445
? VK_RValue
84468446
: (IsArrow ? VK_LValue : BaseExpr.getValueKind());
8447-
HLSLVectorElementExpr *vectorExpr = new (m_context) HLSLVectorElementExpr(
8448-
resultType, VK, &BaseExpr, *member, MemberLoc, positions);
8447+
8448+
Expr *E = &BaseExpr;
8449+
// Insert an lvalue-to-rvalue cast if necessary
8450+
if (BaseExpr.getValueKind() == VK_LValue && VK == VK_RValue) {
8451+
// Remove qualifiers from result type and cast target type
8452+
resultType = resultType.getUnqualifiedType();
8453+
auto targetType = E->getType().getUnqualifiedType();
8454+
E = ImplicitCastExpr::Create(*m_context, targetType,
8455+
CastKind::CK_LValueToRValue, E, nullptr,
8456+
VK_RValue);
8457+
}
8458+
HLSLVectorElementExpr *vectorExpr = new (m_context)
8459+
HLSLVectorElementExpr(resultType, VK, E, *member, MemberLoc, positions);
84498460

84508461
return vectorExpr;
84518462
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// RUN: %dxc -E main -T cs_6_0 -fcgl %s | FileCheck %s
2+
3+
// Validate that when swizzling requires an r-value (i.e. duplicate elements), the result is stored to
4+
// and loaded from a temporary.
5+
// CHECK: store <1 x i32> %splat.splat, <1 x i32>* %tmp
6+
// CHECK-NEXT: %1 = load <1 x i32>, <1 x i32>* %tmp
7+
// CHECK-NEXT: %2 = shufflevector <1 x i32> %1, <1 x i32> undef, <4 x i32> zeroinitializer
8+
// CHECK-NEXT: store <4 x i32> %2, <4 x i32>* %x
9+
10+
groupshared int a;
11+
[numthreads(64, 1, 1)]
12+
void main() {
13+
a = 123;
14+
int4 x = (a).xxxx;
15+
}

tools/clang/test/CodeGenSPIRV/op.vector.swizzle.hlsl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,18 +118,18 @@ void main() {
118118
// Keep lhs.1
119119
// So final selectors to write to lhs.(0, 1, 2, 3): 6, 1, 4, 5
120120
// CHECK-NEXT: [[v22:%[0-9]+]] = OpLoad %v2float %v2f
121-
// CHECK-NEXT: [[vs15:%[0-9]+]] = OpVectorShuffle %v3float [[v22]] [[v22]] 0 1 0
121+
// CHECK-NEXT: [[vs15:%[0-9]+]] = OpVectorShuffle %v2float [[v22]] [[v22]] 1 0
122+
// CHECK-NEXT: [[vs16:%[0-9]+]] = OpVectorShuffle %v3float [[vs15]] [[vs15]] 1 0 1
122123
// CHECK-NEXT: [[v23:%[0-9]+]] = OpLoad %v4float %v4f2
123-
// CHECK-NEXT: [[vs16:%[0-9]+]] = OpVectorShuffle %v4float [[v23]] [[vs15]] 6 1 4 5
124-
// CHECK-NEXT: OpStore %v4f2 [[vs16]]
124+
// CHECK-NEXT: [[vs17:%[0-9]+]] = OpVectorShuffle %v4float [[v23]] [[vs16]] 6 1 4 5
125+
// CHECK-NEXT: OpStore %v4f2 [[vs17]]
125126
v4f2.wzx.grb = v2f.gr.yxy; // select more than original, write to a part
126127

127128
// CHECK-NEXT: [[v24:%[0-9]+]] = OpLoad %v4float %v4f1
128129
// CHECK-NEXT: OpStore %v4f2 [[v24]]
129130
v4f2.wzyx.abgr.xywz.rgab = v4f1.xyzw.xyzw.rgab.rgab; // from original vector to original vector
130-
131-
// CHECK-NEXT: [[v24_0:%[0-9]+]] = OpAccessChain %_ptr_Function_float %v4f1 %int_2
132-
// CHECK-NEXT: [[ce1:%[0-9]+]] = OpLoad %float [[v24_0]]
131+
// CHECK-NEXT: [[v24_0:%[0-9]+]] = OpLoad %v4float %v4f1
132+
// CHECK-NEXT: [[ce1:%[0-9]+]] = OpCompositeExtract %float [[v24_0]] 2
133133
// CHECK-NEXT: [[ac4:%[0-9]+]] = OpAccessChain %_ptr_Function_float %v4f2 %int_1
134134
// CHECK-NEXT: OpStore [[ac4]] [[ce1]]
135135
v4f2.wzyx.zy.x = v4f1.xzyx.y.x; // from one element (rvalue) to one element (lvalue)

tools/clang/test/CodeGenSPIRV/op.vector.swizzle.size1.hlsl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,11 @@ void main(float4 input: INPUT) {
4343

4444
// Selecting from resources
4545
// CHECK: [[fptr:%[0-9]+]] = OpAccessChain %_ptr_Uniform_v4float %PerFrame %int_0 %uint_5 %int_0
46-
// CHECK-NEXT: [[elem:%[0-9]+]] = OpAccessChain %_ptr_Uniform_float [[fptr]] %int_3
47-
// CHECK-NEXT: {{%[0-9]+}} = OpLoad %float [[elem]]
46+
// CHECK-NEXT: [[val:%[0-9]+]] = OpLoad %v4float [[fptr]]
47+
// CHECK-NEXT: {{%[0-9]+}} = OpCompositeExtract %float [[val]] 3
4848
v4f = input * PerFrame[5].f.www.r;
4949
// CHECK: [[fptr_0:%[0-9]+]] = OpAccessChain %_ptr_Uniform_v4float %PerFrame %int_0 %uint_6 %int_0
50-
// CHECK-NEXT: [[elem_0:%[0-9]+]] = OpAccessChain %_ptr_Uniform_float [[fptr_0]] %int_2
51-
// CHECK-NEXT: {{%[0-9]+}} = OpLoad %float [[elem_0]]
50+
// CHECK-NEXT: [[val_0:%[0-9]+]] = OpLoad %v4float [[fptr_0]]
51+
// CHECK-NEXT: {{%[0-9]+}} = OpCompositeExtract %float [[val_0]] 2
5252
sf = PerFrame[6].f.zzz.r * input.y;
5353
}

0 commit comments

Comments
 (0)