Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sycl/include/sycl/handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ class HandlerAccess;
class HostTask;

using EventImplPtr = std::shared_ptr<event_impl>;
using DeviceImplPtr = std::shared_ptr<device_impl>;

template <typename RetType, typename Func, typename Arg>
static Arg member_ptr_helper(RetType (Func::*)(Arg) const);
Expand Down Expand Up @@ -246,6 +247,7 @@ template <typename Type> struct get_kernel_wrapper_name_t {
};

__SYCL_EXPORT device getDeviceFromHandler(handler &);
const DeviceImplPtr &getDeviceImplFromHandler(handler &cgh);

// Checks if a device_global has any registered kernel usage.
__SYCL_EXPORT bool isDeviceGlobalUsedInKernel(const void *DeviceGlobalPtr);
Expand Down Expand Up @@ -3460,6 +3462,8 @@ class __SYCL_EXPORT handler {
typename PropertyListT>
friend class accessor;
friend device detail::getDeviceFromHandler(handler &);
friend const detail::DeviceImplPtr &
detail::getDeviceImplFromHandler(handler &);

template <typename DataT, int Dimensions, access::mode AccessMode,
access::target AccessTarget, access::placeholder IsPlaceholder>
Expand Down
4 changes: 4 additions & 0 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,10 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
/// @return Context associated with graph.
sycl::context getContext() const { return MContext; }

const DeviceImplPtr &getDeviceImplPtr() const {
return getSyclObjImpl(MDevice);
}

/// Query for the device tied to this graph.
/// @return Device associated with graph.
sycl::device getDevice() const { return MDevice; }
Expand Down
15 changes: 12 additions & 3 deletions sycl/source/handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ inline namespace _V1 {

namespace detail {

const DeviceImplPtr &getDeviceImplFromHandler(handler &cgh) {
assert((cgh.MQueue || getSyclObjImpl(cgh)->MGraph) &&
"One of MQueue or MGraph should be nonnull!");
if (cgh.MQueue)
return cgh.MQueue->getDeviceImplPtr();

return getSyclObjImpl(cgh)->MGraph->getDeviceImplPtr();
}

bool isDeviceGlobalUsedInKernel(const void *DeviceGlobalPtr) {
DeviceGlobalMapEntry *DGEntry =
detail::ProgramManager::getInstance().getDeviceGlobalEntry(
Expand Down Expand Up @@ -2030,10 +2039,10 @@ void handler::setUserFacingNodeType(ext::oneapi::experimental::node_type Type) {
}

std::optional<std::array<size_t, 3>> handler::getMaxWorkGroups() {
auto Dev = detail::getSyclObjImpl(detail::getDeviceFromHandler(*this));
const auto &DeviceImpl = detail::getDeviceImplFromHandler(*this);
std::array<size_t, 3> UrResult = {};
auto Ret = Dev->getAdapter()->call_nocheck<UrApiKind::urDeviceGetInfo>(
Dev->getHandleRef(),
auto Ret = DeviceImpl->getAdapter()->call_nocheck<UrApiKind::urDeviceGetInfo>(
DeviceImpl->getHandleRef(),
UrInfoCode<
ext::oneapi::experimental::info::device::max_work_groups<3>>::value,
sizeof(UrResult), &UrResult, nullptr);
Expand Down
Loading