Skip to content

Commit bdb8db3

Browse files
committed
test and comment
Signed-off-by: Sidorov, Dmitry <[email protected]>
1 parent 2b8113b commit bdb8db3

File tree

2 files changed

+46
-6
lines changed

2 files changed

+46
-6
lines changed

llvm/lib/SYCLLowerIR/LowerWGScope.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -819,23 +819,33 @@ PreservedAnalyses SYCLLowerWGScopePass::run(Function &F,
819819
Allocas.insert(AllocaI);
820820
}
821821
for (; I && (I != BB.getTerminator()); I = I->getNextNode()) {
822+
// sycl::group functions returning a structure would return it via sret
823+
// CallInst operand, which will be placed in target's alloca address
824+
// space (which is private for SPIR). As such sycl::group function
825+
// might write to a Global with Local address space - frontend will
826+
// insert address space cast from Local to Private, which is illegal.
827+
// So here we create a temporary alloca, that will be used as sret
828+
// operand, and copy from it to the Global.
822829
if (CallInst *CI = dyn_cast<CallInst>(I)) {
823-
if (CI->getCalledFunction()->getName().
824-
starts_with(GET_GROUP_PREFIX) &&
830+
if (CI->getCalledFunction()->getName().starts_with(GET_GROUP_PREFIX) &&
825831
CI->hasStructRetAttr()) {
826-
if (auto *ASCast = dyn_cast<AddrSpaceCastOperator>(CI->getOperand(0))) {
832+
if (auto *ASCast =
833+
dyn_cast<AddrSpaceCastOperator>(CI->getOperand(0))) {
827834
unsigned SrcAS = ASCast->getSrcAddressSpace();
828835
unsigned DstAS = ASCast->getDestAddressSpace();
829836
if (SrcAS == static_cast<unsigned>(spirv::AddrSpace::Local) &&
830837
DstAS == static_cast<unsigned>(spirv::AddrSpace::Private)) {
838+
LLVM_DEBUG(llvm::dbgs() << "+++ Illegal AS cast found in a call: "
839+
<< *CI << "\n");
831840
IRBuilder<> Builder(CI->getContext());
832841
llvm::BasicBlock &FirstBB = F.getEntryBlock();
833842
Builder.SetInsertPoint(&FirstBB, FirstBB.begin());
834843
Type *ResTy = CI->getParamStructRetType(0);
835-
auto *TMPAlloca = Builder.CreateAlloca(
836-
ResTy, nullptr, "lower_wg.local_copy");
844+
auto *TMPAlloca =
845+
Builder.CreateAlloca(ResTy, nullptr, "lower_wg.local_copy");
837846
Builder.SetInsertPoint(CI->getNextNode());
838-
auto *LI = Builder.CreateLoad(ResTy, TMPAlloca, "lower_wg.private_load");
847+
auto *LI =
848+
Builder.CreateLoad(ResTy, TMPAlloca, "lower_wg.private_load");
839849
Builder.CreateStore(LI, ASCast->getPointerOperand());
840850
ASCast->replaceAllUsesWith(TMPAlloca);
841851
ASCast->dropAllReferences();
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
; RUN: opt -passes=LowerWGScope -S %s -o - | FileCheck %s
2+
3+
; Check that no illegal AS casts remain after the pass
4+
5+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1"
6+
target triple = "spir64-unknown-unknown"
7+
8+
%"class.sycl::_V1::group" = type { %"class.sycl::_V1::range", %"class.sycl::_V1::range", %"class.sycl::_V1::range", %"class.sycl::_V1::id" }
9+
%"class.sycl::_V1::range" = type { %"class.sycl::_V1::detail::array" }
10+
%"class.sycl::_V1::detail::array" = type { [1 x i64] }
11+
%"class.sycl::_V1::id" = type { %"class.sycl::_V1::detail::array" }
12+
13+
@hierarchical = internal addrspace(3) global %"class.sycl::_V1::range" undef, align 8
14+
15+
define internal spir_func void @foo(ptr addrspace(4) noundef align 1 dereferenceable_or_null(1) %this, ptr noundef byval(%"class.sycl::_V1::group") align 8 %group_pid) !work_group_scope !0 {
16+
entry:
17+
; CHECK: entry:
18+
; CHECK-NEXT: %lower_wg.local_copy = alloca %"class.sycl::_V1::id", align 8
19+
; CHECK: call spir_func void @_ZNK4sycl3_V15groupILi1EE12get_group_idEv(ptr dead_on_unwind writable sret(%"class.sycl::_V1::id") align 8 %lower_wg.local_copy, ptr {{.*}})
20+
; CHECK: %lower_wg.private_load = load %"class.sycl::_V1::id", ptr %lower_wg.local_copy, align 8
21+
; CHECK: store %"class.sycl::_V1::id" %lower_wg.private_load, ptr addrspace(3) @hierarchical, align 8
22+
23+
%group_pid.ascast = addrspacecast ptr %group_pid to ptr addrspace(4)
24+
call spir_func void @_ZNK4sycl3_V15groupILi1EE12get_group_idEv(ptr dead_on_unwind writable sret(%"class.sycl::_V1::id") align 8 addrspacecast (ptr addrspace(3) @hierarchical to ptr), ptr addrspace(4) noundef align 8 dereferenceable_or_null(32) %group_pid.ascast)
25+
ret void
26+
}
27+
28+
declare spir_func void @_ZNK4sycl3_V15groupILi1EE12get_group_idEv(ptr dead_on_unwind noalias writable sret(%"class.sycl::_V1::id") align 8 %agg.result, ptr addrspace(4) noundef align 8 dereferenceable_or_null(32) %this)
29+
30+
!0 = !{}

0 commit comments

Comments
 (0)