Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions llvm/lib/SYCLLowerIR/LowerWGLocalMemory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ sycl::getKernelNamesUsingImplicitLocalMem(const Module &M) {
return -1;
};
llvm::for_each(M.functions(), [&](const Function &F) {
if (F.getCallingConv() == CallingConv::SPIR_KERNEL &&
F.hasFnAttribute(WORK_GROUP_STATIC_ATTR)) {
if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
int ArgPos = GetArgumentPos(F);
SPIRKernelNames.emplace_back(F.getName(), ArgPos);
if (ArgPos >= 0 || F.hasFnAttribute(WORK_GROUP_STATIC_ATTR))
SPIRKernelNames.emplace_back(F.getName(), ArgPos);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that these changes can be reverted.

With the changes below, we automatically mark all kernels that use the feature if they weren't already marked by the headers, so there is no need to analyze the arguments of kernels without the attribute.

}
});
}
Expand Down Expand Up @@ -184,11 +184,40 @@ lowerDynamicLocalMemCallDirect(CallInst *CI, Triple TT,

static void lowerLocalMemCall(Function *LocalMemAllocFunc,
std::function<void(CallInst *CI)> TransformCall) {
static SmallPtrSet<Function *, 16> FuncsCache;
SmallVector<CallInst *, 4> DelCalls;
for (User *U : LocalMemAllocFunc->users()) {
auto *CI = cast<CallInst>(U);
TransformCall(CI);
DelCalls.push_back(CI);
// Now, take each kernel that calls the builtins that allocate local memory,
// either directly or through a series of function calls that eventually end
// up in a direct call to the builtin, and attach the
// work-group-memory-static attribute to the kernel if not already attached.
// This is needed because free function kernels do not have the attribute
// added by the library as is the case with other types of kernels.
if (!FuncsCache.insert(CI->getFunction()).second)
continue; // We have already traversed call graph from this function

SmallVector<Function *, 8> WorkList;
WorkList.push_back(CI->getFunction());
while (!WorkList.empty()) {
auto *F = WorkList.back();
WorkList.pop_back();

// Mark kernel as using scratch memory if it isn't marked already
if (F->getCallingConv() == CallingConv::SPIR_KERNEL &&
!F->hasFnAttribute(WORK_GROUP_STATIC_ATTR))
F->addFnAttr(WORK_GROUP_STATIC_ATTR);

for (auto *FU : F->users()) {
if (auto *UCI = dyn_cast<CallInst>(FU)) {
if (FuncsCache.insert(UCI->getFunction()).second)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This if can be merged together with the one above, I think.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think they cannot be merged, since this if is inside the while loop and the one on line 199 is not; the functions that are checked by these if statements may potentially be different.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean lines 214 and 215.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The compiler seems to complain when I merge them.

WorkList.push_back(UCI->getFunction());
} // Even though there could be other uses of a Function, we don't
// care about them because we are only concerned about call graph
}
}
}

for (auto *CI : DelCalls) {
Expand Down
63 changes: 63 additions & 0 deletions sycl/test-e2e/FreeFunctionKernels/work_group_scratch_memory.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// REQUIRES: aspect-usm_shared_allocations

// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

// This test verifies that we can compile, run and get correct results when
// using a free function kernel that uses the work group scratch memory feature.

#include <sycl/ext/oneapi/work_group_static.hpp>

#include "helpers.hpp"
#include <cassert>
#include <sycl/ext/oneapi/experimental/enqueue_functions.hpp>
#include <sycl/ext/oneapi/free_function_queries.hpp>
#include <sycl/usm.hpp>

namespace syclext = sycl::ext::oneapi;
namespace syclexp = sycl::ext::oneapi::experimental;

constexpr int SIZE = 16;

// Free function kernel declared as a 1-dimensional nd_range kernel.
// For every i in [0, SIZE) each work-item stores 2 * src[i] into its own slot
// of the work-group scratch memory and then copies scratch slot i into dst[i].
//
// src: input buffer; elements [0, SIZE) are read.
// dst: output buffer; elements [0, SIZE) are written.
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>))
void double_kernel(float *src, float *dst) {
// Linear index of this work-item within its work-group.
size_t lid = syclext::this_work_item::get_nd_item<1>().get_local_linear_id();

// Scratch memory obtained from the extension, viewed here as an array of float.
float *local_mem = (float *)syclexp::get_work_group_scratch_memory();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While we are it, we need to make sure that other extensions for local/work-group memory are covered as well (i.e. work with free function kernels):

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While we are it, we need to make sure that other extensions for local/work-group memory are covered as well (i.e. work with free function kernels):

I have added the test.


// NOTE(review): every work-item runs the whole loop, writes only slot `lid`,
// but reads slot `i`. For i != lid that slot is written by a different
// work-item and no barrier is visible between the write and the read, so the
// result presumably relies on how work-items are scheduled -- TODO confirm
// whether a group barrier or per-item indexing (src[lid] / dst[lid]) is needed.
for (int i = 0; i < SIZE; i++) {
local_mem[lid] = 2 * src[i];
dst[i] = local_mem[i];
}
}

// Host driver for the test: fills `src`, launches `double_kernel` as a single
// work-group of SIZE work-items with SIZE floats of work-group scratch memory,
// and checks that every element of `dst` is twice the matching element of
// `src`. Returns 0 on success; a mismatch trips the assert.
int main() {
  sycl::queue q;
  float *src = sycl::malloc_shared<float>(SIZE, q);
  float *dst = sycl::malloc_shared<float>(SIZE, q);

  // Initialize every element, starting at index 0: USM allocations are not
  // zero-initialized, and both the kernel and the verification loop below
  // read src[0].
  for (int i = 0; i < SIZE; i++) {
    src[i] = static_cast<float>(i);
  }

  // Free function kernels are looked up through an executable kernel bundle.
  auto kbndl =
      syclexp::get_kernel_bundle<double_kernel, sycl::bundle_state::executable>(
          q.get_context());
  sycl::kernel k = kbndl.template ext_oneapi_get_kernel<double_kernel>();

  // One work-group of SIZE work-items; request SIZE floats of scratch memory
  // so that the kernel can index local_mem[0..SIZE).
  syclexp::launch_config cfg{
      ::sycl::nd_range<1>(::sycl::range<1>(SIZE), ::sycl::range<1>(SIZE)),
      syclexp::properties{
          syclexp::work_group_scratch_size(SIZE * sizeof(float))}};

  syclexp::nd_launch(q, cfg, k, src, dst);
  q.wait();

  for (int i = 0; i < SIZE; i++) {
    assert(dst[i] == 2 * src[i]);
  }

  sycl::free(src, q);
  sycl::free(dst, q);
  return 0;
}
Loading