From 6972fbbd9cd5cd969cc7a81b378188d6d0da8b30 Mon Sep 17 00:00:00 2001 From: Harald van Dijk Date: Tue, 7 Jan 2025 11:07:44 +0000 Subject: [PATCH] [NativeCPU] Handle local args. Depending on the number of available threads, NativeCPU goes through different code paths for launching kernels. Some of these were missing the call to kernel.handleLocalArgs, resulting in local pointers being left as nullptr. Skip this code path for kernels that use local pointers. --- source/adapters/native_cpu/enqueue.cpp | 8 ++++---- source/adapters/native_cpu/kernel.hpp | 4 +++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/source/adapters/native_cpu/enqueue.cpp b/source/adapters/native_cpu/enqueue.cpp index 6e4094ddef..ec5a6cf339 100644 --- a/source/adapters/native_cpu/enqueue.cpp +++ b/source/adapters/native_cpu/enqueue.cpp @@ -138,12 +138,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch( #else bool isLocalSizeOne = ndr.LocalSize[0] == 1 && ndr.LocalSize[1] == 1 && ndr.LocalSize[2] == 1; - if (isLocalSizeOne && ndr.GlobalSize[0] > numParallelThreads) { + if (isLocalSizeOne && ndr.GlobalSize[0] > numParallelThreads && + !hKernel->hasLocalArgs()) { // If the local size is one, we make the assumption that we are running a // parallel_for over a sycl::range. - // Todo: we could add compiler checks and - // kernel properties for this (e.g. check that no barriers are called, no - // local memory args). + // Todo: we could add more compiler checks and + // kernel properties for this (e.g. check that no barriers are called). // Todo: this assumes that dim 0 is the best dimension over which we want to // parallelize diff --git a/source/adapters/native_cpu/kernel.hpp b/source/adapters/native_cpu/kernel.hpp index e2df672d05..4d2dec85cb 100644 --- a/source/adapters/native_cpu/kernel.hpp +++ b/source/adapters/native_cpu/kernel.hpp @@ -142,7 +142,9 @@ struct ur_kernel_handle_t_ : RefCounted { _localMemPoolSize = reqSize; } - // To be called before executing a work group + bool hasLocalArgs() const { return !_localArgInfo.empty(); } + + // To be called before executing a work group if local args are present void handleLocalArgs(size_t numParallelThread, size_t threadId) { // For each local argument we have size*numthreads size_t offset = 0;