Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions test_common/harness/extensionHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,15 @@
} \
} while (false)

#define GET_FUNCTION_EXTENSION_ADDRESS(device, FUNC) \
FUNC = \
reinterpret_cast<FUNC##_fn>(clGetExtensionFunctionAddressForPlatform( \
getPlatformFromDevice(device), #FUNC)); \
if (FUNC == nullptr) \
{ \
log_error("ERROR: clGetExtensionFunctionAddressForPlatform failed" \
" with " #FUNC "\n"); \
return TEST_FAIL; \
}

#endif // _extensionHelpers_h
8 changes: 4 additions & 4 deletions test_conformance/common/directx_wrapper/directx_wrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ class DirectXWrapper {
public:
DirectXWrapper();

ID3D12Device* getDXDevice() const;
ID3D12CommandQueue* getDXCommandQueue() const;
ID3D12CommandAllocator* getDXCommandAllocator() const;
[[nodiscard]] ID3D12Device* getDXDevice() const;
[[nodiscard]] ID3D12CommandQueue* getDXCommandQueue() const;
[[nodiscard]] ID3D12CommandAllocator* getDXCommandAllocator() const;

protected:
ComPtr<ID3D12Device> dx_device = nullptr;
Expand All @@ -39,7 +39,7 @@ class DirectXWrapper {
class DirectXFenceWrapper {
public:
DirectXFenceWrapper(ID3D12Device* dx_device);
ID3D12Fence* operator*() const { return dx_fence.Get(); }
[[nodiscard]] ID3D12Fence* get() const { return dx_fence.Get(); }

private:
ComPtr<ID3D12Fence> dx_fence = nullptr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

int main(int argc, const char *argv[])
{
return runTestHarness(argc, argv, test_registry::getInstance().num_tests(),
test_registry::getInstance().definitions(), false, 0);
return runTestHarness(
argc, argv, static_cast<int>(test_registry::getInstance().num_tests()),
test_registry::getInstance().definitions(), false, 0);
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,97 +21,149 @@
#include "harness/errorHelpers.h"
#include "directx_wrapper.hpp"

class CLDXSemaphoreWrapper {
public:
CLDXSemaphoreWrapper(cl_device_id device, cl_context context,
ID3D12Device* dx_device)
: device(device), context(context), dx_device(dx_device){};

int createSemaphoreFromFence(ID3D12Fence* fence)
{
cl_int errcode = CL_SUCCESS;

GET_PFN(device, clCreateSemaphoreWithPropertiesKHR);

const HRESULT hr = dx_device->CreateSharedHandle(
fence, nullptr, GENERIC_ALL, nullptr, &fence_handle);
test_error(FAILED(hr), "Failed to get shared handle from D3D12 fence");

cl_semaphore_properties_khr sem_props[] = {
static_cast<cl_semaphore_properties_khr>(CL_SEMAPHORE_TYPE_KHR),
static_cast<cl_semaphore_properties_khr>(
CL_SEMAPHORE_TYPE_BINARY_KHR),
static_cast<cl_semaphore_properties_khr>(
CL_SEMAPHORE_HANDLE_D3D12_FENCE_KHR),
reinterpret_cast<cl_semaphore_properties_khr>(fence_handle), 0
};
semaphore =
clCreateSemaphoreWithPropertiesKHR(context, sem_props, &errcode);
test_error(errcode, "Could not create semaphore");

return CL_SUCCESS;
}

~CLDXSemaphoreWrapper()
struct DXFenceTestBase
{
DXFenceTestBase(cl_device_id device, cl_context context,
cl_command_queue queue, cl_int num_elems)
: device(device), context(context), queue(queue), num_elems(num_elems)
{}
virtual ~DXFenceTestBase()
{
releaseSemaphore();
if (fence_handle)
{
CloseHandle(fence_handle);
fence_handle = nullptr;
}
if (fence_wrapper)
{
delete fence_wrapper;
fence_wrapper = nullptr;
}
if (semaphore)
{
clReleaseSemaphoreKHR(semaphore);
semaphore = nullptr;
}
};

const cl_semaphore_khr* operator&() const { return &semaphore; };
cl_semaphore_khr operator*() const { return semaphore; };
virtual int SetUp()
{
REQUIRE_EXTENSION("cl_khr_external_semaphore");
REQUIRE_EXTENSION("cl_khr_external_semaphore_dx_fence");

// Obtain pointers to semaphore's API
GET_FUNCTION_EXTENSION_ADDRESS(device,
clCreateSemaphoreWithPropertiesKHR);
GET_FUNCTION_EXTENSION_ADDRESS(device, clReleaseSemaphoreKHR);
GET_FUNCTION_EXTENSION_ADDRESS(device, clEnqueueSignalSemaphoresKHR);
GET_FUNCTION_EXTENSION_ADDRESS(device, clEnqueueWaitSemaphoresKHR);
GET_FUNCTION_EXTENSION_ADDRESS(device, clGetSemaphoreHandleForTypeKHR);
GET_FUNCTION_EXTENSION_ADDRESS(device, clRetainSemaphoreKHR);
GET_FUNCTION_EXTENSION_ADDRESS(device, clGetSemaphoreInfoKHR);

test_error(
!is_import_handle_available(CL_SEMAPHORE_HANDLE_D3D12_FENCE_KHR),
"Could not find CL_SEMAPHORE_HANDLE_D3D12_FENCE_KHR between the "
"supported import types");

// Import D3D12 fence into OpenCL
fence_wrapper = new DirectXFenceWrapper(dx_wrapper.getDXDevice());
semaphore = createSemaphoreFromFence(fence_wrapper->get());
test_assert_error(!!semaphore, "Could not create semaphore");

return TEST_PASS;
}

HANDLE getHandle() const { return fence_handle; };
virtual cl_int Run() = 0;

private:
cl_semaphore_khr semaphore;
ComPtr<ID3D12Fence> fence;
HANDLE fence_handle;
cl_device_id device;
cl_context context;
ComPtr<ID3D12Device> dx_device;
protected:
int errcode = CL_SUCCESS;

int releaseSemaphore() const
cl_device_id device = nullptr;
cl_context context = nullptr;
cl_command_queue queue = nullptr;
cl_int num_elems = 0;
DirectXWrapper dx_wrapper;

cl_semaphore_payload_khr semaphore_payload = 1;
cl_semaphore_khr semaphore = nullptr;
HANDLE fence_handle = nullptr;
DirectXFenceWrapper *fence_wrapper = nullptr;

clCreateSemaphoreWithPropertiesKHR_fn clCreateSemaphoreWithPropertiesKHR =
nullptr;
clEnqueueSignalSemaphoresKHR_fn clEnqueueSignalSemaphoresKHR = nullptr;
clEnqueueWaitSemaphoresKHR_fn clEnqueueWaitSemaphoresKHR = nullptr;
clReleaseSemaphoreKHR_fn clReleaseSemaphoreKHR = nullptr;
clGetSemaphoreInfoKHR_fn clGetSemaphoreInfoKHR = nullptr;
clRetainSemaphoreKHR_fn clRetainSemaphoreKHR = nullptr;
clGetSemaphoreHandleForTypeKHR_fn clGetSemaphoreHandleForTypeKHR = nullptr;

[[nodiscard]] bool is_import_handle_available(
const cl_external_memory_handle_type_khr handle_type)
{
GET_PFN(device, clReleaseSemaphoreKHR);

if (semaphore)
size_t import_types_size = 0;
errcode =
clGetDeviceInfo(device, CL_DEVICE_SEMAPHORE_IMPORT_HANDLE_TYPES_KHR,
0, nullptr, &import_types_size);
if (errcode != CL_SUCCESS)
{
clReleaseSemaphoreKHR(semaphore);
log_error("Could not query import semaphore handle types");
return false;
}
std::vector<cl_external_semaphore_handle_type_khr> import_types(
import_types_size / sizeof(cl_external_semaphore_handle_type_khr));
errcode =
clGetDeviceInfo(device, CL_DEVICE_SEMAPHORE_IMPORT_HANDLE_TYPES_KHR,
import_types_size, import_types.data(), nullptr);
if (errcode != CL_SUCCESS)
{
log_error("Could not query import semaphore handle types");
return false;
}

return CL_SUCCESS;
return std::find(import_types.begin(), import_types.end(), handle_type)
!= import_types.end();
}

cl_semaphore_khr createSemaphoreFromFence(ID3D12Fence *src_fence)
{
const HRESULT hr = dx_wrapper.getDXDevice()->CreateSharedHandle(
src_fence, nullptr, GENERIC_ALL, nullptr, &fence_handle);
if (FAILED(hr)) return nullptr;

const cl_semaphore_properties_khr sem_props[] = {
static_cast<cl_semaphore_properties_khr>(CL_SEMAPHORE_TYPE_KHR),
static_cast<cl_semaphore_properties_khr>(
CL_SEMAPHORE_TYPE_BINARY_KHR),
static_cast<cl_semaphore_properties_khr>(
CL_SEMAPHORE_HANDLE_D3D12_FENCE_KHR),
reinterpret_cast<cl_semaphore_properties_khr>(fence_handle), 0
};
cl_semaphore_khr tmp_semaphore =
clCreateSemaphoreWithPropertiesKHR(context, sem_props, &errcode);
if (errcode != CL_SUCCESS) return nullptr;

return tmp_semaphore;
}
};

static bool
is_import_handle_available(cl_device_id device,
const cl_external_memory_handle_type_khr handle_type)
template <class T>
int MakeAndRunTest(cl_device_id device, cl_context context,
cl_command_queue queue, cl_int nelems)
{
int errcode = CL_SUCCESS;
size_t import_types_size = 0;
errcode =
clGetDeviceInfo(device, CL_DEVICE_SEMAPHORE_IMPORT_HANDLE_TYPES_KHR, 0,
nullptr, &import_types_size);
if (errcode != CL_SUCCESS)
cl_int status = TEST_PASS;
try
{
log_error("Could not query import semaphore handle types");
return false;
}
std::vector<cl_external_semaphore_handle_type_khr> import_types(
import_types_size / sizeof(cl_external_semaphore_handle_type_khr));
errcode =
clGetDeviceInfo(device, CL_DEVICE_SEMAPHORE_IMPORT_HANDLE_TYPES_KHR,
import_types_size, import_types.data(), nullptr);
if (errcode != CL_SUCCESS)
auto test_fixture = T(device, context, queue, nelems);
status = test_fixture.SetUp();
if (status != TEST_PASS) return status;
status = test_fixture.Run();
} catch (const std::runtime_error &e)
{
log_error("Could not query import semaphore handle types");
return false;
log_error("%s", e.what());
return TEST_FAIL;
}

return std::find(import_types.begin(), import_types.end(), handle_type)
!= import_types.end();
return status;
}
Loading