@@ -98,17 +98,19 @@ ur_result_t ur_kernel_handle_t_::release() {
9898void ur_kernel_handle_t_::completeInitialization () {
9999 // Cache kernel name. Should be the same for all devices
100100 assert (deviceKernels.size () > 0 );
101- auto nonEmptyKernel =
102- std::find_if (deviceKernels.begin (), deviceKernels.end (),
103- [](const auto &kernel) { return kernel.has_value (); });
101+ nonEmptyKernel =
102+ &std::find_if (deviceKernels.begin (), deviceKernels.end (),
103+ [](const auto &kernel) { return kernel.has_value (); })
104+ ->value ();
104105
105- zeKernelName .Compute = [kernel =
106- &nonEmptyKernel-> value ()](std::string &name ) {
106+ zeCommonProperties .Compute = [kernel = nonEmptyKernel](
107+ common_properties_t &props ) {
107108 size_t size = 0 ;
108109 ZE_CALL_NOCHECK (zeKernelGetName, (kernel->hKernel .get (), &size, nullptr ));
109- name.resize (size);
110+ props. name .resize (size);
110111 ZE_CALL_NOCHECK (zeKernelGetName,
111- (kernel->hKernel .get (), &size, name.data ()));
112+ (kernel->hKernel .get (), &size, props.name .data ()));
113+ props.numKernelArgs = kernel->zeKernelProperties ->numKernelArgs ;
112114 };
113115}
114116
@@ -142,8 +144,9 @@ ur_kernel_handle_t_::getZeHandle(ur_device_handle_t hDevice) {
142144 return deviceKernels[hDevice->Id .value ()].value ().hKernel .get ();
143145}
144146
145- const std::string &ur_kernel_handle_t_::getName () const {
146- return *zeKernelName.operator ->();
147+ ur_kernel_handle_t_::common_properties_t
148+ ur_kernel_handle_t_::getCommonProperties () const {
149+ return zeCommonProperties.get ();
147150}
148151
149152const ze_kernel_properties_t &
@@ -154,10 +157,7 @@ ur_kernel_handle_t_::getProperties(ur_device_handle_t hDevice) const {
154157
155158 assert (deviceKernels[hDevice->Id .value ()].value ().hKernel .get ());
156159
157- return *deviceKernels[hDevice->Id .value ()]
158- .value ()
159- .zeKernelProperties .
160- operator ->();
160+ return deviceKernels[hDevice->Id .value ()].value ().zeKernelProperties .get ();
161161}
162162
163163ur_result_t ur_kernel_handle_t_::setArgValue (
@@ -178,16 +178,26 @@ ur_result_t ur_kernel_handle_t_::setArgValue(
178178 pArgValue = nullptr ;
179179 }
180180
181+ if (argIndex > zeCommonProperties->numKernelArgs - 1 ) {
182+ return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
183+ }
184+
181185 std::scoped_lock<ur_shared_mutex> guard (Mutex);
182186
183187 for (auto &singleDeviceKernel : deviceKernels) {
184188 if (!singleDeviceKernel.has_value ()) {
185189 continue ;
186190 }
187191
188- ZE2UR_CALL (zeKernelSetArgumentValue,
189- (singleDeviceKernel.value ().hKernel .get (), argIndex, argSize,
190- pArgValue));
192+ auto zeResult = ZE_CALL_NOCHECK (zeKernelSetArgumentValue,
193+ (singleDeviceKernel.value ().hKernel .get (),
194+ argIndex, argSize, pArgValue));
195+
196+ if (zeResult == ZE_RESULT_ERROR_INVALID_ARGUMENT) {
197+ return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE;
198+ } else if (zeResult != ZE_RESULT_SUCCESS) {
199+ return ze2urResult (zeResult);
200+ }
191201 }
192202 return UR_RESULT_SUCCESS;
193203}
@@ -257,6 +267,17 @@ std::vector<ur_device_handle_t> ur_kernel_handle_t_::getDevices() const {
257267 return devices;
258268}
259269
270+ std::vector<char > ur_kernel_handle_t_::getSourceAttributes () const {
271+ uint32_t size;
272+ ZE2UR_CALL_THROWS (zeKernelGetSourceAttributes,
273+ (nonEmptyKernel->hKernel .get (), &size, nullptr ));
274+ std::vector<char > attributes (size);
275+ char *dataPtr = attributes.data ();
276+ ZE2UR_CALL_THROWS (zeKernelGetSourceAttributes,
277+ (nonEmptyKernel->hKernel .get (), &size, &dataPtr));
278+ return attributes;
279+ }
280+
260281namespace ur ::level_zero {
261282ur_result_t urKernelCreate (ur_program_handle_t hProgram,
262283 const char *pKernelName,
@@ -477,4 +498,40 @@ ur_result_t urKernelGetSubGroupInfo(
477498 }
478499 return UR_RESULT_SUCCESS;
479500}
501+
502+ ur_result_t urKernelGetInfo (ur_kernel_handle_t hKernel,
503+ ur_kernel_info_t paramName, size_t propSize,
504+ void *pKernelInfo, size_t *pPropSizeRet) {
505+
506+ UrReturnHelper ReturnValue (propSize, pKernelInfo, pPropSizeRet);
507+
508+ std::shared_lock<ur_shared_mutex> Guard (hKernel->Mutex );
509+ switch (paramName) {
510+ case UR_KERNEL_INFO_CONTEXT:
511+ return ReturnValue (
512+ ur_context_handle_t {hKernel->getProgramHandle ()->Context });
513+ case UR_KERNEL_INFO_PROGRAM:
514+ return ReturnValue (ur_program_handle_t {hKernel->getProgramHandle ()});
515+ case UR_KERNEL_INFO_FUNCTION_NAME: {
516+ auto kernelName = hKernel->getCommonProperties ().name ;
517+ return ReturnValue (static_cast <const char *>(kernelName.c_str ()));
518+ }
519+ case UR_KERNEL_INFO_NUM_REGS:
520+ case UR_KERNEL_INFO_NUM_ARGS:
521+ return ReturnValue (uint32_t {hKernel->getCommonProperties ().numKernelArgs });
522+ case UR_KERNEL_INFO_REFERENCE_COUNT:
523+ return ReturnValue (uint32_t {hKernel->RefCount .load ()});
524+ case UR_KERNEL_INFO_ATTRIBUTES: {
525+ auto attributes = hKernel->getSourceAttributes ();
526+ return ReturnValue (static_cast <const char *>(attributes.data ()));
527+ }
528+ default :
529+ logger::error (
530+ " Unsupported ParamName in urKernelGetInfo: ParamName={}(0x{})" ,
531+ paramName, logger::toHex (paramName));
532+ return UR_RESULT_ERROR_INVALID_VALUE;
533+ }
534+
535+ return UR_RESULT_SUCCESS;
536+ }
480537} // namespace ur::level_zero
0 commit comments