Skip to content

Commit 528ca1a

Browse files
committed
Add program handle
1 parent 2b29c7f commit 528ca1a

File tree

4 files changed

+91
-67
lines changed

4 files changed

+91
-67
lines changed

source/adapters/opencl/enqueue.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
//===----------------------------------------------------------------------===//
1010

1111
#include "common.hpp"
12+
#include "program.hpp"
1213

1314
cl_map_flags convertURMapFlagsToCL(ur_map_flags_t URFlags) {
1415
cl_map_flags CLFlags = 0;
@@ -453,7 +454,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite(
453454
return UR_RESULT_ERROR_INVALID_OPERATION;
454455

455456
Res = F(cl_adapter::cast<cl_command_queue>(hQueue),
456-
cl_adapter::cast<cl_program>(hProgram), name, blockingWrite, count,
457+
hProgram->get(), name, blockingWrite, count,
457458
offset, pSrc, numEventsInWaitList,
458459
cl_adapter::cast<const cl_event *>(phEventWaitList),
459460
cl_adapter::cast<cl_event *>(phEvent));
@@ -484,7 +485,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableRead(
484485
return UR_RESULT_ERROR_INVALID_OPERATION;
485486

486487
Res = F(cl_adapter::cast<cl_command_queue>(hQueue),
487-
cl_adapter::cast<cl_program>(hProgram), name, blockingRead, count,
488+
hProgram->get(), name, blockingRead, count,
488489
offset, pDst, numEventsInWaitList,
489490
cl_adapter::cast<const cl_event *>(phEventWaitList),
490491
cl_adapter::cast<cl_event *>(phEvent));
@@ -515,7 +516,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueReadHostPipe(
515516
if (FuncPtr) {
516517
RetVal = mapCLErrorToUR(
517518
FuncPtr(cl_adapter::cast<cl_command_queue>(hQueue),
518-
cl_adapter::cast<cl_program>(hProgram), pipe_symbol, blocking,
519+
hProgram->get(), pipe_symbol, blocking,
519520
pDst, size, numEventsInWaitList,
520521
cl_adapter::cast<const cl_event *>(phEventWaitList),
521522
cl_adapter::cast<cl_event *>(phEvent)));
@@ -547,7 +548,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueWriteHostPipe(
547548
if (FuncPtr) {
548549
RetVal = mapCLErrorToUR(
549550
FuncPtr(cl_adapter::cast<cl_command_queue>(hQueue),
550-
cl_adapter::cast<cl_program>(hProgram), pipe_symbol, blocking,
551+
hProgram->get(), pipe_symbol, blocking,
551552
pSrc, size, numEventsInWaitList,
552553
cl_adapter::cast<const cl_event *>(phEventWaitList),
553554
cl_adapter::cast<cl_event *>(phEvent)));

source/adapters/opencl/kernel.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
//===----------------------------------------------------------------------===//
1010
#include "common.hpp"
1111
#include "device.hpp"
12+
#include "program.hpp"
1213

1314
#include <algorithm>
1415
#include <memory>
@@ -19,7 +20,7 @@ urKernelCreate(ur_program_handle_t hProgram, const char *pKernelName,
1920

2021
cl_int CLResult;
2122
*phKernel = cl_adapter::cast<ur_kernel_handle_t>(clCreateKernel(
22-
cl_adapter::cast<cl_program>(hProgram), pKernelName, &CLResult));
23+
hProgram->get(), pKernelName, &CLResult));
2324
CL_RETURN_ON_FAILURE(CLResult);
2425
return UR_RESULT_SUCCESS;
2526
}

source/adapters/opencl/program.cpp

Lines changed: 58 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -12,42 +12,31 @@
1212
#include "context.hpp"
1313
#include "device.hpp"
1414
#include "platform.hpp"
15+
#include "program.hpp"
1516

1617
static ur_result_t getDevicesFromProgram(
1718
ur_program_handle_t hProgram,
1819
std::unique_ptr<std::vector<cl_device_id>> &DevicesInProgram) {
1920

20-
cl_uint DeviceCount;
21-
CL_RETURN_ON_FAILURE(clGetProgramInfo(cl_adapter::cast<cl_program>(hProgram),
22-
CL_PROGRAM_NUM_DEVICES, sizeof(cl_uint),
23-
&DeviceCount, nullptr));
24-
25-
if (DeviceCount < 1) {
26-
return UR_RESULT_ERROR_INVALID_CONTEXT;
21+
if (!hProgram->Context || !hProgram->Context->DeviceCount) {
22+
return UR_RESULT_ERROR_INVALID_PROGRAM;
2723
}
28-
24+
cl_uint DeviceCount = hProgram->Context->DeviceCount;
2925
DevicesInProgram = std::make_unique<std::vector<cl_device_id>>(DeviceCount);
30-
31-
CL_RETURN_ON_FAILURE(clGetProgramInfo(
32-
cl_adapter::cast<cl_program>(hProgram), CL_PROGRAM_DEVICES,
33-
DeviceCount * sizeof(cl_device_id), (*DevicesInProgram).data(), nullptr));
34-
26+
for (uint32_t i = 0; i < DeviceCount; i++) {
27+
(*DevicesInProgram)[i] = hProgram->Context->Devices[i]->get();
28+
}
3529
return UR_RESULT_SUCCESS;
3630
}
3731

3832
UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithIL(
3933
ur_context_handle_t hContext, const void *pIL, size_t length,
4034
const ur_program_properties_t *, ur_program_handle_t *phProgram) {
4135

42-
std::unique_ptr<std::vector<cl_device_id>> DevicesInCtx;
43-
CL_RETURN_ON_FAILURE_AND_SET_NULL(
44-
cl_adapter::getDevicesFromContext(hContext, DevicesInCtx), phProgram);
45-
46-
cl_platform_id CurPlatform;
47-
CL_RETURN_ON_FAILURE_AND_SET_NULL(
48-
clGetDeviceInfo((*DevicesInCtx)[0], CL_DEVICE_PLATFORM,
49-
sizeof(cl_platform_id), &CurPlatform, nullptr),
50-
phProgram);
36+
if (!hContext->DeviceCount || !hContext->Devices[0]->Platform) {
37+
return UR_RESULT_ERROR_INVALID_CONTEXT;
38+
}
39+
cl_platform_id CurPlatform = hContext->Devices[0]->Platform->get();
5140

5241
oclv::OpenCLVersion PlatVer;
5342
CL_RETURN_ON_FAILURE_AND_SET_NULL(
@@ -57,7 +46,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithIL(
5746
if (PlatVer >= oclv::V2_1) {
5847

5948
/* Make sure all devices support CL 2.1 or newer as well. */
60-
for (cl_device_id Dev : *DevicesInCtx) {
49+
for (ur_device_handle_t URDev : hContext->Devices) {
50+
cl_device_id Dev = URDev->get();
6151
oclv::OpenCLVersion DevVer;
6252

6353
CL_RETURN_ON_FAILURE_AND_SET_NULL(
@@ -79,15 +69,17 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithIL(
7969
}
8070
}
8171

82-
*phProgram = cl_adapter::cast<ur_program_handle_t>(clCreateProgramWithIL(
83-
hContext->get(), pIL, length, &Err));
72+
cl_program Program = clCreateProgramWithIL(hContext->get(), pIL, length, &Err);
8473
CL_RETURN_ON_FAILURE(Err);
74+
75+
*phProgram = new ur_program_handle_t_(Program, hContext);
8576
} else {
8677

8778
/* If none of the devices conform with CL 2.1 or newer make sure they all
8879
* support the cl_khr_il_program extension.
8980
*/
90-
for (cl_device_id Dev : *DevicesInCtx) {
81+
for (ur_device_handle_t URDev : hContext->Devices) {
82+
cl_device_id Dev = URDev->get();
9183
bool Supported = false;
9284
CL_RETURN_ON_FAILURE_AND_SET_NULL(
9385
cl_adapter::checkDeviceExtensions(Dev, {"cl_khr_il_program"},
@@ -106,9 +98,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithIL(
10698
CurPlatform, "clCreateProgramWithILKHR"));
10799

108100
assert(FuncPtr != nullptr);
101+
cl_program Program = FuncPtr(hContext->get(), pIL, length, &Err);
102+
*phProgram = new ur_program_handle_t_(Program, hContext);
109103

110-
*phProgram = cl_adapter::cast<ur_program_handle_t>(
111-
FuncPtr(hContext->get(), pIL, length, &Err));
112104
CL_RETURN_ON_FAILURE(Err);
113105
}
114106

@@ -124,9 +116,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
124116
const size_t Lengths[1] = {size};
125117
cl_int BinaryStatus[1];
126118
cl_int CLResult;
127-
*phProgram = cl_adapter::cast<ur_program_handle_t>(clCreateProgramWithBinary(
119+
cl_program Program = clCreateProgramWithBinary(
128120
hContext->get(), cl_adapter::cast<cl_uint>(1u),
129-
Devices, Lengths, &pBinary, BinaryStatus, &CLResult));
121+
Devices, Lengths, &pBinary, BinaryStatus, &CLResult);
122+
*phProgram = new ur_program_handle_t_(Program, hContext);
130123
CL_RETURN_ON_FAILURE(BinaryStatus[0]);
131124
CL_RETURN_ON_FAILURE(CLResult);
132125

@@ -140,7 +133,7 @@ urProgramCompile([[maybe_unused]] ur_context_handle_t hContext,
140133
std::unique_ptr<std::vector<cl_device_id>> DevicesInProgram;
141134
CL_RETURN_ON_FAILURE(getDevicesFromProgram(hProgram, DevicesInProgram));
142135

143-
CL_RETURN_ON_FAILURE(clCompileProgram(cl_adapter::cast<cl_program>(hProgram),
136+
CL_RETURN_ON_FAILURE(clCompileProgram(hProgram->get(),
144137
DevicesInProgram->size(),
145138
DevicesInProgram->data(), pOptions, 0,
146139
nullptr, nullptr, nullptr, nullptr));
@@ -178,7 +171,7 @@ UR_APIEXPORT ur_result_t UR_APICALL
178171
urProgramGetInfo(ur_program_handle_t hProgram, ur_program_info_t propName,
179172
size_t propSize, void *pPropValue, size_t *pPropSizeRet) {
180173
size_t CheckPropSize = 0;
181-
auto ClResult = clGetProgramInfo(cl_adapter::cast<cl_program>(hProgram),
174+
auto ClResult = clGetProgramInfo(hProgram->get(),
182175
mapURProgramInfoToCL(propName), propSize,
183176
pPropValue, &CheckPropSize);
184177
if (pPropValue && CheckPropSize != propSize) {
@@ -199,7 +192,7 @@ urProgramBuild([[maybe_unused]] ur_context_handle_t hContext,
199192
CL_RETURN_ON_FAILURE(getDevicesFromProgram(hProgram, DevicesInProgram));
200193

201194
CL_RETURN_ON_FAILURE(clBuildProgram(
202-
cl_adapter::cast<cl_program>(hProgram), DevicesInProgram->size(),
195+
hProgram->get(), DevicesInProgram->size(),
203196
DevicesInProgram->data(), pOptions, nullptr, nullptr));
204197
return UR_RESULT_SUCCESS;
205198
}
@@ -210,11 +203,16 @@ urProgramLink(ur_context_handle_t hContext, uint32_t count,
210203
ur_program_handle_t *phProgram) {
211204

212205
cl_int CLResult;
213-
*phProgram = cl_adapter::cast<ur_program_handle_t>(
206+
std::vector<cl_program> CLPrograms(count);
207+
for (uint32_t i = 0; i < count; i++) {
208+
CLPrograms[i] = phPrograms[i]->get();
209+
}
210+
cl_program Program =
214211
clLinkProgram(hContext->get(), 0, nullptr,
215212
pOptions, cl_adapter::cast<cl_uint>(count),
216-
cl_adapter::cast<const cl_program *>(phPrograms), nullptr,
217-
nullptr, &CLResult));
213+
CLPrograms.data(), nullptr,
214+
nullptr, &CLResult);
215+
*phProgram = new ur_program_handle_t_(Program, hContext);
218216
CL_RETURN_ON_FAILURE(CLResult);
219217

220218
return UR_RESULT_SUCCESS;
@@ -280,14 +278,14 @@ urProgramGetBuildInfo(ur_program_handle_t hProgram, ur_device_handle_t hDevice,
280278
UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet);
281279
cl_program_binary_type BinaryType;
282280
CL_RETURN_ON_FAILURE(clGetProgramBuildInfo(
283-
cl_adapter::cast<cl_program>(hProgram), hDevice->get(),
281+
hProgram->get(), hDevice->get(),
284282
mapURProgramBuildInfoToCL(propName), sizeof(cl_program_binary_type),
285283
&BinaryType, nullptr));
286284
return ReturnValue(mapCLBinaryTypeToUR(BinaryType));
287285
}
288286
size_t CheckPropSize = 0;
289287
cl_int ClErr =
290-
clGetProgramBuildInfo(cl_adapter::cast<cl_program>(hProgram),
288+
clGetProgramBuildInfo(hProgram->get(),
291289
hDevice->get(), mapURProgramBuildInfoToCL(propName),
292290
propSize, pPropValue, &CheckPropSize);
293291
if (pPropValue && CheckPropSize != propSize) {
@@ -304,30 +302,32 @@ urProgramGetBuildInfo(ur_program_handle_t hProgram, ur_device_handle_t hDevice,
304302
UR_APIEXPORT ur_result_t UR_APICALL
305303
urProgramRetain(ur_program_handle_t hProgram) {
306304

307-
CL_RETURN_ON_FAILURE(clRetainProgram(cl_adapter::cast<cl_program>(hProgram)));
305+
CL_RETURN_ON_FAILURE(clRetainProgram(hProgram->get()));
308306
return UR_RESULT_SUCCESS;
309307
}
310308

311309
UR_APIEXPORT ur_result_t UR_APICALL
312310
urProgramRelease(ur_program_handle_t hProgram) {
313311

314312
CL_RETURN_ON_FAILURE(
315-
clReleaseProgram(cl_adapter::cast<cl_program>(hProgram)));
313+
clReleaseProgram(hProgram->get()));
316314
return UR_RESULT_SUCCESS;
317315
}
318316

319317
UR_APIEXPORT ur_result_t UR_APICALL urProgramGetNativeHandle(
320318
ur_program_handle_t hProgram, ur_native_handle_t *phNativeProgram) {
321319

322-
*phNativeProgram = reinterpret_cast<ur_native_handle_t>(hProgram);
320+
*phNativeProgram = reinterpret_cast<ur_native_handle_t>(hProgram->get());
323321
return UR_RESULT_SUCCESS;
324322
}
325323

326324
UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithNativeHandle(
327-
ur_native_handle_t hNativeProgram, ur_context_handle_t,
325+
ur_native_handle_t hNativeProgram, ur_context_handle_t hContext,
328326
const ur_program_native_properties_t *pProperties,
329327
ur_program_handle_t *phProgram) {
330-
*phProgram = reinterpret_cast<ur_program_handle_t>(hNativeProgram);
328+
cl_program NativeHandle =
329+
reinterpret_cast<cl_program>(hNativeProgram);
330+
*phProgram = new ur_program_handle_t_(NativeHandle, hContext);
331331
if (!pProperties || !pProperties->isNativeHandleOwned) {
332332
return urProgramRetain(*phProgram);
333333
}
@@ -338,20 +338,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramSetSpecializationConstants(
338338
ur_program_handle_t hProgram, uint32_t count,
339339
const ur_specialization_constant_info_t *pSpecConstants) {
340340

341-
cl_program CLProg = cl_adapter::cast<cl_program>(hProgram);
342-
cl_context Ctx = nullptr;
343-
size_t RetSize = 0;
344-
345-
CL_RETURN_ON_FAILURE(clGetProgramInfo(CLProg, CL_PROGRAM_CONTEXT, sizeof(Ctx),
346-
&Ctx, &RetSize));
341+
cl_program CLProg = hProgram->get();
342+
if (!hProgram->Context) {
343+
return UR_RESULT_ERROR_INVALID_PROGRAM;
344+
}
345+
ur_context_handle_t Ctx = hProgram->Context;
346+
if (!Ctx->DeviceCount || !Ctx->Devices[0]->Platform) {
347+
return UR_RESULT_ERROR_INVALID_CONTEXT;
348+
}
347349

348350
std::unique_ptr<std::vector<cl_device_id>> DevicesInCtx;
349-
cl_adapter::getDevicesFromContext(cl_adapter::cast<ur_context_handle_t>(Ctx),
350-
DevicesInCtx);
351+
cl_adapter::getDevicesFromContext(Ctx, DevicesInCtx);
351352

352-
cl_platform_id CurPlatform;
353-
clGetDeviceInfo((*DevicesInCtx)[0], CL_DEVICE_PLATFORM,
354-
sizeof(cl_platform_id), &CurPlatform, nullptr);
353+
cl_platform_id CurPlatform = Ctx->Devices[0]->Platform->get();
355354

356355
oclv::OpenCLVersion PlatVer;
357356
cl_adapter::getPlatformVersion(CurPlatform, PlatVer);
@@ -383,7 +382,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramSetSpecializationConstants(
383382
SetProgramSpecializationConstant = nullptr;
384383
const ur_result_t URResult = cl_ext::getExtFuncFromContext<
385384
decltype(SetProgramSpecializationConstant)>(
386-
Ctx, cl_ext::ExtFuncPtrCache->clSetProgramSpecializationConstantCache,
385+
Ctx->get(), cl_ext::ExtFuncPtrCache->clSetProgramSpecializationConstantCache,
387386
cl_ext::SetProgramSpecializationConstantName,
388387
&SetProgramSpecializationConstant);
389388

@@ -430,10 +429,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetFunctionPointer(
430429
ur_device_handle_t hDevice, ur_program_handle_t hProgram,
431430
const char *pFunctionName, void **ppFunctionPointer) {
432431

433-
cl_context CLContext = nullptr;
434-
CL_RETURN_ON_FAILURE(clGetProgramInfo(cl_adapter::cast<cl_program>(hProgram),
435-
CL_PROGRAM_CONTEXT, sizeof(CLContext),
436-
&CLContext, nullptr));
432+
cl_context CLContext = hProgram->Context->get();
437433

438434
cl_ext::clGetDeviceFunctionPointer_fn FuncT = nullptr;
439435

@@ -453,14 +449,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetFunctionPointer(
453449
// throws exceptions.
454450
*ppFunctionPointer = 0;
455451
size_t Size;
456-
CL_RETURN_ON_FAILURE(clGetProgramInfo(cl_adapter::cast<cl_program>(hProgram),
452+
CL_RETURN_ON_FAILURE(clGetProgramInfo(hProgram->get(),
457453
CL_PROGRAM_KERNEL_NAMES, 0, nullptr,
458454
&Size));
459455

460456
std::string KernelNames(Size, ' ');
461457

462458
CL_RETURN_ON_FAILURE(clGetProgramInfo(
463-
cl_adapter::cast<cl_program>(hProgram), CL_PROGRAM_KERNEL_NAMES,
459+
hProgram->get(), CL_PROGRAM_KERNEL_NAMES,
464460
KernelNames.size(), &KernelNames[0], nullptr));
465461

466462
// Get rid of the null terminator and search for the kernel name. If the
@@ -471,7 +467,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetFunctionPointer(
471467
}
472468

473469
const cl_int CLResult =
474-
FuncT(hDevice->get(), cl_adapter::cast<cl_program>(hProgram),
470+
FuncT(hDevice->get(), hProgram->get(),
475471
pFunctionName, reinterpret_cast<cl_ulong *>(ppFunctionPointer));
476472
// GPU runtime sometimes returns CL_INVALID_ARG_VALUE if the function address
477473
// cannot be found but the kernel exists. As the kernel does exist, return

source/adapters/opencl/program.hpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//===--------- program.hpp - OpenCL Adapter ---------------------------===//
2+
//
3+
// Copyright (C) 2023 Intel Corporation
4+
//
5+
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM
6+
// Exceptions. See LICENSE.TXT
7+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8+
//
9+
//===----------------------------------------------------------------------===//
10+
#pragma once
11+
12+
#include "common.hpp"
13+
14+
#include <vector>
15+
16+
struct ur_program_handle_t_ {
17+
using native_type = cl_program;
18+
native_type Program;
19+
ur_context_handle_t Context;
20+
21+
ur_program_handle_t_(native_type Prog, ur_context_handle_t Ctx) : Program(Prog), Context(Ctx) {}
22+
23+
~ur_program_handle_t_() {}
24+
25+
native_type get() { return Program; }
26+
};

0 commit comments

Comments
 (0)