Skip to content
Merged
50 changes: 50 additions & 0 deletions onnxruntime/test/ep_weight_sharing_ctx_gen/README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# ONNXRuntime EP Context Model Generation with Weight Sharing

> [!NOTE]
> This tool is deprecated. Please use the public ONNX Runtime Python APIs to compile models with resource sharing. Refer to the example Python script at the end of this document.

[EP context with weight sharing design doc](https://onnxruntime.ai/docs/execution-providers/EP-Context-Design.html#epcontext-with-weight-sharing)

OnnxRuntime provides the ep_weight_sharing_ctx_gen tool to automate the weight-sharing workflow. This tool handles the entire process. This tool is specifically designed for weight sharing scenarios, streamlining the EPContext model generation process.
Expand All @@ -13,6 +16,7 @@ Example: ./ep_weight_sharing_ctx_gen -e qnn -i "soc_model|60 htp_graph_finalizat

Options:
-e [qnn|tensorrt|openvino|vitisai]: Specifies the compile based provider 'qnn', 'tensorrt', 'openvino', 'vitisai'. Default: 'qnn'.
-p [plugin_ep_config_json_file]: Specify JSON configuration file for a plugin EP.
-v: Show verbose information.
-C: Specify session configuration entries as key-value pairs: -C "<key1>|<value1> <key2>|<value2>"
Refer to onnxruntime_session_options_config_keys.h for valid keys and values.
Expand All @@ -36,3 +40,49 @@ Options:

-h: help
```

# Example: Use Python APIs to compile models with resource sharing
Use of the public ORT Python APIs is now recommended for compiling models with resource (e.g., "weight") sharing.
The following snippet shows an example that compiles two models using an example plugin EP.

```Python
import onnxruntime
import os

def main():
ep_name = "example_ep"
ep_lib_path = "example_plugin_ep.dll"

onnxruntime.register_execution_provider_library(ep_name, os.path.realpath(ep_lib_path))

# Find one or more EP devices that correspond to the EP of interest.
# In this example, we pick the first one.
ep_device = next((d for d in onnxruntime.get_ep_devices() if d.ep_name == ep_name), None)

# These are the names/paths to the input and output models.
input_models = ["model_0.onnx", "model_1.onnx"]
output_models = ["model_0_ctx.onnx", "model_1_ctx.onnx"]

num_models = len(input_models)
session_options = onnxruntime.SessionOptions()
provider_options = {} # Empty for this example

# Set option that tells EP to share resources (e.g., weights) across sessions.
session_options.add_session_config_entry("ep.share_ep_contexts", "1")
session_options.add_provider_for_devices([ep_device], provider_options)

# Compile individual models
for i in range(len(input_models)):
if i == num_models - 1:
# Tell EP that this is the last compiling session that will be sharing resources.
session_options.add_session_config_entry("ep.stop_share_ep_contexts", "1")

model_compiler = onnxruntime.ModelCompiler(
session_options,
input_models[i],
embed_compiled_data_into_model=False,
)
model_compiler.compile_to_file(output_models[i])

onnxruntime.unregister_execution_provider_library(ep_name)
```
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "command_args_parser.h"

#include <string.h>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string_view>
Expand All @@ -21,6 +22,7 @@
#include <core/platform/path_lib.h>
#include <core/optimizer/graph_transformer_level.h>

#include "nlohmann/json.hpp"
#include "test_configuration.h"

namespace onnxruntime {
Expand All @@ -35,6 +37,7 @@ namespace qnnctxgen {
"\n"
"Options:\n"
"\t-e [qnn|tensorrt|openvino|vitisai]: Specifies the compile based provider 'qnn', 'tensorrt', 'openvino', 'vitisai'. Default: 'qnn'.\n"
"\t-p [plugin_ep_config_json_file]: Specify JSON configuration file for a plugin EP.\n"
"\t-v: Show verbose information.\n"
"\t-C: Specify session configuration entries as key-value pairs: -C \"<key1>|<value1> <key2>|<value2>\" \n"
"\t Refer to onnxruntime_session_options_config_keys.h for valid keys and values. \n"
Expand All @@ -58,6 +61,7 @@ namespace qnnctxgen {
"\n"
"\t-h: help\n");
}

#ifdef _WIN32
static const ORTCHAR_T* delimiter = L",";
#else
Expand Down Expand Up @@ -110,9 +114,62 @@ static bool ParseSessionConfigs(const std::string& configs_string,
return true;
}

static bool ParsePluginEpConfig(const std::string& json_file_path, PluginEpConfig& config_out) {
using json = nlohmann::json;
bool success = true;

ORT_TRY {
std::ifstream ifs{json_file_path};
if (!ifs) {
fprintf(stderr, "ERROR: Failed to open plugin EP configuration file at path: %s\n", json_file_path.c_str());
return false;
}

std::string content(std::istreambuf_iterator<char>{ifs},
std::istreambuf_iterator<char>{});
PluginEpConfig config{};
const auto parsed_json = json::parse(content);

// required keys
parsed_json.at("ep_library_registration_name").get_to(config.ep_library_registration_name);
parsed_json.at("ep_library_path").get_to(config.ep_library_path);

// optional keys
config.default_ep_options = parsed_json.value<decltype(config.default_ep_options)>("default_ep_options", {});
config.selected_ep_name = parsed_json.value<decltype(config.selected_ep_name)>("selected_ep_name", {});
config.selected_ep_device_indices =
parsed_json.value<decltype(config.selected_ep_device_indices)>("selected_ep_device_indices", {});

if (config.selected_ep_name.empty() == config.selected_ep_device_indices.empty()) {
fprintf(stderr,
"ERROR: Plugin EP configuration must specify exactly one of 'selected_ep_name' "
"or 'selected_ep_device_indices'\n");
return false;
}

config_out = std::move(config);
return success;
}
ORT_CATCH(const json::exception& e) {
ORT_HANDLE_EXCEPTION([&]() {
std::string kExampleValidJsonStr =
"{\n"
" \"ep_library_registration_name\": \"example_plugin_ep\",\n"
" \"ep_library_path\": \"/path/to/example_plugin_ep.dll\",\n"
" \"selected_ep_name\": \"example_plugin_ep\"\n"
"}";

success = false;
fprintf(stderr, "ERROR: JSON parse error: %s\n", e.what());
fprintf(stderr, "This is an example valid JSON configuration:\n%s\n", kExampleValidJsonStr.c_str());
});
}
return success;
}

/*static*/ bool CommandLineParser::ParseArguments(TestConfig& test_config, int argc, ORTCHAR_T* argv[]) {
int ch;
while ((ch = getopt(argc, argv, ORT_TSTR("e:o:u:i:C:vh"))) != -1) {
while ((ch = getopt(argc, argv, ORT_TSTR("e:p:o:u:i:C:vh"))) != -1) {
switch (ch) {
case 'e':
if (!CompareCString(optarg, ORT_TSTR("qnn"))) {
Expand All @@ -128,6 +185,20 @@ static bool ParseSessionConfigs(const std::string& configs_string,
return false;
}
break;
case 'p': {
#ifdef _MSC_VER
std::string plugin_ep_config_file_path = ToUTF8String(optarg);
#else
std::string plugin_ep_config_file_path = optarg;
#endif
PluginEpConfig plugin_ep_config{};
if (!ParsePluginEpConfig(plugin_ep_config_file_path, plugin_ep_config)) {
return false;
}

test_config.machine_config.plugin_ep_config = std::move(plugin_ep_config);
break;
}
case 'v':
test_config.run_config.f_verbose = true;
break;
Expand Down Expand Up @@ -202,6 +273,11 @@ static bool ParseSessionConfigs(const std::string& configs_string,
argc -= optind;
argv += optind;

if (argc == 0) {
fprintf(stderr, "ERROR: Did not specify model paths\n");
return false;
}

ParsePaths(argv[0], test_config.model_file_paths);

return true;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"ep_library_registration_name": "example_plugin_ep",
"ep_library_path": "example_plugin_ep.dll",
"selected_ep_name": "example_plugin_ep",
"default_ep_options": { "option_key": "option_value" }
}
82 changes: 81 additions & 1 deletion onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@

// onnx dependencies
#include "onnx/onnx_pb.h"
#include <algorithm>
#include <fstream>
#include <sstream>

using namespace onnxruntime;
using ProviderOptions = std::unordered_map<std::string, std::string>;
Expand Down Expand Up @@ -81,6 +83,76 @@ static void UpdateEpContextModel(const std::vector<std::basic_string<ORTCHAR_T>>
}
}

using PluginEpLibraryRegistrationHandle = std::unique_ptr<void, std::function<void(void*)>>;

struct PluginEpState {
PluginEpLibraryRegistrationHandle plugin_ep_library_registration_handle{};
std::vector<Ort::ConstEpDevice> selected_ep_devices{};
};

static PluginEpLibraryRegistrationHandle RegisterPluginEpLibrary(Ort::Env& env,
const std::string& ep_library_registration_name,
const std::basic_string<ORTCHAR_T>& ep_library_path) {
env.RegisterExecutionProviderLibrary(ep_library_registration_name.c_str(), ep_library_path);

auto unregister_ep_library = [&env, registration_name = ep_library_registration_name](void* p) {
if (p == nullptr) {
return;
}

ORT_TRY {
env.UnregisterExecutionProviderLibrary(registration_name.c_str());
}
ORT_CATCH(const Ort::Exception& e) {
ORT_HANDLE_EXCEPTION([&]() {
std::cerr << "Failed to unregister EP library with name '" << registration_name << "': " << e.what() << "\n";
});
}
};

// Set `handle_value` to something not equal to nullptr. The particular value doesn't really matter.
// We are just using the unique_ptr deleter to unregister the EP library.
void* const handle_value = reinterpret_cast<void*>(0x1);
return PluginEpLibraryRegistrationHandle{handle_value, unregister_ep_library};
}

static Ort::Status SetPluginEpSessionOptions(Ort::Env& env, Ort::SessionOptions& session_options,
const qnnctxgen::PluginEpConfig& config, PluginEpState& plugin_ep_state) {
auto plugin_ep_library_registration_handle = RegisterPluginEpLibrary(env,
config.ep_library_registration_name,
ToPathString(config.ep_library_path));

std::vector<Ort::ConstEpDevice> ep_devices = env.GetEpDevices();
std::vector<Ort::ConstEpDevice> selected_ep_devices{};

if (!config.selected_ep_device_indices.empty()) {
for (const auto idx : config.selected_ep_device_indices) {
if (idx >= ep_devices.size()) {
std::ostringstream oss;
oss << "Selected EP device index is out of range (max is " << ep_devices.size() - 1 << "): " << idx;
return Ort::Status(oss.str().c_str(), ORT_FAIL);
}

selected_ep_devices.push_back(ep_devices[idx]);
}
} else {
std::copy_if(ep_devices.begin(), ep_devices.end(), std::back_inserter(selected_ep_devices),
[&selected_ep_name = std::as_const(config.selected_ep_name)](Ort::ConstEpDevice ep_device) {
return ep_device.EpName() == selected_ep_name;
});
}

if (selected_ep_devices.empty()) {
return Ort::Status("No EP devices were selected.", ORT_FAIL);
}

session_options.AppendExecutionProvider_V2(env, selected_ep_devices, config.default_ep_options);
plugin_ep_state.plugin_ep_library_registration_handle = std::move(plugin_ep_library_registration_handle);
plugin_ep_state.selected_ep_devices = std::move(selected_ep_devices);

return Ort::Status{nullptr};
}

#ifdef _WIN32
int real_main(int argc, wchar_t* argv[]) {
#else
Expand All @@ -98,6 +170,7 @@ int real_main(int argc, char* argv[]) {
Ort::Env env(logging_level, "ep_weight_sharing");

ORT_TRY {
std::optional<PluginEpState> plugin_ep_state = std::nullopt;
Ort::SessionOptions so;
so.SetLogId("ep_weight_sharing_ctx_gen_session_logger");
// Set default session option to dump EPContext model with non-embed mode
Expand Down Expand Up @@ -136,7 +209,14 @@ int real_main(int argc, char* argv[]) {
// The context binary file generated later includes all graphs from previous models
{
std::string provider_name_ = test_config.machine_config.provider_type_name;
if (provider_name_ == onnxruntime::kQnnExecutionProvider) {

if (const auto& plugin_ep_config = test_config.machine_config.plugin_ep_config; plugin_ep_config.has_value()) {
plugin_ep_state = PluginEpState{};

if (Ort::Status status = SetPluginEpSessionOptions(env, so, *plugin_ep_config, *plugin_ep_state); !status.IsOK()) {
ORT_CXX_API_THROW(status.GetErrorMessage(), status.GetErrorCode());
}
} else if (provider_name_ == onnxruntime::kQnnExecutionProvider) {
#ifdef USE_QNN
so.AppendExecutionProvider("QNN", provider_options);
#else
Expand Down
18 changes: 18 additions & 0 deletions onnxruntime/test/ep_weight_sharing_ctx_gen/test_configuration.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include <map>
#include <cstdint>
#include <optional>
#include <string>
#include <unordered_map>

Expand All @@ -14,8 +15,25 @@
namespace onnxruntime {
namespace qnnctxgen {

// Configuration for initializing the dynamic plugin EP infrastructure.
struct PluginEpConfig {
std::string ep_library_registration_name{};
std::string ep_library_path{};

// Note: Exactly one of `selected_ep_name` or `selected_ep_device_indices` should be set.
// An empty value for either means it is unset.

// Specifies the EP devices matching this EP name as the selected EP devices.
std::string selected_ep_name{};
// Specifies the selected EP devices by their indices.
std::vector<size_t> selected_ep_device_indices{};

std::unordered_map<std::string, std::string> default_ep_options{};
};

struct MachineConfig {
std::string provider_type_name{onnxruntime::kQnnExecutionProvider};
std::optional<PluginEpConfig> plugin_ep_config = std::nullopt;
};

struct RunConfig {
Expand Down
Loading
Loading