Skip to content

Commit 139b815

Browse files
author
Ewan Crawford
committed
Fix bug in lifetime of mem obj properties
1 parent 1d63e97 commit 139b815

File tree

5 files changed

+30
-18
lines changed

5 files changed

+30
-18
lines changed

sycl/source/detail/graph_impl.cpp

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1427,11 +1427,10 @@ void exec_graph_impl::update(
14271427
// For each partition in the executable graph, call UR update on the
14281428
// command-buffer with the nodes to update.
14291429
auto PartitionedNodes = getURUpdatableNodes(Nodes);
1430-
for (auto It = PartitionedNodes.begin(); It != PartitionedNodes.end();
1431-
It++) {
1432-
auto &Partition = MPartitions[It->first];
1430+
for (auto &[PartitionIndex, NodeImpl] : PartitionedNodes) {
1431+
auto &Partition = MPartitions[PartitionIndex];
14331432
auto CommandBuffer = Partition->MCommandBuffers[MDevice];
1434-
updateURImpl(CommandBuffer, It->second);
1433+
updateURImpl(CommandBuffer, NodeImpl);
14351434
}
14361435
}
14371436

@@ -1497,6 +1496,7 @@ void exec_graph_impl::populateURKernelUpdateStructs(
14971496
const std::shared_ptr<node_impl> &Node,
14981497
std::pair<ur_program_handle_t, ur_kernel_handle_t> &BundleObjs,
14991498
std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t> &MemobjDescs,
1499+
std::vector<ur_kernel_arg_mem_obj_properties_t> &MemobjProps,
15001500
std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t> &PtrDescs,
15011501
std::vector<ur_exp_command_buffer_update_value_arg_desc_t> &ValueDescs,
15021502
sycl::detail::NDRDescT &NDRDesc,
@@ -1580,6 +1580,7 @@ void exec_graph_impl::populateURKernelUpdateStructs(
15801580
MemobjDescs.reserve(MaskedArgs.size());
15811581
PtrDescs.reserve(MaskedArgs.size());
15821582
ValueDescs.reserve(MaskedArgs.size());
1583+
MemobjProps.resize(MaskedArgs.size()); // resize since we access by reference
15831584

15841585
UpdateDesc.stype =
15851586
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC;
@@ -1606,27 +1607,27 @@ void exec_graph_impl::populateURKernelUpdateStructs(
16061607
sycl::detail::Requirement *Req =
16071608
static_cast<sycl::detail::Requirement *>(NodeArg.MPtr);
16081609

1609-
ur_kernel_arg_mem_obj_properties_t MemObjProps;
1610-
MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
1611-
MemObjProps.pNext = nullptr;
1610+
ur_kernel_arg_mem_obj_properties_t &MemObjProp = MemobjProps[i];
1611+
MemObjProp.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
1612+
MemObjProp.pNext = nullptr;
16121613
switch (Req->MAccessMode) {
16131614
case access::mode::read: {
1614-
MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
1615+
MemObjProp.memoryAccess = UR_MEM_FLAG_READ_ONLY;
16151616
break;
16161617
}
16171618
case access::mode::write:
16181619
case access::mode::discard_write: {
1619-
MemObjProps.memoryAccess = UR_MEM_FLAG_WRITE_ONLY;
1620+
MemObjProp.memoryAccess = UR_MEM_FLAG_WRITE_ONLY;
16201621
break;
16211622
}
16221623
default: {
1623-
MemObjProps.memoryAccess = UR_MEM_FLAG_READ_WRITE;
1624+
MemObjProp.memoryAccess = UR_MEM_FLAG_READ_WRITE;
16241625
break;
16251626
}
16261627
}
16271628
MemobjDescs.push_back(
16281629
{UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_MEMOBJ_ARG_DESC, nullptr,
1629-
static_cast<uint32_t>(NodeArg.MIndex), &MemObjProps,
1630+
static_cast<uint32_t>(NodeArg.MIndex), &MemObjProp,
16301631
static_cast<ur_mem_handle_t>(Req->MData)});
16311632

16321633
} break;
@@ -1718,8 +1719,16 @@ void exec_graph_impl::updateURImpl(
17181719
return;
17191720
}
17201721

1722+
// The urCommandBufferUpdateKernelLaunchExp API takes structs which contain
1723+
// members that are pointers to other structs. The lifetime of all the
1724+
// pointers (including nested pointers) needs to be valid at the time of the
1725+
// urCommandBufferUpdateKernelLaunchExp call. Define the objects here which
1726+
// will be populated and used in the urCommandBufferUpdateKernelLaunchExp
1727+
// call.
17211728
std::vector<std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t>>
17221729
MemobjDescsList(NumUpdatableNodes);
1730+
std::vector<std::vector<ur_kernel_arg_mem_obj_properties_t>> MemobjPropsList(
1731+
NumUpdatableNodes);
17231732
std::vector<std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t>>
17241733
PtrDescsList(NumUpdatableNodes);
17251734
std::vector<std::vector<ur_exp_command_buffer_update_value_arg_desc_t>>
@@ -1737,13 +1746,15 @@ void exec_graph_impl::updateURImpl(
17371746
assert(Node->MCGType == sycl::detail::CGType::Kernel);
17381747

17391748
auto &MemobjDescs = MemobjDescsList[StructListIndex];
1749+
auto &MemobjProps = MemobjPropsList[StructListIndex];
17401750
auto &KernelBundleObjs = KernelBundleObjList[StructListIndex];
17411751
auto &PtrDescs = PtrDescsList[StructListIndex];
17421752
auto &ValueDescs = ValueDescsList[StructListIndex];
17431753
auto &NDRDesc = NDRDescList[StructListIndex];
17441754
auto &UpdateDesc = UpdateDescList[StructListIndex];
1745-
populateURKernelUpdateStructs(Node, KernelBundleObjs, MemobjDescs, PtrDescs,
1746-
ValueDescs, NDRDesc, UpdateDesc);
1755+
populateURKernelUpdateStructs(Node, KernelBundleObjs, MemobjDescs,
1756+
MemobjProps, PtrDescs, ValueDescs, NDRDesc,
1757+
UpdateDesc);
17471758
StructListIndex++;
17481759
}
17491760

sycl/source/detail/graph_impl.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1409,6 +1409,7 @@ class exec_graph_impl {
14091409
/// @param[out] BundleObjs UR objects created from kernel bundle.
14101410
/// Responsibility of the caller to release.
14111411
/// @param[out] MemobjDescs Memory object arguments to update.
1412+
/// @param[out] MemobjProps Properties used in /p MemobjDescs structs.
14121413
/// @param[out] PtrDescs Pointer arguments to update.
14131414
/// @param[out] ValueDescs Value arguments to update.
14141415
/// @param[out] NDRDesc ND-Range to update.
@@ -1417,6 +1418,7 @@ class exec_graph_impl {
14171418
const std::shared_ptr<node_impl> &Node,
14181419
std::pair<ur_program_handle_t, ur_kernel_handle_t> &BundleObjs,
14191420
std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t> &MemobjDescs,
1421+
std::vector<ur_kernel_arg_mem_obj_properties_t> &MemobjProps,
14201422
std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t> &PtrDescs,
14211423
std::vector<ur_exp_command_buffer_update_value_arg_desc_t> &ValueDescs,
14221424
sycl::detail::NDRDescT &NDRDesc,

sycl/source/detail/scheduler/commands.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3749,10 +3749,9 @@ ur_result_t UpdateCommandBufferCommand::enqueueImp() {
37493749
auto PartitionedNodes = MGraph->getURUpdatableNodes(MNodes);
37503750
auto Device = MQueue->get_device();
37513751
auto &Partitions = MGraph->getPartitions();
3752-
for (auto It = PartitionedNodes.begin(); It != PartitionedNodes.end(); It++) {
3753-
const int PartitionIndex = It->first;
3752+
for (auto &[PartitionIndex, NodeImpl] : PartitionedNodes) {
37543753
auto CommandBuffer = Partitions[PartitionIndex]->MCommandBuffers[Device];
3755-
MGraph->updateURImpl(CommandBuffer, It->second);
3754+
MGraph->updateURImpl(CommandBuffer, NodeImpl);
37563755
}
37573756

37583757
return UR_RESULT_SUCCESS;

unified-runtime/scripts/templates/ldrddi.cpp.mako

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ namespace ur_loader
215215
%if func_basename == "CommandBufferUpdateKernelLaunchExp":
216216
## CommandBufferUpdateKernelLaunchExp entry-point takes a list of structs with
217217
## handle members, as well as members defining a nested list of structs
218-
## containing handles. This useage is not supported yet, so special case as
218+
## containing handles. This usage is not supported yet, so special case as
219219
## a temporary measure.
220220
std::vector<ur_exp_command_buffer_update_kernel_launch_desc_t> pUpdateKernelLaunchVector = {};
221221
std::vector<std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t>>

unified-runtime/source/adapters/opencl/command_buffer.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
678678
cl_uint NumConfigs = ConfigList.size();
679679
std::vector<cl_command_buffer_update_type_khr> ConfigTypes(
680680
NumConfigs, CL_STRUCTURE_TYPE_MUTABLE_DISPATCH_CONFIG_KHR);
681-
std::vector<void *> ConfigPtrs(NumConfigs);
681+
std::vector<const void *> ConfigPtrs(NumConfigs);
682682
for (cl_uint i = 0; i < NumConfigs; i++) {
683683
ConfigPtrs[i] = &ConfigList[i];
684684
}

0 commit comments

Comments
 (0)