@@ -1427,11 +1427,10 @@ void exec_graph_impl::update(
14271427 // For each partition in the executable graph, call UR update on the
14281428 // command-buffer with the nodes to update.
14291429 auto PartitionedNodes = getURUpdatableNodes (Nodes);
1430- for (auto It = PartitionedNodes.begin (); It != PartitionedNodes.end ();
1431- It++) {
1432- auto &Partition = MPartitions[It->first ];
1430+ for (auto &[PartitionIndex, NodeImpl] : PartitionedNodes) {
1431+ auto &Partition = MPartitions[PartitionIndex];
14331432 auto CommandBuffer = Partition->MCommandBuffers [MDevice];
1434- updateURImpl (CommandBuffer, It-> second );
1433+ updateURImpl (CommandBuffer, NodeImpl );
14351434 }
14361435 }
14371436
@@ -1497,6 +1496,7 @@ void exec_graph_impl::populateURKernelUpdateStructs(
14971496 const std::shared_ptr<node_impl> &Node,
14981497 std::pair<ur_program_handle_t , ur_kernel_handle_t > &BundleObjs,
14991498 std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t > &MemobjDescs,
1499+ std::vector<ur_kernel_arg_mem_obj_properties_t > &MemobjProps,
15001500 std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t > &PtrDescs,
15011501 std::vector<ur_exp_command_buffer_update_value_arg_desc_t > &ValueDescs,
15021502 sycl::detail::NDRDescT &NDRDesc,
@@ -1580,6 +1580,7 @@ void exec_graph_impl::populateURKernelUpdateStructs(
15801580 MemobjDescs.reserve (MaskedArgs.size ());
15811581 PtrDescs.reserve (MaskedArgs.size ());
15821582 ValueDescs.reserve (MaskedArgs.size ());
1583+ MemobjProps.resize (MaskedArgs.size ()); // resize since we access by reference
15831584
15841585 UpdateDesc.stype =
15851586 UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC;
@@ -1606,27 +1607,27 @@ void exec_graph_impl::populateURKernelUpdateStructs(
16061607 sycl::detail::Requirement *Req =
16071608 static_cast <sycl::detail::Requirement *>(NodeArg.MPtr );
16081609
1609- ur_kernel_arg_mem_obj_properties_t MemObjProps ;
1610- MemObjProps .stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
1611- MemObjProps .pNext = nullptr ;
1610+ ur_kernel_arg_mem_obj_properties_t &MemObjProp = MemobjProps[i] ;
1611+ MemObjProp .stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
1612+ MemObjProp .pNext = nullptr ;
16121613 switch (Req->MAccessMode ) {
16131614 case access::mode::read: {
1614- MemObjProps .memoryAccess = UR_MEM_FLAG_READ_ONLY;
1615+ MemObjProp .memoryAccess = UR_MEM_FLAG_READ_ONLY;
16151616 break ;
16161617 }
16171618 case access::mode::write:
16181619 case access::mode::discard_write: {
1619- MemObjProps .memoryAccess = UR_MEM_FLAG_WRITE_ONLY;
1620+ MemObjProp .memoryAccess = UR_MEM_FLAG_WRITE_ONLY;
16201621 break ;
16211622 }
16221623 default : {
1623- MemObjProps .memoryAccess = UR_MEM_FLAG_READ_WRITE;
1624+ MemObjProp .memoryAccess = UR_MEM_FLAG_READ_WRITE;
16241625 break ;
16251626 }
16261627 }
16271628 MemobjDescs.push_back (
16281629 {UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_MEMOBJ_ARG_DESC, nullptr ,
1629- static_cast <uint32_t >(NodeArg.MIndex ), &MemObjProps ,
1630+ static_cast <uint32_t >(NodeArg.MIndex ), &MemObjProp ,
16301631 static_cast <ur_mem_handle_t >(Req->MData )});
16311632
16321633 } break ;
@@ -1718,8 +1719,16 @@ void exec_graph_impl::updateURImpl(
17181719 return ;
17191720 }
17201721
1722+ // The urCommandBufferUpdateKernelLaunchExp API takes structs which contain
1723+ // members that are pointers to other structs. The lifetime of all the
1724+ // pointers (including nested pointers) needs to be valid at the time of the
1725+ // urCommandBufferUpdateKernelLaunchExp call. Define the objects here which
1726+ // will be populated and used in the urCommandBufferUpdateKernelLaunchExp
1727+ // call.
17211728 std::vector<std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t >>
17221729 MemobjDescsList (NumUpdatableNodes);
1730+ std::vector<std::vector<ur_kernel_arg_mem_obj_properties_t >> MemobjPropsList (
1731+ NumUpdatableNodes);
17231732 std::vector<std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t >>
17241733 PtrDescsList (NumUpdatableNodes);
17251734 std::vector<std::vector<ur_exp_command_buffer_update_value_arg_desc_t >>
@@ -1737,13 +1746,15 @@ void exec_graph_impl::updateURImpl(
17371746 assert (Node->MCGType == sycl::detail::CGType::Kernel);
17381747
17391748 auto &MemobjDescs = MemobjDescsList[StructListIndex];
1749+ auto &MemobjProps = MemobjPropsList[StructListIndex];
17401750 auto &KernelBundleObjs = KernelBundleObjList[StructListIndex];
17411751 auto &PtrDescs = PtrDescsList[StructListIndex];
17421752 auto &ValueDescs = ValueDescsList[StructListIndex];
17431753 auto &NDRDesc = NDRDescList[StructListIndex];
17441754 auto &UpdateDesc = UpdateDescList[StructListIndex];
1745- populateURKernelUpdateStructs (Node, KernelBundleObjs, MemobjDescs, PtrDescs,
1746- ValueDescs, NDRDesc, UpdateDesc);
1755+ populateURKernelUpdateStructs (Node, KernelBundleObjs, MemobjDescs,
1756+ MemobjProps, PtrDescs, ValueDescs, NDRDesc,
1757+ UpdateDesc);
17471758 StructListIndex++;
17481759 }
17491760
0 commit comments