@@ -56,10 +56,22 @@ struct ur_kernel_handle_t_ {
5656 using args_t = std::array<char , MAX_PARAM_BYTES>;
5757 using args_size_t = std::vector<size_t >;
5858 using args_index_t = std::vector<void *>;
59+ /// Storage shared by all arguments; each argument's bytes are copied
60+ /// (memcpy) into it when the argument is added.
5961 args_t Storage;
62+ // / Aligned size of each parameter, including padding.
6063 args_size_t ParamSizes;
64+ /// Pointer into the \p Storage allocation for each parameter.
6165 args_index_t Indices;
62- args_size_t OffsetPerIndex;
66+ // / Aligned size in bytes for each local memory parameter after padding has
67+ // / been added. Zero if the argument at the index isn't a local memory
68+ // / argument.
69+ args_size_t AlignedLocalMemSize;
70+ // / Original size in bytes for each local memory parameter, prior to being
71+ // / padded to appropriate alignment. Zero if the argument at the index
72+ // / isn't a local memory argument.
73+ args_size_t OriginalLocalMemSize;
74+
6375 // A struct to keep track of memargs so that we can do dependency analysis
6476 // at urEnqueueKernelLaunch
6577 struct mem_obj_arg {
@@ -88,7 +100,8 @@ struct ur_kernel_handle_t_ {
88100 Indices.resize (Index + 2 , Indices.back ());
89101 // Ensure enough space for the new argument
90102 ParamSizes.resize (Index + 1 );
91- OffsetPerIndex.resize (Index + 1 );
103+ AlignedLocalMemSize.resize (Index + 1 );
104+ OriginalLocalMemSize.resize (Index + 1 );
92105 }
93106 ParamSizes[Index] = Size;
94107 // calculate the insertion point on the array
@@ -97,28 +110,83 @@ struct ur_kernel_handle_t_ {
97110 // Update the stored value for the argument
98111 std::memcpy (&Storage[InsertPos], Arg, Size);
99112 Indices[Index] = &Storage[InsertPos];
100- OffsetPerIndex [Index] = LocalSize;
113+ AlignedLocalMemSize [Index] = LocalSize;
101114 }
102115
103- void addLocalArg (size_t Index, size_t Size) {
104- size_t LocalOffset = this ->getLocalSize ();
116+ // / Returns the padded size and offset of a local memory argument.
117+ // / Local memory arguments need to be padded if the alignment for the size
118+ // / doesn't match the current offset into the kernel local data.
119+ // / @param Index Kernel arg index.
120+ // / @param Size User passed size of local parameter.
121+ // / @return Tuple of (Aligned size, Aligned offset into local data).
122+ std::pair<size_t , size_t > calcAlignedLocalArgument (size_t Index,
123+ size_t Size) {
124+ // Store the unpadded size of the local argument
125+ if (Index + 2 > Indices.size ()) {
126+ AlignedLocalMemSize.resize (Index + 1 );
127+ OriginalLocalMemSize.resize (Index + 1 );
128+ }
129+ OriginalLocalMemSize[Index] = Size;
105130
106- // maximum required alignment is the size of the largest vector type
131+ // Calculate the current starting offset into local data
132+ const size_t LocalOffset = std::accumulate (
133+ std::begin (AlignedLocalMemSize),
134+ std::next (std::begin (AlignedLocalMemSize), Index), size_t {0 });
135+
136+ // Maximum required alignment is the size of the largest vector type
107137 const size_t MaxAlignment = sizeof (double ) * 16 ;
108138
109- // for arguments smaller than the maximum alignment simply align to the
139+ // For arguments smaller than the maximum alignment simply align to the
110140 // size of the argument
111141 const size_t Alignment = std::min (MaxAlignment, Size);
112142
113- // align the argument
143+ // Align the argument
114144 size_t AlignedLocalOffset = LocalOffset;
115- size_t Pad = LocalOffset % Alignment;
145+ const size_t Pad = LocalOffset % Alignment;
116146 if (Pad != 0 ) {
117147 AlignedLocalOffset += Alignment - Pad;
118148 }
119149
120- addArg (Index, sizeof (size_t ), (const void *)&AlignedLocalOffset,
121- Size + AlignedLocalOffset - LocalOffset);
150+ const size_t AlignedLocalSize = Size + (AlignedLocalOffset - LocalOffset);
151+ return std::make_pair (AlignedLocalSize, AlignedLocalOffset);
152+ }
153+
154+ void addLocalArg (size_t Index, size_t Size) {
155+ // Get the aligned argument size and offset into local data
156+ size_t AlignedLocalSize, AlignedLocalOffset;
157+ std::tie (AlignedLocalSize, AlignedLocalOffset) =
158+ calcAlignedLocalArgument (Index, Size);
159+
160+ // Store argument details
161+ addArg (Index, sizeof (size_t ), (const void *)&(AlignedLocalOffset),
162+ AlignedLocalSize);
163+
164+ // For every existing local argument which follows at later argument
165+ // indices, updated the offset and pointer into the kernel local memory.
166+ // Required as padding will need to be recalculated.
167+ const size_t NumArgs = Indices.size () - 1 ; // Accounts for implicit arg
168+ for (auto SuccIndex = Index + 1 ; SuccIndex < NumArgs; SuccIndex++) {
169+ const size_t OriginalLocalSize = OriginalLocalMemSize[SuccIndex];
170+ if (OriginalLocalSize == 0 ) {
171+ // Skip if successor argument isn't a local memory arg
172+ continue ;
173+ }
174+
175+ // Recalculate alignment
176+ size_t SuccAlignedLocalSize, SuccAlignedLocalOffset;
177+ std::tie (SuccAlignedLocalSize, SuccAlignedLocalOffset) =
178+ calcAlignedLocalArgument (SuccIndex, OriginalLocalSize);
179+
180+ // Store new local memory size
181+ AlignedLocalMemSize[SuccIndex] = SuccAlignedLocalSize;
182+
183+ // Store new offset into local data
184+ const size_t InsertPos =
185+ std::accumulate (std::begin (ParamSizes),
186+ std::begin (ParamSizes) + SuccIndex, size_t {0 });
187+ std::memcpy (&Storage[InsertPos], &SuccAlignedLocalOffset,
188+ sizeof (size_t ));
189+ }
122190 }
123191
124192 void addMemObjArg (int Index, ur_mem_handle_t hMem, ur_mem_flags_t Flags) {
@@ -140,15 +208,11 @@ struct ur_kernel_handle_t_ {
140208 std::memcpy (ImplicitOffsetArgs, ImplicitOffset, Size);
141209 }
142210
143- void clearLocalSize () {
144- std::fill (std::begin (OffsetPerIndex), std::end (OffsetPerIndex), 0 );
145- }
146-
147211 const args_index_t &getIndices () const noexcept { return Indices; }
148212
149213 uint32_t getLocalSize () const {
150- return std::accumulate (std::begin (OffsetPerIndex ),
151- std::end (OffsetPerIndex ), 0 );
214+ return std::accumulate (std::begin (AlignedLocalMemSize ),
215+ std::end (AlignedLocalMemSize ), 0 );
152216 }
153217 } Args;
154218
@@ -220,6 +284,4 @@ struct ur_kernel_handle_t_ {
220284 }
221285
222286 uint32_t getLocalSize () const noexcept { return Args.getLocalSize (); }
223-
224- void clearLocalSize () { Args.clearLocalSize (); }
225287};
0 commit comments