Skip to content

Commit 6b4b0eb

Browse files
authored
[SPIRV] Implements struct-to-int casting (microsoft#5492)
Allows casting a struct to an integer type of at least the same size, which worked for the DXIL target but not previously for SPIR-V.
1 parent ea4aca9 commit 6b4b0eb

File tree

3 files changed

+154
-1
lines changed

3 files changed

+154
-1
lines changed

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3535,7 +3535,8 @@ SpirvInstruction *SpirvEmitter::processFlatConversion(
35353535
// one member, S, then (T)<an-instance-of-S> is allowed, which essentially
35363536
// constructs a new T instance using the instance of S as its only member.
35373537
// Check whether we are handling that case here first.
3538-
if (field->getType().getCanonicalType() == initType.getCanonicalType()) {
3538+
if (!field->isBitField() &&
3539+
field->getType().getCanonicalType() == initType.getCanonicalType()) {
35393540
fields.push_back(initInstr);
35403541
} else {
35413542
fields.push_back(processFlatConversion(field->getType(), initType,
@@ -8164,6 +8165,41 @@ SpirvInstruction *SpirvEmitter::castToInt(SpirvInstruction *fromVal,
81648165
}
81658166
}
81668167

8168+
if (const auto *recordType = fromType->getAs<RecordType>()) {
8169+
// This code is bogus but approximates the current (unspec'd)
8170+
// behavior for the DXIL target.
8171+
assert(recordType->isStructureType());
8172+
8173+
auto fieldDecl = recordType->getDecl()->field_begin();
8174+
QualType fieldType = fieldDecl->getType();
8175+
QualType elemType = {};
8176+
SpirvInstruction *firstField;
8177+
8178+
if (isVectorType(fieldType, &elemType)) {
8179+
fieldType = elemType;
8180+
firstField = spvBuilder.createCompositeExtract(fieldType, fromVal, {0, 0},
8181+
srcLoc, srcRange);
8182+
} else {
8183+
firstField = spvBuilder.createCompositeExtract(fieldType, fromVal, {0},
8184+
srcLoc, srcRange);
8185+
if (fieldDecl->isBitField()) {
8186+
auto offset = spvBuilder.getConstantInt(astContext.UnsignedIntTy,
8187+
llvm::APInt(32, 0));
8188+
auto width = spvBuilder.getConstantInt(
8189+
astContext.UnsignedIntTy,
8190+
llvm::APInt(32, fieldDecl->getBitWidthValue(astContext)));
8191+
firstField = spvBuilder.createBitFieldExtract(
8192+
fieldType, firstField, offset, width,
8193+
toIntType->hasSignedIntegerRepresentation(), srcLoc);
8194+
}
8195+
}
8196+
8197+
SpirvInstruction *result =
8198+
castToInt(firstField, fieldType, toIntType, srcLoc, srcRange);
8199+
result->setLayoutRule(fromVal->getLayoutRule());
8200+
return result;
8201+
}
8202+
81678203
return nullptr;
81688204
}
81698205

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
// RUN: %dxc -T cs_6_4 -HV 2021 -E main
2+
3+
struct ColorRGB {
4+
uint R : 8;
5+
uint G : 8;
6+
uint B : 8;
7+
};
8+
9+
struct ColorRGBA {
10+
uint R : 8;
11+
uint G : 8;
12+
uint B : 8;
13+
uint A : 8;
14+
};
15+
16+
struct TwoColors {
17+
ColorRGBA rgba1;
18+
ColorRGBA rgba2;
19+
};
20+
21+
struct Mixed {
22+
float f;
23+
uint i;
24+
};
25+
26+
struct Vectors {
27+
uint2 p1;
28+
uint2 p2;
29+
};
30+
31+
RWStructuredBuffer<uint> buf : r0;
32+
RWStructuredBuffer<uint64_t> lbuf : r1;
33+
34+
// CHECK: OpName [[BUF:%[^ ]*]] "buf"
35+
// CHECK: OpName [[LBUF:%[^ ]*]] "lbuf"
36+
// CHECK: OpName [[COLORRGB:%[^ ]*]] "ColorRGB"
37+
// CHECK: OpName [[COLORRGBA:%[^ ]*]] "ColorRGBA"
38+
// CHECK: OpName [[TWOCOLORS:%[^ ]*]] "TwoColors"
39+
// CHECK: OpName [[VECTORS:%[^ ]*]] "Vectors"
40+
// CHECK: OpName [[MIXED:%[^ ]*]] "Mixed"
41+
42+
[numthreads(1,1,1)]
43+
void main()
44+
{
45+
ColorRGB rgb;
46+
ColorRGBA c0;
47+
ColorRGBA c1;
48+
TwoColors colors;
49+
Vectors v;
50+
Mixed m = {-1.0, 1};
51+
rgb.R = 127;
52+
rgb.G = 127;
53+
rgb.B = 127;
54+
c0.R = 255;
55+
c0.G = 127;
56+
c0.B = 63;
57+
c0.A = 31;
58+
c1.R = 15;
59+
c1.G = 7;
60+
c1.B = 3;
61+
c1.A = 1;
62+
colors.rgba1 = c0;
63+
colors.rgba2 = c1;
64+
v.p1.x = 3;
65+
v.p1.y = 2;
66+
v.p2.x = 1;
67+
v.p2.y = 0;
68+
69+
// CHECK-DAG: [[FLOAT:%[^ ]*]] = OpTypeFloat 32
70+
// CHECK-DAG: [[FN1:%[^ ]*]] = OpConstant [[FLOAT]] -1
71+
// CHECK-DAG: [[UINT:%[^ ]*]] = OpTypeInt 32 0
72+
// CHECK-DAG: [[U127:%[^ ]*]] = OpConstant [[UINT]] 127
73+
// CHECK-DAG: [[INT:%[^ ]*]] = OpTypeInt 32 1
74+
// CHECK-DAG: [[I0:%[^ ]*]] = OpConstant [[INT]] 0
75+
// CHECK-DAG: [[U0:%[^ ]*]] = OpConstant [[UINT]] 0
76+
// CHECK-DAG: [[U8:%[^ ]*]] = OpConstant [[UINT]] 8
77+
// CHECK-DAG: [[U255:%[^ ]*]] = OpConstant [[UINT]] 255
78+
// CHECK-DAG: [[U3:%[^ ]*]] = OpConstant [[UINT]] 3
79+
// CHECK-DAG: [[ULONG:%[^ ]*]] = OpTypeInt 64 0
80+
// CHECK-DAG: [[DOUBLE:%[^ ]*]] = OpTypeFloat 64
81+
82+
buf[0] = (uint) colors;
83+
// CHECK: [[COLORS:%[^ ]*]] = OpLoad [[TWOCOLORS]]
84+
// CHECK: [[COLORS0:%[^ ]*]] = OpCompositeExtract [[COLORRGBA]] [[COLORS]] 0
85+
// CHECK: [[COLORS00:%[^ ]*]] = OpCompositeExtract [[UINT]] [[COLORS0]] 0
86+
// CHECK: [[COLORS000:%[^ ]*]] = OpBitFieldUExtract [[UINT]] [[COLORS00]] [[U0]] [[U8]]
87+
// CHECK: [[BUF00:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[BUF]] [[I0]] [[U0]]
88+
// CHECK: OpStore [[BUF00]] [[COLORS000]]
89+
90+
buf[0] -= (uint) rgb;
91+
// CHECK: [[RGB:%[^ ]*]] = OpLoad [[COLORRGB]]
92+
// CHECK: [[RGB0:%[^ ]*]] = OpCompositeExtract [[UINT]] [[RGB]] 0
93+
// CHECK: [[RGB00:%[^ ]*]] = OpBitFieldUExtract [[UINT]] [[RGB0]] [[U0]] [[U8]]
94+
// CHECK: [[BUF00:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[BUF]] [[I0]] [[U0]]
95+
// CHECK: [[V1:%[^ ]*]] = OpLoad [[UINT]] [[BUF00]]
96+
// CHECK: [[V2:%[^ ]*]] = OpISub [[UINT]] [[V1]] [[RGB00]]
97+
// CHECK: OpStore [[BUF00]] [[V2]]
98+
99+
lbuf[0] = (uint64_t) v;
100+
// CHECK: [[VECS:%[^ ]*]] = OpLoad [[VECTORS]]
101+
// CHECK: [[VECS00:%[^ ]*]] = OpCompositeExtract [[UINT]] [[VECS]] 0 0
102+
// CHECK: [[V1:%[^ ]*]] = OpUConvert [[ULONG]] [[VECS00]]
103+
// CHECK: [[LBUF00:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[LBUF]] [[I0]] [[U0]]
104+
// CHECK: OpStore [[LBUF00]] [[V1]]
105+
106+
lbuf[0] += (uint64_t) m;
107+
// CHECK: [[MIX:%[^ ]*]] = OpLoad [[MIXED]]
108+
// CHECK: [[MIX0:%[^ ]*]] = OpCompositeExtract [[FLOAT]] [[MIX]] 0
109+
// CHECK: [[V1:%[^ ]*]] = OpFConvert [[DOUBLE]] [[MIX0]]
110+
// CHECK: [[V2:%[^ ]*]] = OpConvertFToU [[ULONG]] [[V1]]
111+
// CHECK: [[LBUF00:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[LBUF]] [[I0]] [[U0]]
112+
// CHECK: [[V3:%[^ ]*]] = OpLoad [[ULONG]] [[LBUF00]]
113+
// CHECK: [[V4:%[^ ]*]] = OpIAdd [[ULONG]] [[V3]] [[V2]]
114+
// CHECK: OpStore [[LBUF00]] [[V4]]
115+
}
116+

tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,7 @@ TEST_F(FileTest, CastImplicitVecToMat) {
377377
runFileTest("cast.vec-to-mat.implicit.hlsl");
378378
}
379379
TEST_F(FileTest, CastMatrixToVector) { runFileTest("cast.mat-to-vec.hlsl"); }
380+
TEST_F(FileTest, CastStructToInt) { runFileTest("cast.struct-to-int.hlsl"); }
380381
TEST_F(FileTest, CastBitwidth) { runFileTest("cast.bitwidth.hlsl"); }
381382

382383
TEST_F(FileTest, CastLiteralTypeForArraySubscript) {

0 commit comments

Comments
 (0)