Skip to content

Commit 1e36d65

Browse files
use the right queue for memory migration
Signed-off-by: Tikhomirova, Kseniya <[email protected]>
1 parent 06fec7d commit 1e36d65

File tree

2 files changed

+23
-20
lines changed

2 files changed

+23
-20
lines changed

sycl/source/detail/scheduler/commands.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ class Command {
228228
/// Get the context of the queue this command will be submitted to. Could
229229
/// differ from the context of MQueue for memory copy commands.
230230
ContextImplPtr getWorkerContext() const;
231+
QueueImplPtr getWorkerQueue() const { return MWorkerQueue; }
231232

232233
/// Returns true iff the command produces a UR event on non-host devices.
233234
virtual bool producesPiEvent() const;

sycl/source/detail/scheduler/graph_builder.cpp

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -304,10 +304,10 @@ static Command *insertMapUnmapForLinkedCmds(AllocaCommandBase *AllocaCmdSrc,
304304
assert(AllocaCmdSrc->MIsActive &&
305305
"Expected source alloca command to be active");
306306

307-
if (!AllocaCmdSrc->getQueue()) {
307+
if (!AllocaCmdSrc->getWorkerQueue()) {
308308
UnMapMemObject *UnMapCmd = new UnMapMemObject(
309309
AllocaCmdDst, *AllocaCmdDst->getRequirement(),
310-
&AllocaCmdSrc->MMemAllocation, AllocaCmdDst->getQueue());
310+
&AllocaCmdSrc->MMemAllocation, AllocaCmdDst->getWorkerQueue());
311311

312312
std::swap(AllocaCmdSrc->MIsActive, AllocaCmdDst->MIsActive);
313313

@@ -316,7 +316,7 @@ static Command *insertMapUnmapForLinkedCmds(AllocaCommandBase *AllocaCmdSrc,
316316

317317
MapMemObject *MapCmd = new MapMemObject(
318318
AllocaCmdSrc, *AllocaCmdSrc->getRequirement(),
319-
&AllocaCmdDst->MMemAllocation, AllocaCmdSrc->getQueue(), MapMode);
319+
&AllocaCmdDst->MMemAllocation, AllocaCmdSrc->getWorkerQueue(), MapMode);
320320

321321
std::swap(AllocaCmdSrc->MIsActive, AllocaCmdDst->MIsActive);
322322

@@ -349,9 +349,10 @@ Command *Scheduler::GraphBuilder::insertMemoryMove(
349349
// current context, need to find a parent alloca command for it (it must be
350350
// there)
351351
auto IsSuitableAlloca = [Record](AllocaCommandBase *AllocaCmd) {
352-
bool Res = isOnSameContext(Record->MCurContext, AllocaCmd->getQueue()) &&
353-
// Looking for a parent buffer alloca command
354-
AllocaCmd->getType() == Command::CommandType::ALLOCA;
352+
bool Res =
353+
isOnSameContext(Record->MCurContext, AllocaCmd->getWorkerQueue()) &&
354+
// Looking for a parent buffer alloca command
355+
AllocaCmd->getType() == Command::CommandType::ALLOCA;
355356
return Res;
356357
};
357358
const auto It =
@@ -389,10 +390,10 @@ Command *Scheduler::GraphBuilder::insertMemoryMove(
389390
} else {
390391
// Full copy of buffer is needed to avoid loss of data that may be caused
391392
// by copying specific range from host to device and backwards.
392-
NewCmd =
393-
new MemCpyCommand(*AllocaCmdSrc->getRequirement(), AllocaCmdSrc,
394-
*AllocaCmdDst->getRequirement(), AllocaCmdDst,
395-
AllocaCmdSrc->getQueue(), AllocaCmdDst->getQueue());
393+
NewCmd = new MemCpyCommand(*AllocaCmdSrc->getRequirement(), AllocaCmdSrc,
394+
*AllocaCmdDst->getRequirement(), AllocaCmdDst,
395+
AllocaCmdSrc->getWorkerQueue(),
396+
AllocaCmdDst->getWorkerQueue());
396397
}
397398
}
398399
std::vector<Command *> ToCleanUp;
@@ -413,7 +414,7 @@ Command *Scheduler::GraphBuilder::insertMemoryMove(
413414
Command *Scheduler::GraphBuilder::remapMemoryObject(
414415
MemObjRecord *Record, Requirement *Req, AllocaCommandBase *HostAllocaCmd,
415416
std::vector<Command *> &ToEnqueue) {
416-
assert(!HostAllocaCmd->getQueue() && "Host alloca command expected");
417+
assert(!HostAllocaCmd->getWorkerQueue() && "Host alloca command expected");
417418
assert(HostAllocaCmd->MIsActive && "Active alloca command expected");
418419

419420
AllocaCommandBase *LinkedAllocaCmd = HostAllocaCmd->MLinkedAllocaCmd;
@@ -423,15 +424,16 @@ Command *Scheduler::GraphBuilder::remapMemoryObject(
423424

424425
UnMapMemObject *UnMapCmd = new UnMapMemObject(
425426
LinkedAllocaCmd, *LinkedAllocaCmd->getRequirement(),
426-
&HostAllocaCmd->MMemAllocation, LinkedAllocaCmd->getQueue());
427+
&HostAllocaCmd->MMemAllocation, LinkedAllocaCmd->getWorkerQueue());
427428

428429
// Map write only as read-write
429430
access::mode MapMode = Req->MAccessMode;
430431
if (MapMode == access::mode::write)
431432
MapMode = access::mode::read_write;
432-
MapMemObject *MapCmd = new MapMemObject(
433-
LinkedAllocaCmd, *LinkedAllocaCmd->getRequirement(),
434-
&HostAllocaCmd->MMemAllocation, LinkedAllocaCmd->getQueue(), MapMode);
433+
MapMemObject *MapCmd =
434+
new MapMemObject(LinkedAllocaCmd, *LinkedAllocaCmd->getRequirement(),
435+
&HostAllocaCmd->MMemAllocation,
436+
LinkedAllocaCmd->getWorkerQueue(), MapMode);
435437

436438
std::vector<Command *> ToCleanUp;
437439
for (Command *Dep : Deps) {
@@ -474,7 +476,7 @@ Scheduler::GraphBuilder::addCopyBack(Requirement *Req,
474476

475477
auto MemCpyCmdUniquePtr = std::make_unique<MemCpyToHostCommand>(
476478
*SrcAllocaCmd->getRequirement(), SrcAllocaCmd, *Req, &Req->MData,
477-
SrcAllocaCmd->getQueue());
479+
SrcAllocaCmd->getWorkerQueue());
478480

479481
if (!MemCpyCmdUniquePtr)
480482
throw exception(make_error_code(errc::memory_allocation),
@@ -522,7 +524,7 @@ Scheduler::GraphBuilder::addHostAccessor(Requirement *Req,
522524
AllocaCommandBase *HostAllocaCmd =
523525
getOrCreateAllocaForReq(Record, Req, nullptr, ToEnqueue);
524526

525-
if (isOnSameContext(Record->MCurContext, HostAllocaCmd->getQueue())) {
527+
if (isOnSameContext(Record->MCurContext, HostAllocaCmd->getWorkerQueue())) {
526528
if (!isAccessModeAllowed(Req->MAccessMode, Record->MHostAccess)) {
527529
remapMemoryObject(Record, Req,
528530
Req->MIsSubBuffer ? (static_cast<AllocaSubBufCommand *>(
@@ -603,7 +605,7 @@ Scheduler::GraphBuilder::findDepsForReq(MemObjRecord *Record,
603605

604606
// Going through copying memory between contexts is not supported.
605607
if (Dep.MDepCommand) {
606-
auto DepQueue = Dep.MDepCommand->getQueue();
608+
auto DepQueue = Dep.MDepCommand->getWorkerQueue();
607609
CanBypassDep &= isOnSameContext(Context, DepQueue);
608610
}
609611

@@ -644,7 +646,7 @@ AllocaCommandBase *Scheduler::GraphBuilder::findAllocaForReq(
644646
bool AllowConst) {
645647
auto IsSuitableAlloca = [&Context, Req,
646648
AllowConst](AllocaCommandBase *AllocaCmd) {
647-
bool Res = isOnSameContext(Context, AllocaCmd->getQueue());
649+
bool Res = isOnSameContext(Context, AllocaCmd->getWorkerQueue());
648650
if (IsSuitableSubReq(Req)) {
649651
const Requirement *TmpReq = AllocaCmd->getRequirement();
650652
Res &= AllocaCmd->getType() == Command::CommandType::ALLOCA_SUB_BUF;
@@ -1222,7 +1224,7 @@ Command *Scheduler::GraphBuilder::connectDepEvent(
12221224
try {
12231225
std::shared_ptr<detail::HostTask> HT(new detail::HostTask);
12241226
std::unique_ptr<detail::CG> ConnectCG(new detail::CGHostTask(
1225-
std::move(HT), /* Queue = */ Cmd->getQueue(), /* Context = */ {},
1227+
std::move(HT), /* Queue = */ Cmd->getWorkerQueue(), /* Context = */ {},
12261228
/* Args = */ {},
12271229
detail::CG::StorageInitHelper(
12281230
/* ArgsStorage = */ {}, /* AccStorage = */ {},

0 commit comments

Comments
 (0)