@@ -194,13 +194,20 @@ ur_kernel_handle_t_::getProperties(ur_device_handle_t hDevice) const {
194194}
195195
196196ur_result_t ur_kernel_handle_t_::setArgValue (
197- uint32_t argIndex, size_t argSize,
197+ ur_device_handle_t hDevice, uint32_t argIndex, size_t argSize,
198198 const ur_kernel_arg_value_properties_t * /* pProperties*/ ,
199199 const void *pArgValue) {
200200 if (argIndex > zeCommonProperties.numKernelArgs - 1 ) {
201201 return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
202202 }
203203
204+ if (hDevice) { // Set argument only on the specified device
205+ auto &deviceKernel = deviceKernels[deviceIndex (hDevice)].value ();
206+ UR_CALL (setArgValueOnZeKernel (deviceKernel.hKernel .get (), argIndex, argSize,
207+ pArgValue));
208+ return UR_RESULT_SUCCESS;
209+ }
210+
204211 for (auto &singleDeviceKernel : deviceKernels) {
205212 if (!singleDeviceKernel.has_value ()) {
206213 continue ;
@@ -213,12 +220,13 @@ ur_result_t ur_kernel_handle_t_::setArgValue(
213220}
214221
215222ur_result_t ur_kernel_handle_t_::setArgPointer (
216- uint32_t argIndex,
223+ ur_device_handle_t hDevice, uint32_t argIndex,
217224 const ur_kernel_arg_pointer_properties_t * /* pProperties*/ ,
218225 const void *pArgValue) {
219226
220227 // KernelSetArgValue is expecting a pointer to the argument
221- return setArgValue (argIndex, sizeof (const void *), nullptr , &pArgValue);
228+ return setArgValue (hDevice, argIndex, sizeof (const void *), nullptr ,
229+ &pArgValue);
222230}
223231
224232ur_program_handle_t ur_kernel_handle_t_::getProgramHandle () const {
@@ -429,7 +437,8 @@ ur_result_t urKernelSetArgValue(
429437 TRACK_SCOPE_LATENCY (" urKernelSetArgValue" );
430438
431439 std::scoped_lock<ur_shared_mutex> guard (hKernel->Mutex );
432- return hKernel->setArgValue (argIndex, argSize, pProperties, pArgValue);
440+ return hKernel->setArgValue (nullptr , argIndex, argSize, pProperties,
441+ pArgValue);
433442} catch (...) {
434443 return exceptionToResult (std::current_exception ());
435444}
@@ -492,7 +501,7 @@ ur_result_t urKernelSetArgLocal(
492501
493502 std::scoped_lock<ur_shared_mutex> guard (hKernel->Mutex );
494503
495- return hKernel->setArgValue (argIndex, argSize, nullptr , nullptr );
504+ return hKernel->setArgValue (nullptr , argIndex, argSize, nullptr , nullptr );
496505} catch (...) {
497506 return exceptionToResult (std::current_exception ());
498507}
@@ -736,7 +745,7 @@ ur_result_t urKernelSetArgSampler(
736745 ur_sampler_handle_t hArgValue) try {
737746 TRACK_SCOPE_LATENCY (" urKernelSetArgSampler" );
738747 std::scoped_lock<ur_shared_mutex> guard (hKernel->Mutex );
739- return hKernel->setArgValue (argIndex, sizeof (void *), nullptr ,
748+ return hKernel->setArgValue (nullptr , argIndex, sizeof (void *), nullptr ,
740749 &hArgValue->ZeSampler );
741750} catch (...) {
742751 return exceptionToResult (std::current_exception ());
0 commit comments