diff --git a/sycl/include/sycl/handler.hpp b/sycl/include/sycl/handler.hpp index 2f7bbdfebc1f..95626d4a94ec 100644 --- a/sycl/include/sycl/handler.hpp +++ b/sycl/include/sycl/handler.hpp @@ -210,6 +210,7 @@ class HandlerAccess; class HostTask; using EventImplPtr = std::shared_ptr; +using DeviceImplPtr = std::shared_ptr; template static Arg member_ptr_helper(RetType (Func::*)(Arg) const); @@ -246,6 +247,7 @@ template struct get_kernel_wrapper_name_t { }; __SYCL_EXPORT device getDeviceFromHandler(handler &); +const DeviceImplPtr &getDeviceImplFromHandler(handler &); // Checks if a device_global has any registered kernel usage. __SYCL_EXPORT bool isDeviceGlobalUsedInKernel(const void *DeviceGlobalPtr); @@ -3460,6 +3462,8 @@ class __SYCL_EXPORT handler { typename PropertyListT> friend class accessor; friend device detail::getDeviceFromHandler(handler &); + friend const detail::DeviceImplPtr & + detail::getDeviceImplFromHandler(handler &); template diff --git a/sycl/source/detail/graph_impl.hpp b/sycl/source/detail/graph_impl.hpp index 6045cc9ededa..f5a81c1cbf7c 100644 --- a/sycl/source/detail/graph_impl.hpp +++ b/sycl/source/detail/graph_impl.hpp @@ -919,6 +919,12 @@ class graph_impl : public std::enable_shared_from_this { /// @return Context associated with graph. sycl::context getContext() const { return MContext; } + /// Query for the device_impl tied to this graph. + /// @return device_impl shared ptr reference associated with graph. + const DeviceImplPtr &getDeviceImplPtr() const { + return getSyclObjImpl(MDevice); + } + /// Query for the device tied to this graph. /// @return Device associated with graph. sycl::device getDevice() const { return MDevice; } diff --git a/sycl/source/handler.cpp b/sycl/source/handler.cpp index 5103e5524669..86a1c2544d69 100644 --- a/sycl/source/handler.cpp +++ b/sycl/source/handler.cpp @@ -42,6 +42,15 @@ inline namespace _V1 { namespace detail { +const DeviceImplPtr &getDeviceImplFromHandler(handler &CGH) { + assert((CGH.MQueue || getSyclObjImpl(CGH)->MGraph) && + "One of MQueue or MGraph should be nonnull!"); + if (CGH.MQueue) + return CGH.MQueue->getDeviceImplPtr(); + + return getSyclObjImpl(CGH)->MGraph->getDeviceImplPtr(); +} + bool isDeviceGlobalUsedInKernel(const void *DeviceGlobalPtr) { DeviceGlobalMapEntry *DGEntry = detail::ProgramManager::getInstance().getDeviceGlobalEntry( @@ -2030,10 +2039,10 @@ void handler::setUserFacingNodeType(ext::oneapi::experimental::node_type Type) { } std::optional> handler::getMaxWorkGroups() { - auto Dev = detail::getSyclObjImpl(detail::getDeviceFromHandler(*this)); + const auto &DeviceImpl = detail::getDeviceImplFromHandler(*this); std::array UrResult = {}; - auto Ret = Dev->getAdapter()->call_nocheck( - Dev->getHandleRef(), + auto Ret = DeviceImpl->getAdapter()->call_nocheck( + DeviceImpl->getHandleRef(), UrInfoCode< ext::oneapi::experimental::info::device::max_work_groups<3>>::value, sizeof(UrResult), &UrResult, nullptr);