Skip to content

Commit 1d63e97

Browse files
author
Ewan Crawford
committed
Refactor how we find updatable partitions
1 parent fa81c82 commit 1d63e97

File tree

3 files changed

+84
-87
lines changed

3 files changed

+84
-87
lines changed

sycl/source/detail/graph_impl.cpp

Lines changed: 55 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1401,6 +1401,16 @@ void exec_graph_impl::update(
14011401
std::vector<sycl::detail::AccessorImplHost *> UpdateRequirements;
14021402
bool NeedScheduledUpdate = needsScheduledUpdate(Nodes, UpdateRequirements);
14031403
if (NeedScheduledUpdate) {
1404+
// Clean up any execution events which have finished so we don't pass them
1405+
// to the scheduler.
1406+
for (auto It = MExecutionEvents.begin(); It != MExecutionEvents.end();) {
1407+
if ((*It)->isCompleted()) {
1408+
It = MExecutionEvents.erase(It);
1409+
continue;
1410+
}
1411+
++It;
1412+
}
1413+
14041414
auto AllocaQueue = std::make_shared<sycl::detail::queue_impl>(
14051415
sycl::detail::getSyclObjImpl(MGraphImpl->getDevice()),
14061416
sycl::detail::getSyclObjImpl(MGraphImpl->getContext()),
@@ -1416,11 +1426,12 @@ void exec_graph_impl::update(
14161426
} else {
14171427
// For each partition in the executable graph, call UR update on the
14181428
// command-buffer with the nodes to update.
1419-
auto PartitionedNodes = getPartitionForNodes(Nodes);
1429+
auto PartitionedNodes = getURUpdatableNodes(Nodes);
14201430
for (auto It = PartitionedNodes.begin(); It != PartitionedNodes.end();
14211431
It++) {
1422-
auto CommandBuffer = It->first->MCommandBuffers[MDevice];
1423-
updateKernelsImpl(CommandBuffer, It->second);
1432+
auto &Partition = MPartitions[It->first];
1433+
auto CommandBuffer = Partition->MCommandBuffers[MDevice];
1434+
updateURImpl(CommandBuffer, It->second);
14241435
}
14251436
}
14261437

@@ -1475,16 +1486,6 @@ bool exec_graph_impl::needsScheduledUpdate(
14751486
}
14761487
}
14771488

1478-
// Clean up any execution events which have finished so we don't pass them to
1479-
// the scheduler.
1480-
for (auto It = MExecutionEvents.begin(); It != MExecutionEvents.end();) {
1481-
if ((*It)->isCompleted()) {
1482-
It = MExecutionEvents.erase(It);
1483-
continue;
1484-
}
1485-
++It;
1486-
}
1487-
14881489
// If we have previous execution events do the update through the scheduler to
14891490
// ensure it is ordered correctly.
14901491
NeedScheduledUpdate |= MExecutionEvents.size() > 0;
@@ -1499,7 +1500,7 @@ void exec_graph_impl::populateURKernelUpdateStructs(
14991500
std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t> &PtrDescs,
15001501
std::vector<ur_exp_command_buffer_update_value_arg_desc_t> &ValueDescs,
15011502
sycl::detail::NDRDescT &NDRDesc,
1502-
ur_exp_command_buffer_update_kernel_launch_desc_t &UpdateDesc) {
1503+
ur_exp_command_buffer_update_kernel_launch_desc_t &UpdateDesc) const {
15031504
auto ContextImpl = sycl::detail::getSyclObjImpl(MContext);
15041505
const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter();
15051506
auto DeviceImpl = sycl::detail::getSyclObjImpl(MGraphImpl->getDevice());
@@ -1656,97 +1657,84 @@ void exec_graph_impl::populateURKernelUpdateStructs(
16561657
auto ExecNode = MIDCache.find(Node->MID);
16571658
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");
16581659

1659-
ur_exp_command_buffer_command_handle_t Command =
1660-
MCommandMap[ExecNode->second];
1661-
UpdateDesc.hCommand = Command;
1660+
auto Command = MCommandMap.find(ExecNode->second);
1661+
assert(Command != MCommandMap.end());
1662+
UpdateDesc.hCommand = Command->second;
16621663

16631664
// Update ExecNode with new values from Node, in case we ever need to
16641665
// rebuild the command buffers
16651666
ExecNode->second->updateFromOtherNode(Node);
16661667
}
16671668

1668-
std::map<std::shared_ptr<partition>, std::vector<std::shared_ptr<node_impl>>>
1669-
exec_graph_impl::getPartitionForNodes(
1670-
const std::vector<std::shared_ptr<node_impl>> &Nodes) {
1671-
// Iterate over each partition in the executable graph, and find the nodes
1672-
// in "Nodes" that also exist in the partition.
1673-
std::map<std::shared_ptr<partition>, std::vector<std::shared_ptr<node_impl>>>
1674-
PartitionedNodes;
1675-
for (const auto &Partition : MPartitions) {
1676-
std::vector<std::shared_ptr<node_impl>> NodesForPartition;
1677-
const auto PartitionBegin = Partition->MSchedule.begin();
1678-
const auto PartitionEnd = Partition->MSchedule.end();
1679-
for (auto &Node : Nodes) {
1680-
auto ExecNode = MIDCache.find(Node->MID);
1681-
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");
1682-
1683-
if (std::find_if(PartitionBegin, PartitionEnd,
1684-
[ExecNode](const auto &PartitionNode) {
1685-
return PartitionNode->MID == ExecNode->second->MID;
1686-
}) != PartitionEnd) {
1687-
NodesForPartition.push_back(Node);
1688-
}
1689-
}
1690-
if (!NodesForPartition.empty()) {
1691-
PartitionedNodes.insert({Partition, NodesForPartition});
1669+
std::map<int, std::vector<std::shared_ptr<node_impl>>>
1670+
exec_graph_impl::getURUpdatableNodes(
1671+
const std::vector<std::shared_ptr<node_impl>> &Nodes) const {
1672+
// Iterate over the list of nodes, and for every node that can
1673+
// be updated through UR, add it to the list of nodes for
1674+
// that can be updated for the UR command-buffer partition.
1675+
std::map<int, std::vector<std::shared_ptr<node_impl>>> PartitionedNodes;
1676+
1677+
// Initialize vector for each partition
1678+
for (size_t i = 0; i < MPartitions.size(); i++) {
1679+
PartitionedNodes[i] = {};
1680+
}
1681+
1682+
for (auto &Node : Nodes) {
1683+
// Kernel node update is the only command type supported in UR for update.
1684+
if (Node->MCGType != sycl::detail::CGType::Kernel) {
1685+
continue;
16921686
}
1687+
1688+
auto ExecNode = MIDCache.find(Node->MID);
1689+
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");
1690+
auto PartitionIndex = MPartitionNodes.find(ExecNode->second);
1691+
assert(PartitionIndex != MPartitionNodes.end());
1692+
PartitionedNodes[PartitionIndex->second].push_back(Node);
16931693
}
16941694

16951695
return PartitionedNodes;
16961696
}
16971697

16981698
void exec_graph_impl::updateHostTasksImpl(
1699-
const std::vector<std::shared_ptr<node_impl>> &Nodes) {
1699+
const std::vector<std::shared_ptr<node_impl>> &Nodes) const {
17001700
for (auto &Node : Nodes) {
17011701
if (Node->MNodeType != node_type::host_task) {
17021702
continue;
17031703
}
17041704
// Query the ID cache to find the equivalent exec node for the node passed
17051705
// to this function.
1706-
// TODO: Handle subgraphs or any other cases where multiple nodes may be
1707-
// associated with a single key, once those node types are supported for
1708-
// update.
17091706
auto ExecNode = MIDCache.find(Node->MID);
17101707
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");
17111708

17121709
ExecNode->second->updateFromOtherNode(Node);
17131710
}
17141711
}
17151712

1716-
void exec_graph_impl::updateKernelsImpl(
1713+
void exec_graph_impl::updateURImpl(
17171714
ur_exp_command_buffer_handle_t CommandBuffer,
1718-
const std::vector<std::shared_ptr<node_impl>> &Nodes) {
1719-
// Kernel node update is the only command type supported in UR for update.
1720-
// Updating any other types of nodes, e.g. empty & barrier nodes is a no-op.
1721-
size_t NumKernelNodes = 0;
1722-
for (auto &N : Nodes) {
1723-
if (N->MCGType == sycl::detail::CGType::Kernel) {
1724-
NumKernelNodes++;
1725-
}
1726-
}
1727-
1728-
// Don't need to call through to UR if no kernel nodes to update
1729-
if (NumKernelNodes == 0) {
1715+
const std::vector<std::shared_ptr<node_impl>> &Nodes) const {
1716+
const size_t NumUpdatableNodes = Nodes.size();
1717+
if (NumUpdatableNodes == 0) {
17301718
return;
17311719
}
17321720

17331721
std::vector<std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t>>
1734-
MemobjDescsList(NumKernelNodes);
1722+
MemobjDescsList(NumUpdatableNodes);
17351723
std::vector<std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t>>
1736-
PtrDescsList(NumKernelNodes);
1724+
PtrDescsList(NumUpdatableNodes);
17371725
std::vector<std::vector<ur_exp_command_buffer_update_value_arg_desc_t>>
1738-
ValueDescsList(NumKernelNodes);
1739-
std::vector<sycl::detail::NDRDescT> NDRDescList(NumKernelNodes);
1726+
ValueDescsList(NumUpdatableNodes);
1727+
std::vector<sycl::detail::NDRDescT> NDRDescList(NumUpdatableNodes);
17401728
std::vector<ur_exp_command_buffer_update_kernel_launch_desc_t> UpdateDescList(
1741-
NumKernelNodes);
1729+
NumUpdatableNodes);
17421730
std::vector<std::pair<ur_program_handle_t, ur_kernel_handle_t>>
1743-
KernelBundleObjList(NumKernelNodes);
1731+
KernelBundleObjList(NumUpdatableNodes);
17441732

17451733
size_t StructListIndex = 0;
17461734
for (auto &Node : Nodes) {
1747-
if (Node->MCGType != sycl::detail::CGType::Kernel) {
1748-
continue;
1749-
}
1735+
// This should be the case when getURUpdatableNodes() is used to
1736+
// create the list of nodes.
1737+
assert(Node->MCGType == sycl::detail::CGType::Kernel);
17501738

17511739
auto &MemobjDescs = MemobjDescsList[StructListIndex];
17521740
auto &KernelBundleObjs = KernelBundleObjList[StructListIndex];

sycl/source/detail/graph_impl.hpp

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,18 +1303,30 @@ class exec_graph_impl {
13031303
void update(std::shared_ptr<node_impl> Node);
13041304
void update(const std::vector<std::shared_ptr<node_impl>> &Nodes);
13051305

1306-
/// Calls UR entry-point to update kernel nodes in command-buffer.
1306+
/// Calls UR entry-point to update nodes in command-buffer.
13071307
/// @param CommandBuffer The UR command-buffer to update commands in.
1308-
/// @param Nodes List of nodes to update. May contain nodes of non-kernel
1309-
/// type, but only kernel nodes from the list will be used for update
1310-
void updateKernelsImpl(ur_exp_command_buffer_handle_t CommandBuffer,
1311-
const std::vector<std::shared_ptr<node_impl>> &Nodes);
1308+
/// @param Nodes List of nodes to update. Only nodes which can be updated
1309+
/// through UR should be included in this list, currently this is only
1310+
/// nodes of kernel type.
1311+
void updateURImpl(ur_exp_command_buffer_handle_t CommandBuffer,
1312+
const std::vector<std::shared_ptr<node_impl>> &Nodes) const;
13121313

1313-
/// Splits a list of nodes into separate lists depending on partition.
1314+
/// Update host-task nodes
1315+
/// @param Nodes List of nodes to update, any node that is not a host-task
1316+
/// will be ignored.
1317+
void updateHostTasksImpl(
1318+
const std::vector<std::shared_ptr<node_impl>> &Nodes) const;
1319+
1320+
/// Splits a list of nodes into separate lists of nodes for each
1321+
/// command-buffer partition.
1322+
///
1323+
/// Only nodes that can be updated through the UR interface are included
1324+
/// in the list. Currently this is only kernel node types.
1325+
///
13141326
/// @param Nodes List of nodes to split
1315-
/// @return Map of partitions to nodes
1316-
std::map<std::shared_ptr<partition>, std::vector<std::shared_ptr<node_impl>>>
1317-
getPartitionForNodes(const std::vector<std::shared_ptr<node_impl>> &Nodes);
1327+
/// @return Map of partition indexes to nodes
1328+
std::map<int, std::vector<std::shared_ptr<node_impl>>> getURUpdatableNodes(
1329+
const std::vector<std::shared_ptr<node_impl>> &Nodes) const;
13181330

13191331
unsigned long long getID() const { return MID; }
13201332

@@ -1408,13 +1420,7 @@ class exec_graph_impl {
14081420
std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t> &PtrDescs,
14091421
std::vector<ur_exp_command_buffer_update_value_arg_desc_t> &ValueDescs,
14101422
sycl::detail::NDRDescT &NDRDesc,
1411-
ur_exp_command_buffer_update_kernel_launch_desc_t &UpdateDesc);
1412-
1413-
/// Updates host-task nodes in the graph
1414-
/// @param Nodes List of nodes to update, any node that is not a host-task
1415-
/// will be ignored.
1416-
void
1417-
updateHostTasksImpl(const std::vector<std::shared_ptr<node_impl>> &Nodes);
1423+
ur_exp_command_buffer_update_kernel_launch_desc_t &UpdateDesc) const;
14181424

14191425
/// Execution schedule of nodes in the graph.
14201426
std::list<std::shared_ptr<node_impl>> MSchedule;

sycl/source/detail/scheduler/commands.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3744,12 +3744,15 @@ ur_result_t UpdateCommandBufferCommand::enqueueImp() {
37443744
}
37453745

37463746
// Split list of nodes into nodes per UR command-buffer partition, then
3747-
// call UR update on each command-buffer partition
3748-
auto PartitionedNodes = MGraph->getPartitionForNodes(MNodes);
3747+
// call UR update on each command-buffer partition with those updatable
3748+
// nodes.
3749+
auto PartitionedNodes = MGraph->getURUpdatableNodes(MNodes);
37493750
auto Device = MQueue->get_device();
3751+
auto &Partitions = MGraph->getPartitions();
37503752
for (auto It = PartitionedNodes.begin(); It != PartitionedNodes.end(); It++) {
3751-
auto CommandBuffer = It->first->MCommandBuffers[Device];
3752-
MGraph->updateKernelsImpl(CommandBuffer, It->second);
3753+
const int PartitionIndex = It->first;
3754+
auto CommandBuffer = Partitions[PartitionIndex]->MCommandBuffers[Device];
3755+
MGraph->updateURImpl(CommandBuffer, It->second);
37533756
}
37543757

37553758
return UR_RESULT_SUCCESS;

0 commit comments

Comments
 (0)