Skip to content

Commit c3ef726

Browse files
committed
Add hDevice argument to ur_kernel_handle_t_::setArgValue()
Add hDevice argument to ur_kernel_handle_t_::setArgValue() to make it possible to set an argument only on the specified device. Signed-off-by: Lukasz Dorau <lukasz.dorau@intel.com>
1 parent 1c74478 commit c3ef726

File tree

3 files changed

+23
-12
lines changed

3 files changed

+23
-12
lines changed

unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1123,19 +1123,20 @@ ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExpOld(
11231123
wait_list_view &waitListView, ur_event_handle_t phEvent) {
11241124
{
11251125
std::scoped_lock<ur_shared_mutex> guard(hKernel->Mutex);
1126+
ur_device_handle_t hDevice = this->hDevice.get();
11261127
for (uint32_t argIndex = 0; argIndex < numArgs; argIndex++) {
11271128
switch (pArgs[argIndex].type) {
11281129
case UR_EXP_KERNEL_ARG_TYPE_LOCAL:
1129-
UR_CALL(hKernel->setArgValue(pArgs[argIndex].index,
1130+
UR_CALL(hKernel->setArgValue(hDevice, pArgs[argIndex].index,
11301131
pArgs[argIndex].size, nullptr, nullptr));
11311132
break;
11321133
case UR_EXP_KERNEL_ARG_TYPE_VALUE:
1133-
UR_CALL(hKernel->setArgValue(pArgs[argIndex].index,
1134+
UR_CALL(hKernel->setArgValue(hDevice, pArgs[argIndex].index,
11341135
pArgs[argIndex].size, nullptr,
11351136
pArgs[argIndex].value.value));
11361137
break;
11371138
case UR_EXP_KERNEL_ARG_TYPE_POINTER:
1138-
UR_CALL(hKernel->setArgPointer(pArgs[argIndex].index, nullptr,
1139+
UR_CALL(hKernel->setArgPointer(hDevice, pArgs[argIndex].index, nullptr,
11391140
pArgs[argIndex].value.pointer));
11401141
break;
11411142
case UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ:
@@ -1147,7 +1148,7 @@ ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExpOld(
11471148
break;
11481149
case UR_EXP_KERNEL_ARG_TYPE_SAMPLER: {
11491150
UR_CALL(
1150-
hKernel->setArgValue(argIndex, sizeof(void *), nullptr,
1151+
hKernel->setArgValue(hDevice, argIndex, sizeof(void *), nullptr,
11511152
&pArgs[argIndex].value.sampler->ZeSampler));
11521153
break;
11531154
}

unified-runtime/source/adapters/level_zero/v2/kernel.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -194,13 +194,20 @@ ur_kernel_handle_t_::getProperties(ur_device_handle_t hDevice) const {
194194
}
195195

196196
ur_result_t ur_kernel_handle_t_::setArgValue(
197-
uint32_t argIndex, size_t argSize,
197+
ur_device_handle_t hDevice, uint32_t argIndex, size_t argSize,
198198
const ur_kernel_arg_value_properties_t * /*pProperties*/,
199199
const void *pArgValue) {
200200
if (argIndex > zeCommonProperties.numKernelArgs - 1) {
201201
return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
202202
}
203203

204+
if (hDevice) { // Set argument only on the specified device
205+
auto &deviceKernel = deviceKernels[deviceIndex(hDevice)].value();
206+
UR_CALL(setArgValueOnZeKernel(deviceKernel.hKernel.get(), argIndex, argSize,
207+
pArgValue));
208+
return UR_RESULT_SUCCESS;
209+
}
210+
204211
for (auto &singleDeviceKernel : deviceKernels) {
205212
if (!singleDeviceKernel.has_value()) {
206213
continue;
@@ -213,12 +220,13 @@ ur_result_t ur_kernel_handle_t_::setArgValue(
213220
}
214221

215222
ur_result_t ur_kernel_handle_t_::setArgPointer(
216-
uint32_t argIndex,
223+
ur_device_handle_t hDevice, uint32_t argIndex,
217224
const ur_kernel_arg_pointer_properties_t * /*pProperties*/,
218225
const void *pArgValue) {
219226

220227
// KernelSetArgValue is expecting a pointer to the argument
221-
return setArgValue(argIndex, sizeof(const void *), nullptr, &pArgValue);
228+
return setArgValue(hDevice, argIndex, sizeof(const void *), nullptr,
229+
&pArgValue);
222230
}
223231

224232
ur_program_handle_t ur_kernel_handle_t_::getProgramHandle() const {
@@ -429,7 +437,8 @@ ur_result_t urKernelSetArgValue(
429437
TRACK_SCOPE_LATENCY("urKernelSetArgValue");
430438

431439
std::scoped_lock<ur_shared_mutex> guard(hKernel->Mutex);
432-
return hKernel->setArgValue(argIndex, argSize, pProperties, pArgValue);
440+
return hKernel->setArgValue(nullptr, argIndex, argSize, pProperties,
441+
pArgValue);
433442
} catch (...) {
434443
return exceptionToResult(std::current_exception());
435444
}
@@ -492,7 +501,7 @@ ur_result_t urKernelSetArgLocal(
492501

493502
std::scoped_lock<ur_shared_mutex> guard(hKernel->Mutex);
494503

495-
return hKernel->setArgValue(argIndex, argSize, nullptr, nullptr);
504+
return hKernel->setArgValue(nullptr, argIndex, argSize, nullptr, nullptr);
496505
} catch (...) {
497506
return exceptionToResult(std::current_exception());
498507
}
@@ -736,7 +745,7 @@ ur_result_t urKernelSetArgSampler(
736745
ur_sampler_handle_t hArgValue) try {
737746
TRACK_SCOPE_LATENCY("urKernelSetArgSampler");
738747
std::scoped_lock<ur_shared_mutex> guard(hKernel->Mutex);
739-
return hKernel->setArgValue(argIndex, sizeof(void *), nullptr,
748+
return hKernel->setArgValue(nullptr, argIndex, sizeof(void *), nullptr,
740749
&hArgValue->ZeSampler);
741750
} catch (...) {
742751
return exceptionToResult(std::current_exception());

unified-runtime/source/adapters/level_zero/v2/kernel.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,14 @@ struct ur_kernel_handle_t_ : ur_object {
6767
const ze_kernel_properties_t &getProperties(ur_device_handle_t hDevice) const;
6868

6969
// Implementation of urKernelSetArgValue.
70-
ur_result_t setArgValue(uint32_t argIndex, size_t argSize,
70+
ur_result_t setArgValue(ur_device_handle_t hDevice, uint32_t argIndex,
71+
size_t argSize,
7172
const ur_kernel_arg_value_properties_t *pProperties,
7273
const void *pArgValue);
7374

7475
// Implementation of urKernelSetArgPointer.
7576
ur_result_t
76-
setArgPointer(uint32_t argIndex,
77+
setArgPointer(ur_device_handle_t hDevice, uint32_t argIndex,
7778
const ur_kernel_arg_pointer_properties_t *pProperties,
7879
const void *pArgValue);
7980

0 commit comments

Comments
 (0)