@@ -710,7 +710,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelCreate(
710710 ZeKernelDesc.pKernelName = KernelName;
711711
712712 ze_kernel_handle_t ZeKernel;
713- ZE2UR_CALL (zeKernelCreate, (ZeModule, &ZeKernelDesc, &ZeKernel));
713+ auto ZeResult =
714+ ZE_CALL_NOCHECK (zeKernelCreate, (ZeModule, &ZeKernelDesc, &ZeKernel));
715+ // Gracefully handle the case that kernel create fails.
716+ if (ZeResult != ZE_RESULT_SUCCESS) {
717+ delete *RetKernel;
718+ *RetKernel = nullptr ;
719+ return ze2urResult (ZeResult);
720+ }
714721
715722 auto ZeDevice = It.first ;
716723
@@ -764,20 +771,29 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
764771 PArgValue = nullptr ;
765772 }
766773
774+ if (ArgIndex > Kernel->ZeKernelProperties ->numKernelArgs - 1 ) {
775+ return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
776+ }
777+
767778 std::scoped_lock<ur_shared_mutex> Guard (Kernel->Mutex );
779+ ze_result_t ZeResult = ZE_RESULT_SUCCESS;
768780 if (Kernel->ZeKernelMap .empty ()) {
769781 auto ZeKernel = Kernel->ZeKernel ;
770- ZE2UR_CALL (zeKernelSetArgumentValue,
771- (ZeKernel, ArgIndex, ArgSize, PArgValue));
782+ ZeResult = ZE_CALL_NOCHECK (zeKernelSetArgumentValue,
783+ (ZeKernel, ArgIndex, ArgSize, PArgValue));
772784 } else {
773785 for (auto It : Kernel->ZeKernelMap ) {
774786 auto ZeKernel = It.second ;
775- ZE2UR_CALL (zeKernelSetArgumentValue,
776- (ZeKernel, ArgIndex, ArgSize, PArgValue));
787+ ZeResult = ZE_CALL_NOCHECK (zeKernelSetArgumentValue,
788+ (ZeKernel, ArgIndex, ArgSize, PArgValue));
777789 }
778790 }
779791
780- return UR_RESULT_SUCCESS;
792+ if (ZeResult == ZE_RESULT_ERROR_INVALID_ARGUMENT) {
793+ return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE;
794+ }
795+
796+ return ze2urResult (ZeResult);
781797}
782798
783799UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgLocal (
@@ -826,6 +842,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo(
826842 } catch (...) {
827843 return UR_RESULT_ERROR_UNKNOWN;
828844 }
845+ case UR_KERNEL_INFO_NUM_REGS:
829846 case UR_KERNEL_INFO_NUM_ARGS:
830847 return ReturnValue (uint32_t {Kernel->ZeKernelProperties ->numKernelArgs });
831848 case UR_KERNEL_INFO_REFERENCE_COUNT:
@@ -1076,6 +1093,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgSampler(
10761093) {
10771094 std::ignore = Properties;
10781095 std::scoped_lock<ur_shared_mutex> Guard (Kernel->Mutex );
1096+ if (ArgIndex > Kernel->ZeKernelProperties ->numKernelArgs - 1 ) {
1097+ return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
1098+ }
10791099 ZE2UR_CALL (zeKernelSetArgumentValue, (Kernel->ZeKernel , ArgIndex,
10801100 sizeof (void *), &ArgValue->ZeSampler ));
10811101
@@ -1095,6 +1115,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgMemObj(
10951115 // The ArgValue may be a NULL pointer in which case a NULL value is used for
10961116 // the kernel argument declared as a pointer to global or constant memory.
10971117
1118+ if (ArgIndex > Kernel->ZeKernelProperties ->numKernelArgs - 1 ) {
1119+ return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
1120+ }
1121+
10981122 ur_mem_handle_t_ *UrMem = ur_cast<ur_mem_handle_t_ *>(ArgValue);
10991123
11001124 ur_mem_handle_t_::access_mode_t UrAccessMode = ur_mem_handle_t_::read_write;
0 commit comments