@@ -61,10 +61,22 @@ struct ur_kernel_handle_t_ {
6161 using args_t = std::array<char , MaxParamBytes>;
6262 using args_size_t = std::vector<size_t >;
6363 using args_index_t = std::vector<void *>;
64+ // / Storage shared by all args which is mem copied into when adding a new
65+ // / argument.
6466 args_t Storage;
67+ // / Aligned size of each parameter, including padding.
6568 args_size_t ParamSizes;
69+ // / Byte offset into /p Storage allocation for each parameter.
6670 args_index_t Indices;
67- args_size_t OffsetPerIndex;
71+ // / Aligned size in bytes for each local memory parameter after padding has
72+ // / been added. Zero if the argument at the index isn't a local memory
73+ // / argument.
74+ args_size_t AlignedLocalMemSize;
75+ // / Original size in bytes for each local memory parameter, prior to being
76+ // / padded to appropriate alignment. Zero if the argument at the index
77+ // / isn't a local memory argument.
78+ args_size_t OriginalLocalMemSize;
79+
6880 // A struct to keep track of memargs so that we can do dependency analysis
6981 // at urEnqueueKernelLaunch
7082 struct mem_obj_arg {
@@ -93,7 +105,8 @@ struct ur_kernel_handle_t_ {
93105 Indices.resize (Index + 2 , Indices.back ());
94106 // Ensure enough space for the new argument
95107 ParamSizes.resize (Index + 1 );
96- OffsetPerIndex.resize (Index + 1 );
108+ AlignedLocalMemSize.resize (Index + 1 );
109+ OriginalLocalMemSize.resize (Index + 1 );
97110 }
98111 ParamSizes[Index] = Size;
99112 // calculate the insertion point on the array
@@ -102,28 +115,83 @@ struct ur_kernel_handle_t_ {
102115 // Update the stored value for the argument
103116 std::memcpy (&Storage[InsertPos], Arg, Size);
104117 Indices[Index] = &Storage[InsertPos];
105- OffsetPerIndex [Index] = LocalSize;
118+ AlignedLocalMemSize [Index] = LocalSize;
106119 }
107120
108- void addLocalArg (size_t Index, size_t Size) {
109- size_t LocalOffset = this ->getLocalSize ();
121+ // / Returns the padded size and offset of a local memory argument.
122+ // / Local memory arguments need to be padded if the alignment for the size
123+ // / doesn't match the current offset into the kernel local data.
124+ // / @param Index Kernel arg index.
125+ // / @param Size User passed size of local parameter.
126+ // / @return Tuple of (Aligned size, Aligned offset into local data).
127+ std::pair<size_t , size_t > calcAlignedLocalArgument (size_t Index,
128+ size_t Size) {
129+ // Store the unpadded size of the local argument
130+ if (Index + 2 > Indices.size ()) {
131+ AlignedLocalMemSize.resize (Index + 1 );
132+ OriginalLocalMemSize.resize (Index + 1 );
133+ }
134+ OriginalLocalMemSize[Index] = Size;
135+
136+ // Calculate the current starting offset into local data
137+ const size_t LocalOffset = std::accumulate (
138+ std::begin (AlignedLocalMemSize),
139+ std::next (std::begin (AlignedLocalMemSize), Index), size_t {0 });
110140
111- // maximum required alignment is the size of the largest vector type
141+ // Maximum required alignment is the size of the largest vector type
112142 const size_t MaxAlignment = sizeof (double ) * 16 ;
113143
114- // for arguments smaller than the maximum alignment simply align to the
144+ // For arguments smaller than the maximum alignment simply align to the
115145 // size of the argument
116146 const size_t Alignment = std::min (MaxAlignment, Size);
117147
118- // align the argument
148+ // Align the argument
119149 size_t AlignedLocalOffset = LocalOffset;
120- size_t Pad = LocalOffset % Alignment;
150+ const size_t Pad = LocalOffset % Alignment;
121151 if (Pad != 0 ) {
122152 AlignedLocalOffset += Alignment - Pad;
123153 }
124154
155+ const size_t AlignedLocalSize = Size + (AlignedLocalOffset - LocalOffset);
156+ return std::make_pair (AlignedLocalSize, AlignedLocalOffset);
157+ }
158+
159+ void addLocalArg (size_t Index, size_t Size) {
160+ // Get the aligned argument size and offset into local data
161+ size_t AlignedLocalSize, AlignedLocalOffset;
162+ std::tie (AlignedLocalSize, AlignedLocalOffset) =
163+ calcAlignedLocalArgument (Index, Size);
164+
165+ // Store argument details
125166 addArg (Index, sizeof (size_t ), (const void *)&(AlignedLocalOffset),
126- Size + (AlignedLocalOffset - LocalOffset));
167+ AlignedLocalSize);
168+
169+ // For every existing local argument which follows at later argument
170+ // indices, updated the offset and pointer into the kernel local memory.
171+ // Required as padding will need to be recalculated.
172+ const size_t NumArgs = Indices.size () - 1 ; // Accounts for implicit arg
173+ for (auto SuccIndex = Index + 1 ; SuccIndex < NumArgs; SuccIndex++) {
174+ const size_t OriginalLocalSize = OriginalLocalMemSize[SuccIndex];
175+ if (OriginalLocalSize == 0 ) {
176+ // Skip if successor argument isn't a local memory arg
177+ continue ;
178+ }
179+
180+ // Recalculate alignment
181+ size_t SuccAlignedLocalSize, SuccAlignedLocalOffset;
182+ std::tie (SuccAlignedLocalSize, SuccAlignedLocalOffset) =
183+ calcAlignedLocalArgument (SuccIndex, OriginalLocalSize);
184+
185+ // Store new local memory size
186+ AlignedLocalMemSize[SuccIndex] = SuccAlignedLocalSize;
187+
188+ // Store new offset into local data
189+ const size_t InsertPos =
190+ std::accumulate (std::begin (ParamSizes),
191+ std::begin (ParamSizes) + SuccIndex, size_t {0 });
192+ std::memcpy (&Storage[InsertPos], &SuccAlignedLocalOffset,
193+ sizeof (size_t ));
194+ }
127195 }
128196
129197 void addMemObjArg (int Index, ur_mem_handle_t hMem, ur_mem_flags_t Flags) {
@@ -145,15 +213,11 @@ struct ur_kernel_handle_t_ {
145213 std::memcpy (ImplicitOffsetArgs, ImplicitOffset, Size);
146214 }
147215
148- void clearLocalSize () {
149- std::fill (std::begin (OffsetPerIndex), std::end (OffsetPerIndex), 0 );
150- }
151-
152216 const args_index_t &getIndices () const noexcept { return Indices; }
153217
154218 uint32_t getLocalSize () const {
155- return std::accumulate (std::begin (OffsetPerIndex ),
156- std::end (OffsetPerIndex ), 0 );
219+ return std::accumulate (std::begin (AlignedLocalMemSize ),
220+ std::end (AlignedLocalMemSize ), 0 );
157221 }
158222 } Args;
159223
@@ -240,7 +304,5 @@ struct ur_kernel_handle_t_ {
240304
241305 uint32_t getLocalSize () const noexcept { return Args.getLocalSize (); }
242306
243- void clearLocalSize () { Args.clearLocalSize (); }
244-
245307 size_t getRegsPerThread () const noexcept { return RegsPerThread; };
246308};
0 commit comments