Skip to content

Commit 5131596

Browse files
authored
Add ability for EP to get vendor ID and device ID from OrtMemoryDevice (microsoft#25222)
### Description <!-- Describe your changes. --> EP implementations need to be able to read vendor id and device id to implement OrtDataTransferImpl::CanCopy correctly. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
1 parent 930aa91 commit 5131596

File tree

4 files changed

+63
-1
lines changed

4 files changed

+63
-1
lines changed

include/onnxruntime/core/session/onnxruntime_ep_c_api.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,28 @@ struct OrtEpApi {
296296
* \since Version 1.23.
297297
*/
298298
ORT_API_T(OrtDeviceMemoryType, MemoryDevice_GetMemoryType, _In_ const OrtMemoryDevice* memory_device);
299+
300+
/** \brief Get the vendor ID from an OrtMemoryDevice instance.
301+
*
302+
* The vendor ID is used to identify the vendor of the device, and is typically set to the PCI vendor ID.
303+
*
304+
* If the device is not vendor specific (e.g. CPU memory) the vendor ID is set to 0.
305+
*
306+
* \param[in] memory_device OrtMemoryDevice instance.
307+
* \return The vendor ID value.
308+
*
309+
* \since Version 1.23.
310+
*/
311+
ORT_API_T(uint32_t, MemoryDevice_GetVendorId, _In_ const OrtMemoryDevice* memory_device);
312+
313+
/** \brief Get the device ID from an OrtMemoryDevice instance.
314+
*
315+
* \param[in] memory_device OrtMemoryDevice instance.
316+
* \return The device ID.
317+
*
318+
* \since Version 1.23.
319+
*/
320+
ORT_API_T(uint32_t, MemoryDevice_GetDeviceId, _In_ const OrtMemoryDevice* memory_device);
299321
};
300322

301323
/**

onnxruntime/core/session/ep_api.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,14 @@ ORT_API(OrtDeviceMemoryType, MemoryDevice_GetMemoryType, _In_ const OrtMemoryDev
152152
: OrtDeviceMemoryType_HOST_ACCESSIBLE;
153153
}
154154

155+
ORT_API(uint32_t, MemoryDevice_GetVendorId, _In_ const OrtMemoryDevice* memory_device) {
156+
return memory_device->Vendor();
157+
}
158+
159+
ORT_API(uint32_t, MemoryDevice_GetDeviceId, _In_ const OrtMemoryDevice* memory_device) {
160+
return memory_device->Id();
161+
}
162+
155163
static constexpr OrtEpApi ort_ep_api = {
156164
// NOTE: ABI compatibility depends on the order within this struct so all additions must be at the end,
157165
// and no functions can be removed (the implementation needs to change to return an error).
@@ -171,6 +179,8 @@ static constexpr OrtEpApi ort_ep_api = {
171179
&OrtExecutionProviderApi::MemoryDevice_AreEqual,
172180
&OrtExecutionProviderApi::MemoryDevice_GetDeviceType,
173181
&OrtExecutionProviderApi::MemoryDevice_GetMemoryType,
182+
&OrtExecutionProviderApi::MemoryDevice_GetVendorId,
183+
&OrtExecutionProviderApi::MemoryDevice_GetDeviceId,
174184
};
175185

176186
// checks that we don't violate the rule that the functions must remain in the slots they were originally assigned

onnxruntime/core/session/ep_api.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,6 @@ ORT_API_STATUS_IMPL(Value_GetMemoryDevice, _In_ const OrtValue* value, _Out_ con
3333
ORT_API(bool, MemoryDevice_AreEqual, _In_ const OrtMemoryDevice* a, _In_ const OrtMemoryDevice* b);
3434
ORT_API(OrtMemoryInfoDeviceType, MemoryDevice_GetDeviceType, _In_ const OrtMemoryDevice* memory_device);
3535
ORT_API(OrtDeviceMemoryType, MemoryDevice_GetMemoryType, _In_ const OrtMemoryDevice* memory_device);
36+
ORT_API(uint32_t, MemoryDevice_GetVendorId, _In_ const OrtMemoryDevice* memory_device);
37+
ORT_API(uint32_t, MemoryDevice_GetDeviceId, _In_ const OrtMemoryDevice* memory_device);
3638
} // namespace OrtExecutionProviderApi

onnxruntime/test/autoep/library/ep_data_transfer.cc

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,38 @@
1010
bool ORT_API_CALL ExampleDataTransfer::CanCopyImpl(void* this_ptr,
1111
const OrtMemoryDevice* src_memory_device,
1212
const OrtMemoryDevice* dst_memory_device) noexcept {
13+
static constexpr uint32_t VendorId = 0xBE57; // Example vendor ID for demonstration purposes.
14+
1315
auto& impl = *static_cast<ExampleDataTransfer*>(this_ptr);
1416
bool src_is_our_device = impl.ep_api.MemoryDevice_AreEqual(src_memory_device, impl.device_mem_info);
1517
bool dst_is_our_device = impl.ep_api.MemoryDevice_AreEqual(dst_memory_device, impl.device_mem_info);
1618

17-
return src_is_our_device || dst_is_our_device;
19+
if (src_is_our_device && dst_is_our_device) {
20+
return true;
21+
}
22+
23+
// implementation should check if the copy is possible, which may require checking the device type, the memory type
24+
// and the vendor and device IDs as needed.
25+
OrtMemoryInfoDeviceType src_device_type = impl.ep_api.MemoryDevice_GetDeviceType(src_memory_device);
26+
OrtMemoryInfoDeviceType dst_device_type = impl.ep_api.MemoryDevice_GetDeviceType(dst_memory_device);
27+
// OrtDeviceMemoryType src_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(src_memory_device);
28+
// OrtDeviceMemoryType dst_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(dst_memory_device);
29+
// uint32_t src_device_vendor_id = impl.ep_api.MemoryDevice_GetVendorId(src_memory_device);
30+
// uint32_t dst_device_vendor_id = impl.ep_api.MemoryDevice_GetVendorId(dst_memory_device);
31+
// uint32_t src_device_device_id = impl.ep_api.MemoryDevice_GetDeviceId(src_memory_device);
32+
// uint32_t dst_device_device_id = impl.ep_api.MemoryDevice_GetDeviceId(dst_memory_device);
33+
34+
if (src_is_our_device) {
35+
// check device type and vendor to see if compatible
36+
return (dst_device_type == OrtMemoryInfoDeviceType_CPU);
37+
}
38+
39+
if (dst_is_our_device) {
40+
// check device type and vendor to see if compatible
41+
return (src_device_type == OrtMemoryInfoDeviceType_CPU);
42+
}
43+
44+
return false;
1845
}
1946

2047
// function to copy one or more tensors.
@@ -41,6 +68,7 @@ OrtStatus* ORT_API_CALL ExampleDataTransfer::CopyTensorsImpl(void* this_ptr,
4168

4269
OrtMemoryInfoDeviceType src_device_type = impl.ep_api.MemoryDevice_GetDeviceType(src_device);
4370
OrtMemoryInfoDeviceType dst_device_type = impl.ep_api.MemoryDevice_GetDeviceType(dst_device);
71+
4472
// OrtDeviceMemoryType src_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(src_device);
4573
// OrtDeviceMemoryType dst_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(dst_device);
4674
// bool copy_involves_pinned_memory = src_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE ||

0 commit comments

Comments
 (0)