Skip to content

Commit 18f51e9

Browse files
committed
[SPIRV] Folding global constant variables
At the beginning, if we use static constant variables of template structures with library profile, we might get undefined values. We considered this is a bug. After thorough debugging, this appears to be expected behavior -- we need to initialize global static variables at the right time. In the common stage (non-library profile), these variables are initialized at the start of the main function before calling the user's main function. But in the library profile, we cannot easily initialize them... hence the undefined values. What if global variables were compile-time constants? We don't care how they're initialized -- they're COMPILE-TIME CONSTANTS! Fold them and promote them! On the other hand, this seems benefits all shaders, not just library profile. We can generate smaller SPIR-V code with constant folding. This change cannot fix all global static variables initialization issues; it only addresses issues for global compile-time constant variables. Fixes: #7049
1 parent d72f75e commit 18f51e9

12 files changed

+129
-71
lines changed

tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,7 +1006,7 @@ SpirvInstruction *DeclResultIdMapper::getDeclEvalInfo(const ValueDecl *decl,
10061006
// implicit VarDecl. All implicit VarDecls are lazily created in order to
10071007
// avoid creating large number of unused variables/constants/enums.
10081008
if (!info) {
1009-
tryToCreateImplicitConstVar(decl);
1009+
tryToCreateConstantVar(decl);
10101010
info = getDeclSpirvInfo(decl);
10111011
}
10121012

@@ -4850,19 +4850,57 @@ SpirvVariable *DeclResultIdMapper::createRayTracingNVStageVar(
48504850
return retVal;
48514851
}
48524852

4853-
void DeclResultIdMapper::tryToCreateImplicitConstVar(const ValueDecl *decl) {
4853+
bool DeclResultIdMapper::tryToCreateConstantVar(const ValueDecl *decl) {
4854+
// TODO: support spirv basic type with constant intrinsic. (e.g. int8)
48544855
const VarDecl *varDecl = dyn_cast<VarDecl>(decl);
4855-
if (!varDecl || !varDecl->isImplicit())
4856-
return;
4856+
if (!varDecl)
4857+
return false;
4858+
4859+
const BuiltinType *type = decl->getType()->getAs<BuiltinType>();
4860+
if (!type)
4861+
return false;
48574862

48584863
APValue *val = varDecl->evaluateValue();
48594864
if (!val)
4860-
return;
4865+
return false;
48614866

4862-
SpirvInstruction *constVal =
4863-
spvBuilder.getConstantInt(astContext.UnsignedIntTy, val->getInt());
4867+
SpirvInstruction *constVal = nullptr;
4868+
switch (type->getKind()) {
4869+
case BuiltinType::Bool: // bool
4870+
constVal = spvBuilder.getConstantBool(val->getInt().getExtValue());
4871+
break;
4872+
case BuiltinType::UShort: // uint16_t
4873+
constVal =
4874+
spvBuilder.getConstantInt(astContext.UnsignedShortTy, val->getInt());
4875+
break;
4876+
case BuiltinType::UInt: // uint32_t
4877+
constVal =
4878+
spvBuilder.getConstantInt(astContext.UnsignedIntTy, val->getInt());
4879+
break;
4880+
case BuiltinType::Short: // int16_t
4881+
constVal = spvBuilder.getConstantInt(astContext.ShortTy, val->getInt());
4882+
break;
4883+
case BuiltinType::Int: // int32_t
4884+
constVal = spvBuilder.getConstantInt(astContext.IntTy, val->getInt());
4885+
break;
4886+
case BuiltinType::Half: // float16_t
4887+
constVal = spvBuilder.getConstantFloat(astContext.HalfTy, val->getFloat());
4888+
break;
4889+
case BuiltinType::Float: // float32_t
4890+
case BuiltinType::HalfFloat: // float16_t without -enable-16bit-types
4891+
constVal = spvBuilder.getConstantFloat(astContext.FloatTy, val->getFloat());
4892+
break;
4893+
case BuiltinType::Double: // float64_t
4894+
constVal =
4895+
spvBuilder.getConstantFloat(astContext.DoubleTy, val->getFloat());
4896+
break;
4897+
default:
4898+
assert(false && "Unsupported builtin type evaluation at compile-time");
4899+
return false;
4900+
}
48644901
constVal->setRValue(true);
48654902
registerVariableForDecl(varDecl, constVal);
4903+
return true;
48664904
}
48674905

48684906
void DeclResultIdMapper::decorateWithIntrinsicAttrs(

tools/clang/lib/SPIRV/DeclResultIdMapper.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -351,13 +351,14 @@ class DeclResultIdMapper {
351351
/// \brief Sets the entry function.
352352
void setEntryFunction(SpirvFunction *fn) { entryFunction = fn; }
353353

354-
/// \brief If the given decl is an implicit VarDecl that evaluates to a
355-
/// constant, it evaluates the constant and registers the resulting SPIR-V
356-
/// instruction in the astDecls map. Otherwise returns without doing anything.
354+
/// \brief If the given decl is a VarDecl that evaluates to a constant, it
355+
/// evaluates the constant and registers the resulting SPIR-V instruction in
356+
/// the astDecls map. Otherwise returns without doing anything. The typical
357+
/// cases are implicit VarDecls and global static constant variables.
357358
///
358359
/// Note: There are many cases where the front-end might create such implicit
359360
/// VarDecls (such as some ray tracing enums).
360-
void tryToCreateImplicitConstVar(const ValueDecl *);
361+
bool tryToCreateConstantVar(const ValueDecl *);
361362

362363
/// \brief Creates instructions to copy output stage variables defined by
363364
/// outputPatchDecl to hullMainOutputPatch that is a variable for the

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2148,9 +2148,12 @@ void SpirvEmitter::doVarDecl(const VarDecl *decl) {
21482148
// We already know the variable is not externally visible here. If it does
21492149
// not have local storage, it should be file scope variable.
21502150
const bool isFileScopeVar = !decl->hasLocalStorage();
2151-
if (isFileScopeVar)
2151+
if (isFileScopeVar) {
2152+
if (decl->getType().isConstQualified() &&
2153+
declIdMapper.tryToCreateConstantVar(decl))
2154+
return;
21522155
var = declIdMapper.createFileVar(decl, llvm::None);
2153-
else
2156+
} else
21542157
var = declIdMapper.createFnVar(decl, llvm::None);
21552158

21562159
// Emit OpStore to initialize the variable

tools/clang/test/CodeGenSPIRV/nested.static.var.hlsl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
// RUN: %dxc -T cs_6_0 -E main -fcgl %s -spirv 2>&1 | FileCheck %s
22

3-
// Check that the variable `value` is defined, and set to 6 in the entry point wrapper.
4-
// CHECK: %value = OpVariable %_ptr_Private_uint Private
5-
// CHECK: %main = OpFunction %void None
3+
// Check that the folded constant value `6u` is stored into buffer.
4+
// CHECK: %src_main = OpFunction %void None
65
// CHECK-NEXT: OpLabel
7-
// CHECK-NEXT: OpStore %value %uint_6
6+
// CHECK-NEXT: [[BUFFER_A_0:%.*]] = OpAccessChain %_ptr_Uniform_uint %a %int_0 %uint_0
7+
// CHECK-NEXT: OpStore [[BUFFER_A_0]] %uint_6
88

99
struct A {
1010
struct B { static const uint value = 6u; };

tools/clang/test/CodeGenSPIRV/oo.static.member.init.hlsl

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,26 +14,24 @@ class T {
1414

1515
const int T::SIX = 6;
1616

17+
// CHECK-DAG: %int_5 = OpConstant %int 5
18+
// CHECK-DAG: %int_6 = OpConstant %int 6
19+
1720
int foo(int val) { return val; }
1821

19-
// CHECK: %FIVE = OpVariable %_ptr_Private_int Private
20-
// CHECK: %SIX = OpVariable %_ptr_Private_int Private
21-
// CHECK: %FIVE_0 = OpVariable %_ptr_Private_int Private
22-
// CHECK: %SIX_0 = OpVariable %_ptr_Private_int Private
2322
int main() : A {
24-
// CHECK-LABEL: %main = OpFunction
25-
26-
// CHECK: OpStore %FIVE %int_5
27-
// CHECK: OpStore %SIX %int_6
28-
// CHECK: OpStore %FIVE_0 %int_5
29-
// CHECK: OpStore %SIX_0 %int_6
30-
// CHECK: OpFunctionCall %int %src_main
31-
3223
// CHECK-LABEL: %src_main = OpFunction
33-
34-
// CHECK: OpLoad %int %FIVE
35-
// CHECK: OpLoad %int %SIX
36-
// CHECK: OpLoad %int %FIVE_0
37-
// CHECK: OpLoad %int %SIX_0
38-
return foo(S::FIVE) + foo(S::SIX) + foo(T::FIVE) + foo(T::SIX);
24+
return
25+
// CHECK: OpStore [[FOO_PARAM_0:%.*]] %int_5
26+
// CHECK-NEXT: [[FOO_RETURN_0:%.*]] = OpFunctionCall %int %foo [[FOO_PARAM_0]]
27+
foo(S::FIVE) +
28+
// CHECK: OpStore [[FOO_PARAM_1:%.*]] %int_6
29+
// CHECK-NEXT: [[FOO_RETURN_1:%.*]] = OpFunctionCall %int %foo [[FOO_PARAM_1]]
30+
foo(S::SIX) +
31+
// CHECK: OpStore [[FOO_PARAM_2:%.*]] %int_5
32+
// CHECK-NEXT: [[FOO_RETURN_2:%.*]] = OpFunctionCall %int %foo [[FOO_PARAM_2]]
33+
foo(T::FIVE) +
34+
// CHECK: OpStore [[FOO_PARAM_3:%.*]] %int_6
35+
// CHECK-NEXT: [[FOO_RETURN_3:%.*]] = OpFunctionCall %int %foo [[FOO_PARAM_3]]
36+
foo(T::SIX);
3937
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// RUN: %dxc -O3 -T lib_6_8 -HV 2021 -spirv -fcgl -fspv-target-env=universal1.5 %s | FileCheck %s
2+
3+
static const uint32_t kIterCnt = 1024;
4+
static const uint32_t kSizeShared = 4096;
5+
static const uint32_t kOffset = 2048;
6+
7+
groupshared uint32_t sharedMem[kSizeShared * sizeof(uint32_t)];
8+
9+
export void testcase(uint32_t threadIdInGroup, uint32_t threadGroupSize) {
10+
// CHECK: OpULessThan %bool %{{.*}} %uint_1024
11+
for (uint32_t i = threadIdInGroup; i < kIterCnt; i += threadGroupSize) {
12+
// CHECK: [[INDEX:%.*]] = OpIAdd %uint %uint_2048
13+
// CHECK: OpAccessChain %_ptr_Workgroup_uint %sharedMem [[INDEX]]
14+
sharedMem[kOffset + i] = 0xffu;
15+
}
16+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// RUN: %dxc -O3 -T lib_6_8 -HV 2021 -spirv -fcgl -fspv-target-env=universal1.5 -enable-16bit-types -Wno-gnu-static-float-init %s | FileCheck %s
2+
3+
template <typename S> struct Trait;
4+
template <> struct Trait<half> {
5+
using type = half;
6+
static const uint32_t size = 2;
7+
static const half value = 1.0;
8+
};
9+
10+
uint32_t get_size() { return Trait<half>::size; }
11+
float cvt(half x) { return float(x); }
12+
13+
// CHECK-LABEL: %testcase = OpFunction %float
14+
export float testcase(int x) {
15+
if (x == 2) {
16+
// CHECK: OpReturnValue %float_2
17+
return float(Trait<half>::size);
18+
} else if (x == 4) {
19+
// CHECK: OpStore %param_var_x %half_0x1p_0
20+
// CHECK-NEXT: OpFunctionCall %float %cvt %param_var_x
21+
return cvt(Trait<half>::value);
22+
} else if (x == 8) {
23+
return get_size();
24+
}
25+
return 0.f;
26+
}
27+
28+
// CHECK-LABEL: %get_size = OpFunction %uint
29+
// CHECK: OpReturnValue %uint_2

tools/clang/test/CodeGenSPIRV/template.class.partial.specialization.hlsl

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,6 @@ uint32_t elementCount()
1818

1919
RWBuffer<int> o;
2020

21-
// Initialize the static members at the start of wrapper
22-
// CHECK: %main = OpFunction %void None
23-
// CHECK: OpStore %RowCount %uint_4
24-
// CHECK: OpStore %ColumnCount %uint_4
25-
// CHECK: OpStore %RowCount_0 %uint_3
26-
// CHECK: OpStore %ColumnCount_0 %uint_2
27-
// CHECK: OpFunctionEnd
28-
29-
30-
3121
// CHECK: %src_main = OpFunction %void None
3222
[numthreads(64,1,1)]
3323
void main()
@@ -40,16 +30,12 @@ void main()
4030

4131
// CHECK: %elementCount = OpFunction %uint None
4232
// CHECK-NEXT: OpLabel
43-
// CHECK-NEXT: [[rc:%[0-9]+]] = OpLoad %uint %RowCount
44-
// CHECK-NEXT: [[cc:%[0-9]+]] = OpLoad %uint %ColumnCount
45-
// CHECK-NEXT: [[mul:%[0-9]+]] = OpIMul %uint [[rc]] [[cc]]
33+
// CHECK-NEXT: [[mul:%[0-9]+]] = OpIMul %uint %uint_4 %uint_4
4634
// CHECK-NEXT: OpReturnValue [[mul]]
4735
// CHECK-NEXT: OpFunctionEnd
4836

4937
// CHECK: %elementCount_0 = OpFunction %uint None
5038
// CHECK-NEXT: %bb_entry_1 = OpLabel
51-
// CHECK-NEXT: [[rc:%[0-9]+]] = OpLoad %uint %RowCount_0
52-
// CHECK-NEXT: [[cc:%[0-9]+]] = OpLoad %uint %ColumnCount_0
53-
// CHECK-NEXT: [[mul:%[0-9]+]] = OpIMul %uint [[rc]] [[cc]]
39+
// CHECK-NEXT: [[mul:%[0-9]+]] = OpIMul %uint %uint_3 %uint_2
5440
// CHECK-NEXT: OpReturnValue [[mul]]
5541
// CHECK-NEXT: OpFunctionEnd

tools/clang/test/CodeGenSPIRV/type.template.struct.template-instance.hlsl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,8 @@ struct Foo {
1313
};
1414

1515
void main() {
16-
// CHECK: [[bar_int:%[a-zA-Z0-9_]+]] = OpVariable %_ptr_Private_int Private
17-
// CHECK: [[bar_float:%[a-zA-Z0-9_]+]] = OpVariable %_ptr_Private_float Private
18-
19-
// CHECK: OpStore [[bar_int]] %int_0
20-
// CHECK: OpStore [[bar_float]] %float_0
21-
16+
// `Foo<int>::bar` is a global constant value,
17+
// it's folded at compile-time and no longer declare variable.
2218
Foo<int>::bar;
2319

2420
// CHECK: %x = OpVariable %_ptr_Function_int Function

tools/clang/test/CodeGenSPIRV/type.type-alias.template.hlsl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,11 @@ struct integral_constant {
88
template <bool val>
99
using bool_constant = integral_constant<bool, val>;
1010

11+
// CHECK: %true = OpConstantTrue %bool
1112
bool main(): SV_Target {
12-
// CHECK: OpStore %value %true
1313
// CHECK: %tru = OpVariable %_ptr_Function_integral_constant Function
1414
bool_constant<true> tru;
1515

16-
// CHECK: [[value:%[0-9]+]] = OpLoad %bool %value
17-
// CHECK: OpReturnValue [[value]]
16+
// CHECK: OpReturnValue %true
1817
return tru.value;
1918
}

0 commit comments

Comments
 (0)