-
Notifications
You must be signed in to change notification settings - Fork 796
[SYCL] Allow work group scratch memory to be used with free function kernels #19837
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
668a2f8
48e6bb6
2340c2c
75edca0
93a29e2
625677c
6e814de
a3f8114
bd359c6
dc1e3d0
d41fee7
183042a
0154a11
ac977fc
44b0404
1e00a45
052a1e9
0e01b13
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -184,11 +184,40 @@ lowerDynamicLocalMemCallDirect(CallInst *CI, Triple TT, | |
|
|
||
| static void lowerLocalMemCall(Function *LocalMemAllocFunc, | ||
| std::function<void(CallInst *CI)> TransformCall) { | ||
| static SmallPtrSet<Function *, 16> FuncsCache; | ||
| SmallVector<CallInst *, 4> DelCalls; | ||
| for (User *U : LocalMemAllocFunc->users()) { | ||
| auto *CI = cast<CallInst>(U); | ||
| TransformCall(CI); | ||
| DelCalls.push_back(CI); | ||
| // Now, take each kernel that calls the builtins that allocate local memory, | ||
| // either directly or through a series of function calls that eventually end | ||
| // up in a direct call to the builtin, and attach the | ||
| // work-group-memory-static attribute to the kernel if not already attached. | ||
| // This is needed because free function kernels do not have the attribute | ||
| // added by the library as is the case with other types of kernels. | ||
| if (!FuncsCache.insert(CI->getFunction()).second) | ||
| continue; // We have already traversed call graph from this function. | ||
|
|
||
| SmallVector<Function *, 8> WorkList; | ||
| WorkList.push_back(CI->getFunction()); | ||
| while (!WorkList.empty()) { | ||
| Function *F = WorkList.back(); | ||
| WorkList.pop_back(); | ||
|
|
||
| // Mark kernel as using scratch memory if it isn't marked already. | ||
| if (F->getCallingConv() == CallingConv::SPIR_KERNEL && | ||
| !F->hasFnAttribute(WORK_GROUP_STATIC_ATTR)) | ||
| F->addFnAttr(WORK_GROUP_STATIC_ATTR); | ||
maarquitos14 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| for (auto *FU : F->users()) { | ||
| if (auto *UCI = dyn_cast<CallInst>(FU)) { | ||
| if (FuncsCache.insert(UCI->getFunction()).second) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This if can be merged together with the one above, I think.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think they cannot be merged since this if is inside the while loop and the one on line 199 is not, they functions that are checked by these if statements may potentially be different.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I mean line 214 and 215.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The compiler seems to complain when I merge them. |
||
| WorkList.push_back(UCI->getFunction()); | ||
| } // Even though there could be other uses of a Function, we don't | ||
| // care about them because we are only concerned about call graph. | ||
| } | ||
| } | ||
| } | ||
|
|
||
| for (auto *CI : DelCalls) { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,114 @@ | ||
| // REQUIRES: aspect-usm_shared_allocations | ||
| // UNSUPPORTED: target-amd | ||
| // UNSUPPORTED-TRACKER: https://github.com/intel/llvm/issues/16072 | ||
|
|
||
| // RUN: %{build} -o %t.out | ||
| // RUN: %{run} %t.out | ||
|
|
||
| // This test verifies that we can compile, run and get correct results when | ||
| // using a free function kernel that allocates shared local memory in a kernel | ||
| // either by way of the work group scratch memory extension or the work group | ||
| // static memory extension. | ||
|
|
||
| #include "helpers.hpp" | ||
maarquitos14 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| #include <cassert> | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this one should be the last one.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The formatter disagrees, I have made the change manually so hopefully it wont fail the formatter pre-commit check.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, the formatter doesn't accept this so I am reverting it. |
||
| #include <sycl/ext/oneapi/experimental/enqueue_functions.hpp> | ||
| #include <sycl/ext/oneapi/free_function_queries.hpp> | ||
| #include <sycl/ext/oneapi/work_group_static.hpp> | ||
| #include <sycl/group_barrier.hpp> | ||
| #include <sycl/usm.hpp> | ||
|
|
||
| namespace syclext = sycl::ext::oneapi; | ||
| namespace syclexp = sycl::ext::oneapi::experimental; | ||
|
|
||
| constexpr int SIZE = 16; | ||
|
|
||
| SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>)) | ||
| void scratchKernel(float *src, float *dst) { | ||
| size_t lid = syclext::this_work_item::get_nd_item<1>().get_local_linear_id(); | ||
|
||
| float *localMem = | ||
| reinterpret_cast<float *>(syclexp::get_work_group_scratch_memory()); | ||
| localMem[lid] = 2 * src[lid]; | ||
| dst[lid] = localMem[lid]; | ||
| } | ||
|
|
||
| SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>)) | ||
| void staticKernel(float *src, float *dst) { | ||
| sycl::nd_item<1> item = syclext::this_work_item::get_nd_item<1>(); | ||
| size_t lid = item.get_local_linear_id(); | ||
| syclexp::work_group_static<float[SIZE]> localMem; | ||
| localMem[lid] = src[lid] * src[lid]; | ||
| sycl::group_barrier(item.get_group()); | ||
| if (item.get_group().leader()) { // Check that memory is indeed shared between | ||
| // the work group. | ||
| for (int i = 0; i < SIZE; ++i) | ||
lbushi25 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| assert(localMem[i] == src[i] * src[i]); | ||
| } | ||
| dst[lid] = localMem[lid]; | ||
| } | ||
|
|
||
| SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>)) | ||
| void scratchStaticKernel(float *src, float *dst) { | ||
| size_t lid = syclext::this_work_item::get_nd_item<1>().get_local_linear_id(); | ||
| float *scratchMem = | ||
| reinterpret_cast<float *>(syclexp::get_work_group_scratch_memory()); | ||
| syclexp::work_group_static<float[SIZE]> staticMem; | ||
| scratchMem[lid] = src[lid]; | ||
| staticMem[lid] = src[lid]; | ||
| dst[lid] = scratchMem[lid] + staticMem[lid]; | ||
| } | ||
|
|
||
| int main() { | ||
| sycl::queue q; | ||
| float *src = sycl::malloc_shared<float>(SIZE, q); | ||
| float *dst = sycl::malloc_shared<float>(SIZE, q); | ||
|
|
||
| for (int i = 0; i < SIZE; i++) { | ||
lbushi25 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| src[i] = i; | ||
| } | ||
|
|
||
| auto scratchBndl = | ||
| syclexp::get_kernel_bundle<scratchKernel, sycl::bundle_state::executable>( | ||
| q.get_context()); | ||
| auto staticBndl = | ||
| syclexp::get_kernel_bundle<staticKernel, sycl::bundle_state::executable>( | ||
| q.get_context()); | ||
| auto scratchStaticBndl = syclexp::get_kernel_bundle< | ||
| scratchStaticKernel, sycl::bundle_state::executable>(q.get_context()); | ||
|
|
||
| sycl::kernel scratchKrn = | ||
| scratchBndl.template ext_oneapi_get_kernel<scratchKernel>(); | ||
| sycl::kernel staticKrn = | ||
| staticBndl.template ext_oneapi_get_kernel<staticKernel>(); | ||
| sycl::kernel scratchStaticKrn = | ||
| scratchStaticBndl.template ext_oneapi_get_kernel<scratchStaticKernel>(); | ||
| syclexp::launch_config scratchKernelcfg{ | ||
| ::sycl::nd_range<1>(::sycl::range<1>(SIZE), ::sycl::range<1>(SIZE)), | ||
| syclexp::properties{ | ||
| syclexp::work_group_scratch_size(SIZE * sizeof(float))}}; | ||
| syclexp::launch_config staticKernelcfg{ | ||
| ::sycl::nd_range<1>(::sycl::range<1>(SIZE), ::sycl::range<1>(SIZE))}; | ||
|
|
||
| syclexp::nd_launch(q, scratchKernelcfg, scratchKrn, src, dst); | ||
| q.wait(); | ||
| for (int i = 0; i < SIZE; i++) { | ||
lbushi25 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| assert(dst[i] == 2 * src[i]); | ||
| } | ||
|
|
||
| syclexp::nd_launch(q, staticKernelcfg, staticKrn, src, dst); | ||
| q.wait(); | ||
| for (int i = 0; i < SIZE; i++) { | ||
lbushi25 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| assert(dst[i] == src[i] * src[i]); | ||
| } | ||
|
|
||
| syclexp::nd_launch(q, scratchKernelcfg, scratchStaticKrn, src, dst); | ||
| q.wait(); | ||
| for (int i = 0; i < SIZE; i++) { | ||
lbushi25 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| assert(dst[i] == 2 * src[i]); | ||
| } | ||
|
|
||
| sycl::free(src, q); | ||
| sycl::free(dst, q); | ||
| return 0; | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.