Skip to content

Commit fa81c82

Browse files
Ewan Crawford and RossBrunton
committed
[SYCL][Graph][UR] Propagate graph update list to UR
Update the `urCommandBufferUpdateKernelLaunchExp` API for updating commands in a command-buffer to take a list of commands. The current API operates on a single command, which means that the SYCL-Graph `update(std::vector<nodes>)` API needs to serialize the list into N calls to the UR API. Given that both OpenCL `clUpdateMutableCommandsKHR` and Level-Zero `zeCommandListUpdateMutableCommandsExp` can operate on a list of commands, this serialization at the UR layer of the stack introduces extra host API calls. This PR improves the `urCommandBufferUpdateKernelLaunchExp` API so that a list of commands is passed all the way from SYCL to the native backend API. As highlighted in oneapi-src/unified-runtime#2671, this patch requires the auto-generated handle-translation code to be able to handle a list of structs, which is not currently the case. This PR includes an API-specific temporary workaround in the mako file, which will unblock this PR until a more permanent solution is completed. Co-authored-by: Ross Brunton <[email protected]>
1 parent f022906 commit fa81c82

32 files changed

+1577
-1030
lines changed

sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1431,6 +1431,7 @@ Exceptions:
14311431
created.
14321432
* Throws with error code `invalid` if `node` is not part of the
14331433
graph.
1434+
* If any other exception is thrown the state of the graph node is undefined.
14341435

14351436
|
14361437
[source,c++]
@@ -1465,6 +1466,7 @@ Exceptions:
14651466
`property::graph::updatable` was not set when the executable graph was created.
14661467
* Throws with error code `invalid` if any node in `nodes` is not part of the
14671468
graph.
1469+
* If any other exception is thrown the state of the graph nodes is undefined.
14681470

14691471
|
14701472
[source, c++]
@@ -1517,6 +1519,8 @@ Exceptions:
15171519
* Throws synchronously with error code `invalid` if
15181520
`property::graph::updatable` was not set when the executable graph was
15191521
created.
1522+
1523+
* If any other exception is thrown the state of the graph nodes is undefined.
15201524
|===
15211525

15221526
Table {counter: tableNumber}. Member functions of the `command_graph` class for

sycl/source/detail/graph_impl.cpp

Lines changed: 188 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1381,18 +1381,72 @@ void exec_graph_impl::update(std::shared_ptr<node_impl> Node) {
13811381

13821382
void exec_graph_impl::update(
13831383
const std::vector<std::shared_ptr<node_impl>> &Nodes) {
1384-
13851384
if (!MIsUpdatable) {
13861385
throw sycl::exception(sycl::make_error_code(errc::invalid),
13871386
"update() cannot be called on a executable graph "
13881387
"which was not created with property::updatable");
13891388
}
13901389

1390+
// If the graph contains host tasks we need special handling here because
1391+
// their state lives in the graph object itself, so we must do the update
1392+
// immediately here. Whereas all other command state lives in the backend so
1393+
// it can be scheduled along with other commands.
1394+
if (MContainsHostTask) {
1395+
updateHostTasksImpl(Nodes);
1396+
}
1397+
1398+
// If there are any accessor requirements, we have to update through the
1399+
// scheduler to ensure that any allocations have taken place before trying
1400+
// to update.
1401+
std::vector<sycl::detail::AccessorImplHost *> UpdateRequirements;
1402+
bool NeedScheduledUpdate = needsScheduledUpdate(Nodes, UpdateRequirements);
1403+
if (NeedScheduledUpdate) {
1404+
auto AllocaQueue = std::make_shared<sycl::detail::queue_impl>(
1405+
sycl::detail::getSyclObjImpl(MGraphImpl->getDevice()),
1406+
sycl::detail::getSyclObjImpl(MGraphImpl->getContext()),
1407+
sycl::async_handler{}, sycl::property_list{});
1408+
1409+
// Track the event for the update command since execution may be blocked by
1410+
// other scheduler commands
1411+
auto UpdateEvent =
1412+
sycl::detail::Scheduler::getInstance().addCommandGraphUpdate(
1413+
this, Nodes, AllocaQueue, UpdateRequirements, MExecutionEvents);
1414+
1415+
MExecutionEvents.push_back(UpdateEvent);
1416+
} else {
1417+
// For each partition in the executable graph, call UR update on the
1418+
// command-buffer with the nodes to update.
1419+
auto PartitionedNodes = getPartitionForNodes(Nodes);
1420+
for (auto It = PartitionedNodes.begin(); It != PartitionedNodes.end();
1421+
It++) {
1422+
auto CommandBuffer = It->first->MCommandBuffers[MDevice];
1423+
updateKernelsImpl(CommandBuffer, It->second);
1424+
}
1425+
}
1426+
1427+
// Rebuild cached requirements and accessor storage for this graph with
1428+
// updated nodes
1429+
MRequirements.clear();
1430+
MAccessors.clear();
1431+
for (auto &Node : MNodeStorage) {
1432+
if (!Node->MCommandGroup)
1433+
continue;
1434+
MRequirements.insert(MRequirements.end(),
1435+
Node->MCommandGroup->getRequirements().begin(),
1436+
Node->MCommandGroup->getRequirements().end());
1437+
MAccessors.insert(MAccessors.end(),
1438+
Node->MCommandGroup->getAccStorage().begin(),
1439+
Node->MCommandGroup->getAccStorage().end());
1440+
}
1441+
}
1442+
1443+
bool exec_graph_impl::needsScheduledUpdate(
1444+
const std::vector<std::shared_ptr<node_impl>> &Nodes,
1445+
std::vector<sycl::detail::AccessorImplHost *> &UpdateRequirements) {
13911446
// If there are any accessor requirements, we have to update through the
13921447
// scheduler to ensure that any allocations have taken place before trying to
13931448
// update.
13941449
bool NeedScheduledUpdate = false;
1395-
std::vector<sycl::detail::AccessorImplHost *> UpdateRequirements;
13961450
// At worst we may have as many requirements as there are for the entire graph
13971451
// for updating.
13981452
UpdateRequirements.reserve(MRequirements.size());
@@ -1435,94 +1489,17 @@ void exec_graph_impl::update(
14351489
// ensure it is ordered correctly.
14361490
NeedScheduledUpdate |= MExecutionEvents.size() > 0;
14371491

1438-
if (NeedScheduledUpdate) {
1439-
// Copy the list of nodes as we may need to modify it
1440-
auto NodesCopy = Nodes;
1441-
1442-
// If the graph contains host tasks we need special handling here because
1443-
// their state lives in the graph object itself, so we must do the update
1444-
// immediately here. Whereas all other command state lives in the backend so
1445-
// it can be scheduled along with other commands.
1446-
if (MContainsHostTask) {
1447-
std::vector<std::shared_ptr<node_impl>> HostTasks;
1448-
// Remove any nodes that are host tasks and put them in HostTasks
1449-
auto RemovedIter = std::remove_if(
1450-
NodesCopy.begin(), NodesCopy.end(),
1451-
[&HostTasks](const std::shared_ptr<node_impl> &Node) -> bool {
1452-
if (Node->MNodeType == node_type::host_task) {
1453-
HostTasks.push_back(Node);
1454-
return true;
1455-
}
1456-
return false;
1457-
});
1458-
// Clean up extra elements in NodesCopy after the remove
1459-
NodesCopy.erase(RemovedIter, NodesCopy.end());
1460-
1461-
// Update host-tasks synchronously
1462-
for (auto &HostTaskNode : HostTasks) {
1463-
updateImpl(HostTaskNode);
1464-
}
1465-
}
1466-
1467-
auto AllocaQueue = std::make_shared<sycl::detail::queue_impl>(
1468-
sycl::detail::getSyclObjImpl(MGraphImpl->getDevice()),
1469-
sycl::detail::getSyclObjImpl(MGraphImpl->getContext()),
1470-
sycl::async_handler{}, sycl::property_list{});
1471-
1472-
// Track the event for the update command since execution may be blocked by
1473-
// other scheduler commands
1474-
auto UpdateEvent =
1475-
sycl::detail::Scheduler::getInstance().addCommandGraphUpdate(
1476-
this, std::move(NodesCopy), AllocaQueue, UpdateRequirements,
1477-
MExecutionEvents);
1478-
1479-
MExecutionEvents.push_back(UpdateEvent);
1480-
} else {
1481-
for (auto &Node : Nodes) {
1482-
updateImpl(Node);
1483-
}
1484-
}
1485-
1486-
// Rebuild cached requirements and accessor storage for this graph with
1487-
// updated nodes
1488-
MRequirements.clear();
1489-
MAccessors.clear();
1490-
for (auto &Node : MNodeStorage) {
1491-
if (!Node->MCommandGroup)
1492-
continue;
1493-
MRequirements.insert(MRequirements.end(),
1494-
Node->MCommandGroup->getRequirements().begin(),
1495-
Node->MCommandGroup->getRequirements().end());
1496-
MAccessors.insert(MAccessors.end(),
1497-
Node->MCommandGroup->getAccStorage().begin(),
1498-
Node->MCommandGroup->getAccStorage().end());
1499-
}
1492+
return NeedScheduledUpdate;
15001493
}
15011494

1502-
void exec_graph_impl::updateImpl(std::shared_ptr<node_impl> Node) {
1503-
// Updating empty or barrier nodes is a no-op
1504-
if (Node->isEmpty() || Node->MNodeType == node_type::ext_oneapi_barrier) {
1505-
return;
1506-
}
1507-
1508-
// Query the ID cache to find the equivalent exec node for the node passed to
1509-
// this function.
1510-
// TODO: Handle subgraphs or any other cases where multiple nodes may be
1511-
// associated with a single key, once those node types are supported for
1512-
// update.
1513-
auto ExecNode = MIDCache.find(Node->MID);
1514-
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");
1515-
1516-
// Update ExecNode with new values from Node, in case we ever need to
1517-
// rebuild the command buffers
1518-
ExecNode->second->updateFromOtherNode(Node);
1519-
1520-
// Host task update only requires updating the node itself, so can return
1521-
// early
1522-
if (Node->MNodeType == node_type::host_task) {
1523-
return;
1524-
}
1525-
1495+
void exec_graph_impl::populateURKernelUpdateStructs(
1496+
const std::shared_ptr<node_impl> &Node,
1497+
std::pair<ur_program_handle_t, ur_kernel_handle_t> &BundleObjs,
1498+
std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t> &MemobjDescs,
1499+
std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t> &PtrDescs,
1500+
std::vector<ur_exp_command_buffer_update_value_arg_desc_t> &ValueDescs,
1501+
sycl::detail::NDRDescT &NDRDesc,
1502+
ur_exp_command_buffer_update_kernel_launch_desc_t &UpdateDesc) {
15261503
auto ContextImpl = sycl::detail::getSyclObjImpl(MContext);
15271504
const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter();
15281505
auto DeviceImpl = sycl::detail::getSyclObjImpl(MGraphImpl->getDevice());
@@ -1533,9 +1510,8 @@ void exec_graph_impl::updateImpl(std::shared_ptr<node_impl> Node) {
15331510
// Copy args because we may modify them
15341511
std::vector<sycl::detail::ArgDesc> NodeArgs = ExecCG.getArguments();
15351512
// Copy NDR desc since we need to modify it
1536-
auto NDRDesc = ExecCG.MNDRDesc;
1513+
NDRDesc = ExecCG.MNDRDesc;
15371514

1538-
ur_program_handle_t UrProgram = nullptr;
15391515
ur_kernel_handle_t UrKernel = nullptr;
15401516
auto Kernel = ExecCG.MSyclKernel;
15411517
auto KernelBundleImplPtr = ExecCG.MKernelBundle;
@@ -1560,9 +1536,11 @@ void exec_graph_impl::updateImpl(std::shared_ptr<node_impl> Node) {
15601536
UrKernel = Kernel->getHandleRef();
15611537
EliminatedArgMask = Kernel->getKernelArgMask();
15621538
} else {
1539+
ur_program_handle_t UrProgram = nullptr;
15631540
std::tie(UrKernel, std::ignore, EliminatedArgMask, UrProgram) =
15641541
sycl::detail::ProgramManager::getInstance().getOrCreateKernel(
15651542
ContextImpl, DeviceImpl, ExecCG.MKernelName);
1543+
BundleObjs = std::make_pair(UrProgram, UrKernel);
15661544
}
15671545

15681546
// Remove eliminated args
@@ -1596,17 +1574,12 @@ void exec_graph_impl::updateImpl(std::shared_ptr<node_impl> Node) {
15961574
if (EnforcedLocalSize)
15971575
LocalSize = RequiredWGSize;
15981576
}
1599-
// Create update descriptor
16001577

16011578
// Storage for individual arg descriptors
1602-
std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t> MemobjDescs;
1603-
std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t> PtrDescs;
1604-
std::vector<ur_exp_command_buffer_update_value_arg_desc_t> ValueDescs;
16051579
MemobjDescs.reserve(MaskedArgs.size());
16061580
PtrDescs.reserve(MaskedArgs.size());
16071581
ValueDescs.reserve(MaskedArgs.size());
16081582

1609-
ur_exp_command_buffer_update_kernel_launch_desc_t UpdateDesc{};
16101583
UpdateDesc.stype =
16111584
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC;
16121585
UpdateDesc.pNext = nullptr;
@@ -1675,20 +1648,131 @@ void exec_graph_impl::updateImpl(std::shared_ptr<node_impl> Node) {
16751648
UpdateDesc.pNewLocalWorkSize = LocalSize;
16761649
UpdateDesc.newWorkDim = NDRDesc.Dims;
16771650

1651+
// Query the ID cache to find the equivalent exec node for the node passed to
1652+
// this function.
1653+
// TODO: Handle subgraphs or any other cases where multiple nodes may be
1654+
// associated with a single key, once those node types are supported for
1655+
// update.
1656+
auto ExecNode = MIDCache.find(Node->MID);
1657+
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");
1658+
16781659
ur_exp_command_buffer_command_handle_t Command =
16791660
MCommandMap[ExecNode->second];
1680-
ur_result_t Res = Adapter->call_nocheck<
1681-
sycl::detail::UrApiKind::urCommandBufferUpdateKernelLaunchExp>(
1682-
Command, &UpdateDesc);
1661+
UpdateDesc.hCommand = Command;
1662+
1663+
// Update ExecNode with new values from Node, in case we ever need to
1664+
// rebuild the command buffers
1665+
ExecNode->second->updateFromOtherNode(Node);
1666+
}
1667+
1668+
std::map<std::shared_ptr<partition>, std::vector<std::shared_ptr<node_impl>>>
1669+
exec_graph_impl::getPartitionForNodes(
1670+
const std::vector<std::shared_ptr<node_impl>> &Nodes) {
1671+
// Iterate over each partition in the executable graph, and find the nodes
1672+
// in "Nodes" that also exist in the partition.
1673+
std::map<std::shared_ptr<partition>, std::vector<std::shared_ptr<node_impl>>>
1674+
PartitionedNodes;
1675+
for (const auto &Partition : MPartitions) {
1676+
std::vector<std::shared_ptr<node_impl>> NodesForPartition;
1677+
const auto PartitionBegin = Partition->MSchedule.begin();
1678+
const auto PartitionEnd = Partition->MSchedule.end();
1679+
for (auto &Node : Nodes) {
1680+
auto ExecNode = MIDCache.find(Node->MID);
1681+
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");
1682+
1683+
if (std::find_if(PartitionBegin, PartitionEnd,
1684+
[ExecNode](const auto &PartitionNode) {
1685+
return PartitionNode->MID == ExecNode->second->MID;
1686+
}) != PartitionEnd) {
1687+
NodesForPartition.push_back(Node);
1688+
}
1689+
}
1690+
if (!NodesForPartition.empty()) {
1691+
PartitionedNodes.insert({Partition, NodesForPartition});
1692+
}
1693+
}
1694+
1695+
return PartitionedNodes;
1696+
}
1697+
1698+
void exec_graph_impl::updateHostTasksImpl(
1699+
const std::vector<std::shared_ptr<node_impl>> &Nodes) {
1700+
for (auto &Node : Nodes) {
1701+
if (Node->MNodeType != node_type::host_task) {
1702+
continue;
1703+
}
1704+
// Query the ID cache to find the equivalent exec node for the node passed
1705+
// to this function.
1706+
// TODO: Handle subgraphs or any other cases where multiple nodes may be
1707+
// associated with a single key, once those node types are supported for
1708+
// update.
1709+
auto ExecNode = MIDCache.find(Node->MID);
1710+
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");
16831711

1684-
if (UrProgram) {
1685-
// We retained these objects by calling getOrCreateKernel()
1686-
Adapter->call<sycl::detail::UrApiKind::urKernelRelease>(UrKernel);
1687-
Adapter->call<sycl::detail::UrApiKind::urProgramRelease>(UrProgram);
1712+
ExecNode->second->updateFromOtherNode(Node);
16881713
}
1714+
}
16891715

1690-
if (Res != UR_RESULT_SUCCESS) {
1691-
throw sycl::exception(errc::invalid, "Error updating command_graph");
1716+
void exec_graph_impl::updateKernelsImpl(
1717+
ur_exp_command_buffer_handle_t CommandBuffer,
1718+
const std::vector<std::shared_ptr<node_impl>> &Nodes) {
1719+
// Kernel node update is the only command type supported in UR for update.
1720+
// Updating any other types of nodes, e.g. empty & barrier nodes is a no-op.
1721+
size_t NumKernelNodes = 0;
1722+
for (auto &N : Nodes) {
1723+
if (N->MCGType == sycl::detail::CGType::Kernel) {
1724+
NumKernelNodes++;
1725+
}
1726+
}
1727+
1728+
// Don't need to call through to UR if no kernel nodes to update
1729+
if (NumKernelNodes == 0) {
1730+
return;
1731+
}
1732+
1733+
std::vector<std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t>>
1734+
MemobjDescsList(NumKernelNodes);
1735+
std::vector<std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t>>
1736+
PtrDescsList(NumKernelNodes);
1737+
std::vector<std::vector<ur_exp_command_buffer_update_value_arg_desc_t>>
1738+
ValueDescsList(NumKernelNodes);
1739+
std::vector<sycl::detail::NDRDescT> NDRDescList(NumKernelNodes);
1740+
std::vector<ur_exp_command_buffer_update_kernel_launch_desc_t> UpdateDescList(
1741+
NumKernelNodes);
1742+
std::vector<std::pair<ur_program_handle_t, ur_kernel_handle_t>>
1743+
KernelBundleObjList(NumKernelNodes);
1744+
1745+
size_t StructListIndex = 0;
1746+
for (auto &Node : Nodes) {
1747+
if (Node->MCGType != sycl::detail::CGType::Kernel) {
1748+
continue;
1749+
}
1750+
1751+
auto &MemobjDescs = MemobjDescsList[StructListIndex];
1752+
auto &KernelBundleObjs = KernelBundleObjList[StructListIndex];
1753+
auto &PtrDescs = PtrDescsList[StructListIndex];
1754+
auto &ValueDescs = ValueDescsList[StructListIndex];
1755+
auto &NDRDesc = NDRDescList[StructListIndex];
1756+
auto &UpdateDesc = UpdateDescList[StructListIndex];
1757+
populateURKernelUpdateStructs(Node, KernelBundleObjs, MemobjDescs, PtrDescs,
1758+
ValueDescs, NDRDesc, UpdateDesc);
1759+
StructListIndex++;
1760+
}
1761+
1762+
auto ContextImpl = sycl::detail::getSyclObjImpl(MContext);
1763+
const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter();
1764+
Adapter->call<sycl::detail::UrApiKind::urCommandBufferUpdateKernelLaunchExp>(
1765+
CommandBuffer, UpdateDescList.size(), UpdateDescList.data());
1766+
1767+
for (auto &BundleObjs : KernelBundleObjList) {
1768+
// We retained these objects inside populateURKernelUpdateStructs() by calling
1769+
// getOrCreateKernel()
1770+
if (auto &UrKernel = BundleObjs.second; nullptr != UrKernel) {
1771+
Adapter->call<sycl::detail::UrApiKind::urKernelRelease>(UrKernel);
1772+
}
1773+
if (auto &UrProgram = BundleObjs.first; nullptr != UrProgram) {
1774+
Adapter->call<sycl::detail::UrApiKind::urProgramRelease>(UrProgram);
1775+
}
16921776
}
16931777
}
16941778

0 commit comments

Comments
 (0)