@@ -1401,6 +1401,16 @@ void exec_graph_impl::update(
14011401 std::vector<sycl::detail::AccessorImplHost *> UpdateRequirements;
14021402 bool NeedScheduledUpdate = needsScheduledUpdate (Nodes, UpdateRequirements);
14031403 if (NeedScheduledUpdate) {
1404+ // Clean up any execution events which have finished so we don't pass them
1405+ // to the scheduler.
1406+ for (auto It = MExecutionEvents.begin (); It != MExecutionEvents.end ();) {
1407+ if ((*It)->isCompleted ()) {
1408+ It = MExecutionEvents.erase (It);
1409+ continue ;
1410+ }
1411+ ++It;
1412+ }
1413+
14041414 auto AllocaQueue = std::make_shared<sycl::detail::queue_impl>(
14051415 sycl::detail::getSyclObjImpl (MGraphImpl->getDevice ()),
14061416 sycl::detail::getSyclObjImpl (MGraphImpl->getContext ()),
@@ -1416,11 +1426,12 @@ void exec_graph_impl::update(
14161426 } else {
14171427 // For each partition in the executable graph, call UR update on the
14181428 // command-buffer with the nodes to update.
1419- auto PartitionedNodes = getPartitionForNodes (Nodes);
1429+ auto PartitionedNodes = getURUpdatableNodes (Nodes);
14201430 for (auto It = PartitionedNodes.begin (); It != PartitionedNodes.end ();
14211431 It++) {
1422- auto CommandBuffer = It->first ->MCommandBuffers [MDevice];
1423- updateKernelsImpl (CommandBuffer, It->second );
1432+ auto &Partition = MPartitions[It->first ];
1433+ auto CommandBuffer = Partition->MCommandBuffers [MDevice];
1434+ updateURImpl (CommandBuffer, It->second );
14241435 }
14251436 }
14261437
@@ -1475,16 +1486,6 @@ bool exec_graph_impl::needsScheduledUpdate(
14751486 }
14761487 }
14771488
1478- // Clean up any execution events which have finished so we don't pass them to
1479- // the scheduler.
1480- for (auto It = MExecutionEvents.begin (); It != MExecutionEvents.end ();) {
1481- if ((*It)->isCompleted ()) {
1482- It = MExecutionEvents.erase (It);
1483- continue ;
1484- }
1485- ++It;
1486- }
1487-
14881489 // If we have previous execution events do the update through the scheduler to
14891490 // ensure it is ordered correctly.
14901491 NeedScheduledUpdate |= MExecutionEvents.size () > 0 ;
@@ -1499,7 +1500,7 @@ void exec_graph_impl::populateURKernelUpdateStructs(
14991500 std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t > &PtrDescs,
15001501 std::vector<ur_exp_command_buffer_update_value_arg_desc_t > &ValueDescs,
15011502 sycl::detail::NDRDescT &NDRDesc,
1502- ur_exp_command_buffer_update_kernel_launch_desc_t &UpdateDesc) {
1503+ ur_exp_command_buffer_update_kernel_launch_desc_t &UpdateDesc) const {
15031504 auto ContextImpl = sycl::detail::getSyclObjImpl (MContext);
15041505 const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter ();
15051506 auto DeviceImpl = sycl::detail::getSyclObjImpl (MGraphImpl->getDevice ());
@@ -1656,97 +1657,84 @@ void exec_graph_impl::populateURKernelUpdateStructs(
16561657 auto ExecNode = MIDCache.find (Node->MID );
16571658 assert (ExecNode != MIDCache.end () && " Node ID was not found in ID cache" );
16581659
1659- ur_exp_command_buffer_command_handle_t Command =
1660- MCommandMap[ExecNode-> second ] ;
1661- UpdateDesc.hCommand = Command;
1660+ auto Command = MCommandMap. find (ExecNode-> second );
1661+ assert (Command != MCommandMap. end ()) ;
1662+ UpdateDesc.hCommand = Command-> second ;
16621663
16631664 // Update ExecNode with new values from Node, in case we ever need to
16641665 // rebuild the command buffers
16651666 ExecNode->second ->updateFromOtherNode (Node);
16661667}
16671668
1668- std::map<std::shared_ptr<partition>, std::vector<std::shared_ptr<node_impl>>>
1669- exec_graph_impl::getPartitionForNodes (
1670- const std::vector<std::shared_ptr<node_impl>> &Nodes) {
1671- // Iterate over each partition in the executable graph, and find the nodes
1672- // in "Nodes" that also exist in the partition.
1673- std::map<std::shared_ptr<partition>, std::vector<std::shared_ptr<node_impl>>>
1674- PartitionedNodes;
1675- for (const auto &Partition : MPartitions) {
1676- std::vector<std::shared_ptr<node_impl>> NodesForPartition;
1677- const auto PartitionBegin = Partition->MSchedule .begin ();
1678- const auto PartitionEnd = Partition->MSchedule .end ();
1679- for (auto &Node : Nodes) {
1680- auto ExecNode = MIDCache.find (Node->MID );
1681- assert (ExecNode != MIDCache.end () && " Node ID was not found in ID cache" );
1682-
1683- if (std::find_if (PartitionBegin, PartitionEnd,
1684- [ExecNode](const auto &PartitionNode) {
1685- return PartitionNode->MID == ExecNode->second ->MID ;
1686- }) != PartitionEnd) {
1687- NodesForPartition.push_back (Node);
1688- }
1689- }
1690- if (!NodesForPartition.empty ()) {
1691- PartitionedNodes.insert ({Partition, NodesForPartition});
1669+ std::map<int , std::vector<std::shared_ptr<node_impl>>>
1670+ exec_graph_impl::getURUpdatableNodes (
1671+ const std::vector<std::shared_ptr<node_impl>> &Nodes) const {
1672+ // Iterate over the list of nodes, and for every node that can
1673+ // be updated through UR, add it to the list of nodes for
1674+ // that can be updated for the UR command-buffer partition.
1675+ std::map<int , std::vector<std::shared_ptr<node_impl>>> PartitionedNodes;
1676+
1677+ // Initialize vector for each partition
1678+ for (size_t i = 0 ; i < MPartitions.size (); i++) {
1679+ PartitionedNodes[i] = {};
1680+ }
1681+
1682+ for (auto &Node : Nodes) {
1683+ // Kernel node update is the only command type supported in UR for update.
1684+ if (Node->MCGType != sycl::detail::CGType::Kernel) {
1685+ continue ;
16921686 }
1687+
1688+ auto ExecNode = MIDCache.find (Node->MID );
1689+ assert (ExecNode != MIDCache.end () && " Node ID was not found in ID cache" );
1690+ auto PartitionIndex = MPartitionNodes.find (ExecNode->second );
1691+ assert (PartitionIndex != MPartitionNodes.end ());
1692+ PartitionedNodes[PartitionIndex->second ].push_back (Node);
16931693 }
16941694
16951695 return PartitionedNodes;
16961696}
16971697
16981698void exec_graph_impl::updateHostTasksImpl (
1699- const std::vector<std::shared_ptr<node_impl>> &Nodes) {
1699+ const std::vector<std::shared_ptr<node_impl>> &Nodes) const {
17001700 for (auto &Node : Nodes) {
17011701 if (Node->MNodeType != node_type::host_task) {
17021702 continue ;
17031703 }
17041704 // Query the ID cache to find the equivalent exec node for the node passed
17051705 // to this function.
1706- // TODO: Handle subgraphs or any other cases where multiple nodes may be
1707- // associated with a single key, once those node types are supported for
1708- // update.
17091706 auto ExecNode = MIDCache.find (Node->MID );
17101707 assert (ExecNode != MIDCache.end () && " Node ID was not found in ID cache" );
17111708
17121709 ExecNode->second ->updateFromOtherNode (Node);
17131710 }
17141711}
17151712
1716- void exec_graph_impl::updateKernelsImpl (
1713+ void exec_graph_impl::updateURImpl (
17171714 ur_exp_command_buffer_handle_t CommandBuffer,
1718- const std::vector<std::shared_ptr<node_impl>> &Nodes) {
1719- // Kernel node update is the only command type supported in UR for update.
1720- // Updating any other types of nodes, e.g. empty & barrier nodes is a no-op.
1721- size_t NumKernelNodes = 0 ;
1722- for (auto &N : Nodes) {
1723- if (N->MCGType == sycl::detail::CGType::Kernel) {
1724- NumKernelNodes++;
1725- }
1726- }
1727-
1728- // Don't need to call through to UR if no kernel nodes to update
1729- if (NumKernelNodes == 0 ) {
1715+ const std::vector<std::shared_ptr<node_impl>> &Nodes) const {
1716+ const size_t NumUpdatableNodes = Nodes.size ();
1717+ if (NumUpdatableNodes == 0 ) {
17301718 return ;
17311719 }
17321720
17331721 std::vector<std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t >>
1734- MemobjDescsList (NumKernelNodes );
1722+ MemobjDescsList (NumUpdatableNodes );
17351723 std::vector<std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t >>
1736- PtrDescsList (NumKernelNodes );
1724+ PtrDescsList (NumUpdatableNodes );
17371725 std::vector<std::vector<ur_exp_command_buffer_update_value_arg_desc_t >>
1738- ValueDescsList (NumKernelNodes );
1739- std::vector<sycl::detail::NDRDescT> NDRDescList (NumKernelNodes );
1726+ ValueDescsList (NumUpdatableNodes );
1727+ std::vector<sycl::detail::NDRDescT> NDRDescList (NumUpdatableNodes );
17401728 std::vector<ur_exp_command_buffer_update_kernel_launch_desc_t > UpdateDescList (
1741- NumKernelNodes );
1729+ NumUpdatableNodes );
17421730 std::vector<std::pair<ur_program_handle_t , ur_kernel_handle_t >>
1743- KernelBundleObjList (NumKernelNodes );
1731+ KernelBundleObjList (NumUpdatableNodes );
17441732
17451733 size_t StructListIndex = 0 ;
17461734 for (auto &Node : Nodes) {
1747- if (Node-> MCGType != sycl::detail::CGType::Kernel) {
1748- continue ;
1749- }
1735+ // This should be the case when getURUpdatableNodes() is used to
1736+ // create the list of nodes.
1737+ assert (Node-> MCGType == sycl::detail::CGType::Kernel);
17501738
17511739 auto &MemobjDescs = MemobjDescsList[StructListIndex];
17521740 auto &KernelBundleObjs = KernelBundleObjList[StructListIndex];
0 commit comments