Skip to content

Commit d1a6bef

Browse files
author
Fábio Mestre
committed
Implement Dynamic Local Accessors
1 parent 305705c commit d1a6bef

File tree

9 files changed

+584
-0
lines changed

9 files changed

+584
-0
lines changed

sycl/include/sycl/ext/oneapi/experimental/graph.hpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,11 @@ class command_graph<graph_state::executable>
447447
namespace detail {
448448
class __SYCL_EXPORT dynamic_parameter_base {
449449
public:
450+
451+
dynamic_parameter_base(
452+
sycl::ext::oneapi::experimental::command_graph<graph_state::modifiable>
453+
Graph);
454+
450455
dynamic_parameter_base(
451456
sycl::ext::oneapi::experimental::command_graph<graph_state::modifiable>
452457
Graph,
@@ -461,6 +466,13 @@ class __SYCL_EXPORT dynamic_parameter_base {
461466
void updateValue(const raw_kernel_arg *NewRawValue, size_t Size);
462467

463468
void updateAccessor(const sycl::detail::AccessorBaseHost *Acc);
469+
470+
sycl::detail::LocalAccessorImplPtr getLocalAccessor(handler* Handler);
471+
472+
void registerLocalAccessor(sycl::detail::LocalAccessorBaseHost* LocalAccBaseHost, handler* Handler);
473+
474+
void updateLocalAccessor(range<3> NewAllocationSize);
475+
464476
std::shared_ptr<dynamic_parameter_impl> impl;
465477

466478
template <class Obj>
@@ -498,6 +510,42 @@ class dynamic_parameter : public detail::dynamic_parameter_base {
498510
}
499511
};
500512

513+
template <typename DataT, int Dimensions = 1>
514+
class dynamic_local_accessor : public detail::dynamic_parameter_base {
515+
public:
516+
template <int Dims = Dimensions, typename = std::enable_if_t<(Dims > 0)>>
517+
dynamic_local_accessor(command_graph<graph_state::modifiable> Graph,
518+
range<Dimensions> AllocationSize,
519+
const property_list &PropList = {})
520+
: detail::dynamic_parameter_base(Graph), AllocationSize(AllocationSize) {
521+
(void)PropList;
522+
}
523+
524+
void update(range<Dimensions> NewAllocationSize) {
525+
detail::dynamic_parameter_base::updateLocalAccessor(
526+
::sycl::detail::convertToArrayOfN<3, 1>(NewAllocationSize));
527+
};
528+
529+
local_accessor<DataT, Dimensions> get(handler &CGH) {
530+
#ifndef __SYCL_DEVICE_ONLY__
531+
::sycl::detail::LocalAccessorImplPtr BaseLocalAcc = getLocalAccessor(&CGH);
532+
if (BaseLocalAcc) {
533+
return sycl::detail::createSyclObjFromImpl<local_accessor<DataT, Dimensions>>(BaseLocalAcc);
534+
} else {
535+
local_accessor<DataT, Dimensions> LocalAccessor(AllocationSize, CGH);
536+
registerLocalAccessor(
537+
static_cast<sycl::detail::LocalAccessorBaseHost *>(&LocalAccessor), &CGH);
538+
return LocalAccessor;
539+
}
540+
#else
541+
return local_accessor<DataT, Dimensions>();
542+
#endif
543+
};
544+
545+
private:
546+
range<Dimensions> AllocationSize;
547+
};
548+
501549
/// Additional CTAD deduction guides.
502550
template <typename ValueT>
503551
dynamic_parameter(experimental::command_graph<graph_state::modifiable> Graph,

sycl/include/sycl/handler.hpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,22 @@ class __SYCL_EXPORT handler {
647647
registerDynamicParameter(DynamicParam, ArgIndex);
648648
}
649649

650+
// setArgHelper for graph dynamic_local_accessors.
651+
template <typename DataT, int Dims>
652+
void
653+
setArgHelper(int ArgIndex,
654+
ext::oneapi::experimental::dynamic_local_accessor<DataT, Dims>
655+
&DynamicLocalAccessor) {
656+
#ifndef __SYCL_DEVICE_ONLY__
657+
auto LocalAccessor = DynamicLocalAccessor.get(*this);
658+
setArgHelper(ArgIndex, LocalAccessor);
659+
registerDynamicParameter(DynamicLocalAccessor, ArgIndex);
660+
#else
661+
(void)ArgIndex;
662+
(void)DynamicLocalAccessor;
663+
#endif
664+
}
665+
650666
// setArgHelper for the raw_kernel_arg extension type.
651667
void setArgHelper(int ArgIndex,
652668
sycl::ext::oneapi::experimental::raw_kernel_arg &&Arg) {
@@ -1839,6 +1855,13 @@ class __SYCL_EXPORT handler {
18391855
setArgHelper(argIndex, dynamicParam);
18401856
}
18411857

1858+
template <typename DataT, int Dims>
1859+
void set_arg(int argIndex,
1860+
ext::oneapi::experimental::dynamic_local_accessor<DataT, Dims>
1861+
&DynamicLocalAccessor) {
1862+
setArgHelper(argIndex, DynamicLocalAccessor);
1863+
}
1864+
18421865
// set_arg for the raw_kernel_arg extension type.
18431866
void set_arg(int argIndex, ext::oneapi::experimental::raw_kernel_arg &&Arg) {
18441867
setArgHelper(argIndex, std::move(Arg));

sycl/source/detail/graph_impl.cpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1899,6 +1899,11 @@ dynamic_parameter_base::dynamic_parameter_base(
18991899
: impl(std::make_shared<dynamic_parameter_impl>(
19001900
sycl::detail::getSyclObjImpl(Graph), ParamSize, Data)) {}
19011901

1902+
dynamic_parameter_base::dynamic_parameter_base(
1903+
command_graph<graph_state::modifiable> Graph)
1904+
: impl(std::make_shared<dynamic_parameter_impl>(
1905+
sycl::detail::getSyclObjImpl(Graph))) {}
1906+
19021907
void dynamic_parameter_base::updateValue(const void *NewValue, size_t Size) {
19031908
impl->updateValue(NewValue, Size);
19041909
}
@@ -1913,6 +1918,20 @@ void dynamic_parameter_base::updateAccessor(
19131918
impl->updateAccessor(Acc);
19141919
}
19151920

1921+
sycl::detail::LocalAccessorImplPtr
1922+
dynamic_parameter_base::getLocalAccessor(handler *Handler) {
1923+
return impl->getLocalAccessor(Handler);
1924+
}
1925+
1926+
void dynamic_parameter_base::registerLocalAccessor(
1927+
sycl::detail::LocalAccessorBaseHost *LocalAccBaseHost, handler *Handler) {
1928+
impl->registerLocalAccessor(LocalAccBaseHost, Handler);
1929+
}
1930+
1931+
void dynamic_parameter_base::updateLocalAccessor(range<3> NewAllocationSize) {
1932+
impl->updateLocalAccessor(NewAllocationSize);
1933+
}
1934+
19161935
void dynamic_parameter_impl::updateValue(const raw_kernel_arg *NewRawValue,
19171936
size_t Size) {
19181937
// Number of bytes is taken from member of raw_kernel_arg object rather
@@ -1968,6 +1987,53 @@ void dynamic_parameter_impl::updateAccessor(
19681987
sizeof(sycl::detail::AccessorBaseHost));
19691988
}
19701989

1990+
sycl::detail::LocalAccessorImplPtr
1991+
dynamic_parameter_impl::getLocalAccessor(handler *Handler) {
1992+
auto HandlerImpl = sycl::detail::getSyclObjImpl(*Handler);
1993+
auto FindLocalAcc = MHandlerToLocalAccMap.find(HandlerImpl);
1994+
1995+
if (FindLocalAcc != MHandlerToLocalAccMap.end()) {
1996+
auto LocalAccImpl = FindLocalAcc->second;
1997+
return LocalAccImpl;
1998+
}
1999+
return nullptr;
2000+
}
2001+
2002+
void dynamic_parameter_impl::registerLocalAccessor(
2003+
sycl::detail::LocalAccessorBaseHost *LocalAccBaseHost, handler *Handler) {
2004+
2005+
auto HandlerImpl = sycl::detail::getSyclObjImpl(*Handler);
2006+
auto LocalAccImpl = sycl::detail::getSyclObjImpl(*LocalAccBaseHost);
2007+
2008+
MHandlerToLocalAccMap.insert({HandlerImpl, LocalAccImpl});
2009+
}
2010+
2011+
void dynamic_parameter_impl::updateLocalAccessor(range<3> NewAllocationSize) {
2012+
2013+
for (auto &[NodeWeak, ArgIndex] : MNodes) {
2014+
auto NodeShared = NodeWeak.lock();
2015+
if (NodeShared) {
2016+
// We can use the first local accessor in the map since the dimensions
2017+
// and element type should be identical.
2018+
auto LocalAccessor = MHandlerToLocalAccMap.begin()->second;
2019+
dynamic_parameter_impl::updateCGLocalAccessor(
2020+
NodeShared->MCommandGroup, ArgIndex, NewAllocationSize,
2021+
LocalAccessor->MDims, LocalAccessor->MElemSize);
2022+
}
2023+
}
2024+
2025+
for (auto &DynCGInfo : MDynCGs) {
2026+
auto DynCG = DynCGInfo.DynCG.lock();
2027+
if (DynCG) {
2028+
auto &CG = DynCG->MKernels[DynCGInfo.CGIndex];
2029+
auto LocalAccessor = MHandlerToLocalAccMap.begin()->second;
2030+
dynamic_parameter_impl::updateCGLocalAccessor(
2031+
CG, DynCGInfo.ArgIndex, NewAllocationSize, LocalAccessor->MDims,
2032+
LocalAccessor->MElemSize);
2033+
}
2034+
}
2035+
}
2036+
19712037
void dynamic_parameter_impl::updateCGArgValue(
19722038
std::shared_ptr<sycl::detail::CG> CG, int ArgIndex, const void *NewValue,
19732039
size_t Size) {
@@ -2033,6 +2099,27 @@ void dynamic_parameter_impl::updateCGAccessor(
20332099
}
20342100
}
20352101

2102+
void dynamic_parameter_impl::updateCGLocalAccessor(
2103+
std::shared_ptr<sycl::detail::CG> CG, int ArgIndex,
2104+
range<3> NewAllocationSize, int Dims, int ElemSize) {
2105+
auto &Args = static_cast<sycl::detail::CGExecKernel *>(CG.get())->MArgs;
2106+
2107+
for (auto &Arg : Args) {
2108+
if (Arg.MIndex != ArgIndex) {
2109+
continue;
2110+
}
2111+
assert(Arg.MType == sycl::detail::kernel_param_kind_t::kind_std_layout);
2112+
2113+
int SizeInBytes = ElemSize;
2114+
for (int I = 0; I < Dims; ++I)
2115+
SizeInBytes *= NewAllocationSize[I];
2116+
SizeInBytes = std::max(SizeInBytes, 1);
2117+
2118+
Arg.MSize = SizeInBytes;
2119+
break;
2120+
}
2121+
}
2122+
20362123
dynamic_command_group_impl::dynamic_command_group_impl(
20372124
const command_graph<graph_state::modifiable> &Graph)
20382125
: MGraph{sycl::detail::getSyclObjImpl(Graph)}, MActiveCGF(0) {}
@@ -2154,6 +2241,7 @@ size_t dynamic_command_group::get_active_index() const {
21542241
void dynamic_command_group::set_active_index(size_t Index) {
21552242
return impl->setActiveIndex(Index);
21562243
}
2244+
21572245
} // namespace experimental
21582246
} // namespace oneapi
21592247
} // namespace ext

sycl/source/detail/graph_impl.hpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1412,6 +1412,10 @@ class exec_graph_impl {
14121412

14131413
class dynamic_parameter_impl {
14141414
public:
1415+
/// Used for parameters that don't have data such as local_accessors.
1416+
dynamic_parameter_impl(std::shared_ptr<graph_impl> GraphImpl)
1417+
: MGraph(GraphImpl) {}
1418+
14151419
dynamic_parameter_impl(std::shared_ptr<graph_impl> GraphImpl,
14161420
size_t ParamSize, const void *Data)
14171421
: MGraph(GraphImpl), MValueStorage(ParamSize) {
@@ -1477,6 +1481,26 @@ class dynamic_parameter_impl {
14771481
/// @param Acc The new accessor value
14781482
void updateAccessor(const sycl::detail::AccessorBaseHost *Acc);
14791483

1484+
/// Updates the value of all local accessors in registered nodes and dynamic
1485+
/// CGs.
1486+
/// @param NewAllocationSize The new size for the update local accessors.
1487+
void updateLocalAccessor(range<3> NewAllocationSize);
1488+
1489+
/// Gets the implementation for the local accessor that is associated with
1490+
/// a specific handler.
1491+
/// @param The handler that the local accessor is associated with.
1492+
/// @return returns the impl object for the local accessor that is associated
1493+
/// with this handler. Or nullptr if no local accessor has been registered
1494+
/// for this handler.
1495+
sycl::detail::LocalAccessorImplPtr getLocalAccessor(handler *Handler);
1496+
1497+
/// Associates a local accessor with this dynamic local accessor for a
1498+
/// specific handler.
1499+
/// @param LocalAccBase the local accessor that needs to be registered.
1500+
/// @param Handler the handler that the LocalAccessor is associated with.
1501+
void registerLocalAccessor(sycl::detail::LocalAccessorBaseHost *LocalAccBase,
1502+
handler *Handler);
1503+
14801504
/// Static helper function for updating command-group value arguments.
14811505
/// @param CG The command-group to update the argument information for.
14821506
/// @param ArgIndex The argument index to update.
@@ -1493,13 +1517,29 @@ class dynamic_parameter_impl {
14931517
int ArgIndex,
14941518
const sycl::detail::AccessorBaseHost *Acc);
14951519

1520+
/// Static helper function for updating command-group local accessor
1521+
/// arguments.
1522+
/// @param CG The command-group to update the argument information for.
1523+
/// @param ArgIndex The argument index to update.
1524+
/// @param NewAllocationSize The new allocation size for the local accessor
1525+
/// argument.
1526+
/// @param Dims The dimensions of the local accessor argument.
1527+
/// @param ElemSize The size of each element in the local accessor.
1528+
static void updateCGLocalAccessor(std::shared_ptr<sycl::detail::CG> CG,
1529+
int ArgIndex, range<3> NewAllocationSize,
1530+
int Dims, int ElemSize);
1531+
14961532
// Weak ptrs to node_impls which will be updated
14971533
std::vector<std::pair<std::weak_ptr<node_impl>, int>> MNodes;
14981534
// Dynamic command-groups which will be updated
14991535
std::vector<DynamicCGInfo> MDynCGs;
15001536

15011537
std::shared_ptr<graph_impl> MGraph;
15021538
std::vector<std::byte> MValueStorage;
1539+
1540+
std::unordered_map<std::shared_ptr<sycl::detail::handler_impl>,
1541+
sycl::detail::LocalAccessorImplPtr>
1542+
MHandlerToLocalAccMap;
15031543
};
15041544

15051545
class dynamic_command_group_impl
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// RUN: %{build} -o %t.out
2+
// RUN: %{run} %t.out
3+
// Extra run to check for leaks in Level Zero using UR_L0_LEAKS_DEBUG
4+
// RUN: %if level_zero %{env SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=0 %{l0_leak_check} %{run} %t.out 2>&1 | FileCheck %s --implicit-check-not=LEAK %}
5+
// Extra run to check for immediate-command-list in Level Zero
6+
// RUN: %if level_zero %{env SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1 %{l0_leak_check} %{run} %t.out 2>&1 | FileCheck %s --implicit-check-not=LEAK %}
7+
8+
// Tests using dynamic command-group objects with dynamic local accessors.
9+
10+
#include "../graph_common.hpp"
11+
12+
int main() {
13+
using T = int;
14+
15+
const size_t LocalMemSize = 128;
16+
17+
queue Queue{};
18+
19+
std::vector<T> HostDataBeforeUpdate(Size);
20+
std::vector<T> HostDataAfterUpdate(Size);
21+
std::iota(HostDataBeforeUpdate.begin(), HostDataBeforeUpdate.end(), 10);
22+
23+
T *PtrA = malloc_device<T>(Size, Queue);
24+
Queue.copy(HostDataBeforeUpdate.data(), PtrA, Size);
25+
Queue.wait_and_throw();
26+
27+
exp_ext::command_graph Graph{Queue.get_context(), Queue.get_device()};
28+
29+
exp_ext::dynamic_local_accessor<T, 1> DynLocalAccessor{Graph, LocalMemSize};
30+
31+
auto CGFA = [&](handler &CGH) {
32+
CGH.set_arg(0, DynLocalAccessor);
33+
auto LocalMem = DynLocalAccessor.get(CGH);
34+
35+
CGH.parallel_for(nd_range({Size}, {LocalMemSize}), [=](nd_item<1> Item) {
36+
LocalMem[Item.get_local_linear_id()] = Item.get_local_linear_id();
37+
PtrA[Item.get_global_linear_id()] = LocalMem[Item.get_local_linear_id()];
38+
});
39+
};
40+
41+
auto CGFB = [&](handler &CGH) {
42+
CGH.set_arg(0, DynLocalAccessor);
43+
auto LocalMem = DynLocalAccessor.get(CGH);
44+
45+
CGH.parallel_for(
46+
nd_range({Size}, {LocalMemSize * 2}), [=](nd_item<1> Item) {
47+
LocalMem[Item.get_local_linear_id()] = Item.get_local_linear_id();
48+
PtrA[Item.get_global_linear_id()] =
49+
LocalMem[Item.get_local_linear_id()] * 2;
50+
});
51+
};
52+
53+
auto DynamicCG = exp_ext::dynamic_command_group(Graph, {CGFA, CGFB});
54+
auto DynamicCGNode = Graph.add(DynamicCG);
55+
auto ExecGraph = Graph.finalize(exp_ext::property::graph::updatable{});
56+
57+
Queue.submit([&](handler &CGH) { CGH.ext_oneapi_graph(ExecGraph); });
58+
Queue.wait_and_throw();
59+
Queue.copy(PtrA, HostDataBeforeUpdate.data(), Size);
60+
Queue.wait_and_throw();
61+
62+
DynLocalAccessor.update(LocalMemSize * 2);
63+
DynamicCG.set_active_index(1);
64+
ExecGraph.update(DynamicCGNode);
65+
Queue.ext_oneapi_graph(ExecGraph).wait();
66+
67+
Queue.submit([&](handler &CGH) { CGH.ext_oneapi_graph(ExecGraph); });
68+
Queue.wait_and_throw();
69+
Queue.copy(PtrA, HostDataAfterUpdate.data(), Size);
70+
Queue.wait_and_throw();
71+
72+
for (size_t i = 0; i < Size; i++) {
73+
T Ref = i % LocalMemSize;
74+
assert(check_value(i, Ref, HostDataBeforeUpdate[i], "PtrA Before Update"));
75+
}
76+
77+
for (size_t i = 0; i < Size; i++) {
78+
T Ref = i % (LocalMemSize * 2) * 2;
79+
assert(check_value(i, Ref, HostDataAfterUpdate[i], "PtrA After Update"));
80+
}
81+
82+
free(PtrA, Queue);
83+
84+
return 0;
85+
}

0 commit comments

Comments
 (0)