Skip to content

Commit b3a34bb

Browse files
BaseTester: support plugin EPs with compiled nodes and registered kernels (#27176)
### Description Updates the `BaseTester` class used by the `onnxruntime_provider_test` tool to support plugin EPs that use a kernel registry but compile other nodes. For example, TRT EP only uses registered kernels for Memcpy* nodes, but compiles every other node. Without this change, plugin EPs that use a mix of compiled nodes and registered kernels cannot be tested with `onnxruntime_provider_test`. ### Motivation and Context
1 parent 1a71a5f commit b3a34bb

File tree

1 file changed

+30
-13
lines changed

1 file changed

+30
-13
lines changed

onnxruntime/test/unittest_util/base_tester.cc

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ void BaseTester::ExecuteModel(Model& model, SessionType& session,
424424
bool SetEpsForAllNodes(Graph& graph,
425425
const std::vector<std::unique_ptr<IExecutionProvider>>& execution_providers,
426426
const std::vector<std::shared_ptr<CustomRegistry>>* custom_registries,
427-
const std::function<bool(const IExecutionProvider&)>& ep_uses_kernel_registry_fn) {
427+
const std::function<bool(const IExecutionProvider&)>& ep_only_uses_kernel_registry_fn) {
428428
const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{};
429429
const KernelRegistry::TypeConstraintMap type_constraint_map{};
430430

@@ -440,7 +440,7 @@ bool SetEpsForAllNodes(Graph& graph,
440440

441441
node.SetExecutionProviderType(provider_type);
442442

443-
if (!ep_uses_kernel_registry_fn(*ep)) {
443+
if (!ep_only_uses_kernel_registry_fn(*ep)) {
444444
found = true;
445445
break;
446446
}
@@ -659,7 +659,12 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter,
659659
#endif
660660
kDnnlExecutionProvider,
661661
kTensorrtExecutionProvider,
662+
#ifdef USE_NV
663+
// Only include NV TRT RTX EP when is ORT is built with the provider-bridge
664+
// version of the EP (i.e., USE_NV is defined). This allows use of the plugin EP version of the EP
665+
// when ORT is not built any provider-bridge EPs.
662666
kNvTensorRTRTXExecutionProvider,
667+
#endif
663668
kOpenVINOExecutionProvider,
664669
kDmlExecutionProvider,
665670
kAclExecutionProvider,
@@ -830,12 +835,15 @@ void BaseTester::ExecuteModelForEps(
830835

831836
ASSERT_TRUE(!execution_providers.empty()) << "Empty execution providers vector.";
832837
if (try_assign_ep_for_nodes) {
833-
auto ep_uses_kernel_registry = [](const IExecutionProvider& ep) {
838+
auto ep_only_uses_kernel_registry = [](const IExecutionProvider& ep) {
834839
const auto& provider_type = ep.Type();
835840

836-
constexpr std::array kEpsThatDoNotUseKernelRegistry{
841+
constexpr std::array kEpsThatCompileNodes{
837842
kOpenVINOExecutionProvider,
838-
kTensorrtExecutionProvider,
843+
kTensorrtExecutionProvider, // uses kernel registry for Memcpy* nodes only
844+
#ifdef USE_NV
845+
kNvTensorRTRTXExecutionProvider, // uses kernel registry for Memcpy* nodes only
846+
#endif
839847
kNnapiExecutionProvider,
840848
kVSINPUExecutionProvider,
841849
kCoreMLExecutionProvider,
@@ -844,24 +852,33 @@ void BaseTester::ExecuteModelForEps(
844852
kSnpeExecutionProvider,
845853
};
846854

847-
// check list of known EPs that do not use a kernel registry
848-
if (const auto ep_it = std::find(kEpsThatDoNotUseKernelRegistry.begin(), kEpsThatDoNotUseKernelRegistry.end(),
855+
// check list of known EPs that compile nodes
856+
if (const auto ep_it = std::find(kEpsThatCompileNodes.begin(), kEpsThatCompileNodes.end(),
849857
provider_type);
850-
ep_it != kEpsThatDoNotUseKernelRegistry.end()) {
858+
ep_it != kEpsThatCompileNodes.end()) {
851859
return false;
852860
}
853861

854-
// assume that a dynamic plugin EP which does not return a kernel registry does not use one
855-
if (provider_type == dynamic_plugin_ep_infra::GetEpName() &&
856-
ep.GetKernelRegistry() == nullptr) {
857-
return false;
862+
const OrtEp* ort_ep = ep.GetOrtEp();
863+
864+
if (ort_ep != nullptr) { // This is a plugin EP
865+
866+
if (ep.GetKernelRegistry() == nullptr) {
867+
// assume that a dynamic plugin EP which does not return a kernel registry does not use one
868+
return false;
869+
}
870+
871+
if (ort_ep->Compile != nullptr) {
872+
// assume that a plugin EP that compiles nodes does not use a kernel registry for all nodes
873+
return false;
874+
}
858875
}
859876

860877
// otherwise, assume that the EP uses a kernel registry
861878
return true;
862879
};
863880

864-
if (!SetEpsForAllNodes(model.MainGraph(), execution_providers, custom_registries, ep_uses_kernel_registry)) {
881+
if (!SetEpsForAllNodes(model.MainGraph(), execution_providers, custom_registries, ep_only_uses_kernel_registry)) {
865882
std::string providers;
866883
for (const auto& ep : execution_providers) {
867884
providers.append(ep->Type() + " ");

0 commit comments

Comments
 (0)