Skip to content

Commit 67ae789

Browse files
committed
[OMPIRBuilder] Fix reduction codegen for SPIR-V
Signed-off-by: Sarnie, Nick <[email protected]>
1 parent caacfff commit 67ae789

File tree

4 files changed

+37
-6
lines changed

4 files changed

+37
-6
lines changed

clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -869,6 +869,8 @@ CGOpenMPRuntimeGPU::CGOpenMPRuntimeGPU(CodeGenModule &CGM)
869869
CGM.getLangOpts().OpenMPOffloadMandatory,
870870
/*HasRequiresReverseOffload*/ false, /*HasRequiresUnifiedAddress*/ false,
871871
hasRequiresUnifiedSharedMemory(), /*HasRequiresDynamicAllocators*/ false);
872+
Config.setDefaultTargetAS(
873+
CGM.getContext().getTargetInfo().getTargetAddressSpace(LangAS::Default));
872874
OMPBuilder.setConfig(Config);
873875

874876
if (!CGM.getLangOpts().OpenMPIsTargetDevice)
@@ -1243,7 +1245,10 @@ void CGOpenMPRuntimeGPU::emitParallelCall(
12431245
llvm::Value *ID = llvm::ConstantPointerNull::get(CGM.Int8PtrTy);
12441246
if (WFn)
12451247
ID = Bld.CreateBitOrPointerCast(WFn, CGM.Int8PtrTy);
1246-
llvm::Value *FnPtr = Bld.CreateBitOrPointerCast(OutlinedFn, CGM.Int8PtrTy);
1248+
llvm::Type *FnPtrTy = llvm::PointerType::get(
1249+
CGF.getLLVMContext(), CGM.getDataLayout().getProgramAddressSpace());
1250+
1251+
llvm::Value *FnPtr = Bld.CreateBitOrPointerCast(OutlinedFn, FnPtrTy);
12471252

12481253
// Create a private scope that will globalize the arguments
12491254
// passed from the outside of the target region.
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: %clang_cc1 -verify -fopenmp -x c++ -triple x86_64-unknown-linux -fopenmp-targets=spirv64-intel -emit-llvm-bc %s -o %t-host.bc
2+
// RUN: %clang_cc1 -verify -fopenmp -x c++ -triple spirv64-intel -fopenmp-targets=spirv64-intel -emit-llvm %s -fopenmp-is-target-device -fopenmp-host-ir-file-path %t-host.bc -o - | FileCheck %s
3+
4+
// expected-no-diagnostics
5+
6+
// CHECK: call spir_func addrspace(9) void @__kmpc_parallel_51(ptr addrspace(4) addrspacecast (ptr addrspace(1) @{{.*}} to ptr addrspace(4)),
7+
// CHECK-SAME: i32 %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr addrspace(9) @{{.*}}, ptr addrspace(4) {{.*}}, ptr addrspace(4) %{{.*}}, i64 {{.*}})
8+
9+
// CHECK: call addrspace(9) i32 @__kmpc_nvptx_teams_reduce_nowait_v2(ptr addrspace(4) addrspacecast (ptr addrspace(1) @{{.*}} to ptr addrspace(4)),
10+
// CHECK-SAME: ptr addrspace(4) %{{.*}}, i32 1024, i64 4, ptr addrspace(4) %{{.*}}, ptr addrspace(9) @{{.*}}, ptr addrspace(9) @{{.*}}, ptr addrspace(9) @{{.*}}, ptr addrspace(9) @{{.*}}, ptr addrspace(9) @{{.*}}, ptr addrspace(9) @{{.*}})
11+
12+
int main() {
13+
int matrix_sum = 0;
14+
#pragma omp target teams distribute parallel for \
15+
reduction(+:matrix_sum) \
16+
map(tofrom:matrix_sum)
17+
for (int i = 0; i < 100; i++) {
18+
19+
}
20+
21+
return 0;
22+
}

llvm/include/llvm/Frontend/OpenMP/OMPKinds.def

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ __OMP_TYPE(Double)
4242

4343
OMP_TYPE(SizeTy, M.getDataLayout().getIntPtrType(Ctx))
4444
OMP_TYPE(Int63, Type::getIntNTy(Ctx, 63))
45+
OMP_TYPE(FuncPtrTy, PointerType::get(Ctx, M.getDataLayout().getProgramAddressSpace()))
4546

4647
__OMP_PTR_TYPE(VoidPtr)
4748
__OMP_PTR_TYPE(VoidPtrPtr)
@@ -471,7 +472,7 @@ __OMP_RTL(__kmpc_target_init, false, Int32, KernelEnvironmentPtr, KernelLaunchEn
471472
__OMP_RTL(__kmpc_target_deinit, false, Void,)
472473
__OMP_RTL(__kmpc_kernel_prepare_parallel, false, Void, VoidPtr)
473474
__OMP_RTL(__kmpc_parallel_51, false, Void, IdentPtr, Int32, Int32, Int32, Int32,
474-
VoidPtr, VoidPtr, VoidPtrPtr, SizeTy)
475+
FuncPtrTy, VoidPtr, VoidPtrPtr, SizeTy)
475476
__OMP_RTL(__kmpc_for_static_loop_4, false, Void, IdentPtr, VoidPtr, VoidPtr, Int32, Int32, Int32, Int8)
476477
__OMP_RTL(__kmpc_for_static_loop_4u, false, Void, IdentPtr, VoidPtr, VoidPtr, Int32, Int32, Int32, Int8)
477478
__OMP_RTL(__kmpc_for_static_loop_8, false, Void, IdentPtr, VoidPtr, VoidPtr, Int64, Int64, Int64, Int8)

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3623,7 +3623,9 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
36233623
// 1. Build a list of reduction variables.
36243624
// void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
36253625
auto Size = ReductionInfos.size();
3626-
Type *PtrTy = PointerType::getUnqual(Ctx);
3626+
Type *PtrTy = PointerType::get(Ctx, Config.getDefaultTargetAS());
3627+
Type *FuncPtrTy =
3628+
Builder.getPtrTy(M.getDataLayout().getProgramAddressSpace());
36273629
Type *RedArrayTy = ArrayType::get(PtrTy, Size);
36283630
CodeGenIP = Builder.saveIP();
36293631
Builder.restoreIP(AllocaIP);
@@ -3667,9 +3669,9 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
36673669
Builder.getInt64(MaxDataSize * ReductionInfos.size());
36683670
if (!IsTeamsReduction) {
36693671
Value *SarFuncCast =
3670-
Builder.CreatePointerBitCastOrAddrSpaceCast(SarFunc, PtrTy);
3672+
Builder.CreatePointerBitCastOrAddrSpaceCast(SarFunc, FuncPtrTy);
36713673
Value *WcFuncCast =
3672-
Builder.CreatePointerBitCastOrAddrSpaceCast(WcFunc, PtrTy);
3674+
Builder.CreatePointerBitCastOrAddrSpaceCast(WcFunc, FuncPtrTy);
36733675
Value *Args[] = {SrcLocInfo, ReductionDataSize, RL, SarFuncCast,
36743676
WcFuncCast};
36753677
Function *Pv2Ptr = getOrCreateRuntimeFunctionPtr(
@@ -10072,13 +10074,14 @@ void OpenMPIRBuilder::initializeTypes(Module &M) {
1007210074
LLVMContext &Ctx = M.getContext();
1007310075
StructType *T;
1007410076
unsigned DefaultTargetAS = Config.getDefaultTargetAS();
10077+
unsigned ProgramAS = M.getDataLayout().getProgramAddressSpace();
1007510078
#define OMP_TYPE(VarName, InitValue) VarName = InitValue;
1007610079
#define OMP_ARRAY_TYPE(VarName, ElemTy, ArraySize) \
1007710080
VarName##Ty = ArrayType::get(ElemTy, ArraySize); \
1007810081
VarName##PtrTy = PointerType::get(Ctx, DefaultTargetAS);
1007910082
#define OMP_FUNCTION_TYPE(VarName, IsVarArg, ReturnType, ...) \
1008010083
VarName = FunctionType::get(ReturnType, {__VA_ARGS__}, IsVarArg); \
10081-
VarName##Ptr = PointerType::get(Ctx, DefaultTargetAS);
10084+
VarName##Ptr = PointerType::get(Ctx, ProgramAS);
1008210085
#define OMP_STRUCT_TYPE(VarName, StructName, Packed, ...) \
1008310086
T = StructType::getTypeByName(Ctx, StructName); \
1008410087
if (!T) \

0 commit comments

Comments
 (0)