Skip to content

Commit e9ecf06

Browse files
author
Ewan Crawford
committed
Improve solution
Iterate on previous solution so that the local argument offsets at following inidices are updated when an earlier local argument is updated
1 parent bddcb96 commit e9ecf06

File tree

5 files changed

+300
-77
lines changed

5 files changed

+300
-77
lines changed

source/adapters/cuda/command_buffer.cpp

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -522,9 +522,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
522522
DepsList.data(), DepsList.size(),
523523
&NodeParams));
524524

525-
if (LocalSize != 0)
526-
hKernel->clearLocalSize();
527-
528525
// Add signal node if external return event is used.
529526
CUgraphNode SignalNode = nullptr;
530527
if (phEvent) {
@@ -1396,22 +1393,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
13961393

13971394
CUDA_KERNEL_NODE_PARAMS &Params = KernelCommandHandle->Params;
13981395

1399-
const auto LocalSize = KernelCommandHandle->Kernel->getLocalSize();
1400-
if (LocalSize != 0) {
1401-
// Clean the local size, otherwise calling updateKernelArguments() in
1402-
// future updates with local arguments will incorrectly increase the
1403-
// size further.
1404-
KernelCommandHandle->Kernel->clearLocalSize();
1405-
}
1406-
14071396
Params.func = CuFunc;
1408-
Params.gridDimX = static_cast<unsigned int>(BlocksPerGrid[0]);
1409-
Params.gridDimY = static_cast<unsigned int>(BlocksPerGrid[1]);
1410-
Params.gridDimZ = static_cast<unsigned int>(BlocksPerGrid[2]);
1411-
Params.blockDimX = static_cast<unsigned int>(ThreadsPerBlock[0]);
1412-
Params.blockDimY = static_cast<unsigned int>(ThreadsPerBlock[1]);
1413-
Params.blockDimZ = static_cast<unsigned int>(ThreadsPerBlock[2]);
1414-
Params.sharedMemBytes = LocalSize;
1397+
Params.gridDimX = BlocksPerGrid[0];
1398+
Params.gridDimY = BlocksPerGrid[1];
1399+
Params.gridDimZ = BlocksPerGrid[2];
1400+
Params.blockDimX = ThreadsPerBlock[0];
1401+
Params.blockDimY = ThreadsPerBlock[1];
1402+
Params.blockDimZ = ThreadsPerBlock[2];
1403+
Params.sharedMemBytes = KernelCommandHandle->Kernel->getLocalSize();
14151404
Params.kernelParams =
14161405
const_cast<void **>(KernelCommandHandle->Kernel->getArgIndices().data());
14171406

source/adapters/cuda/enqueue.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -485,9 +485,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
485485
ThreadsPerBlock[0], ThreadsPerBlock[1], ThreadsPerBlock[2], LocalSize,
486486
CuStream, const_cast<void **>(ArgIndices.data()), nullptr));
487487

488-
if (LocalSize != 0)
489-
hKernel->clearLocalSize();
490-
491488
if (phEvent) {
492489
UR_CHECK_ERROR(RetImplEvent->record());
493490
*phEvent = RetImplEvent.release();
@@ -665,9 +662,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
665662
const_cast<void **>(ArgIndices.data()),
666663
nullptr));
667664

668-
if (LocalSize != 0)
669-
hKernel->clearLocalSize();
670-
671665
if (phEvent) {
672666
UR_CHECK_ERROR(RetImplEvent->record());
673667
*phEvent = RetImplEvent.release();

source/adapters/cuda/kernel.hpp

Lines changed: 80 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,22 @@ struct ur_kernel_handle_t_ {
6161
using args_t = std::array<char, MaxParamBytes>;
6262
using args_size_t = std::vector<size_t>;
6363
using args_index_t = std::vector<void *>;
64+
/// Storage shared by all args which is mem copied into when adding a new
65+
/// argument.
6466
args_t Storage;
67+
/// Aligned size of each parameter, including padding.
6568
args_size_t ParamSizes;
69+
/// Byte offset into /p Storage allocation for each parameter.
6670
args_index_t Indices;
67-
args_size_t OffsetPerIndex;
71+
/// Aligned size in bytes for each local memory parameter after padding has
72+
/// been added. Zero if the argument at the index isn't a local memory
73+
/// argument.
74+
args_size_t AlignedLocalMemSize;
75+
/// Original size in bytes for each local memory parameter, prior to being
76+
/// padded to appropriate alignment. Zero if the argument at the index
77+
/// isn't a local memory argument.
78+
args_size_t OriginalLocalMemSize;
79+
6880
// A struct to keep track of memargs so that we can do dependency analysis
6981
// at urEnqueueKernelLaunch
7082
struct mem_obj_arg {
@@ -93,7 +105,8 @@ struct ur_kernel_handle_t_ {
93105
Indices.resize(Index + 2, Indices.back());
94106
// Ensure enough space for the new argument
95107
ParamSizes.resize(Index + 1);
96-
OffsetPerIndex.resize(Index + 1);
108+
AlignedLocalMemSize.resize(Index + 1);
109+
OriginalLocalMemSize.resize(Index + 1);
97110
}
98111
ParamSizes[Index] = Size;
99112
// calculate the insertion point on the array
@@ -102,28 +115,83 @@ struct ur_kernel_handle_t_ {
102115
// Update the stored value for the argument
103116
std::memcpy(&Storage[InsertPos], Arg, Size);
104117
Indices[Index] = &Storage[InsertPos];
105-
OffsetPerIndex[Index] = LocalSize;
118+
AlignedLocalMemSize[Index] = LocalSize;
106119
}
107120

108-
void addLocalArg(size_t Index, size_t Size) {
109-
size_t LocalOffset = this->getLocalSize();
121+
/// Returns the padded size and offset of a local memory argument.
122+
/// Local memory arguments need to be padded if the alignment for the size
123+
/// doesn't match the current offset into the kernel local data.
124+
/// @param Index Kernel arg index.
125+
/// @param Size User passed size of local parameter.
126+
/// @return Tuple of (Aligned size, Aligned offset into local data).
127+
std::pair<size_t, size_t> calcAlignedLocalArgument(size_t Index,
128+
size_t Size) {
129+
// Store the unpadded size of the local argument
130+
if (Index + 2 > Indices.size()) {
131+
AlignedLocalMemSize.resize(Index + 1);
132+
OriginalLocalMemSize.resize(Index + 1);
133+
}
134+
OriginalLocalMemSize[Index] = Size;
135+
136+
// Calculate the current starting offset into local data
137+
const size_t LocalOffset = std::accumulate(
138+
std::begin(AlignedLocalMemSize),
139+
std::next(std::begin(AlignedLocalMemSize), Index), size_t{0});
110140

111-
// maximum required alignment is the size of the largest vector type
141+
// Maximum required alignment is the size of the largest vector type
112142
const size_t MaxAlignment = sizeof(double) * 16;
113143

114-
// for arguments smaller than the maximum alignment simply align to the
144+
// For arguments smaller than the maximum alignment simply align to the
115145
// size of the argument
116146
const size_t Alignment = std::min(MaxAlignment, Size);
117147

118-
// align the argument
148+
// Align the argument
119149
size_t AlignedLocalOffset = LocalOffset;
120-
size_t Pad = LocalOffset % Alignment;
150+
const size_t Pad = LocalOffset % Alignment;
121151
if (Pad != 0) {
122152
AlignedLocalOffset += Alignment - Pad;
123153
}
124154

155+
const size_t AlignedLocalSize = Size + (AlignedLocalOffset - LocalOffset);
156+
return std::make_pair(AlignedLocalSize, AlignedLocalOffset);
157+
}
158+
159+
void addLocalArg(size_t Index, size_t Size) {
160+
// Get the aligned argument size and offset into local data
161+
size_t AlignedLocalSize, AlignedLocalOffset;
162+
std::tie(AlignedLocalSize, AlignedLocalOffset) =
163+
calcAlignedLocalArgument(Index, Size);
164+
165+
// Store argument details
125166
addArg(Index, sizeof(size_t), (const void *)&(AlignedLocalOffset),
126-
Size + (AlignedLocalOffset - LocalOffset));
167+
AlignedLocalSize);
168+
169+
// For every existing local argument which follows at later argument
170+
// indices, updated the offset and pointer into the kernel local memory.
171+
// Required as padding will need to be recalculated.
172+
const size_t NumArgs = Indices.size() - 1; // Accounts for implicit arg
173+
for (auto SuccIndex = Index + 1; SuccIndex < NumArgs; SuccIndex++) {
174+
const size_t OriginalLocalSize = OriginalLocalMemSize[SuccIndex];
175+
if (OriginalLocalSize == 0) {
176+
// Skip if successor argument isn't a local memory arg
177+
continue;
178+
}
179+
180+
// Recalculate alignment
181+
size_t SuccAlignedLocalSize, SuccAlignedLocalOffset;
182+
std::tie(SuccAlignedLocalSize, SuccAlignedLocalOffset) =
183+
calcAlignedLocalArgument(SuccIndex, OriginalLocalSize);
184+
185+
// Store new local memory size
186+
AlignedLocalMemSize[SuccIndex] = SuccAlignedLocalSize;
187+
188+
// Store new offset into local data
189+
const size_t InsertPos =
190+
std::accumulate(std::begin(ParamSizes),
191+
std::begin(ParamSizes) + SuccIndex, size_t{0});
192+
std::memcpy(&Storage[InsertPos], &SuccAlignedLocalOffset,
193+
sizeof(size_t));
194+
}
127195
}
128196

129197
void addMemObjArg(int Index, ur_mem_handle_t hMem, ur_mem_flags_t Flags) {
@@ -145,15 +213,11 @@ struct ur_kernel_handle_t_ {
145213
std::memcpy(ImplicitOffsetArgs, ImplicitOffset, Size);
146214
}
147215

148-
void clearLocalSize() {
149-
std::fill(std::begin(OffsetPerIndex), std::end(OffsetPerIndex), 0);
150-
}
151-
152216
const args_index_t &getIndices() const noexcept { return Indices; }
153217

154218
uint32_t getLocalSize() const {
155-
return std::accumulate(std::begin(OffsetPerIndex),
156-
std::end(OffsetPerIndex), 0);
219+
return std::accumulate(std::begin(AlignedLocalMemSize),
220+
std::end(AlignedLocalMemSize), 0);
157221
}
158222
} Args;
159223

@@ -240,7 +304,5 @@ struct ur_kernel_handle_t_ {
240304

241305
uint32_t getLocalSize() const noexcept { return Args.getLocalSize(); }
242306

243-
void clearLocalSize() { Args.clearLocalSize(); }
244-
245307
size_t getRegsPerThread() const noexcept { return RegsPerThread; };
246308
};

test/conformance/device_code/saxpy_usm_local_mem.cpp

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ int main() {
1616

1717
sycl_queue.submit([&](sycl::handler &cgh) {
1818
sycl::local_accessor<uint32_t, 1> local_mem_A(local_size, cgh);
19-
sycl::local_accessor<uint32_t, 1> local_mem_B(1, cgh);
19+
sycl::local_accessor<uint32_t, 1> local_mem_B(local_size * 2, cgh);
2020

2121
cgh.parallel_for<class saxpy_usm_local_mem>(
2222
sycl::nd_range<1>{{array_size}, {local_size}},
@@ -25,17 +25,12 @@ int main() {
2525
auto local_id = itemId.get_local_linear_id();
2626

2727
local_mem_A[local_id] = i;
28-
if (i == 0) {
29-
local_mem_B[0] = 0xA;
30-
}
28+
local_mem_B[local_id * 2] = -i;
29+
local_mem_B[(local_id * 2) + 1] = itemId.get_local_range(0);
3130

3231
Z[i] = A * X[i] + Y[i] + local_mem_A[local_id] +
33-
itemId.get_local_range(0);
34-
35-
if (i == 0) {
36-
Z[i] += local_mem_B[0];
37-
}
38-
32+
local_mem_B[local_id * 2] +
33+
local_mem_B[(local_id * 2) + 1];
3934
});
4035
});
4136
return 0;

0 commit comments

Comments
 (0)