Skip to content

Commit 24e8486

Browse files
committed
[HLSL][SPIRV] Add convergence tokens to entry point wrapper
Inlining currently assumes that either all functions use controlled convergence or none of them do. This is why the entry point wrapper needs to use controlled convergence as well. See: https://github.com/llvm/llvm-project/blob/c85611e8583e6392d56075ebdfa60893b6284813/llvm/lib/Transforms/Utils/InlineFunction.cpp#L2431-L2439
1 parent ab208de commit 24e8486

File tree

3 files changed

+49
-4
lines changed

3 files changed

+49
-4
lines changed

clang/lib/CodeGen/CGHLSLRuntime.cpp

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,17 @@ void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
399399
BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn);
400400
IRBuilder<> B(BB);
401401
llvm::SmallVector<Value *> Args;
402+
403+
SmallVector<OperandBundleDef, 1> OB;
404+
if (CGM.shouldEmitConvergenceTokens()) {
405+
assert(EntryFn->isConvergent());
406+
llvm::Value *
407+
I = B.CreateIntrinsic(llvm::Intrinsic::experimental_convergence_entry, {},
408+
{});
409+
llvm::Value *bundleArgs[] = {I};
410+
OB.emplace_back("convergencectrl", bundleArgs);
411+
}
412+
402413
// FIXME: support struct parameters where semantics are on members.
403414
// See: https://github.com/llvm/llvm-project/issues/57874
404415
unsigned SRetOffset = 0;
@@ -414,7 +425,7 @@ void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
414425
Args.push_back(emitInputSemantic(B, *PD, Param.getType()));
415426
}
416427

417-
CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args);
428+
CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args, OB);
418429
CI->setCallingConv(Fn->getCallingConv());
419430
// FIXME: Handle codegen for return type semantics.
420431
// See: https://github.com/llvm/llvm-project/issues/57875
@@ -469,14 +480,21 @@ void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
469480
for (auto &F : M.functions()) {
470481
if (!F.hasFnAttribute("hlsl.shader"))
471482
continue;
472-
IRBuilder<> B(&F.getEntryBlock(), F.getEntryBlock().begin());
483+
auto* Token = getConvergenceToken(F.getEntryBlock());
484+
Instruction* IP = Token ? Token : &*F.getEntryBlock().begin();
485+
IRBuilder<> B(IP);
486+
std::vector<OperandBundleDef> OB;
487+
if (Token) {
488+
llvm::Value *bundleArgs[] = {Token};
489+
OB.emplace_back("convergencectrl", bundleArgs);
490+
}
473491
for (auto *Fn : CtorFns)
474-
B.CreateCall(FunctionCallee(Fn));
492+
B.CreateCall(FunctionCallee(Fn), {}, OB);
475493

476494
// Insert global dtors before the terminator of the last basic block
477495
B.SetInsertPoint(F.back().getTerminator());
478496
for (auto *Fn : DtorFns)
479-
B.CreateCall(FunctionCallee(Fn));
497+
B.CreateCall(FunctionCallee(Fn), {}, OB);
480498
}
481499

482500
// No need to keep global ctors/dtors for non-lib profile after call to
@@ -489,3 +507,18 @@ void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
489507
GV->eraseFromParent();
490508
}
491509
}
510+
511+
// Looks up the convergence control token defined in BB.
//
// Returns nullptr when CGM.shouldEmitConvergenceTokens() is false. Otherwise
// walks BB from its first instruction to its last and returns the first
// intrinsic call whose ID satisfies llvm::isConvergenceControlIntrinsic().
// Finding no such instruction is treated as a programming error and reported
// through llvm_unreachable.
llvm::Instruction *CGHLSLRuntime::getConvergenceToken(BasicBlock &BB) {
512+
if (!CGM.shouldEmitConvergenceTokens())
513+
return nullptr;
514+
515+
auto E = BB.end();
516+
for(auto I = BB.begin(); I != E; ++I) {
517+
// Only intrinsic calls are candidates; other instructions yield nullptr here.
auto *II = dyn_cast<llvm::IntrinsicInst>(&*I);
518+
if (II && llvm::isConvergenceControlIntrinsic(II->getIntrinsicID())) {
519+
return II;
520+
}
521+
}
522+
llvm_unreachable("Convergence token should have been emitted.");
523+
// Placed after llvm_unreachable, so this return is not expected to execute.
return nullptr;
524+
}

clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ class CGHLSLRuntime {
137137

138138
void emitEntryFunction(const FunctionDecl *FD, llvm::Function *Fn);
139139
void setHLSLFunctionAttributes(const FunctionDecl *FD, llvm::Function *Fn);
140+
llvm::Instruction *getConvergenceToken(llvm::BasicBlock &BB);
140141

141142
private:
142143
void addBufferResourceAnnotation(llvm::GlobalVariable *GV,
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// RUN: %clang_cc1 -triple spirv-pc-vulkan-compute -finclude-default-header -fnative-half-type -disable-llvm-passes -emit-llvm -o - %s | FileCheck %s
2+
3+
// CHECK-LABEL: define void @main()
4+
// CHECK-NEXT: entry:
5+
// CHECK-NEXT: [[token:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
6+
// CHECK-NEXT: call spir_func void @_Z4mainv() [ "convergencectrl"(token [[token]]) ]
7+
8+
// Empty compute-shader entry point. The test only inspects the generated
// wrapper around it (see the FileCheck directives above), so no body is needed.
[numthreads(1,1,1)]
9+
void main() {
10+
}
11+

0 commit comments

Comments
 (0)