@@ -63,6 +63,8 @@ struct ur_kernel_handle_t_ {
6363 args_size_t ParamSizes;
6464 // / Byte offset into /p Storage allocation for each parameter.
6565 args_index_t Indices;
66+ // / Byte offset into Storage at which the next newly added argument is copied.
67+ size_t InsertPos = 0 ;
6668 // / Aligned size in bytes for each local memory parameter after padding has
6769 // / been added. Zero if the argument at the index isn't a local memory
6870 // / argument.
@@ -95,22 +97,30 @@ struct ur_kernel_handle_t_ {
9597 // / Implicit offset argument is kept at the back of the indices collection.
9698 void addArg (size_t Index, size_t Size, const void *Arg,
9799 size_t LocalSize = 0 ) {
100+ // Expand storage to accommodate this Index if needed.
98101 if (Index + 2 > Indices.size ()) {
99- // Move implicit offset argument Index with the end
102+ // Keep the implicit offset argument's pointer at the back of Indices
100103 Indices.resize (Index + 2 , Indices.back ());
101104 // Ensure enough space for the new argument
102105 ParamSizes.resize (Index + 1 );
103106 AlignedLocalMemSize.resize (Index + 1 );
104107 OriginalLocalMemSize.resize (Index + 1 );
105108 }
106- ParamSizes[Index] = Size;
107- // calculate the insertion point on the array
108- size_t InsertPos = std::accumulate (std::begin (ParamSizes),
109- std::begin (ParamSizes) + Index, 0 );
110- // Update the stored value for the argument
111- std::memcpy (&Storage[InsertPos], Arg, Size);
112- Indices[Index] = &Storage[InsertPos];
113- AlignedLocalMemSize[Index] = LocalSize;
109+
110+ // Copy new argument to storage if it hasn't been added before.
111+ if (ParamSizes[Index] == 0 ) {
112+ ParamSizes[Index] = Size;
113+ std::memcpy (&Storage[InsertPos], Arg, Size);
114+ Indices[Index] = &Storage[InsertPos];
115+ AlignedLocalMemSize[Index] = LocalSize;
116+ InsertPos += Size;
117+ }
118+ // Otherwise, update the existing argument.
119+ else {
120+ std::memcpy (Indices[Index], Arg, Size);
121+ AlignedLocalMemSize[Index] = LocalSize;
122+ assert (Size == ParamSizes[Index]);
123+ }
114124 }
115125
116126 // / Returns the padded size and offset of a local memory argument.
@@ -151,20 +161,11 @@ struct ur_kernel_handle_t_ {
151161 return std::make_pair (AlignedLocalSize, AlignedLocalOffset);
152162 }
153163
154- void addLocalArg (size_t Index, size_t Size) {
155- // Get the aligned argument size and offset into local data
156- auto [AlignedLocalSize, AlignedLocalOffset] =
157- calcAlignedLocalArgument (Index, Size);
158-
159- // Store argument details
160- addArg (Index, sizeof (size_t ), (const void *)&(AlignedLocalOffset),
161- AlignedLocalSize);
162-
163- // For every existing local argument which follows at later argument
164- // indices, update the offset and pointer into the kernel local memory.
165- // Required as padding will need to be recalculated.
164+ // For every existing local memory argument at index StartIndex or later,
165+ // update its padded size and its offset into the kernel local memory.
166+ void updateLocalArgOffset (size_t StartIndex) {
166167 const size_t NumArgs = Indices.size () - 1 ; // Accounts for implicit arg
167- for (auto SuccIndex = Index + 1 ; SuccIndex < NumArgs; SuccIndex++) {
168+ for (auto SuccIndex = StartIndex ; SuccIndex < NumArgs; SuccIndex++) {
168169 const size_t OriginalLocalSize = OriginalLocalMemSize[SuccIndex];
169170 if (OriginalLocalSize == 0 ) {
170171 // Skip if successor argument isn't a local memory arg
@@ -179,14 +180,26 @@ struct ur_kernel_handle_t_ {
179180 AlignedLocalMemSize[SuccIndex] = SuccAlignedLocalSize;
180181
181182 // Store new offset into local data
182- const size_t InsertPos =
183- std::accumulate (std::begin (ParamSizes),
184- std::begin (ParamSizes) + SuccIndex, size_t {0 });
185- std::memcpy (&Storage[InsertPos], &SuccAlignedLocalOffset,
183+ std::memcpy (Indices[SuccIndex], &SuccAlignedLocalOffset,
186184 sizeof (size_t ));
187185 }
188186 }
189187
188+ void addLocalArg (size_t Index, size_t Size) {
189+ // Get the aligned argument size and offset into local data
190+ auto [AlignedLocalSize, AlignedLocalOffset] =
191+ calcAlignedLocalArgument (Index, Size);
192+
193+ // Store argument details
194+ addArg (Index, sizeof (size_t ), (const void *)&(AlignedLocalOffset),
195+ AlignedLocalSize);
196+
197+ // For every existing local argument which follows at later argument
198+ // indices, update the offset and pointer into the kernel local memory.
199+ // Required as padding will need to be recalculated.
200+ updateLocalArgOffset (Index + 1 );
201+ }
202+
190203 void addMemObjArg (int Index, ur_mem_handle_t hMem, ur_mem_flags_t Flags) {
191204 assert (hMem && " Invalid mem handle" );
192205 // To avoid redundancy we are not storing mem obj with index i at index
0 commit comments