@@ -66,8 +66,10 @@ struct ur_kernel_handle_t_ {
6666 args_t Storage;
6767 // / Aligned size of each parameter, including padding.
6868 args_size_t ParamSizes;
69- // / Byte offset into /p Storage allocation for each parameter.
70- args_index_t Indices;
69+ // / Pointer into the \p Storage allocation for each argument.
70+ args_index_t ArgPointers;
71+ // / Position in the Storage array where the next argument should be added.
72+ size_t InsertPos = 0 ;
7173 // / Aligned size in bytes for each local memory parameter after padding has
7274 // / been added. Zero if the argument at the index isn't a local memory
7375 // / argument.
@@ -90,33 +92,43 @@ struct ur_kernel_handle_t_ {
9092 std::uint32_t ImplicitOffsetArgs[3 ] = {0 , 0 , 0 };
9193
9294 arguments () {
93- // Place the implicit offset index at the end of the indicies collection
94- Indices.emplace_back (&ImplicitOffsetArgs);
95+ // Place the implicit offset index at the end of the ArgPointers
96+ // collection.
97+ ArgPointers.emplace_back (&ImplicitOffsetArgs);
9598 }
9699
97100 // / Add an argument to the kernel.
98101 // / If the argument existed before, it is replaced.
99102 // / Otherwise, it is added.
100103 // / Gaps are filled with empty arguments.
101- // / Implicit offset argument is kept at the back of the indices collection.
104+ // / Implicit offset argument is kept at the back of the ArgPointers
105+ // / collection.
102106 void addArg (size_t Index, size_t Size, const void *Arg,
103107 size_t LocalSize = 0 ) {
104- if (Index + 2 > Indices.size ()) {
108+ // Expand storage to accommodate this Index if needed.
109+ if (Index + 2 > ArgPointers.size ()) {
105110 // Keep the implicit offset argument pointer at the end
106- Indices .resize (Index + 2 , Indices .back ());
111+ ArgPointers .resize (Index + 2 , ArgPointers .back ());
107112 // Ensure enough space for the new argument
108113 ParamSizes.resize (Index + 1 );
109114 AlignedLocalMemSize.resize (Index + 1 );
110115 OriginalLocalMemSize.resize (Index + 1 );
111116 }
112- ParamSizes[Index] = Size;
113- // calculate the insertion point on the array
114- size_t InsertPos = std::accumulate (std::begin (ParamSizes),
115- std::begin (ParamSizes) + Index, 0 );
116- // Update the stored value for the argument
117- std::memcpy (&Storage[InsertPos], Arg, Size);
118- Indices[Index] = &Storage[InsertPos];
119- AlignedLocalMemSize[Index] = LocalSize;
117+
118+ // Copy new argument to storage if it hasn't been added before.
119+ if (ParamSizes[Index] == 0 ) {
120+ ParamSizes[Index] = Size;
121+ std::memcpy (&Storage[InsertPos], Arg, Size);
122+ ArgPointers[Index] = &Storage[InsertPos];
123+ AlignedLocalMemSize[Index] = LocalSize;
124+ InsertPos += Size;
125+ }
126+ // Otherwise, update the existing argument.
127+ else {
128+ std::memcpy (ArgPointers[Index], Arg, Size);
129+ AlignedLocalMemSize[Index] = LocalSize;
130+ assert (Size == ParamSizes[Index]);
131+ }
120132 }
121133
122134 // / Returns the padded size and offset of a local memory argument.
@@ -128,7 +140,7 @@ struct ur_kernel_handle_t_ {
128140 std::pair<size_t , size_t > calcAlignedLocalArgument (size_t Index,
129141 size_t Size) {
130142 // Store the unpadded size of the local argument
131- if (Index + 2 > Indices .size ()) {
143+ if (Index + 2 > ArgPointers .size ()) {
132144 AlignedLocalMemSize.resize (Index + 1 );
133145 OriginalLocalMemSize.resize (Index + 1 );
134146 }
@@ -158,10 +170,11 @@ struct ur_kernel_handle_t_ {
158170 return std::make_pair (AlignedLocalSize, AlignedLocalOffset);
159171 }
160172
161- // Iterate over all existing local argument which follows StartIndex
173+ // Iterate over each existing local argument starting at StartIndex and
162174 // update its offset and pointer into the kernel local memory.
163175 void updateLocalArgOffset (size_t StartIndex) {
164- const size_t NumArgs = Indices.size () - 1 ; // Accounts for implicit arg
176+ const size_t NumArgs =
177+ ArgPointers.size () - 1 ; // Accounts for implicit arg
165178 for (auto SuccIndex = StartIndex; SuccIndex < NumArgs; SuccIndex++) {
166179 const size_t OriginalLocalSize = OriginalLocalMemSize[SuccIndex];
167180 if (OriginalLocalSize == 0 ) {
@@ -177,10 +190,7 @@ struct ur_kernel_handle_t_ {
177190 AlignedLocalMemSize[SuccIndex] = SuccAlignedLocalSize;
178191
179192 // Store new offset into local data
180- const size_t InsertPos =
181- std::accumulate (std::begin (ParamSizes),
182- std::begin (ParamSizes) + SuccIndex, size_t {0 });
183- std::memcpy (&Storage[InsertPos], &SuccAlignedLocalOffset,
193+ std::memcpy (ArgPointers[SuccIndex], &SuccAlignedLocalOffset,
184194 sizeof (size_t ));
185195 }
186196 }
@@ -228,7 +238,7 @@ struct ur_kernel_handle_t_ {
228238 std::memcpy (ImplicitOffsetArgs, ImplicitOffset, Size);
229239 }
230240
231- const args_index_t &getIndices () const noexcept { return Indices ; }
241+ const args_index_t &getArgPointers () const noexcept { return ArgPointers ; }
232242
233243 uint32_t getLocalSize () const {
234244 return std::accumulate (std::begin (AlignedLocalMemSize),
@@ -299,7 +309,7 @@ struct ur_kernel_handle_t_ {
299309 // / real one required by the kernel, since this cannot be queried from
300310 // / the CUDA Driver API
301311 uint32_t getNumArgs () const noexcept {
302- return static_cast <uint32_t >(Args.Indices .size () - 1 );
312+ return static_cast <uint32_t >(Args.ArgPointers .size () - 1 );
303313 }
304314
305315 void setKernelArg (int Index, size_t Size, const void *Arg) {
@@ -314,8 +324,8 @@ struct ur_kernel_handle_t_ {
314324 return Args.setImplicitOffset (Size, ImplicitOffset);
315325 }
316326
317- const arguments::args_index_t &getArgIndices () const {
318- return Args.getIndices ();
327+ const arguments::args_index_t &getArgPointers () const {
328+ return Args.getArgPointers ();
319329 }
320330
321331 void setWorkGroupMemory (size_t MemSize) { Args.setWorkGroupMemory (MemSize); }
0 commit comments