@@ -61,10 +61,22 @@ struct ur_kernel_handle_t_ {
6161 using args_t = std::array<char , MaxParamBytes>;
6262 using args_size_t = std::vector<size_t >;
6363 using args_index_t = std::vector<void *>;
64+ // / Storage shared by all args which is mem copied into when adding a new
65+ // / argument.
6466 args_t Storage;
67+ // / Aligned size of each parameter, including padding.
6568 args_size_t ParamSizes;
69+ // / Byte offset into /p Storage allocation for each parameter.
6670 args_index_t Indices;
67- args_size_t OffsetPerIndex;
71+ // / Aligned size in bytes for each local memory parameter after padding has
72+ // / been added. Zero if the argument at the index isn't a local memory
73+ // / argument.
74+ args_size_t AlignedLocalMemSize;
75+ // / Original size in bytes for each local memory parameter, prior to being
76+ // / padded to appropriate alignment. Zero if the argument at the index
77+ // / isn't a local memory argument.
78+ args_size_t OriginalLocalMemSize;
79+
6880 // A struct to keep track of memargs so that we can do dependency analysis
6981 // at urEnqueueKernelLaunch
7082 struct mem_obj_arg {
@@ -93,7 +105,8 @@ struct ur_kernel_handle_t_ {
93105 Indices.resize (Index + 2 , Indices.back ());
94106 // Ensure enough space for the new argument
95107 ParamSizes.resize (Index + 1 );
96- OffsetPerIndex.resize (Index + 1 );
108+ AlignedLocalMemSize.resize (Index + 1 );
109+ OriginalLocalMemSize.resize (Index + 1 );
97110 }
98111 ParamSizes[Index] = Size;
99112 // calculate the insertion point on the array
@@ -102,28 +115,81 @@ struct ur_kernel_handle_t_ {
102115 // Update the stored value for the argument
103116 std::memcpy (&Storage[InsertPos], Arg, Size);
104117 Indices[Index] = &Storage[InsertPos];
105- OffsetPerIndex [Index] = LocalSize;
118+ AlignedLocalMemSize [Index] = LocalSize;
106119 }
107120
108- void addLocalArg (size_t Index, size_t Size) {
109- size_t LocalOffset = this ->getLocalSize ();
121+ // / Returns the padded size and offset of a local memory argument.
122+ // / Local memory arguments need to be padded if the alignment for the size
123+ // / doesn't match the current offset into the kernel local data.
124+ // / @param Index Kernel arg index.
125+ // / @param Size User passed size of local parameter.
126+ // / @return Tuple of (Aligned size, Aligned offset into local data).
127+ std::pair<size_t , size_t > calcAlignedLocalArgument (size_t Index,
128+ size_t Size) {
129+ // Store the unpadded size of the local argument
130+ if (Index + 2 > Indices.size ()) {
131+ AlignedLocalMemSize.resize (Index + 1 );
132+ OriginalLocalMemSize.resize (Index + 1 );
133+ }
134+ OriginalLocalMemSize[Index] = Size;
135+
136+ // Calculate the current starting offset into local data
137+ const size_t LocalOffset = std::accumulate (
138+ std::begin (AlignedLocalMemSize),
139+ std::next (std::begin (AlignedLocalMemSize), Index), size_t {0 });
110140
111- // maximum required alignment is the size of the largest vector type
141+ // Maximum required alignment is the size of the largest vector type
112142 const size_t MaxAlignment = sizeof (double ) * 16 ;
113143
114- // for arguments smaller than the maximum alignment simply align to the
144+ // For arguments smaller than the maximum alignment simply align to the
115145 // size of the argument
116146 const size_t Alignment = std::min (MaxAlignment, Size);
117147
118- // align the argument
148+ // Align the argument
119149 size_t AlignedLocalOffset = LocalOffset;
120- size_t Pad = LocalOffset % Alignment;
150+ const size_t Pad = LocalOffset % Alignment;
121151 if (Pad != 0 ) {
122152 AlignedLocalOffset += Alignment - Pad;
123153 }
124154
155+ const size_t AlignedLocalSize = Size + (AlignedLocalOffset - LocalOffset);
156+ return std::make_pair (AlignedLocalSize, AlignedLocalOffset);
157+ }
158+
159+ void addLocalArg (size_t Index, size_t Size) {
160+ // Get the aligned argument size and offset into local data
161+ auto [AlignedLocalSize, AlignedLocalOffset] =
162+ calcAlignedLocalArgument (Index, Size);
163+
164+ // Store argument details
125165 addArg (Index, sizeof (size_t ), (const void *)&(AlignedLocalOffset),
126- Size + (AlignedLocalOffset - LocalOffset));
166+ AlignedLocalSize);
167+
168+ // For every existing local argument which follows at later argument
169+ // indices, update the offset and pointer into the kernel local memory.
170+ // Required as padding will need to be recalculated.
171+ const size_t NumArgs = Indices.size () - 1 ; // Accounts for implicit arg
172+ for (auto SuccIndex = Index + 1 ; SuccIndex < NumArgs; SuccIndex++) {
173+ const size_t OriginalLocalSize = OriginalLocalMemSize[SuccIndex];
174+ if (OriginalLocalSize == 0 ) {
175+ // Skip if successor argument isn't a local memory arg
176+ continue ;
177+ }
178+
179+ // Recalculate alignment
180+ auto [SuccAlignedLocalSize, SuccAlignedLocalOffset] =
181+ calcAlignedLocalArgument (SuccIndex, OriginalLocalSize);
182+
183+ // Store new local memory size
184+ AlignedLocalMemSize[SuccIndex] = SuccAlignedLocalSize;
185+
186+ // Store new offset into local data
187+ const size_t InsertPos =
188+ std::accumulate (std::begin (ParamSizes),
189+ std::begin (ParamSizes) + SuccIndex, size_t {0 });
190+ std::memcpy (&Storage[InsertPos], &SuccAlignedLocalOffset,
191+ sizeof (size_t ));
192+ }
127193 }
128194
129195 void addMemObjArg (int Index, ur_mem_handle_t hMem, ur_mem_flags_t Flags) {
@@ -145,15 +211,11 @@ struct ur_kernel_handle_t_ {
145211 std::memcpy (ImplicitOffsetArgs, ImplicitOffset, Size);
146212 }
147213
148- void clearLocalSize () {
149- std::fill (std::begin (OffsetPerIndex), std::end (OffsetPerIndex), 0 );
150- }
151-
152214 const args_index_t &getIndices () const noexcept { return Indices; }
153215
154216 uint32_t getLocalSize () const {
155- return std::accumulate (std::begin (OffsetPerIndex ),
156- std::end (OffsetPerIndex ), 0 );
217+ return std::accumulate (std::begin (AlignedLocalMemSize ),
218+ std::end (AlignedLocalMemSize ), 0 );
157219 }
158220 } Args;
159221
@@ -240,7 +302,5 @@ struct ur_kernel_handle_t_ {
240302
241303 uint32_t getLocalSize () const noexcept { return Args.getLocalSize (); }
242304
243- void clearLocalSize () { Args.clearLocalSize (); }
244-
245305 size_t getRegsPerThread () const noexcept { return RegsPerThread; };
246306};
0 commit comments