@@ -1381,18 +1381,72 @@ void exec_graph_impl::update(std::shared_ptr<node_impl> Node) {
13811381
13821382void exec_graph_impl::update (
13831383 const std::vector<std::shared_ptr<node_impl>> &Nodes) {
1384-
13851384 if (!MIsUpdatable) {
13861385 throw sycl::exception (sycl::make_error_code (errc::invalid),
13871386 " update() cannot be called on a executable graph "
13881387 " which was not created with property::updatable" );
13891388 }
13901389
1390+ // If the graph contains host tasks we need special handling here because
1391+ // their state lives in the graph object itself, so we must do the update
1392+ // immediately here, whereas all other command state lives in the backend
1393+ // and can therefore be scheduled along with other commands.
1394+ if (MContainsHostTask) {
1395+ updateHostTasksImpl (Nodes);
1396+ }
1397+
1398+ // If there are any accessor requirements, we have to update through the
1399+ // scheduler to ensure that any allocations have taken place before trying
1400+ // to update. UpdateRequirements is an out-parameter of needsScheduledUpdate().
1401+ std::vector<sycl::detail::AccessorImplHost *> UpdateRequirements;
1402+ bool NeedScheduledUpdate = needsScheduledUpdate (Nodes, UpdateRequirements);
1403+ if (NeedScheduledUpdate) {
1404+ auto AllocaQueue = std::make_shared<sycl::detail::queue_impl>(
1405+ sycl::detail::getSyclObjImpl (MGraphImpl->getDevice ()),
1406+ sycl::detail::getSyclObjImpl (MGraphImpl->getContext ()),
1407+ sycl::async_handler{}, sycl::property_list{});
1408+
1409+ // Track the event for the update command since execution may be blocked by
1410+ // other scheduler commands.
1411+ auto UpdateEvent =
1412+ sycl::detail::Scheduler::getInstance ().addCommandGraphUpdate (
1413+ this , Nodes, AllocaQueue, UpdateRequirements, MExecutionEvents);
1414+
1415+ MExecutionEvents.push_back (UpdateEvent);
1416+ } else {
1417+ // No scheduling needed: for each partition that contains nodes to update,
1418+ // call UR update on that partition's command-buffer for MDevice.
1419+ auto PartitionedNodes = getPartitionForNodes (Nodes);
1420+ for (auto It = PartitionedNodes.begin (); It != PartitionedNodes.end ();
1421+ It++) {
1422+ auto CommandBuffer = It->first ->MCommandBuffers [MDevice];
1423+ updateKernelsImpl (CommandBuffer, It->second );
1424+ }
1425+ }
1426+
1427+ // Rebuild the cached requirements and accessor storage for this graph from
1428+ // every node in MNodeStorage that has a command group.
1429+ MRequirements.clear ();
1430+ MAccessors.clear ();
1431+ for (auto &Node : MNodeStorage) {
1432+ if (!Node->MCommandGroup )
1433+ continue ;
1434+ MRequirements.insert (MRequirements.end (),
1435+ Node->MCommandGroup ->getRequirements ().begin (),
1436+ Node->MCommandGroup ->getRequirements ().end ());
1437+ MAccessors.insert (MAccessors.end (),
1438+ Node->MCommandGroup ->getAccStorage ().begin (),
1439+ Node->MCommandGroup ->getAccStorage ().end ());
1440+ }
1441+ }
1442+
1443+ bool exec_graph_impl::needsScheduledUpdate (
1444+ const std::vector<std::shared_ptr<node_impl>> &Nodes,
1445+ std::vector<sycl::detail::AccessorImplHost *> &UpdateRequirements) {
13911446 // If there are any accessor requirements, we have to update through the
13921447 // scheduler to ensure that any allocations have taken place before trying to
13931448 // update.
13941449 bool NeedScheduledUpdate = false ;
1395- std::vector<sycl::detail::AccessorImplHost *> UpdateRequirements;
13961450 // At worst we may have as many requirements as there are for the entire graph
13971451 // for updating.
13981452 UpdateRequirements.reserve (MRequirements.size ());
@@ -1435,94 +1489,17 @@ void exec_graph_impl::update(
14351489 // ensure it is ordered correctly.
14361490 NeedScheduledUpdate |= MExecutionEvents.size () > 0 ;
14371491
1438- if (NeedScheduledUpdate) {
1439- // Copy the list of nodes as we may need to modify it
1440- auto NodesCopy = Nodes;
1441-
1442- // If the graph contains host tasks we need special handling here because
1443- // their state lives in the graph object itself, so we must do the update
1444- // immediately here. Whereas all other command state lives in the backend so
1445- // it can be scheduled along with other commands.
1446- if (MContainsHostTask) {
1447- std::vector<std::shared_ptr<node_impl>> HostTasks;
1448- // Remove any nodes that are host tasks and put them in HostTasks
1449- auto RemovedIter = std::remove_if (
1450- NodesCopy.begin (), NodesCopy.end (),
1451- [&HostTasks](const std::shared_ptr<node_impl> &Node) -> bool {
1452- if (Node->MNodeType == node_type::host_task) {
1453- HostTasks.push_back (Node);
1454- return true ;
1455- }
1456- return false ;
1457- });
1458- // Clean up extra elements in NodesCopy after the remove
1459- NodesCopy.erase (RemovedIter, NodesCopy.end ());
1460-
1461- // Update host-tasks synchronously
1462- for (auto &HostTaskNode : HostTasks) {
1463- updateImpl (HostTaskNode);
1464- }
1465- }
1466-
1467- auto AllocaQueue = std::make_shared<sycl::detail::queue_impl>(
1468- sycl::detail::getSyclObjImpl (MGraphImpl->getDevice ()),
1469- sycl::detail::getSyclObjImpl (MGraphImpl->getContext ()),
1470- sycl::async_handler{}, sycl::property_list{});
1471-
1472- // Track the event for the update command since execution may be blocked by
1473- // other scheduler commands
1474- auto UpdateEvent =
1475- sycl::detail::Scheduler::getInstance ().addCommandGraphUpdate (
1476- this , std::move (NodesCopy), AllocaQueue, UpdateRequirements,
1477- MExecutionEvents);
1478-
1479- MExecutionEvents.push_back (UpdateEvent);
1480- } else {
1481- for (auto &Node : Nodes) {
1482- updateImpl (Node);
1483- }
1484- }
1485-
1486- // Rebuild cached requirements and accessor storage for this graph with
1487- // updated nodes
1488- MRequirements.clear ();
1489- MAccessors.clear ();
1490- for (auto &Node : MNodeStorage) {
1491- if (!Node->MCommandGroup )
1492- continue ;
1493- MRequirements.insert (MRequirements.end (),
1494- Node->MCommandGroup ->getRequirements ().begin (),
1495- Node->MCommandGroup ->getRequirements ().end ());
1496- MAccessors.insert (MAccessors.end (),
1497- Node->MCommandGroup ->getAccStorage ().begin (),
1498- Node->MCommandGroup ->getAccStorage ().end ());
1499- }
1492+ return NeedScheduledUpdate;
15001493}
15011494
1502- void exec_graph_impl::updateImpl (std::shared_ptr<node_impl> Node) {
1503- // Updating empty or barrier nodes is a no-op
1504- if (Node->isEmpty () || Node->MNodeType == node_type::ext_oneapi_barrier) {
1505- return ;
1506- }
1507-
1508- // Query the ID cache to find the equivalent exec node for the node passed to
1509- // this function.
1510- // TODO: Handle subgraphs or any other cases where multiple nodes may be
1511- // associated with a single key, once those node types are supported for
1512- // update.
1513- auto ExecNode = MIDCache.find (Node->MID );
1514- assert (ExecNode != MIDCache.end () && " Node ID was not found in ID cache" );
1515-
1516- // Update ExecNode with new values from Node, in case we ever need to
1517- // rebuild the command buffers
1518- ExecNode->second ->updateFromOtherNode (Node);
1519-
1520- // Host task update only requires updating the node itself, so can return
1521- // early
1522- if (Node->MNodeType == node_type::host_task) {
1523- return ;
1524- }
1525-
1495+ void exec_graph_impl::populateURKernelUpdateStructs (
1496+ const std::shared_ptr<node_impl> &Node,
1497+ std::pair<ur_program_handle_t , ur_kernel_handle_t > &BundleObjs,
1498+ std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t > &MemobjDescs,
1499+ std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t > &PtrDescs,
1500+ std::vector<ur_exp_command_buffer_update_value_arg_desc_t > &ValueDescs,
1501+ sycl::detail::NDRDescT &NDRDesc,
1502+ ur_exp_command_buffer_update_kernel_launch_desc_t &UpdateDesc) {
15261503 auto ContextImpl = sycl::detail::getSyclObjImpl (MContext);
15271504 const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter ();
15281505 auto DeviceImpl = sycl::detail::getSyclObjImpl (MGraphImpl->getDevice ());
@@ -1533,9 +1510,8 @@ void exec_graph_impl::updateImpl(std::shared_ptr<node_impl> Node) {
15331510 // Copy args because we may modify them
15341511 std::vector<sycl::detail::ArgDesc> NodeArgs = ExecCG.getArguments ();
15351512 // Copy NDR desc since we need to modify it
1536- auto NDRDesc = ExecCG.MNDRDesc ;
1513+ NDRDesc = ExecCG.MNDRDesc ;
15371514
1538- ur_program_handle_t UrProgram = nullptr ;
15391515 ur_kernel_handle_t UrKernel = nullptr ;
15401516 auto Kernel = ExecCG.MSyclKernel ;
15411517 auto KernelBundleImplPtr = ExecCG.MKernelBundle ;
@@ -1560,9 +1536,11 @@ void exec_graph_impl::updateImpl(std::shared_ptr<node_impl> Node) {
15601536 UrKernel = Kernel->getHandleRef ();
15611537 EliminatedArgMask = Kernel->getKernelArgMask ();
15621538 } else {
1539+ ur_program_handle_t UrProgram = nullptr ;
15631540 std::tie (UrKernel, std::ignore, EliminatedArgMask, UrProgram) =
15641541 sycl::detail::ProgramManager::getInstance ().getOrCreateKernel (
15651542 ContextImpl, DeviceImpl, ExecCG.MKernelName );
1543+ BundleObjs = std::make_pair (UrProgram, UrKernel);
15661544 }
15671545
15681546 // Remove eliminated args
@@ -1596,17 +1574,12 @@ void exec_graph_impl::updateImpl(std::shared_ptr<node_impl> Node) {
15961574 if (EnforcedLocalSize)
15971575 LocalSize = RequiredWGSize;
15981576 }
1599- // Create update descriptor
16001577
16011578 // Storage for individual arg descriptors
1602- std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t > MemobjDescs;
1603- std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t > PtrDescs;
1604- std::vector<ur_exp_command_buffer_update_value_arg_desc_t > ValueDescs;
16051579 MemobjDescs.reserve (MaskedArgs.size ());
16061580 PtrDescs.reserve (MaskedArgs.size ());
16071581 ValueDescs.reserve (MaskedArgs.size ());
16081582
1609- ur_exp_command_buffer_update_kernel_launch_desc_t UpdateDesc{};
16101583 UpdateDesc.stype =
16111584 UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC;
16121585 UpdateDesc.pNext = nullptr ;
@@ -1675,20 +1648,131 @@ void exec_graph_impl::updateImpl(std::shared_ptr<node_impl> Node) {
16751648 UpdateDesc.pNewLocalWorkSize = LocalSize;
16761649 UpdateDesc.newWorkDim = NDRDesc.Dims ;
16771650
1651+ // Query the ID cache to find the equivalent exec node for the node passed to
1652+ // this function.
1653+ // TODO: Handle subgraphs or any other cases where multiple nodes may be
1654+ // associated with a single key, once those node types are supported for
1655+ // update.
1656+ auto ExecNode = MIDCache.find (Node->MID );
1657+ assert (ExecNode != MIDCache.end () && " Node ID was not found in ID cache" );
1658+
16781659 ur_exp_command_buffer_command_handle_t Command =
16791660 MCommandMap[ExecNode->second ];
1680- ur_result_t Res = Adapter->call_nocheck <
1681- sycl::detail::UrApiKind::urCommandBufferUpdateKernelLaunchExp>(
1682- Command, &UpdateDesc);
1661+ UpdateDesc.hCommand = Command;
1662+
1663+ // Update ExecNode with new values from Node, in case we ever need to
1664+ // rebuild the command buffers
1665+ ExecNode->second ->updateFromOtherNode (Node);
1666+ }
1667+
1668+ std::map<std::shared_ptr<partition>, std::vector<std::shared_ptr<node_impl>>>
1669+ exec_graph_impl::getPartitionForNodes (
1670+ const std::vector<std::shared_ptr<node_impl>> &Nodes) {
1671+ // Iterate over each partition in the executable graph, and collect the nodes
1672+ // in "Nodes" whose exec-node ID (via MIDCache) appears in its schedule.
1673+ std::map<std::shared_ptr<partition>, std::vector<std::shared_ptr<node_impl>>>
1674+ PartitionedNodes;
1675+ for (const auto &Partition : MPartitions) {
1676+ std::vector<std::shared_ptr<node_impl>> NodesForPartition;
1677+ const auto PartitionBegin = Partition->MSchedule .begin ();
1678+ const auto PartitionEnd = Partition->MSchedule .end ();
1679+ for (auto &Node : Nodes) {
1680+ auto ExecNode = MIDCache.find (Node->MID );
1681+ assert (ExecNode != MIDCache.end () && " Node ID was not found in ID cache" );
1682+
1683+ if (std::find_if (PartitionBegin, PartitionEnd,
1684+ [ExecNode](const auto &PartitionNode) {
1685+ return PartitionNode->MID == ExecNode->second ->MID ;
1686+ }) != PartitionEnd) {
1687+ NodesForPartition.push_back (Node);
1688+ }
1689+ }
1690+ if (!NodesForPartition.empty ()) {
1691+ PartitionedNodes.insert ({Partition, NodesForPartition});
1692+ }
1693+ }
1694+
1695+ return PartitionedNodes;
1696+ }
1697+
1698+ void exec_graph_impl::updateHostTasksImpl (
1699+ const std::vector<std::shared_ptr<node_impl>> &Nodes) {
1700+ for (auto &Node : Nodes) {
1701+ if (Node->MNodeType != node_type::host_task) {
1702+ continue ;
1703+ }
1704+ // Query the ID cache to find the equivalent exec node for this host-task
1705+ // node (non-host-task nodes were skipped above).
1706+ // TODO: Handle subgraphs or any other cases where multiple nodes may be
1707+ // associated with a single key, once those node types are supported for
1708+ // update.
1709+ auto ExecNode = MIDCache.find (Node->MID );
1710+ assert (ExecNode != MIDCache.end () && " Node ID was not found in ID cache" );
16831711
1684- if (UrProgram) {
1685- // We retained these objects by calling getOrCreateKernel()
1686- Adapter->call <sycl::detail::UrApiKind::urKernelRelease>(UrKernel);
1687- Adapter->call <sycl::detail::UrApiKind::urProgramRelease>(UrProgram);
1712+ ExecNode->second ->updateFromOtherNode (Node);
16881713 }
1714+ }
16891715
1690- if (Res != UR_RESULT_SUCCESS) {
1691- throw sycl::exception (errc::invalid, " Error updating command_graph" );
1716+ void exec_graph_impl::updateKernelsImpl (
1717+ ur_exp_command_buffer_handle_t CommandBuffer,
1718+ const std::vector<std::shared_ptr<node_impl>> &Nodes) {
1719+ // Kernel node update is the only command type supported in UR for update.
1720+ // Updating any other types of nodes, e.g. empty & barrier nodes, is a no-op.
1721+ size_t NumKernelNodes = 0 ;
1722+ for (auto &N : Nodes) {
1723+ if (N->MCGType == sycl::detail::CGType::Kernel) {
1724+ NumKernelNodes++;
1725+ }
1726+ }
1727+
1728+ // Don't need to call through to UR if there are no kernel nodes to update.
1729+ if (NumKernelNodes == 0 ) {
1730+ return ;
1731+ }
1732+
1733+ std::vector<std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t >>
1734+ MemobjDescsList (NumKernelNodes);
1735+ std::vector<std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t >>
1736+ PtrDescsList (NumKernelNodes);
1737+ std::vector<std::vector<ur_exp_command_buffer_update_value_arg_desc_t >>
1738+ ValueDescsList (NumKernelNodes);
1739+ std::vector<sycl::detail::NDRDescT> NDRDescList (NumKernelNodes);
1740+ std::vector<ur_exp_command_buffer_update_kernel_launch_desc_t > UpdateDescList (
1741+ NumKernelNodes);
1742+ std::vector<std::pair<ur_program_handle_t , ur_kernel_handle_t >>
1743+ KernelBundleObjList (NumKernelNodes);
1744+
1745+ size_t StructListIndex = 0 ;
1746+ for (auto &Node : Nodes) {
1747+ if (Node->MCGType != sycl::detail::CGType::Kernel) {
1748+ continue ;
1749+ }
1750+
1751+ auto &MemobjDescs = MemobjDescsList[StructListIndex];
1752+ auto &KernelBundleObjs = KernelBundleObjList[StructListIndex];
1753+ auto &PtrDescs = PtrDescsList[StructListIndex];
1754+ auto &ValueDescs = ValueDescsList[StructListIndex];
1755+ auto &NDRDesc = NDRDescList[StructListIndex];
1756+ auto &UpdateDesc = UpdateDescList[StructListIndex];
1757+ populateURKernelUpdateStructs (Node, KernelBundleObjs, MemobjDescs, PtrDescs,
1758+ ValueDescs, NDRDesc, UpdateDesc);
1759+ StructListIndex++;
1760+ }
1761+
1762+ auto ContextImpl = sycl::detail::getSyclObjImpl (MContext);
1763+ const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter ();
1764+ Adapter->call <sycl::detail::UrApiKind::urCommandBufferUpdateKernelLaunchExp>(
1765+ CommandBuffer, UpdateDescList.size (), UpdateDescList.data ());
1766+
1767+ for (auto &BundleObjs : KernelBundleObjList) {
1768+ // We retained these objects inside populateURKernelUpdateStructs() by
1769+ // calling getOrCreateKernel(), so release any that were set.
1770+ if (auto &UrKernel = BundleObjs.second ; nullptr != UrKernel) {
1771+ Adapter->call <sycl::detail::UrApiKind::urKernelRelease>(UrKernel);
1772+ }
1773+ if (auto &UrProgram = BundleObjs.first ; nullptr != UrProgram) {
1774+ Adapter->call <sycl::detail::UrApiKind::urProgramRelease>(UrProgram);
1775+ }
16921776 }
16931777}
16941778
0 commit comments