Skip to content

Commit c6d0a75

Browse files
[NFC][SYCL] Pass queue_impl by raw ptr/ref (mostly scheduler)
Continuation of the refactoring efforts in #18715, #18748, #18830, #18907, #18983, and #19006.
1 parent 61eba7c commit c6d0a75

24 files changed

+154
-177
lines changed

sycl/source/detail/graph_impl.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -907,7 +907,7 @@ exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer,
907907

908908
sycl::detail::EventImplPtr Event =
909909
sycl::detail::Scheduler::getInstance().addCG(
910-
Node->getCGCopy(), MQueueImpl,
910+
Node->getCGCopy(), *MQueueImpl,
911911
/*EventNeeded=*/true, CommandBuffer, Deps);
912912

913913
if (MIsUpdatable) {
@@ -1089,7 +1089,7 @@ EventImplPtr exec_graph_impl::enqueueHostTaskPartition(
10891089
NodeCommandGroup->getType()));
10901090

10911091
EventImplPtr SchedulerEvent = sycl::detail::Scheduler::getInstance().addCG(
1092-
std::move(CommandGroup), Queue.shared_from_this(), EventNeeded);
1092+
std::move(CommandGroup), Queue, EventNeeded);
10931093

10941094
if (EventNeeded) {
10951095
return SchedulerEvent;
@@ -1117,7 +1117,7 @@ EventImplPtr exec_graph_impl::enqueuePartitionWithScheduler(
11171117
CommandBuffer, nullptr, std::move(CGData));
11181118

11191119
EventImplPtr SchedulerEvent = sycl::detail::Scheduler::getInstance().addCG(
1120-
std::move(CommandGroup), Queue.shared_from_this(), EventNeeded);
1120+
std::move(CommandGroup), Queue, EventNeeded);
11211121

11221122
if (EventNeeded) {
11231123
SchedulerEvent->setEventFromSubmittedExecCommandBuffer(true);
@@ -1592,7 +1592,7 @@ void exec_graph_impl::update(
15921592
// other scheduler commands
15931593
auto UpdateEvent =
15941594
sycl::detail::Scheduler::getInstance().addCommandGraphUpdate(
1595-
this, Nodes, MQueueImpl, std::move(UpdateRequirements),
1595+
this, Nodes, MQueueImpl.get(), std::move(UpdateRequirements),
15961596
MSchedulerDependencies);
15971597

15981598
MSchedulerDependencies.push_back(UpdateEvent);

sycl/source/detail/queue_impl.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,10 @@ queue_impl::get_backend_info<info::device::backend_version>() const {
118118
}
119119
#endif
120120

121-
static event prepareSYCLEventAssociatedWithQueue(
122-
const std::shared_ptr<detail::queue_impl> &QueueImpl) {
123-
auto EventImpl = detail::event_impl::create_device_event(*QueueImpl);
124-
EventImpl->setContextImpl(QueueImpl->getContextImpl());
121+
static event
122+
prepareSYCLEventAssociatedWithQueue(detail::queue_impl &QueueImpl) {
123+
auto EventImpl = detail::event_impl::create_device_event(QueueImpl);
124+
EventImpl->setContextImpl(QueueImpl.getContextImpl());
125125
EventImpl->setStateIncomplete();
126126
return detail::createSyclObjFromImpl<event>(EventImpl);
127127
}
@@ -464,7 +464,7 @@ event queue_impl::submitMemOpHelper(const std::vector<event> &DepEvents,
464464
event_impl::create_discarded_event());
465465
}
466466

467-
event ResEvent = prepareSYCLEventAssociatedWithQueue(shared_from_this());
467+
event ResEvent = prepareSYCLEventAssociatedWithQueue(*this);
468468
const auto &EventImpl = detail::getSyclObjImpl(ResEvent);
469469
{
470470
NestedCallsTracker tracker;

sycl/source/detail/queue_impl.hpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -653,9 +653,6 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
653653
static ContextImplPtr getContext(queue_impl *Queue) {
654654
return Queue ? Queue->getContextImplPtr() : nullptr;
655655
}
656-
static ContextImplPtr getContext(const QueueImplPtr &Queue) {
657-
return getContext(Queue.get());
658-
}
659656

660657
// Must be called under MMutex protection
661658
void doUnenqueuedCommandCleanup(
@@ -692,7 +689,7 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
692689
protected:
693690
template <typename HandlerType = handler>
694691
EventImplPtr insertHelperBarrier(const HandlerType &Handler) {
695-
auto &Queue = Handler.impl->get_queue();
692+
queue_impl &Queue = Handler.impl->get_queue();
696693
auto ResEvent = detail::event_impl::create_device_event(Queue);
697694
ur_event_handle_t UREvent = nullptr;
698695
getAdapter()->call<UrApiKind::urEnqueueEventsWaitWithBarrier>(

sycl/source/detail/scheduler/graph_builder.cpp

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,6 @@ static bool isOnSameContext(const ContextImplPtr Context, queue_impl *Queue) {
5656
// contexts comparison.
5757
return Context == queue_impl::getContext(Queue);
5858
}
59-
static bool isOnSameContext(const ContextImplPtr Context,
60-
const QueueImplPtr &Queue) {
61-
return isOnSameContext(Context, Queue.get());
62-
}
6359

6460
/// Checks if the required access mode is allowed under the current one.
6561
static bool isAccessModeAllowed(access::mode Required, access::mode Current) {
@@ -182,7 +178,7 @@ MemObjRecord *Scheduler::GraphBuilder::getMemObjRecord(SYCLMemObjI *MemObject) {
182178
}
183179

184180
MemObjRecord *
185-
Scheduler::GraphBuilder::getOrInsertMemObjRecord(const QueueImplPtr &Queue,
181+
Scheduler::GraphBuilder::getOrInsertMemObjRecord(queue_impl *Queue,
186182
const Requirement *Req) {
187183
SYCLMemObjI *MemObject = Req->MSYCLMemObj;
188184
MemObjRecord *Record = getMemObjRecord(MemObject);
@@ -230,8 +226,8 @@ Scheduler::GraphBuilder::getOrInsertMemObjRecord(const QueueImplPtr &Queue,
230226
MemObject->MRecord.reset(
231227
new MemObjRecord{InteropCtxPtr.get(), LeafLimit, AllocateDependency});
232228
std::vector<Command *> ToEnqueue;
233-
getOrCreateAllocaForReq(MemObject->MRecord.get(), Req, InteropQueuePtr,
234-
ToEnqueue);
229+
getOrCreateAllocaForReq(MemObject->MRecord.get(), Req,
230+
InteropQueuePtr.get(), ToEnqueue);
235231
assert(ToEnqueue.empty() && "Creation of the first alloca for a record "
236232
"shouldn't lead to any enqueuing (no linked "
237233
"alloca or exceeding the leaf limit).");
@@ -273,14 +269,13 @@ void Scheduler::GraphBuilder::addNodeToLeaves(
273269
}
274270

275271
UpdateHostRequirementCommand *Scheduler::GraphBuilder::insertUpdateHostReqCmd(
276-
MemObjRecord *Record, Requirement *Req, const QueueImplPtr &Queue,
272+
MemObjRecord *Record, Requirement *Req, queue_impl *Queue,
277273
std::vector<Command *> &ToEnqueue) {
278274
auto Context = queue_impl::getContext(Queue);
279275
AllocaCommandBase *AllocaCmd = findAllocaForReq(Record, Req, Context);
280276
assert(AllocaCmd && "There must be alloca for requirement!");
281277
UpdateHostRequirementCommand *UpdateCommand =
282-
new UpdateHostRequirementCommand(Queue.get(), *Req, AllocaCmd,
283-
&Req->MData);
278+
new UpdateHostRequirementCommand(Queue, *Req, AllocaCmd, &Req->MData);
284279
// Need copy of requirement because after host accessor destructor call
285280
// dependencies become invalid if requirement is stored by pointer.
286281
const Requirement *StoredReq = UpdateCommand->getRequirement();
@@ -329,9 +324,10 @@ static Command *insertMapUnmapForLinkedCmds(AllocaCommandBase *AllocaCmdSrc,
329324
return MapCmd;
330325
}
331326

332-
Command *Scheduler::GraphBuilder::insertMemoryMove(
333-
MemObjRecord *Record, Requirement *Req, const QueueImplPtr &Queue,
334-
std::vector<Command *> &ToEnqueue) {
327+
Command *
328+
Scheduler::GraphBuilder::insertMemoryMove(MemObjRecord *Record,
329+
Requirement *Req, queue_impl *Queue,
330+
std::vector<Command *> &ToEnqueue) {
335331
AllocaCommandBase *AllocaCmdDst =
336332
getOrCreateAllocaForReq(Record, Req, Queue, ToEnqueue);
337333
if (!AllocaCmdDst)
@@ -518,7 +514,7 @@ Scheduler::GraphBuilder::addHostAccessor(Requirement *Req,
518514
auto SYCLMemObj = static_cast<detail::SYCLMemObjT *>(Req->MSYCLMemObj);
519515
SYCLMemObj->handleWriteAccessorCreation();
520516
}
521-
// Host accessor is not attached to any queue so no QueueImplPtr object to be
517+
// Host accessor is not attached to any queue, so there is no queue object to
522518
// pass to getOrInsertMemObjRecord.
523519
MemObjRecord *Record = getOrInsertMemObjRecord(nullptr, Req);
524520
if (MPrintOptionsArray[BeforeAddHostAcc])
@@ -690,7 +686,7 @@ static bool checkHostUnifiedMemory(const ContextImplPtr &Ctx) {
690686
// Note, creation of new allocation command can lead to the current context
691687
// (Record->MCurContext) change.
692688
AllocaCommandBase *Scheduler::GraphBuilder::getOrCreateAllocaForReq(
693-
MemObjRecord *Record, const Requirement *Req, const QueueImplPtr &Queue,
689+
MemObjRecord *Record, const Requirement *Req, queue_impl *Queue,
694690
std::vector<Command *> &ToEnqueue) {
695691
auto Context = queue_impl::getContext(Queue);
696692
AllocaCommandBase *AllocaCmd =
@@ -709,8 +705,8 @@ AllocaCommandBase *Scheduler::GraphBuilder::getOrCreateAllocaForReq(
709705

710706
auto *ParentAlloca =
711707
getOrCreateAllocaForReq(Record, &ParentRequirement, Queue, ToEnqueue);
712-
AllocaCmd = new AllocaSubBufCommand(Queue.get(), *Req, ParentAlloca,
713-
ToEnqueue, ToCleanUp);
708+
AllocaCmd = new AllocaSubBufCommand(Queue, *Req, ParentAlloca, ToEnqueue,
709+
ToCleanUp);
714710
} else {
715711

716712
const Requirement FullReq(/*Offset*/ {0, 0, 0}, Req->MMemoryRange,
@@ -786,8 +782,8 @@ AllocaCommandBase *Scheduler::GraphBuilder::getOrCreateAllocaForReq(
786782
}
787783
}
788784

789-
AllocaCmd = new AllocaCommand(Queue.get(), FullReq, InitFromUserData,
790-
LinkedAllocaCmd);
785+
AllocaCmd =
786+
new AllocaCommand(Queue, FullReq, InitFromUserData, LinkedAllocaCmd);
791787

792788
// Update linked command
793789
if (LinkedAllocaCmd) {
@@ -925,16 +921,16 @@ static void combineAccessModesOfReqs(std::vector<Requirement *> &Reqs) {
925921
}
926922

927923
Command *Scheduler::GraphBuilder::addCG(
928-
std::unique_ptr<detail::CG> CommandGroup, const QueueImplPtr &Queue,
924+
std::unique_ptr<detail::CG> CommandGroup, queue_impl *Queue,
929925
std::vector<Command *> &ToEnqueue, bool EventNeeded,
930926
ur_exp_command_buffer_handle_t CommandBuffer,
931927
const std::vector<ur_exp_command_buffer_sync_point_t> &Dependencies) {
932928
std::vector<Requirement *> &Reqs = CommandGroup->getRequirements();
933929
std::vector<detail::EventImplPtr> &Events = CommandGroup->getEvents();
934930

935-
auto NewCmd = std::make_unique<ExecCGCommand>(
936-
std::move(CommandGroup), Queue.get(), EventNeeded, CommandBuffer,
937-
std::move(Dependencies));
931+
auto NewCmd = std::make_unique<ExecCGCommand>(std::move(CommandGroup), Queue,
932+
EventNeeded, CommandBuffer,
933+
std::move(Dependencies));
938934

939935
if (!NewCmd)
940936
throw exception(make_error_code(errc::memory_allocation),
@@ -957,9 +953,9 @@ Command *Scheduler::GraphBuilder::addCG(
957953
bool isSameCtx = false;
958954

959955
{
960-
const QueueImplPtr &QueueForAlloca =
956+
queue_impl *QueueForAlloca =
961957
isInteropTask
962-
? static_cast<detail::CGHostTask &>(NewCmd->getCG()).MQueue
958+
? static_cast<detail::CGHostTask &>(NewCmd->getCG()).MQueue.get()
963959
: Queue;
964960

965961
Record = getOrInsertMemObjRecord(QueueForAlloca, Req);
@@ -989,15 +985,15 @@ Command *Scheduler::GraphBuilder::addCG(
989985
// Cannot directly copy memory from OpenCL device to OpenCL device -
990986
// create two copies: device->host and host->device.
991987
bool NeedMemMoveToHost = false;
992-
auto MemMoveTargetQueue = Queue;
988+
queue_impl *MemMoveTargetQueue = Queue;
993989

994990
if (isInteropTask) {
995991
const detail::CGHostTask &HT =
996992
static_cast<detail::CGHostTask &>(NewCmd->getCG());
997993

998-
if (!isOnSameContext(Record->MCurContext, HT.MQueue)) {
994+
if (!isOnSameContext(Record->MCurContext, HT.MQueue.get())) {
999995
NeedMemMoveToHost = true;
1000-
MemMoveTargetQueue = HT.MQueue;
996+
MemMoveTargetQueue = HT.MQueue.get();
1001997
}
1002998
} else if (Queue && Record->MCurContext)
1003999
NeedMemMoveToHost = true;
@@ -1229,7 +1225,9 @@ Command *Scheduler::GraphBuilder::connectDepEvent(
12291225
try {
12301226
std::shared_ptr<detail::HostTask> HT(new detail::HostTask);
12311227
std::unique_ptr<detail::CG> ConnectCG(new detail::CGHostTask(
1232-
std::move(HT), /* Queue = */ Cmd->getQueue(), /* Context = */ {},
1228+
std::move(HT),
1229+
/* Queue = */ Cmd->getQueue(),
1230+
/* Context = */ {},
12331231
/* Args = */ {},
12341232
detail::CG::StorageInitHelper(
12351233
/* ArgsStorage = */ {}, /* AccStorage = */ {},
@@ -1280,11 +1278,11 @@ Command *Scheduler::GraphBuilder::addCommandGraphUpdate(
12801278
ext::oneapi::experimental::detail::exec_graph_impl *Graph,
12811279
std::vector<std::shared_ptr<ext::oneapi::experimental::detail::node_impl>>
12821280
Nodes,
1283-
const QueueImplPtr &Queue, std::vector<Requirement *> Requirements,
1281+
queue_impl *Queue, std::vector<Requirement *> Requirements,
12841282
std::vector<detail::EventImplPtr> &Events,
12851283
std::vector<Command *> &ToEnqueue) {
12861284
auto NewCmd =
1287-
std::make_unique<UpdateCommandBufferCommand>(Queue.get(), Graph, Nodes);
1285+
std::make_unique<UpdateCommandBufferCommand>(Queue, Graph, Nodes);
12881286
// If there are multiple requirements for the same memory object, its
12891287
// AllocaCommand creation will be dependent on the access mode of the first
12901288
// requirement. Combine these access modes to take all of them into account.

sycl/source/detail/scheduler/scheduler.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ void Scheduler::waitForRecordToFinish(MemObjRecord *Record,
102102
}
103103

104104
EventImplPtr Scheduler::addCG(
105-
std::unique_ptr<detail::CG> CommandGroup, const QueueImplPtr &Queue,
105+
std::unique_ptr<detail::CG> CommandGroup, queue_impl &Queue,
106106
bool EventNeeded, ur_exp_command_buffer_handle_t CommandBuffer,
107107
const std::vector<ur_exp_command_buffer_sync_point_t> &Dependencies) {
108108
EventImplPtr NewEvent = nullptr;
@@ -127,7 +127,7 @@ EventImplPtr Scheduler::addCG(
127127
break;
128128
}
129129
default:
130-
NewCmd = MGraphBuilder.addCG(std::move(CommandGroup), std::move(Queue),
130+
NewCmd = MGraphBuilder.addCG(std::move(CommandGroup), &Queue,
131131
AuxiliaryCmds, EventNeeded, CommandBuffer,
132132
std::move(Dependencies));
133133
}
@@ -645,7 +645,7 @@ EventImplPtr Scheduler::addCommandGraphUpdate(
645645
ext::oneapi::experimental::detail::exec_graph_impl *Graph,
646646
std::vector<std::shared_ptr<ext::oneapi::experimental::detail::node_impl>>
647647
Nodes,
648-
const QueueImplPtr &Queue, std::vector<Requirement *> Requirements,
648+
queue_impl *Queue, std::vector<Requirement *> Requirements,
649649
std::vector<detail::EventImplPtr> &Events) {
650650
std::vector<Command *> AuxiliaryCmds;
651651
EventImplPtr NewCmdEvent = nullptr;

sycl/source/detail/scheduler/scheduler.hpp

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,6 @@ class DispatchHostTask;
187187

188188
using ContextImplPtr = std::shared_ptr<detail::context_impl>;
189189
using EventImplPtr = std::shared_ptr<detail::event_impl>;
190-
using QueueImplPtr = std::shared_ptr<detail::queue_impl>;
191190
using StreamImplPtr = std::shared_ptr<detail::stream_impl>;
192191

193192
using CommandPtr = std::unique_ptr<Command>;
@@ -379,7 +378,7 @@ class Scheduler {
379378
/// \return an event object to wait on for command group completion. It can
380379
/// be a discarded event.
381380
EventImplPtr addCG(
382-
std::unique_ptr<detail::CG> CommandGroup, const QueueImplPtr &Queue,
381+
std::unique_ptr<detail::CG> CommandGroup, queue_impl &Queue,
383382
bool EventNeeded, ur_exp_command_buffer_handle_t CommandBuffer = nullptr,
384383
const std::vector<ur_exp_command_buffer_sync_point_t> &Dependencies = {});
385384

@@ -477,7 +476,7 @@ class Scheduler {
477476
ext::oneapi::experimental::detail::exec_graph_impl *Graph,
478477
std::vector<std::shared_ptr<ext::oneapi::experimental::detail::node_impl>>
479478
Nodes,
480-
const QueueImplPtr &Queue, std::vector<Requirement *> Requirements,
479+
queue_impl *Queue, std::vector<Requirement *> Requirements,
481480
std::vector<detail::EventImplPtr> &Events);
482481

483482
static bool CheckEventReadiness(context_impl &Context,
@@ -560,9 +559,8 @@ class Scheduler {
560559
/// \return a command that represents command group execution. Only the
561560
/// Command pointer is returned; the declaration has no bool result that
562561
/// indicates whether the command should be enqueued right away.
563-
Command *addCG(std::unique_ptr<detail::CG> CommandGroup,
564-
const QueueImplPtr &Queue, std::vector<Command *> &ToEnqueue,
565-
bool EventNeeded,
562+
Command *addCG(std::unique_ptr<detail::CG> CommandGroup, queue_impl *Queue,
563+
std::vector<Command *> &ToEnqueue, bool EventNeeded,
566564
ur_exp_command_buffer_handle_t CommandBuffer = nullptr,
567565
const std::vector<ur_exp_command_buffer_sync_point_t>
568566
&Dependencies = {});
@@ -600,15 +598,15 @@ class Scheduler {
600598
/// used when the user provides a "secondary" queue to the submit method
601599
/// which may be used when the command fails to enqueue/execute in the
602600
/// primary queue.
603-
void rescheduleCommand(Command *Cmd, const QueueImplPtr &Queue);
601+
void rescheduleCommand(Command *Cmd, queue_impl *Queue);
604602

605603
/// \return a pointer to the corresponding memory object record for the
606604
/// SYCL memory object provided, or nullptr if it does not exist.
607605
MemObjRecord *getMemObjRecord(SYCLMemObjI *MemObject);
608606

609607
/// \return a pointer to the MemObjRecord for the memory object referenced by
610608
/// \p Req. If no record exists yet, a new one is created for it.
611-
MemObjRecord *getOrInsertMemObjRecord(const QueueImplPtr &Queue,
609+
MemObjRecord *getOrInsertMemObjRecord(queue_impl *Queue,
612610
const Requirement *Req);
613611

614612
/// Decrements leaf counters for all leaves of the record.
@@ -656,7 +654,7 @@ class Scheduler {
656654
std::vector<
657655
std::shared_ptr<ext::oneapi::experimental::detail::node_impl>>
658656
Nodes,
659-
const QueueImplPtr &Queue, std::vector<Requirement *> Requirements,
657+
queue_impl *Queue, std::vector<Requirement *> Requirements,
660658
std::vector<detail::EventImplPtr> &Events,
661659
std::vector<Command *> &ToEnqueue);
662660

@@ -673,7 +671,7 @@ class Scheduler {
673671
/// \param Req is a Requirement describing destination.
674672
/// \param Queue is a queue that is bound to target context.
675673
Command *insertMemoryMove(MemObjRecord *Record, Requirement *Req,
676-
const QueueImplPtr &Queue,
674+
queue_impl *Queue,
677675
std::vector<Command *> &ToEnqueue);
678676

679677
// Inserts commands required to remap the memory object to its current host
@@ -684,7 +682,7 @@ class Scheduler {
684682

685683
UpdateHostRequirementCommand *
686684
insertUpdateHostReqCmd(MemObjRecord *Record, Requirement *Req,
687-
const QueueImplPtr &Queue,
685+
queue_impl *Queue,
688686
std::vector<Command *> &ToEnqueue);
689687

690688
/// Finds dependencies for the requirement.
@@ -717,7 +715,7 @@ class Scheduler {
717715
/// If none found, creates new one.
718716
AllocaCommandBase *
719717
getOrCreateAllocaForReq(MemObjRecord *Record, const Requirement *Req,
720-
const QueueImplPtr &Queue,
718+
queue_impl *Queue,
721719
std::vector<Command *> &ToEnqueue);
722720

723721
void markModifiedIfWrite(MemObjRecord *Record, Requirement *Req);

sycl/source/handler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -934,7 +934,7 @@ event handler::finalize() {
934934
CommandGroup->getRequirements().size() == 0;
935935

936936
detail::EventImplPtr Event = detail::Scheduler::getInstance().addCG(
937-
std::move(CommandGroup), Queue->shared_from_this(), !DiscardEvent);
937+
std::move(CommandGroup), *Queue, !DiscardEvent);
938938

939939
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
940940
MLastEvent = DiscardEvent ? nullptr : Event;

0 commit comments

Comments
 (0)