|
12 | 12 |
|
13 | 13 | #include "context.hpp" |
14 | 14 | #include "kernel.hpp" |
| 15 | +#include "memory.hpp" |
15 | 16 |
|
16 | 17 | #include "../device.hpp" |
17 | 18 | #include "../platform.hpp" |
18 | 19 | #include "../program.hpp" |
19 | 20 | #include "../ur_interface_loader.hpp" |
20 | 21 |
|
21 | | -ur_single_device_kernel_t::ur_single_device_kernel_t(ze_device_handle_t hDevice, |
| 22 | +ur_single_device_kernel_t::ur_single_device_kernel_t(ur_device_handle_t hDevice, |
22 | 23 | ze_kernel_handle_t hKernel, |
23 | 24 | bool ownZeHandle) |
24 | 25 | : hDevice(hDevice), hKernel(hKernel, ownZeHandle) { |
@@ -54,7 +55,7 @@ ur_kernel_handle_t_::ur_kernel_handle_t_(ur_program_handle_t hProgram, |
54 | 55 | assert(urDevice != hProgram->Context->getDevices().end()); |
55 | 56 | auto deviceId = (*urDevice)->Id.value(); |
56 | 57 |
|
57 | | - deviceKernels[deviceId].emplace(zeDevice, zeKernel, true); |
| 58 | + deviceKernels[deviceId].emplace(*urDevice, zeKernel, true); |
58 | 59 | } |
59 | 60 | completeInitialization(); |
60 | 61 | } |
@@ -118,7 +119,7 @@ ur_kernel_handle_t_::getZeHandle(ur_device_handle_t hDevice) { |
118 | 119 | auto &kernel = deviceKernels[0].value(); |
119 | 120 |
|
120 | 121 | // hDevice is nullptr for native handle |
121 | | - if ((kernel.hDevice != nullptr && kernel.hDevice != hDevice->ZeDevice)) { |
| 122 | + if ((kernel.hDevice != nullptr && kernel.hDevice != hDevice)) { |
122 | 123 | throw UR_RESULT_ERROR_INVALID_DEVICE; |
123 | 124 | } |
124 | 125 |
|
@@ -239,6 +240,16 @@ ur_result_t ur_kernel_handle_t_::setExecInfo(ur_kernel_exec_info_t propName, |
239 | 240 | return UR_RESULT_SUCCESS; |
240 | 241 | } |
241 | 242 |
|
| 243 | +std::vector<ur_device_handle_t> ur_kernel_handle_t_::getDevices() const { |
| 244 | + std::vector<ur_device_handle_t> devices; |
| 245 | + for (size_t i = 0; i < deviceKernels.size(); ++i) { |
| 246 | + if (deviceKernels[i].has_value()) { |
| 247 | + devices.push_back(deviceKernels[i].value().hDevice); |
| 248 | + } |
| 249 | + } |
| 250 | + return devices; |
| 251 | +} |
| 252 | + |
242 | 253 | namespace ur::level_zero { |
243 | 254 | ur_result_t urKernelCreate(ur_program_handle_t hProgram, |
244 | 255 | const char *pKernelName, |
@@ -291,6 +302,28 @@ ur_result_t urKernelSetArgPointer( |
291 | 302 | return hKernel->setArgPointer(argIndex, pProperties, pArgValue); |
292 | 303 | } |
293 | 304 |
|
| 305 | +ur_result_t |
| 306 | +urKernelSetArgMemObj(ur_kernel_handle_t hKernel, uint32_t argIndex, |
| 307 | + const ur_kernel_arg_mem_obj_properties_t *pProperties, |
| 308 | + ur_mem_handle_t hArgValue) { |
| 309 | + TRACK_SCOPE_LATENCY("ur_kernel_handle_t_::setArgMemObj"); |
| 310 | + |
| 311 | + // TODO: support properties |
| 312 | + std::ignore = pProperties; |
| 313 | + |
| 314 | + auto kernelDevices = hKernel->getDevices(); |
| 315 | + if (kernelDevices.size() == 1) { |
| 316 | + auto zePtr = hArgValue->getPtr(kernelDevices.front()); |
| 317 | + return hKernel->setArgPointer(argIndex, nullptr, zePtr); |
| 318 | + } else { |
| 319 | + // TODO: Implement this for multi-device kernels. |
| 320 | + // Do this the same way as in legacy (keep a pending Args vector and |
| 321 | + // do actual allocation on kernel submission) or allocate the memory |
| 322 | + // immediately (only for small allocations?) |
| 323 | + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; |
| 324 | + } |
| 325 | +} |
| 326 | + |
294 | 327 | ur_result_t urKernelSetExecInfo( |
295 | 328 | ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object |
296 | 329 | ur_kernel_exec_info_t propName, ///< [in] name of the execution attribute |
|
0 commit comments