Skip to content

Commit 1b405f6

Browse files
authored
[SPIR-V] Add support for 64-bit switch sel (microsoft#6049)
Allow 64-bit integer variables as switch statement selectors and clarify the error message for literal selectors (related to microsoft#4338).
1 parent 2c913d6 commit 1b405f6

File tree

10 files changed

+68
-18
lines changed

10 files changed

+68
-18
lines changed

tools/clang/include/clang/SPIRV/SpirvBuilder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ class SpirvBuilder {
377377
void
378378
createSwitch(SpirvBasicBlock *mergeLabel, SpirvInstruction *selector,
379379
SpirvBasicBlock *defaultLabel,
380-
llvm::ArrayRef<std::pair<uint32_t, SpirvBasicBlock *>> target,
380+
llvm::ArrayRef<std::pair<llvm::APInt, SpirvBasicBlock *>> target,
381381
SourceLocation, SourceRange);
382382

383383
/// \brief Creates a fragment-shader discard via by emitting OpKill.

tools/clang/include/clang/SPIRV/SpirvInstruction.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -824,7 +824,7 @@ class SpirvSwitch : public SpirvBranching {
824824
SpirvSwitch(
825825
SourceLocation loc, SpirvInstruction *selector,
826826
SpirvBasicBlock *defaultLabel,
827-
llvm::ArrayRef<std::pair<uint32_t, SpirvBasicBlock *>> &targetsVec);
827+
llvm::ArrayRef<std::pair<llvm::APInt, SpirvBasicBlock *>> &targetsVec);
828828

829829
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvSwitch)
830830

@@ -837,7 +837,7 @@ class SpirvSwitch : public SpirvBranching {
837837

838838
SpirvInstruction *getSelector() const { return selector; }
839839
SpirvBasicBlock *getDefaultLabel() const { return defaultLabel; }
840-
llvm::ArrayRef<std::pair<uint32_t, SpirvBasicBlock *>> getTargets() const {
840+
llvm::ArrayRef<std::pair<llvm::APInt, SpirvBasicBlock *>> getTargets() const {
841841
return targets;
842842
}
843843
// Returns the branch label that will be taken for the given literal.
@@ -853,7 +853,7 @@ class SpirvSwitch : public SpirvBranching {
853853
private:
854854
SpirvInstruction *selector;
855855
SpirvBasicBlock *defaultLabel;
856-
llvm::SmallVector<std::pair<uint32_t, SpirvBasicBlock *>, 4> targets;
856+
llvm::SmallVector<std::pair<llvm::APInt, SpirvBasicBlock *>, 4> targets;
857857
};
858858

859859
/// \brief OpUnreachable instruction

tools/clang/lib/SPIRV/EmitVisitor.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -846,7 +846,7 @@ bool EmitVisitor::visit(SpirvSwitch *inst) {
846846
curInst.push_back(
847847
getOrAssignResultId<SpirvBasicBlock>(inst->getDefaultLabel()));
848848
for (const auto &target : inst->getTargets()) {
849-
curInst.push_back(target.first);
849+
typeHandler.emitIntLiteral(target.first, curInst);
850850
curInst.push_back(getOrAssignResultId<SpirvBasicBlock>(target.second));
851851
}
852852
finalizeInstruction(&mainBinary);
@@ -2611,6 +2611,12 @@ template <typename vecType>
26112611
void EmitTypeHandler::emitIntLiteral(const SpirvConstantInteger *intLiteral,
26122612
vecType &outInst) {
26132613
const auto &literalVal = intLiteral->getValue();
2614+
emitIntLiteral(literalVal, outInst);
2615+
}
2616+
2617+
template <typename vecType>
2618+
void EmitTypeHandler::emitIntLiteral(const llvm::APInt &literalVal,
2619+
vecType &outInst) {
26142620
bool positive = !literalVal.isNegative();
26152621
if (literalVal.getBitWidth() <= 32) {
26162622
outInst.push_back(positive ? literalVal.getZExtValue()

tools/clang/lib/SPIRV/EmitVisitor.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ class EmitTypeHandler {
114114
void emitFloatLiteral(const SpirvConstantFloat *, vecType &outInst);
115115
template <typename vecType>
116116
void emitIntLiteral(const SpirvConstantInteger *, vecType &outInst);
117+
template <typename vecType>
118+
void emitIntLiteral(const llvm::APInt &literalVal, vecType &outInst);
117119

118120
private:
119121
void initTypeInstruction(spv::Op op);

tools/clang/lib/SPIRV/SpirvBuilder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -761,7 +761,7 @@ SpirvSelect *SpirvBuilder::createSelect(QualType resultType,
761761
void SpirvBuilder::createSwitch(
762762
SpirvBasicBlock *mergeLabel, SpirvInstruction *selector,
763763
SpirvBasicBlock *defaultLabel,
764-
llvm::ArrayRef<std::pair<uint32_t, SpirvBasicBlock *>> target,
764+
llvm::ArrayRef<std::pair<llvm::APInt, SpirvBasicBlock *>> target,
765765
SourceLocation loc, SourceRange range) {
766766
assert(insertPoint && "null insert point");
767767
// Create the OpSelectioMerege.

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@
2525
#include "clang/SPIRV/AstTypeProbe.h"
2626
#include "clang/SPIRV/String.h"
2727
#include "clang/Sema/Sema.h"
28+
#include "llvm/ADT/APInt.h"
2829
#include "llvm/ADT/SetVector.h"
2930
#include "llvm/ADT/StringExtras.h"
31+
#include "llvm/Support/Casting.h"
3032

3133
#ifdef SUPPORT_QUERY_GIT_COMMIT_INFO
3234
#include "clang/Basic/Version.h"
@@ -13422,7 +13424,7 @@ bool SpirvEmitter::allSwitchCasesAreIntegerLiterals(const Stmt *root) {
1342213424

1342313425
void SpirvEmitter::discoverAllCaseStmtInSwitchStmt(
1342413426
const Stmt *root, SpirvBasicBlock **defaultBB,
13425-
std::vector<std::pair<uint32_t, SpirvBasicBlock *>> *targets) {
13427+
std::vector<std::pair<llvm::APInt, SpirvBasicBlock *>> *targets) {
1342613428
if (!root)
1342713429
return;
1342813430

@@ -13442,7 +13444,7 @@ void SpirvEmitter::discoverAllCaseStmtInSwitchStmt(
1344213444
}
1344313445

1344413446
std::string caseLabel;
13445-
uint32_t caseValue = 0;
13447+
llvm::APInt caseValue;
1344613448
if (defaultStmt) {
1344713449
// This is the default branch.
1344813450
caseLabel = "switch.default";
@@ -13452,15 +13454,10 @@ void SpirvEmitter::discoverAllCaseStmtInSwitchStmt(
1345213454
// case <literal_integer>: {...; break;}
1345313455
const Expr *caseExpr = caseStmt->getLHS();
1345413456
assert(caseExpr && caseExpr->isEvaluatable(astContext));
13455-
auto bitWidth = astContext.getIntWidth(caseExpr->getType());
13456-
if (bitWidth != 32)
13457-
emitError(
13458-
"non-32bit integer case value in switch statement unimplemented",
13459-
caseExpr->getExprLoc());
1346013457
Expr::EvalResult evalResult;
1346113458
caseExpr->EvaluateAsRValue(evalResult, astContext);
13462-
const int64_t value = evalResult.Val.getInt().getSExtValue();
13463-
caseValue = static_cast<uint32_t>(value);
13459+
caseValue = evalResult.Val.getInt();
13460+
const int64_t value = caseValue.getSExtValue();
1346413461
caseLabel = "switch." + std::string(value < 0 ? "n" : "") +
1346513462
llvm::itostr(std::abs(value));
1346613463
}
@@ -13535,6 +13532,13 @@ void SpirvEmitter::processSwitchStmtUsingSpirvOpSwitch(
1353513532
doDeclStmt(condVarDeclStmt);
1353613533

1353713534
auto *cond = switchStmt->getCond();
13535+
if (llvm::dyn_cast<IntegerLiteral>(cond)) {
13536+
emitError(
13537+
"integer literal selectors in switch statements not yet implemented",
13538+
cond->getLocStart());
13539+
return;
13540+
}
13541+
1353813542
auto *selector = doExpr(cond);
1353913543

1354013544
// We need a merge block regardless of the number of switch cases.
@@ -13547,7 +13551,7 @@ void SpirvEmitter::processSwitchStmtUsingSpirvOpSwitch(
1354713551
auto *defaultBB = mergeBB;
1354813552

1354913553
// (literal, labelId) pairs to pass to the OpSwitch instruction.
13550-
std::vector<std::pair<uint32_t, SpirvBasicBlock *>> targets;
13554+
std::vector<std::pair<llvm::APInt, SpirvBasicBlock *>> targets;
1355113555
discoverAllCaseStmtInSwitchStmt(switchStmt->getBody(), &defaultBB, &targets);
1355213556

1355313557
// Create the OpSelectionMerge and OpSwitch.

tools/clang/lib/SPIRV/SpirvEmitter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -919,7 +919,7 @@ class SpirvEmitter : public ASTConsumer {
919919
/// method panics if it finds a case value that is not an integer literal.
920920
void discoverAllCaseStmtInSwitchStmt(
921921
const Stmt *root, SpirvBasicBlock **defaultBB,
922-
std::vector<std::pair<uint32_t, SpirvBasicBlock *>> *targets);
922+
std::vector<std::pair<llvm::APInt, SpirvBasicBlock *>> *targets);
923923

924924
/// Flattens structured AST of the given switch statement into a vector of AST
925925
/// nodes and stores into flatSwitch.

tools/clang/lib/SPIRV/SpirvInstruction.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ SpirvReturn::SpirvReturn(SourceLocation loc, SpirvInstruction *retVal,
384384
SpirvSwitch::SpirvSwitch(
385385
SourceLocation loc, SpirvInstruction *selectorInst,
386386
SpirvBasicBlock *defaultLbl,
387-
llvm::ArrayRef<std::pair<uint32_t, SpirvBasicBlock *>> &targetsVec)
387+
llvm::ArrayRef<std::pair<llvm::APInt, SpirvBasicBlock *>> &targetsVec)
388388
: SpirvBranching(IK_Switch, spv::Op::OpSwitch, loc), selector(selectorInst),
389389
defaultLabel(defaultLbl), targets(targetsVec.begin(), targetsVec.end()) {}
390390

tools/clang/test/CodeGenSPIRV_Lit/cf.switch.opswitch.hlsl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,4 +357,31 @@ void main() {
357357
break;
358358
}
359359

360+
361+
/////////////////////////////////////////
362+
// 64-bit integer variable as selector //
363+
/////////////////////////////////////////
364+
int64_t longsel;
365+
// CHECK: [[longSelector:%[0-9]+]] = OpLoad %long %longsel
366+
// CHECK-NEXT: OpSelectionMerge %switch_merge_10 None
367+
// CHECK-NEXT: OpSwitch [[longSelector]] %switch_merge_10 -1 %switch_n1
368+
switch (longsel) {
369+
case -1:
370+
result = 0;
371+
break;
372+
}
373+
374+
375+
//////////////////////////////////////////////////
376+
// 64-bit unsigned integer variable as selector //
377+
//////////////////////////////////////////////////
378+
uint64_t ulongsel;
379+
// CHECK: [[ulongSelector:%[0-9]+]] = OpLoad %ulong %ulongsel
380+
// CHECK-NEXT: OpSelectionMerge %switch_merge_11 None
381+
// CHECK-NEXT: OpSwitch [[ulongSelector]] %switch_merge_11 12345678910 %switch_12345678910
382+
switch (ulongsel) {
383+
case 12345678910:
384+
result = 0;
385+
break;
386+
}
360387
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// RUN: not %dxc -T ps_6_0 -fcgl -spirv %s 2>&1 | FileCheck %s
2+
3+
// CHECK: error: integer literal selectors in switch statements not yet implemented
4+
5+
float main() : SV_TARGET {
6+
switch (0) {
7+
case 0:
8+
return 1;
9+
}
10+
return 0;
11+
}

0 commit comments

Comments
 (0)