Skip to content

Commit 1e7f05e

Browse files
Verify DX sharing based on AdapterLuid
Change-Id: I86e970cbc48256e5941f0a071dc549dd22423105 Signed-off-by: Mateusz Jablonski <[email protected]>
1 parent d00daac commit 1e7f05e

File tree

6 files changed

+46
-20
lines changed

6 files changed

+46
-20
lines changed

opencl/source/os_interface/windows/api_win.cpp

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#include "shared/source/helpers/get_info.h"
99
#include "shared/source/helpers/hw_info.h"
10+
#include "shared/source/os_interface/windows/os_interface.h"
1011
#include "shared/source/utilities/api_intercept.h"
1112

1213
#include "opencl/source/api/api.h"
@@ -22,6 +23,18 @@
2223

2324
using namespace NEO;
2425

26+
ClDevice *pickDeviceWithAdapterLuid(Platform *platform, LUID adapterLuid) {
27+
ClDevice *deviceToReturn = nullptr;
28+
for (auto i = 0u; i < platform->getNumDevices(); i++) {
29+
auto device = platform->getClDevice(i);
30+
if (device->getRootDeviceEnvironment().osInterface->get()->getWddm()->verifyAdapterLuid(adapterLuid)) {
31+
deviceToReturn = device;
32+
break;
33+
}
34+
}
35+
return deviceToReturn;
36+
}
37+
2538
void NEO::MemObj::getOsSpecificMemObjectInfo(const cl_mem_info &paramName, size_t *srcParamSize, void **srcParam) {
2639
switch (paramName) {
2740
case CL_MEM_D3D10_RESOURCE_KHR:
@@ -295,6 +308,7 @@ cl_int CL_API_CALL clGetDeviceIDsFromD3D10KHR(cl_platform_id platform, cl_d3d10_
295308
cl_int retCode = CL_SUCCESS;
296309

297310
Platform *platformInternal = nullptr;
311+
ClDevice *device = nullptr;
298312
auto retVal = validateObjects(WithCastToInternal(platform, &platformInternal));
299313
API_ENTER(&retVal);
300314
DBG_LOG_INPUTS("platform", platform,
@@ -312,8 +326,6 @@ cl_int CL_API_CALL clGetDeviceIDsFromD3D10KHR(cl_platform_id platform, cl_d3d10_
312326
sharingFcns.getDxgiDescFcn = (D3DSharingFunctions<D3DTypesHelper::D3D10>::GetDxgiDescFcn)DebugManager.injectFcn;
313327
}
314328

315-
ClDevice *device = platformInternal->getClDevice(0);
316-
317329
switch (d3dDeviceSource) {
318330
case CL_D3D10_DEVICE_KHR:
319331
d3dDevice = (ID3D10Device *)d3dObject;
@@ -328,7 +340,7 @@ cl_int CL_API_CALL clGetDeviceIDsFromD3D10KHR(cl_platform_id platform, cl_d3d10_
328340
}
329341

330342
sharingFcns.getDxgiDescFcn(&dxgiDesc, dxgiAdapter, d3dDevice);
331-
if (dxgiDesc.VendorId != INTEL_VENDOR_ID || dxgiDesc.DeviceId != device->getHardwareInfo().platform.usDeviceID) {
343+
if (dxgiDesc.VendorId != INTEL_VENDOR_ID) {
332344
GetInfoHelper::set(numDevices, localNumDevices);
333345
retVal = CL_DEVICE_NOT_FOUND;
334346
return retVal;
@@ -337,8 +349,13 @@ cl_int CL_API_CALL clGetDeviceIDsFromD3D10KHR(cl_platform_id platform, cl_d3d10_
337349
switch (d3dDeviceSet) {
338350
case CL_PREFERRED_DEVICES_FOR_D3D10_KHR:
339351
case CL_ALL_DEVICES_FOR_D3D10_KHR:
340-
GetInfoHelper::set(devices, static_cast<cl_device_id>(device));
341-
localNumDevices = 1;
352+
device = pickDeviceWithAdapterLuid(platformInternal, dxgiDesc.AdapterLuid);
353+
if (device) {
354+
GetInfoHelper::set(devices, static_cast<cl_device_id>(device));
355+
localNumDevices = 1;
356+
} else {
357+
retCode = CL_DEVICE_NOT_FOUND;
358+
}
342359
break;
343360
default:
344361
retCode = CL_INVALID_VALUE;
@@ -502,6 +519,7 @@ cl_int CL_API_CALL clGetDeviceIDsFromD3D11KHR(cl_platform_id platform, cl_d3d11_
502519
cl_uint localNumDevices = 0;
503520

504521
Platform *platformInternal = nullptr;
522+
ClDevice *device = nullptr;
505523
auto retVal = validateObjects(WithCastToInternal(platform, &platformInternal));
506524
API_ENTER(&retVal);
507525
DBG_LOG_INPUTS("platform", platform,
@@ -519,8 +537,6 @@ cl_int CL_API_CALL clGetDeviceIDsFromD3D11KHR(cl_platform_id platform, cl_d3d11_
519537
sharingFcns.getDxgiDescFcn = (D3DSharingFunctions<D3DTypesHelper::D3D11>::GetDxgiDescFcn)DebugManager.injectFcn;
520538
}
521539

522-
ClDevice *device = platformInternal->getClDevice(0);
523-
524540
switch (d3dDeviceSource) {
525541
case CL_D3D11_DEVICE_KHR:
526542
d3dDevice = (ID3D11Device *)d3dObject;
@@ -536,7 +552,7 @@ cl_int CL_API_CALL clGetDeviceIDsFromD3D11KHR(cl_platform_id platform, cl_d3d11_
536552
}
537553

538554
sharingFcns.getDxgiDescFcn(&dxgiDesc, dxgiAdapter, d3dDevice);
539-
if (dxgiDesc.VendorId != INTEL_VENDOR_ID || dxgiDesc.DeviceId != device->getHardwareInfo().platform.usDeviceID) {
555+
if (dxgiDesc.VendorId != INTEL_VENDOR_ID) {
540556
GetInfoHelper::set(numDevices, localNumDevices);
541557
retVal = CL_DEVICE_NOT_FOUND;
542558
return retVal;
@@ -545,8 +561,13 @@ cl_int CL_API_CALL clGetDeviceIDsFromD3D11KHR(cl_platform_id platform, cl_d3d11_
545561
switch (d3dDeviceSet) {
546562
case CL_PREFERRED_DEVICES_FOR_D3D11_KHR:
547563
case CL_ALL_DEVICES_FOR_D3D11_KHR:
548-
GetInfoHelper::set(devices, static_cast<cl_device_id>(device));
549-
localNumDevices = 1;
564+
device = pickDeviceWithAdapterLuid(platformInternal, dxgiDesc.AdapterLuid);
565+
if (device) {
566+
GetInfoHelper::set(devices, static_cast<cl_device_id>(device));
567+
localNumDevices = 1;
568+
} else {
569+
retVal = CL_DEVICE_NOT_FOUND;
570+
}
550571
break;
551572
default:
552573
retVal = CL_INVALID_VALUE;

opencl/test/unit_test/d3d_sharing/d3d_tests_part2.cpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
*
66
*/
77

8+
#include "shared/source/os_interface/windows/os_interface.h"
89
#include "shared/source/utilities/arrayref.h"
910
#include "shared/test/unit_test/helpers/debug_manager_state_restore.h"
1011

@@ -18,6 +19,7 @@
1819
#include "opencl/source/sharings/d3d/d3d_surface.h"
1920
#include "opencl/source/sharings/d3d/d3d_texture.h"
2021
#include "opencl/test/unit_test/fixtures/d3d_test_fixture.h"
22+
#include "opencl/test/unit_test/mocks/mock_wddm.h"
2123

2224
#include "gmock/gmock.h"
2325
#include "gtest/gtest.h"
@@ -562,14 +564,11 @@ INSTANTIATE_TYPED_TEST_CASE_P(D3DSharingTests, D3DTests, D3DTypes);
562564

563565
using D3D10Test = D3DTests<D3DTypesHelper::D3D10>;
564566

565-
TEST_F(D3D10Test, givenIncompatibleD3DAdapterWhenGettingDeviceIdsThenNoDevicesAreReturned) {
567+
TEST_F(D3D10Test, givenIncompatibleAdapterLuidWhenGettingDeviceIdsThenNoDevicesAreReturned) {
566568
cl_device_id deviceID;
567569
cl_uint numDevices = 15;
568-
auto clAdapterId = context->getDevice(0)->getHardwareInfo().platform.usDeviceID;
569-
auto d3dAdapterId = clAdapterId + 1;
570-
mockSharingFcns->mockDxgiDesc.DeviceId = d3dAdapterId;
570+
static_cast<WddmMock *>(context->getDevice(0)->getRootDeviceEnvironment().osInterface->get()->getWddm())->verifyAdapterLuidReturnValue = false;
571571

572-
EXPECT_NE(clAdapterId, d3dAdapterId);
573572
auto retVal = clGetDeviceIDsFromD3D10KHR(pPlatform, CL_D3D10_DEVICE_KHR, nullptr, CL_ALL_DEVICES_FOR_D3D10_KHR, 1, &deviceID, &numDevices);
574573

575574
EXPECT_EQ(CL_DEVICE_NOT_FOUND, retVal);
@@ -578,14 +577,11 @@ TEST_F(D3D10Test, givenIncompatibleD3DAdapterWhenGettingDeviceIdsThenNoDevicesAr
578577

579578
using D3D11Test = D3DTests<D3DTypesHelper::D3D11>;
580579

581-
TEST_F(D3D11Test, givenIncompatibleD3DAdapterWhenGettingDeviceIdsThenNoDevicesAreReturned) {
580+
TEST_F(D3D11Test, givenIncompatibleAdapterLuidWhenGettingDeviceIdsThenNoDevicesAreReturned) {
582581
cl_device_id deviceID;
583582
cl_uint numDevices = 15;
584-
auto clAdapterId = context->getDevice(0)->getHardwareInfo().platform.usDeviceID;
585-
auto d3dAdapterId = clAdapterId + 1;
586-
mockSharingFcns->mockDxgiDesc.DeviceId = d3dAdapterId;
583+
static_cast<WddmMock *>(context->getDevice(0)->getRootDeviceEnvironment().osInterface->get()->getWddm())->verifyAdapterLuidReturnValue = false;
587584

588-
EXPECT_NE(clAdapterId, d3dAdapterId);
589585
auto retVal = clGetDeviceIDsFromD3D11KHR(pPlatform, CL_D3D11_DEVICE_KHR, nullptr, CL_ALL_DEVICES_FOR_D3D11_KHR, 1, &deviceID, &numDevices);
590586

591587
EXPECT_EQ(CL_DEVICE_NOT_FOUND, retVal);

opencl/test/unit_test/fixtures/d3d_test_fixture.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ class D3DTests : public PlatformFixture, public ::testing::Test {
9393
}
9494

9595
void SetUp() override {
96+
VariableBackup<UltHwConfig> backup(&ultHwConfig);
97+
ultHwConfig.useMockedPrepareDeviceEnvironmentsFunc = false;
9698
PlatformFixture::SetUp();
9799
rootDeviceIndex = pPlatform->getClDevice(0)->getRootDeviceIndex();
98100
context = new MockContext(pPlatform->getClDevice(0));

opencl/test/unit_test/mocks/mock_wddm.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ class WddmMock : public Wddm {
8888
uint64_t *getPagingFenceAddress() override;
8989
void waitOnPagingFenceFromCpu() override;
9090
void createPagingFenceLogger() override;
91+
bool verifyAdapterLuid(LUID adapterLuid) const override { return verifyAdapterLuidReturnValue; }
9192

9293
bool configureDeviceAddressSpace() {
9394
configureDeviceAddressSpaceResult.called++;
@@ -140,6 +141,7 @@ class WddmMock : public Wddm {
140141
WddmMockHelpers::CallResult waitOnPagingFenceFromCpuResult;
141142

142143
NTSTATUS createAllocationStatus = STATUS_SUCCESS;
144+
bool verifyAdapterLuidReturnValue = true;
143145
bool mapGpuVaStatus = true;
144146
bool callBaseDestroyAllocations = true;
145147
bool failOpenSharedHandle = false;

shared/source/os_interface/windows/wddm/wddm.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -902,6 +902,10 @@ PFND3DKMT_ESCAPE Wddm::getEscapeHandle() const {
902902
return getGdi()->escape;
903903
}
904904

905+
bool Wddm::verifyAdapterLuid(LUID adapterLuid) const {
906+
return adapterLuid.HighPart == hwDeviceId->getAdapterLuid().HighPart && adapterLuid.LowPart == hwDeviceId->getAdapterLuid().LowPart;
907+
}
908+
905909
VOID *Wddm::registerTrimCallback(PFND3DKMT_TRIMNOTIFICATIONCALLBACK callback, WddmResidencyController &residencyController) {
906910
if (DebugManager.flags.DoNotRegisterTrimCallback.get()) {
907911
return nullptr;

shared/source/os_interface/windows/wddm/wddm.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ class Wddm {
119119
D3DKMT_HANDLE getPagingQueue() const { return pagingQueue; }
120120
D3DKMT_HANDLE getPagingQueueSyncObject() const { return pagingQueueSyncObject; }
121121
inline Gdi *getGdi() const { return hwDeviceId->getGdi(); }
122+
MOCKABLE_VIRTUAL bool verifyAdapterLuid(LUID adapterLuid) const;
122123

123124
PFND3DKMT_ESCAPE getEscapeHandle() const;
124125

0 commit comments

Comments
 (0)