Skip to content

Commit 56f3c40

Browse files
authored
[SPIR-V] Emit OpUndef for undefined values (microsoft#6686)
Before this change, OpConstantNull was emitted when an undef value was required. This causes an issue for some types which cannot have the OpConstantNull value. In addition, it mixed well-defined values with undefined values, which prevents any kind of optimization/analysis later on. Fixes microsoft#6653 --------- Signed-off-by: Nathan Gauër <[email protected]>
1 parent 84d39b6 commit 56f3c40

13 files changed

+167
-9
lines changed

tools/clang/include/clang/SPIRV/SpirvBuilder.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,7 @@ class SpirvBuilder {
748748
llvm::ArrayRef<SpirvConstant *> constituents,
749749
bool specConst = false);
750750
SpirvConstant *getConstantNull(QualType);
751+
SpirvUndef *getUndef(QualType);
751752

752753
SpirvString *createString(llvm::StringRef str);
753754
SpirvString *getString(llvm::StringRef str);

tools/clang/include/clang/SPIRV/SpirvInstruction.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ class SpirvInstruction {
6767
IK_ConstantComposite,
6868
IK_ConstantNull,
6969

70+
// OpUndef
71+
IK_Undef,
72+
7073
// Function structure kinds
7174

7275
IK_FunctionParameter, // OpFunctionParameter
@@ -1302,6 +1305,22 @@ class SpirvConstantNull : public SpirvConstant {
13021305
bool operator==(const SpirvConstantNull &that) const;
13031306
};
13041307

1308+
class SpirvUndef : public SpirvInstruction {
1309+
public:
1310+
SpirvUndef(QualType type);
1311+
1312+
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvUndef)
1313+
1314+
// For LLVM-style RTTI
1315+
static bool classof(const SpirvInstruction *inst) {
1316+
return inst->getKind() == IK_Undef;
1317+
}
1318+
1319+
bool operator==(const SpirvUndef &that) const;
1320+
1321+
bool invokeVisitor(Visitor *v) override;
1322+
};
1323+
13051324
/// \brief OpCompositeConstruct instruction
13061325
class SpirvCompositeConstruct : public SpirvInstruction {
13071326
public:

tools/clang/include/clang/SPIRV/SpirvModule.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,9 @@ class SpirvModule {
142142
// Adds a constant to the module.
143143
void addConstant(SpirvConstant *);
144144

145+
// Adds an Undef to the module.
146+
void addUndef(SpirvUndef *);
147+
145148
// Adds given string to the module which will be emitted via OpString.
146149
void addString(SpirvString *);
147150

@@ -202,6 +205,7 @@ class SpirvModule {
202205
decorations;
203206

204207
std::vector<SpirvConstant *> constants;
208+
std::vector<SpirvUndef *> undefs;
205209
std::vector<SpirvVariable *> variables;
206210
// A vector of functions in the module in the order that they should be
207211
// emitted. The order starts with the entry-point function followed by a

tools/clang/include/clang/SPIRV/SpirvVisitor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ class Visitor {
8989
DEFINE_VISIT_METHOD(SpirvConstantFloat)
9090
DEFINE_VISIT_METHOD(SpirvConstantComposite)
9191
DEFINE_VISIT_METHOD(SpirvConstantNull)
92+
DEFINE_VISIT_METHOD(SpirvUndef)
9293
DEFINE_VISIT_METHOD(SpirvCompositeConstruct)
9394
DEFINE_VISIT_METHOD(SpirvCompositeExtract)
9495
DEFINE_VISIT_METHOD(SpirvCompositeInsert)

tools/clang/lib/SPIRV/EmitVisitor.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,6 +1006,13 @@ bool EmitVisitor::visit(SpirvConstantNull *inst) {
10061006
return true;
10071007
}
10081008

1009+
bool EmitVisitor::visit(SpirvUndef *inst) {
1010+
typeHandler.getOrCreateUndef(inst);
1011+
emitDebugNameForInstruction(getOrAssignResultId<SpirvInstruction>(inst),
1012+
inst->getDebugName());
1013+
return true;
1014+
}
1015+
10091016
bool EmitVisitor::visit(SpirvCompositeConstruct *inst) {
10101017
initInstruction(inst);
10111018
curInst.push_back(inst->getResultTypeId());
@@ -2010,6 +2017,8 @@ uint32_t EmitTypeHandler::getOrCreateConstant(SpirvConstant *inst) {
20102017
return getOrCreateConstantNull(constNull);
20112018
} else if (auto *constBool = dyn_cast<SpirvConstantBoolean>(inst)) {
20122019
return getOrCreateConstantBool(constBool);
2020+
} else if (auto *constUndef = dyn_cast<SpirvUndef>(inst)) {
2021+
return getOrCreateUndef(constUndef);
20132022
}
20142023

20152024
llvm_unreachable("cannot emit unknown constant type");
@@ -2070,6 +2079,31 @@ uint32_t EmitTypeHandler::getOrCreateConstantNull(SpirvConstantNull *inst) {
20702079
return inst->getResultId();
20712080
}
20722081

2082+
uint32_t EmitTypeHandler::getOrCreateUndef(SpirvUndef *inst) {
2083+
auto canonicalType = inst->getAstResultType().getCanonicalType();
2084+
auto found = std::find_if(
2085+
emittedUndef.begin(), emittedUndef.end(),
2086+
[canonicalType](SpirvUndef *cached) {
2087+
return cached->getAstResultType().getCanonicalType() == canonicalType;
2088+
});
2089+
2090+
if (found != emittedUndef.end()) {
2091+
// We have already emitted this constant. Reuse.
2092+
inst->setResultId((*found)->getResultId());
2093+
return inst->getResultId();
2094+
}
2095+
2096+
// Constant wasn't emitted in the past.
2097+
const uint32_t typeId = emitType(inst->getResultType());
2098+
initTypeInstruction(inst->getopcode());
2099+
curTypeInst.push_back(typeId);
2100+
curTypeInst.push_back(getOrAssignResultId<SpirvInstruction>(inst));
2101+
finalizeTypeInstruction();
2102+
// Remember this constant for the future
2103+
emittedUndef.push_back(inst);
2104+
return inst->getResultId();
2105+
}
2106+
20732107
uint32_t EmitTypeHandler::getOrCreateConstantFloat(SpirvConstantFloat *inst) {
20742108
llvm::APFloat value = inst->getValue();
20752109
const SpirvType *type = inst->getResultType();

tools/clang/lib/SPIRV/EmitVisitor.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class EmitTypeHandler {
5757
typeConstantBinary(typesVec), takeNextIdFunction(takeNextIdFn),
5858
emittedConstantInts({}), emittedConstantFloats({}),
5959
emittedConstantComposites({}), emittedConstantNulls({}),
60-
emittedConstantBools() {
60+
emittedUndef({}), emittedConstantBools() {
6161
assert(decVec);
6262
assert(typesVec);
6363
}
@@ -107,6 +107,7 @@ class EmitTypeHandler {
107107
uint32_t getOrCreateConstantFloat(SpirvConstantFloat *);
108108
uint32_t getOrCreateConstantComposite(SpirvConstantComposite *);
109109
uint32_t getOrCreateConstantNull(SpirvConstantNull *);
110+
uint32_t getOrCreateUndef(SpirvUndef *);
110111
uint32_t getOrCreateConstantBool(SpirvConstantBoolean *);
111112
template <typename vecType>
112113
void emitLiteral(const SpirvConstant *, vecType &outInst);
@@ -172,6 +173,7 @@ class EmitTypeHandler {
172173
emittedConstantFloats;
173174
llvm::SmallVector<SpirvConstantComposite *, 8> emittedConstantComposites;
174175
llvm::SmallVector<SpirvConstantNull *, 8> emittedConstantNulls;
176+
llvm::SmallVector<SpirvUndef *, 8> emittedUndef;
175177
SpirvConstantBoolean *emittedConstantBools[2];
176178
llvm::DenseSet<const SpirvInstruction *> emittedSpecConstantInstructions;
177179

@@ -252,6 +254,7 @@ class EmitVisitor : public Visitor {
252254
bool visit(SpirvConstantFloat *) override;
253255
bool visit(SpirvConstantComposite *) override;
254256
bool visit(SpirvConstantNull *) override;
257+
bool visit(SpirvUndef *) override;
255258
bool visit(SpirvCompositeConstruct *) override;
256259
bool visit(SpirvCompositeExtract *) override;
257260
bool visit(SpirvCompositeInsert *) override;

tools/clang/lib/SPIRV/SpirvBuilder.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1826,6 +1826,13 @@ SpirvConstant *SpirvBuilder::getConstantNull(QualType type) {
18261826
return nullConst;
18271827
}
18281828

1829+
SpirvUndef *SpirvBuilder::getUndef(QualType type) {
1830+
// We do not care about making unique constants at this point.
1831+
auto *undef = new (context) SpirvUndef(type);
1832+
mod->addUndef(undef);
1833+
return undef;
1834+
}
1835+
18291836
SpirvString *SpirvBuilder::createString(llvm::StringRef str) {
18301837
// Create a SpirvString instruction
18311838
auto *instr = new (context) SpirvString(/* SourceLocation */ {}, str);

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1517,10 +1517,9 @@ void SpirvEmitter::doFunctionDecl(const FunctionDecl *decl) {
15171517
spvBuilder.createReturn(returnLoc);
15181518
} else {
15191519
// If the source code does not provide a proper return value for some
1520-
// control flow path, it's undefined behavior. We just return null
1521-
// value here.
1522-
spvBuilder.createReturnValue(spvBuilder.getConstantNull(retType),
1523-
returnLoc);
1520+
// control flow path, it's undefined behavior. We just return an
1521+
// undefined value here.
1522+
spvBuilder.createReturnValue(spvBuilder.getUndef(retType), returnLoc);
15241523
}
15251524
}
15261525
}

tools/clang/lib/SPIRV/SpirvInstruction.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvConstantInteger)
5757
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvConstantFloat)
5858
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvConstantComposite)
5959
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvConstantNull)
60+
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvUndef)
6061
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvCompositeConstruct)
6162
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvCompositeExtract)
6263
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvCompositeInsert)
@@ -540,6 +541,11 @@ bool SpirvConstant::operator==(const SpirvConstant &that) const {
540541
if (thatNullInst == nullptr)
541542
return false;
542543
return *nullInst == *thatNullInst;
544+
} else if (auto *nullInst = dyn_cast<SpirvUndef>(this)) {
545+
auto *thatNullInst = dyn_cast<SpirvUndef>(&that);
546+
if (thatNullInst == nullptr)
547+
return false;
548+
return *nullInst == *thatNullInst;
543549
}
544550

545551
assert(false && "operator== undefined for SpirvConstant subclass");
@@ -613,6 +619,15 @@ bool SpirvConstantNull::operator==(const SpirvConstantNull &that) const {
613619
astResultType == that.astResultType;
614620
}
615621

622+
SpirvUndef::SpirvUndef(QualType type)
623+
: SpirvInstruction(IK_Undef, spv::Op::OpUndef, type,
624+
/*SourceLocation*/ {}) {}
625+
626+
bool SpirvUndef::operator==(const SpirvUndef &that) const {
627+
return opcode == that.opcode && resultType == that.resultType &&
628+
astResultType == that.astResultType;
629+
}
630+
616631
SpirvCompositeExtract::SpirvCompositeExtract(QualType resultType,
617632
SourceLocation loc,
618633
SpirvInstruction *compositeInst,

tools/clang/lib/SPIRV/SpirvModule.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ namespace spirv {
1717
SpirvModule::SpirvModule()
1818
: capabilities({}), extensions({}), extInstSets({}), memoryModel(nullptr),
1919
entryPoints({}), executionModes({}), moduleProcesses({}), decorations({}),
20-
constants({}), variables({}), functions({}), debugInstructions({}),
21-
perVertexInterp(false) {}
20+
constants({}), undefs({}), variables({}), functions({}),
21+
debugInstructions({}), perVertexInterp(false) {}
2222

2323
SpirvModule::~SpirvModule() {
2424
for (auto *cap : capabilities)
@@ -43,6 +43,8 @@ SpirvModule::~SpirvModule() {
4343
decoration->releaseMemory();
4444
for (auto *constant : constants)
4545
constant->releaseMemory();
46+
for (auto *undef : undefs)
47+
undef->releaseMemory();
4648
for (auto *var : variables)
4749
var->releaseMemory();
4850
for (auto *di : debugInstructions)
@@ -91,6 +93,12 @@ bool SpirvModule::invokeVisitor(Visitor *visitor, bool reverseOrder) {
9193
return false;
9294
}
9395

96+
for (auto iter = undefs.rbegin(); iter != undefs.rend(); ++iter) {
97+
auto *undef = *iter;
98+
if (!undef->invokeVisitor(visitor))
99+
return false;
100+
}
101+
94102
// Since SetVector doesn't have 'rbegin()' and 'rend()' methods, we use
95103
// manual indexing.
96104
for (auto decorIndex = decorations.size(); decorIndex > 0; --decorIndex) {
@@ -203,6 +211,10 @@ bool SpirvModule::invokeVisitor(Visitor *visitor, bool reverseOrder) {
203211
if (!constant->invokeVisitor(visitor))
204212
return false;
205213

214+
for (auto undef : undefs)
215+
if (!undef->invokeVisitor(visitor))
216+
return false;
217+
206218
for (auto var : variables)
207219
if (!var->invokeVisitor(visitor))
208220
return false;
@@ -334,6 +346,11 @@ void SpirvModule::addConstant(SpirvConstant *constant) {
334346
constants.push_back(constant);
335347
}
336348

349+
void SpirvModule::addUndef(SpirvUndef *undef) {
350+
assert(undef);
351+
undefs.push_back(undef);
352+
}
353+
337354
void SpirvModule::addString(SpirvString *str) {
338355
assert(str);
339356
constStrings.push_back(str);

0 commit comments

Comments
 (0)