Skip to content

Commit 321d0a0

Browse files
Merge remote-tracking branches 'my/queue_impl', 'my/get-worker-context' and 'my/context_impl-queue_impl-ctor' into HEAD
3 parents: c6d0a75 + bd88f59 + 6ca9490; commit: 321d0a0

File tree

11 files changed, with 49 additions (+49) and 53 deletions (-53).

sycl/source/backend.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ __SYCL_EXPORT queue make_queue(ur_native_handle_t NativeHandle,
126126
ur_device_handle_t UrDevice =
127127
Device ? getSyclObjImpl(*Device)->getHandleRef() : nullptr;
128128
const auto &Adapter = getAdapter(Backend);
129-
const auto &ContextImpl = getSyclObjImpl(Context);
129+
context_impl &ContextImpl = *getSyclObjImpl(Context);
130130

131131
if (PropList.has_property<ext::intel::property::queue::compute_index>()) {
132132
throw sycl::exception(
@@ -156,7 +156,7 @@ __SYCL_EXPORT queue make_queue(ur_native_handle_t NativeHandle,
156156
ur_queue_handle_t UrQueue = nullptr;
157157

158158
Adapter->call<UrApiKind::urQueueCreateWithNativeHandle>(
159-
NativeHandle, ContextImpl->getHandleRef(), UrDevice, &NativeProperties,
159+
NativeHandle, ContextImpl.getHandleRef(), UrDevice, &NativeProperties,
160160
&UrQueue);
161161
// Construct the SYCL queue from UR queue.
162162
return detail::createSyclObjFromImpl<queue>(

sycl/source/detail/graph_impl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -996,7 +996,7 @@ exec_graph_impl::exec_graph_impl(sycl::context Context,
996996
: MSchedule(), MGraphImpl(GraphImpl), MSyncPoints(),
997997
MQueueImpl(sycl::detail::queue_impl::create(
998998
*sycl::detail::getSyclObjImpl(GraphImpl->getDevice()),
999-
sycl::detail::getSyclObjImpl(Context), sycl::async_handler{},
999+
*sycl::detail::getSyclObjImpl(Context), sycl::async_handler{},
10001000
sycl::property_list{})),
10011001
MDevice(GraphImpl->getDevice()), MContext(Context), MRequirements(),
10021002
MSchedulerDependencies(),

sycl/source/detail/queue_impl.hpp

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -117,11 +117,11 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
117117
/// constructed.
118118
/// \param AsyncHandler is a SYCL asynchronous exception handler.
119119
/// \param PropList is a list of properties to use for queue construction.
120-
queue_impl(device_impl &Device, const ContextImplPtr &Context,
120+
queue_impl(device_impl &Device, std::shared_ptr<context_impl> &&Context,
121121
const async_handler &AsyncHandler, const property_list &PropList,
122122
private_tag)
123-
: MDevice(Device), MContext(Context), MAsyncHandler(AsyncHandler),
124-
MPropList(PropList),
123+
: MDevice(Device), MContext(std::move(Context)),
124+
MAsyncHandler(AsyncHandler), MPropList(PropList),
125125
MIsInorder(has_property<property::queue::in_order>()),
126126
MIsProfilingEnabled(has_property<property::queue::enable_profiling>()),
127127
MQueueID{
@@ -146,8 +146,8 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
146146
"Queue compute index must be a non-negative number less than "
147147
"device's number of available compute queue indices.");
148148
}
149-
if (!Context->isDeviceValid(Device)) {
150-
if (Context->getBackend() == backend::opencl)
149+
if (!MContext->isDeviceValid(Device)) {
150+
if (MContext->getBackend() == backend::opencl)
151151
throw sycl::exception(
152152
make_error_code(errc::invalid),
153153
"Queue cannot be constructed with the given context and device "
@@ -177,17 +177,13 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
177177
trySwitchingToNoEventsMode();
178178
}
179179

180-
sycl::detail::optional<event> getLastEvent();
180+
queue_impl(device_impl &Device, context_impl &Context,
181+
const async_handler &AsyncHandler, const property_list &PropList,
182+
private_tag Tag)
183+
: queue_impl(Device, Context.shared_from_this(), AsyncHandler, PropList,
184+
Tag) {}
181185

182-
/// Constructs a SYCL queue from adapter interoperability handle.
183-
///
184-
/// \param UrQueue is a raw UR queue handle.
185-
/// \param Context is a SYCL context to associate with the queue being
186-
/// constructed.
187-
/// \param AsyncHandler is a SYCL asynchronous exception handler.
188-
queue_impl(ur_queue_handle_t UrQueue, const ContextImplPtr &Context,
189-
const async_handler &AsyncHandler, private_tag tag)
190-
: queue_impl(UrQueue, Context, AsyncHandler, {}, tag) {}
186+
sycl::detail::optional<event> getLastEvent();
191187

192188
/// Constructs a SYCL queue from adapter interoperability handle.
193189
///
@@ -196,27 +192,28 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
196192
/// constructed.
197193
/// \param AsyncHandler is a SYCL asynchronous exception handler.
198194
/// \param PropList is the queue properties.
199-
queue_impl(ur_queue_handle_t UrQueue, const ContextImplPtr &Context,
195+
queue_impl(ur_queue_handle_t UrQueue, context_impl &Context,
200196
const async_handler &AsyncHandler, const property_list &PropList,
201197
private_tag)
202198
: MDevice([&]() -> device_impl & {
203199
ur_device_handle_t DeviceUr{};
204-
const AdapterPtr &Adapter = Context->getAdapter();
200+
const AdapterPtr &Adapter = Context.getAdapter();
205201
// TODO catch an exception and put it to list of asynchronous
206202
// exceptions
207203
Adapter->call<UrApiKind::urQueueGetInfo>(
208204
UrQueue, UR_QUEUE_INFO_DEVICE, sizeof(DeviceUr), &DeviceUr,
209205
nullptr);
210-
device_impl *Device = Context->findMatchingDeviceImpl(DeviceUr);
206+
device_impl *Device = Context.findMatchingDeviceImpl(DeviceUr);
211207
if (Device == nullptr) {
212208
throw sycl::exception(
213209
make_error_code(errc::invalid),
214210
"Device provided by native Queue not found in Context.");
215211
}
216212
return *Device;
217213
}()),
218-
MContext(Context), MAsyncHandler(AsyncHandler), MPropList(PropList),
219-
MQueue(UrQueue), MIsInorder(has_property<property::queue::in_order>()),
214+
MContext(Context.shared_from_this()), MAsyncHandler(AsyncHandler),
215+
MPropList(PropList), MQueue(UrQueue),
216+
MIsInorder(has_property<property::queue::in_order>()),
220217
MIsProfilingEnabled(has_property<property::queue::enable_profiling>()),
221218
MQueueID{
222219
MNextAvailableQueueID.fetch_add(1, std::memory_order_relaxed)} {
@@ -985,7 +982,7 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
985982
mutable std::mutex MMutex;
986983

987984
device_impl &MDevice;
988-
const ContextImplPtr MContext;
985+
const std::shared_ptr<context_impl> MContext;
989986

990987
/// These events are tracked, but not owned, by the queue.
991988
std::vector<std::weak_ptr<event_impl>> MEventsWeak;

sycl/source/detail/scheduler/commands.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -759,7 +759,6 @@ void Command::makeTraceEventEpilog() {
759759

760760
Command *Command::processDepEvent(EventImplPtr DepEvent, const DepDesc &Dep,
761761
std::vector<Command *> &ToCleanUp) {
762-
const ContextImplPtr &WorkerContext = getWorkerContext();
763762

764763
// 1. Non-host events can be ignored if they are not fully initialized.
765764
// 2. Some types of commands do not produce UR events after they are
@@ -780,8 +779,9 @@ Command *Command::processDepEvent(EventImplPtr DepEvent, const DepDesc &Dep,
780779
Command *ConnectionCmd = nullptr;
781780

782781
context_impl &DepEventContext = DepEvent->getContextImpl();
782+
context_impl *WorkerContext = getWorkerContext();
783783
// If contexts don't match we'll connect them using host task
784-
if (&DepEventContext != WorkerContext.get() && WorkerContext) {
784+
if (&DepEventContext != WorkerContext && WorkerContext) {
785785
Scheduler::GraphBuilder &GB = Scheduler::getInstance().MGraphBuilder;
786786
ConnectionCmd = GB.connectDepEvent(this, DepEvent, Dep, ToCleanUp);
787787
} else
@@ -790,10 +790,10 @@ Command *Command::processDepEvent(EventImplPtr DepEvent, const DepDesc &Dep,
790790
return ConnectionCmd;
791791
}
792792

793-
ContextImplPtr Command::getWorkerContext() const {
793+
context_impl *Command::getWorkerContext() const {
794794
if (!MQueue)
795795
return nullptr;
796-
return MQueue->getContextImplPtr();
796+
return &MQueue->getContextImpl();
797797
}
798798

799799
bool Command::producesPiEvent() const { return true; }
@@ -1547,10 +1547,10 @@ void MemCpyCommand::emitInstrumentationData() {
15471547
#endif
15481548
}
15491549

1550-
ContextImplPtr MemCpyCommand::getWorkerContext() const {
1550+
context_impl *MemCpyCommand::getWorkerContext() const {
15511551
if (!MWorkerQueue)
15521552
return nullptr;
1553-
return MWorkerQueue->getContextImplPtr();
1553+
return &MWorkerQueue->getContextImpl();
15541554
}
15551555

15561556
bool MemCpyCommand::producesPiEvent() const {
@@ -1720,10 +1720,10 @@ void MemCpyCommandHost::emitInstrumentationData() {
17201720
#endif
17211721
}
17221722

1723-
ContextImplPtr MemCpyCommandHost::getWorkerContext() const {
1723+
context_impl *MemCpyCommandHost::getWorkerContext() const {
17241724
if (!MWorkerQueue)
17251725
return nullptr;
1726-
return MWorkerQueue->getContextImplPtr();
1726+
return &MWorkerQueue->getContextImpl();
17271727
}
17281728

17291729
ur_result_t MemCpyCommandHost::enqueueImp() {

sycl/source/detail/scheduler/commands.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ class Command {
221221

222222
/// Get the context of the queue this command will be submitted to. Could
223223
/// differ from the context of MQueue for memory copy commands.
224-
virtual ContextImplPtr getWorkerContext() const;
224+
virtual context_impl *getWorkerContext() const;
225225

226226
/// Returns true iff the command produces a UR event on non-host devices.
227227
virtual bool producesPiEvent() const;
@@ -584,7 +584,7 @@ class MemCpyCommand : public Command {
584584
void printDot(std::ostream &Stream) const final;
585585
const Requirement *getRequirement() const final { return &MDstReq; }
586586
void emitInstrumentationData() final;
587-
ContextImplPtr getWorkerContext() const final;
587+
context_impl *getWorkerContext() const final;
588588
bool producesPiEvent() const final;
589589

590590
private:
@@ -608,7 +608,7 @@ class MemCpyCommandHost : public Command {
608608
void printDot(std::ostream &Stream) const final;
609609
const Requirement *getRequirement() const final { return &MDstReq; }
610610
void emitInstrumentationData() final;
611-
ContextImplPtr getWorkerContext() const final;
611+
context_impl *getWorkerContext() const final;
612612

613613
private:
614614
ur_result_t enqueueImp() final;

sycl/source/detail/scheduler/graph_builder.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ Scheduler::GraphBuilder::getOrInsertMemObjRecord(queue_impl *Queue,
206206
cleanupCommand(Cmd);
207207
};
208208

209-
const ContextImplPtr &InteropCtxPtr = Req->MSYCLMemObj->getInteropContext();
209+
context_impl *InteropCtxPtr = Req->MSYCLMemObj->getInteropContext();
210210
if (InteropCtxPtr) {
211211
// The memory object has been constructed using interoperability constructor
212212
// which means that there is already an allocation(cl_mem) in some context.
@@ -221,10 +221,10 @@ Scheduler::GraphBuilder::getOrInsertMemObjRecord(queue_impl *Queue,
221221
// here, we need to create a dummy queue bound to the context and one of the
222222
// devices from the context.
223223
std::shared_ptr<queue_impl> InteropQueuePtr = queue_impl::create(
224-
Dev, InteropCtxPtr, async_handler{}, property_list{});
224+
Dev, *InteropCtxPtr, async_handler{}, property_list{});
225225

226226
MemObject->MRecord.reset(
227-
new MemObjRecord{InteropCtxPtr.get(), LeafLimit, AllocateDependency});
227+
new MemObjRecord{InteropCtxPtr, LeafLimit, AllocateDependency});
228228
std::vector<Command *> ToEnqueue;
229229
getOrCreateAllocaForReq(MemObject->MRecord.get(), Req,
230230
InteropQueuePtr.get(), ToEnqueue);
@@ -1217,7 +1217,7 @@ void Scheduler::GraphBuilder::removeRecordForMemObj(SYCLMemObjI *MemObject) {
12171217
Command *Scheduler::GraphBuilder::connectDepEvent(
12181218
Command *const Cmd, const EventImplPtr &DepEvent, const DepDesc &Dep,
12191219
std::vector<Command *> &ToCleanUp) {
1220-
assert(Cmd->getWorkerContext().get() != &DepEvent->getContextImpl());
1220+
assert(Cmd->getWorkerContext() != &DepEvent->getContextImpl());
12211221

12221222
// construct Host Task type command manually and make it depend on DepEvent
12231223
ExecCGCommand *ConnectCmd = nullptr;

sycl/source/detail/sycl_mem_obj_i.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ class context_impl;
2222
struct MemObjRecord;
2323

2424
using EventImplPtr = std::shared_ptr<detail::event_impl>;
25-
using ContextImplPtr = std::shared_ptr<detail::context_impl>;
2625

2726
// The class serves as an interface in the scheduler for all SYCL memory
2827
// objects.
@@ -72,7 +71,7 @@ class SYCLMemObjI {
7271

7372
// Returns the context which is passed if a memory object is created using
7473
// interoperability constructor, nullptr otherwise.
75-
virtual ContextImplPtr getInteropContext() const = 0;
74+
virtual detail::context_impl *getInteropContext() const = 0;
7675

7776
protected:
7877
// Pointer to the record that contains the memory commands. This is managed

sycl/source/detail/sycl_mem_obj_t.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ class event_impl;
3636
class Adapter;
3737
using AdapterPtr = std::shared_ptr<Adapter>;
3838

39-
using ContextImplPtr = std::shared_ptr<context_impl>;
4039
using EventImplPtr = std::shared_ptr<event_impl>;
4140

4241
// The class serves as a base for all SYCL memory objects.
@@ -281,7 +280,9 @@ class SYCLMemObjT : public SYCLMemObjI {
281280

282281
MemObjType getType() const override { return MemObjType::Undefined; }
283282

284-
ContextImplPtr getInteropContext() const override { return MInteropContext; }
283+
context_impl *getInteropContext() const override {
284+
return MInteropContext.get();
285+
}
285286

286287
bool isInterop() const override;
287288

@@ -339,7 +340,7 @@ class SYCLMemObjT : public SYCLMemObjI {
339340
// Should wait on this event before start working with such memory object.
340341
EventImplPtr MInteropEvent;
341342
// Context passed by user to interoperability constructor.
342-
ContextImplPtr MInteropContext;
343+
std::shared_ptr<context_impl> MInteropContext;
343344
// Native backend memory object handle passed by user to interoperability
344345
// constructor.
345346
ur_mem_handle_t MInteropMemObject;

sycl/source/queue.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,14 @@ queue::queue(const context &SyclContext, const device_selector &DeviceSelector,
6565
const device &SyclDevice = *std::max_element(Devs.begin(), Devs.end(), Comp);
6666

6767
impl = detail::queue_impl::create(*detail::getSyclObjImpl(SyclDevice),
68-
detail::getSyclObjImpl(SyclContext),
68+
*detail::getSyclObjImpl(SyclContext),
6969
AsyncHandler, PropList);
7070
}
7171

7272
queue::queue(const context &SyclContext, const device &SyclDevice,
7373
const async_handler &AsyncHandler, const property_list &PropList) {
7474
impl = detail::queue_impl::create(*detail::getSyclObjImpl(SyclDevice),
75-
detail::getSyclObjImpl(SyclContext),
75+
*detail::getSyclObjImpl(SyclContext),
7676
AsyncHandler, PropList);
7777
}
7878

@@ -100,7 +100,7 @@ queue::queue(cl_command_queue clQueue, const context &SyclContext,
100100
impl = detail::queue_impl::create(
101101
// TODO(pi2ur): Don't cast straight from cl_command_queue
102102
reinterpret_cast<ur_queue_handle_t>(clQueue),
103-
detail::getSyclObjImpl(SyclContext), AsyncHandler, PropList);
103+
*detail::getSyclObjImpl(SyclContext), AsyncHandler, PropList);
104104
}
105105

106106
cl_command_queue queue::get() const { return impl->get(); }

sycl/unittests/scheduler/HostTaskAndBarrier.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,15 @@
2020
namespace {
2121
using namespace sycl;
2222
using EventImplPtr = std::shared_ptr<sycl::detail::event_impl>;
23-
using ContextImplPtr = std::shared_ptr<sycl::detail::context_impl>;
2423

2524
constexpr auto DisableCleanupName = "SYCL_DISABLE_EXECUTION_GRAPH_CLEANUP";
2625

2726
class TestQueueImpl : public sycl::detail::queue_impl {
2827
public:
29-
TestQueueImpl(ContextImplPtr SyclContext, sycl::detail::device_impl &Dev)
28+
TestQueueImpl(sycl::detail::context_impl &SyclContext,
29+
sycl::detail::device_impl &Dev)
3030
: sycl::detail::queue_impl(Dev, SyclContext,
31-
SyclContext->get_async_handler(), {},
31+
SyclContext.get_async_handler(), {},
3232
sycl::detail::queue_impl::private_tag{}) {}
3333
using sycl::detail::queue_impl::MDefaultGraphDeps;
3434
using sycl::detail::queue_impl::MExtGraphDeps;
@@ -46,7 +46,7 @@ class BarrierHandlingWithHostTask : public ::testing::Test {
4646
sycl::device SyclDev =
4747
sycl::detail::select_device(sycl::default_selector_v, SyclContext);
4848
QueueDevImpl.reset(
49-
new TestQueueImpl(sycl::detail::getSyclObjImpl(SyclContext),
49+
new TestQueueImpl(*sycl::detail::getSyclObjImpl(SyclContext),
5050
*sycl::detail::getSyclObjImpl(SyclDev)));
5151

5252
MainLock.lock();

0 commit comments

Comments
 (0)