@@ -68,6 +68,8 @@ struct ur_kernel_handle_t_ {
6868 args_size_t ParamSizes;
6969 // / Byte offset into /p Storage allocation for each parameter.
7070 args_index_t Indices;
71+ // / Largest argument index that has been added to this kernel so far.
72+ size_t InsertPos = 0 ;
7173 // / Aligned size in bytes for each local memory parameter after padding has
7274 // / been added. Zero if the argument at the index isn't a local memory
7375 // / argument.
@@ -101,6 +103,8 @@ struct ur_kernel_handle_t_ {
101103 // / Implicit offset argument is kept at the back of the indices collection.
102104 void addArg (size_t Index, size_t Size, const void *Arg,
103105 size_t LocalSize = 0 ) {
106+
107+ // Expand storage to accommodate this Index if needed.
104108 if (Index + 2 > Indices.size ()) {
105109 // Move implicit offset argument index with the end
106110 Indices.resize (Index + 2 , Indices.back ());
@@ -109,14 +113,21 @@ struct ur_kernel_handle_t_ {
109113 AlignedLocalMemSize.resize (Index + 1 );
110114 OriginalLocalMemSize.resize (Index + 1 );
111115 }
112- ParamSizes[Index] = Size;
113- // calculate the insertion point on the array
114- size_t InsertPos = std::accumulate (std::begin (ParamSizes),
115- std::begin (ParamSizes) + Index, 0 );
116- // Update the stored value for the argument
117- std::memcpy (&Storage[InsertPos], Arg, Size);
118- Indices[Index] = &Storage[InsertPos];
119- AlignedLocalMemSize[Index] = LocalSize;
116+
117+ // Copy new argument to storage if it hasn't been added before.
118+ if (ParamSizes[Index] == 0 ) {
119+ ParamSizes[Index] = Size;
120+ std::memcpy (&Storage[InsertPos], Arg, Size);
121+ Indices[Index] = &Storage[InsertPos];
122+ AlignedLocalMemSize[Index] = LocalSize;
123+ InsertPos += Size;
124+ }
125+ // Otherwise, update the existing argument.
126+ else {
127+ std::memcpy (Indices[Index], Arg, Size);
128+ AlignedLocalMemSize[Index] = LocalSize;
129+ assert (Size == ParamSizes[Index]);
130+ }
120131 }
121132
122133 // / Returns the padded size and offset of a local memory argument.
@@ -177,10 +188,7 @@ struct ur_kernel_handle_t_ {
177188 AlignedLocalMemSize[SuccIndex] = SuccAlignedLocalSize;
178189
179190 // Store new offset into local data
180- const size_t InsertPos =
181- std::accumulate (std::begin (ParamSizes),
182- std::begin (ParamSizes) + SuccIndex, size_t {0 });
183- std::memcpy (&Storage[InsertPos], &SuccAlignedLocalOffset,
191+ std::memcpy (Indices[SuccIndex], &SuccAlignedLocalOffset,
184192 sizeof (size_t ));
185193 }
186194 }
0 commit comments