Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sycl/include/sycl/exception_list.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ inline namespace _V1 {
// Forward declaration
namespace detail {
class queue_impl;
class device_impl;
}

/// A list of asynchronous exceptions.
Expand All @@ -46,6 +47,7 @@ class __SYCL_EXPORT exception_list {

private:
friend class detail::queue_impl;
friend class detail::device_impl;
void PushBack(const_reference Value);
void PushBack(value_type &&Value);
void Clear() noexcept;
Expand Down
32 changes: 32 additions & 0 deletions sycl/source/detail/device_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2262,6 +2262,31 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
return {};
}

/// Puts exception to the list of asynchronous ecxeptions.
///
/// \param QueueWeakPtr is a weak pointer referring to the queue to report
/// the asynchronous exceptions for.
/// \param ExceptionPtr is a pointer to exception to be put.
void reportAsyncException(std::weak_ptr<queue_impl> QueueWeakPtr,
const std::exception_ptr &ExceptionPtr) {
std::lock_guard<std::mutex> Lock(MAsyncExceptionsMutex);
MAsyncExceptions[QueueWeakPtr].PushBack(ExceptionPtr);
}

/// Extracts all unconsumed asynchronous exceptions for a given queue.
///
/// \param QueueWeakPtr is a weak pointer referring to the queue to extract
/// unconsumed asynchronous exceptions for.
exception_list flushAsyncExceptions(std::weak_ptr<queue_impl> QueueWeakPtr) {
std::lock_guard<std::mutex> Lock(MAsyncExceptionsMutex);
auto ExceptionsEntryIt = MAsyncExceptions.find(QueueWeakPtr);
if (ExceptionsEntryIt == MAsyncExceptions.end())
return exception_list{};
exception_list Exceptions = std::move(ExceptionsEntryIt->second);
MAsyncExceptions.erase(ExceptionsEntryIt);
return Exceptions;
}

private:
ur_device_handle_t MDevice = 0;
// This is used for getAdapter so should be above other properties.
Expand All @@ -2272,6 +2297,13 @@ class device_impl : public std::enable_shared_from_this<device_impl> {

const ur_device_handle_t MRootDevice;

// Asynchronous exceptions are captured at device-level until flushed, either
// by queues, events or a synchronization on the device itself.
std::mutex MAsyncExceptionsMutex;
std::map<std::weak_ptr<queue_impl>, exception_list,
std::owner_less<std::weak_ptr<queue_impl>>>
MAsyncExceptions;

// Order of caches matters! UR must come before SYCL info descriptors (because
// get_info calls get_info_impl but the opposite never happens) and both
// should come before aspects.
Expand Down
27 changes: 24 additions & 3 deletions sycl/source/detail/event_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,9 @@ void event_impl::initHostProfilingInfo() {
MHostProfilingInfo->setDevice(&Device);
}

void event_impl::setSubmittedQueue(std::weak_ptr<queue_impl> SubmittedQueue) {
MSubmittedQueue = std::move(SubmittedQueue);
void event_impl::setSubmittedQueue(queue_impl *SubmittedQueue) {
MSubmittedQueue = SubmittedQueue->weak_from_this();
MSubmittedDevice = &SubmittedQueue->getDeviceImpl();
}

#ifdef XPTI_ENABLE_INSTRUMENTATION
Expand Down Expand Up @@ -308,8 +309,28 @@ void event_impl::wait(bool *Success) {
void event_impl::wait_and_throw() {
wait();

if (std::shared_ptr<queue_impl> SubmittedQueue = MSubmittedQueue.lock())
if (std::shared_ptr<queue_impl> SubmittedQueue = MSubmittedQueue.lock()) {
SubmittedQueue->throw_asynchronous();
return;
}

// If the queue has died, we rely on finding its exceptions through the
// device.
if (MSubmittedDevice == nullptr)
return;

// If MSubmittedQueue has died, get flush any exceptions associated with it
// still, then user either the context async_handler or the default
// async_handler.
exception_list Exceptions =
MSubmittedDevice->flushAsyncExceptions(MSubmittedQueue);
if (Exceptions.size() == 0)
return;

if (MContext && MContext->get_async_handler())
MContext->get_async_handler()(std::move(Exceptions));
else
defaultAsyncHandler(std::move(Exceptions));
}

void event_impl::checkProfilingPreconditions() const {
Expand Down
5 changes: 3 additions & 2 deletions sycl/source/detail/event_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,10 @@ class event_impl {
MWorkerQueue = std::move(WorkerQueue);
};

/// Sets original queue used for submission.
/// Sets original queue and device used for submission.
///
/// @return
void setSubmittedQueue(std::weak_ptr<queue_impl> SubmittedQueue);
void setSubmittedQueue(queue_impl *SubmittedQueue);

/// Indicates if this event is not associated with any command and doesn't
/// have native handle.
Expand Down Expand Up @@ -394,6 +394,7 @@ class event_impl {

std::weak_ptr<queue_impl> MWorkerQueue;
std::weak_ptr<queue_impl> MSubmittedQueue;
device_impl *MSubmittedDevice = nullptr;

/// Dependency events prepared for waiting by backend.
std::vector<EventImplPtr> MPreparedDepsEvents;
Expand Down
37 changes: 10 additions & 27 deletions sycl/source/detail/queue_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,6 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
// notification and destroy the trace event for this queue.
destructorNotification();
#endif
throw_asynchronous();
auto status =
getAdapter().call_nocheck<UrApiKind::urQueueRelease>(MQueue);
// If loader is already closed, it'll return a not-initialized status
Expand Down Expand Up @@ -393,9 +392,6 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
/// @param Loc is the code location of the submit call (default argument)
void wait(const detail::code_location &Loc = {});

/// \return list of asynchronous exceptions occurred during execution.
exception_list getExceptionList() const { return MExceptions; }

/// @param Loc is the code location of the submit call (default argument)
void wait_and_throw(const detail::code_location &Loc = {}) {
wait(Loc);
Expand All @@ -408,21 +404,20 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
/// Synchronous errors will be reported through SYCL exceptions.
/// Asynchronous errors will be passed to the async_handler passed to the
/// queue on construction. If no async_handler was provided then
/// asynchronous exceptions will be lost.
/// asynchronous exceptions will be passed to the default async_handler.
void throw_asynchronous() {
if (!MAsyncHandler)
exception_list Exceptions =
getDeviceImpl().flushAsyncExceptions(weak_from_this());
if (Exceptions.size() == 0)
return;

exception_list Exceptions;
{
std::lock_guard<std::mutex> Lock(MMutex);
std::swap(Exceptions, MExceptions);
}
// Unlock the mutex before calling user-provided handler to avoid
// potential deadlock if the same queue is somehow referenced in the
// handler.
if (Exceptions.size())
if (MAsyncHandler)
MAsyncHandler(std::move(Exceptions));
else if (const async_handler &CtxAsyncHandler =
getContextImpl().get_async_handler())
CtxAsyncHandler(std::move(Exceptions));
else
defaultAsyncHandler(std::move(Exceptions));
}

/// Creates UR properties array.
Expand Down Expand Up @@ -570,14 +565,6 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
event mem_advise(const void *Ptr, size_t Length, ur_usm_advice_flags_t Advice,
const std::vector<event> &DepEvents, bool CallerNeedsEvent);

/// Puts exception to the list of asynchronous ecxeptions.
///
/// \param ExceptionPtr is a pointer to exception to be put.
void reportAsyncException(const std::exception_ptr &ExceptionPtr) {
std::lock_guard<std::mutex> Lock(MMutex);
MExceptions.PushBack(ExceptionPtr);
}

static ThreadPool &getThreadPool() {
return GlobalHandler::instance().getHostTaskThreadPool();
}
Expand Down Expand Up @@ -979,10 +966,6 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
/// These events are tracked, but not owned, by the queue.
std::vector<std::weak_ptr<event_impl>> MEventsWeak;

/// Events without data dependencies (such as USM) need an owner,
/// additionally, USM operations are not added to the scheduler command graph,
/// queue is the only owner on the runtime side.
exception_list MExceptions;
const async_handler MAsyncHandler;
const property_list MPropList;

Expand Down
28 changes: 17 additions & 11 deletions sycl/source/detail/scheduler/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,12 +359,14 @@ class DispatchHostTask {
AdapterWithEvents.first->call<UrApiKind::urEventWait>(RawEvents.size(),
RawEvents.data());
} catch (const sycl::exception &) {
MThisCmd->MEvent->getSubmittedQueue()->reportAsyncException(
std::current_exception());
auto QueuePtr = MThisCmd->MEvent->getSubmittedQueue();
QueuePtr->getDeviceImpl().reportAsyncException(
QueuePtr, std::current_exception());
return false;
} catch (...) {
MThisCmd->MEvent->getSubmittedQueue()->reportAsyncException(
std::current_exception());
auto QueuePtr = MThisCmd->MEvent->getSubmittedQueue();
QueuePtr->getDeviceImpl().reportAsyncException(
QueuePtr, std::current_exception());
return false;
}
}
Expand Down Expand Up @@ -407,7 +409,8 @@ class DispatchHostTask {
make_error_code(errc::runtime),
std::string("Couldn't wait for host-task's dependencies")));

MThisCmd->MEvent->getSubmittedQueue()->reportAsyncException(EPtr);
auto QueuePtr = MThisCmd->MEvent->getSubmittedQueue();
QueuePtr->getDeviceImpl().reportAsyncException(QueuePtr, EPtr);
// reset host-task's lambda and quit
HostTask.MHostTask.reset();
Scheduler::getInstance().NotifyHostTaskCompletion(MThisCmd);
Expand Down Expand Up @@ -469,8 +472,9 @@ class DispatchHostTask {
}
}
#endif
MThisCmd->MEvent->getSubmittedQueue()->reportAsyncException(
CurrentException);
auto QueuePtr = MThisCmd->MEvent->getSubmittedQueue();
QueuePtr->getDeviceImpl().reportAsyncException(QueuePtr,
CurrentException);
}

HostTask.MHostTask.reset();
Expand All @@ -487,8 +491,9 @@ class DispatchHostTask {
Scheduler::getInstance().NotifyHostTaskCompletion(MThisCmd);
} catch (...) {
auto CurrentException = std::current_exception();
MThisCmd->MEvent->getSubmittedQueue()->reportAsyncException(
CurrentException);
auto QueuePtr = MThisCmd->MEvent->getSubmittedQueue();
QueuePtr->getDeviceImpl().reportAsyncException(QueuePtr,
CurrentException);
}
}
};
Expand Down Expand Up @@ -563,7 +568,8 @@ Command::Command(
MCommandBuffer(CommandBuffer), MSyncPointDeps(SyncPoints) {
MWorkerQueue = MQueue;
MEvent->setWorkerQueue(MWorkerQueue);
MEvent->setSubmittedQueue(MWorkerQueue);
if (Queue)
MEvent->setSubmittedQueue(Queue);
MEvent->setCommand(this);
if (MQueue)
MEvent->setContextImpl(MQueue->getContextImpl());
Expand Down Expand Up @@ -1958,7 +1964,7 @@ ExecCGCommand::ExecCGCommand(
assert(SubmitQueue &&
"Host task command group must have a valid submit queue");

MEvent->setSubmittedQueue(SubmitQueue->weak_from_this());
MEvent->setSubmittedQueue(SubmitQueue);
// Initialize host profiling info if the queue has profiling enabled.
if (SubmitQueue->MIsProfilingEnabled)
MEvent->initHostProfilingInfo();
Expand Down
3 changes: 2 additions & 1 deletion sycl/source/detail/scheduler/scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,8 @@ EventImplPtr Scheduler::addCopyBack(Requirement *Req) {
auto WorkerQueue = NewCmd->getEvent()->getWorkerQueue();
assert(WorkerQueue &&
"WorkerQueue for CopyBack command must be not null");
WorkerQueue->reportAsyncException(std::current_exception());
WorkerQueue->getDeviceImpl().reportAsyncException(
WorkerQueue, std::current_exception());
}
}
EventImplPtr NewEvent = NewCmd->getEvent();
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -956,7 +956,7 @@ event handler::finalize() {
// it to the graph to create a node, rather than submit it to the scheduler.
if (auto GraphImpl = Queue->getCommandGraph(); GraphImpl) {
auto EventImpl = detail::event_impl::create_completed_host_event();
EventImpl->setSubmittedQueue(Queue->weak_from_this());
EventImpl->setSubmittedQueue(Queue);
ext::oneapi::experimental::detail::node_impl *NodeImpl = nullptr;

// GraphImpl is read and written in this scope so we lock this graph
Expand Down
Loading
Loading