Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sycl/include/sycl/handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ class HandlerAccess;
class HostTask;

using EventImplPtr = std::shared_ptr<event_impl>;
using DeviceImplPtr = std::shared_ptr<device_impl>;

template <typename RetType, typename Func, typename Arg>
static Arg member_ptr_helper(RetType (Func::*)(Arg) const);
Expand Down Expand Up @@ -246,6 +247,7 @@ template <typename Type> struct get_kernel_wrapper_name_t {
};

__SYCL_EXPORT device getDeviceFromHandler(handler &);
const DeviceImplPtr &getDeviceImplFromHandler(handler &);

// Checks if a device_global has any registered kernel usage.
__SYCL_EXPORT bool isDeviceGlobalUsedInKernel(const void *DeviceGlobalPtr);
Expand Down Expand Up @@ -3460,6 +3462,8 @@ class __SYCL_EXPORT handler {
typename PropertyListT>
friend class accessor;
friend device detail::getDeviceFromHandler(handler &);
friend const detail::DeviceImplPtr &
detail::getDeviceImplFromHandler(handler &);

template <typename DataT, int Dimensions, access::mode AccessMode,
access::target AccessTarget, access::placeholder IsPlaceholder>
Expand Down
6 changes: 6 additions & 0 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,12 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
/// @return Context associated with graph.
sycl::context getContext() const { return MContext; }

/// Query for the device_impl tied to this graph.
/// @return device_impl shared ptr reference associated with graph.
const DeviceImplPtr &getDeviceImplPtr() const {
return getSyclObjImpl(MDevice);
}

/// Query for the device tied to this graph.
/// @return Device associated with graph.
sycl::device getDevice() const { return MDevice; }
Expand Down
15 changes: 12 additions & 3 deletions sycl/source/handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ inline namespace _V1 {

namespace detail {

const DeviceImplPtr &getDeviceImplFromHandler(handler &CGH) {
assert((CGH.MQueue || getSyclObjImpl(CGH)->MGraph) &&
"One of MQueue or MGraph should be nonnull!");
if (CGH.MQueue)
return CGH.MQueue->getDeviceImplPtr();

return getSyclObjImpl(CGH)->MGraph->getDeviceImplPtr();
}

bool isDeviceGlobalUsedInKernel(const void *DeviceGlobalPtr) {
DeviceGlobalMapEntry *DGEntry =
detail::ProgramManager::getInstance().getDeviceGlobalEntry(
Expand Down Expand Up @@ -2030,10 +2039,10 @@ void handler::setUserFacingNodeType(ext::oneapi::experimental::node_type Type) {
}

std::optional<std::array<size_t, 3>> handler::getMaxWorkGroups() {
auto Dev = detail::getSyclObjImpl(detail::getDeviceFromHandler(*this));
const auto &DeviceImpl = detail::getDeviceImplFromHandler(*this);
std::array<size_t, 3> UrResult = {};
auto Ret = Dev->getAdapter()->call_nocheck<UrApiKind::urDeviceGetInfo>(
Dev->getHandleRef(),
auto Ret = DeviceImpl->getAdapter()->call_nocheck<UrApiKind::urDeviceGetInfo>(
DeviceImpl->getHandleRef(),
UrInfoCode<
ext::oneapi::experimental::info::device::max_work_groups<3>>::value,
sizeof(UrResult), &UrResult, nullptr);
Expand Down