Skip to content

Commit c8e3d6f

Browse files
committed
call EpDevice_AddAllocatorInfo in GetSupportedDevicesImpl
1 parent 3d6fa57 commit c8e3d6f

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,15 +92,22 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp
9292
factory->ort_api.AddKeyValuePair(ep_options, "trt_builder_optimization_level", "3");
9393

9494
// OrtEpDevice copies ep_metadata and ep_options.
95+
OrtEpDevice* ep_device = nullptr;
9596
auto* status = factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, ep_metadata, ep_options,
96-
&ep_devices[num_ep_devices++]);
97+
&ep_device);
9798

9899
factory->ort_api.ReleaseKeyValuePairs(ep_metadata);
99100
factory->ort_api.ReleaseKeyValuePairs(ep_options);
100101

101102
if (status != nullptr) {
102103
return status;
103104
}
105+
106+
// register the allocator info required by the EP.
107+
RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, factory->default_gpu_memory_info_.get()));
108+
RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, factory->host_accessible_gpu_memory_info_.get()));
109+
110+
ep_devices[num_ep_devices++] = ep_device;
104111
}
105112

106113
// C++ API equivalent. Throws on error.

0 commit comments

Comments
 (0)