
Commit 2ca537d

committed
python API samples
1 parent 7a635da commit 2ca537d

File tree

6 files changed (+260 −36 lines changed)


python/README.md

Lines changed: 2 additions & 1 deletion
@@ -2,7 +2,8 @@
 
 ## API
 
-[Run the ONNX Runtime session creation and inference API](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/python/api)
+The [api directory](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/python/api) contains samples that demonstrate how to use the ONNX Runtime Python API.
+These samples show very minimal API usage that is not execution provider specific.
 
 ## OpenVINO Execution Provider
 
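For quick orientation, the minimal pattern these samples build on is: create an `InferenceSession`, then call `run` with a feed dict of NumPy arrays. A short sketch using the add model and the `x`/`y`/`z` tensor names from `getting_started.py` (added later in this commit):

```
import numpy as np
import onnxruntime

# Create a CPU session for an already-exported model and run it with NumPy inputs.
session = onnxruntime.InferenceSession(".model.onnx", providers=["CPUExecutionProvider"])
z = session.run(["z"], {"x": np.float32([1.0, 2.0, 3.0]), "y": np.float32([4.0, 5.0, 6.0])})
print(z[0])  # [5. 7. 9.]
```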

python/api/README.md

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
# Python API Samples

This directory contains sample scripts demonstrating various ONNX Runtime Python API features:

- `getting_started.py`
  Introduces the basics of exporting a simple PyTorch model to ONNX, running inference with ONNX Runtime, and handling inputs and outputs as NumPy arrays.

- `compile_api.py`
  Shows how to programmatically compile an ONNX model for a specific execution provider (e.g., TensorRT RTX) into an [EP context](https://onnxruntime.ai/docs/execution-providers/EP-Context-Design.html) ONNX model. The sample measures model load and compile times to demonstrate the performance improvement, and an input model can be specified; see the example invocation after this list.
  - For `NvTensorRTRTXExecutionProvider`, try adding the provider option for a runtime cache (`-p NvTensorRTRTXExecutionProvider -popt "nv_runtime_cache_path=./cache"`), which further increases the load speed of a compiled model.

- `device_bindings.py`
  Demonstrates advanced device bindings, including running ONNX models on CPU or GPU, using ONNX Runtime's `OrtValue` for device memory, and direct inference with PyTorch tensors on the selected device. It also demonstrates how to interact with ORT using DLPack.
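An example invocation of `compile_api.py` with the options described above (the output path and cache directory are only examples) could look like:

```
python compile_api.py -o model_ctx.onnx -p NvTensorRTRTXExecutionProvider -popt "nv_runtime_cache_path=./cache"
```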
Each sample is self-contained and includes comments explaining the main concepts.

### Setup

Besides installing the ONNX Runtime package, there are a few other dependencies required for the samples to work correctly.
Please pick and install your preferred [onnxruntime package](https://onnxruntime.ai/docs/get-started/with-python.html#install-onnx-runtime) manually.
```
pip install -r requirements.txt
# to install ORT GPU with the required CUDA dependencies
pip install onnxruntime-gpu[cuda,cudnn]
```

python/api/compile_api.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
import argparse
import os
import time

import onnxruntime as ort

# Set logger severity to warning level to reduce console output.
ort.set_default_logger_severity(3)

# Default execution provider for NVIDIA GPUs.
TRT_RTX_EP = "NvTensorRTRTXExecutionProvider"


def compile(input_path, output_path, provider, ep_options, embed_mode=False):
    """
    Compiles an ONNX model for a specified execution provider and saves it.

    Args:
        input_path (str): Path to the original ONNX model.
        output_path (str): Path to save the compiled model.
        provider (str): The name of the execution provider.
        ep_options (dict): The execution provider options.
        embed_mode (bool): If True, embeds the compiled binary data into the ONNX file.
    """
    # Remove the output file if it already exists to ensure a clean compilation.
    if os.path.exists(output_path):
        os.remove(output_path)
        print(f"> Previous compiled model at {output_path} removed.")

    # Create session options and add the provider.
    session_options = ort.SessionOptions()
    session_options.add_provider(provider, ep_options)

    # Create a ModelCompiler instance for the input model.
    model_compiler = ort.ModelCompiler(
        session_options,
        input_path,
        embed_compiled_data_into_model=embed_mode
    )

    print(f"\n> Compiling model with '{provider}'...")
    start = time.perf_counter()
    # Execute the compilation process.
    model_compiler.compile_to_file(output_path)
    stop = time.perf_counter()

    if os.path.exists(output_path):
        print("> Compiled successfully!")
        print(f"> Compile time: {stop - start:.3f} sec")
        print(f"> Compiled model saved at {output_path}")


def load_session(model_path, provider, ep_options):
    """
    Loads an ONNX model into an InferenceSession and measures the loading time.

    Args:
        model_path (str): Path to the ONNX model file.
        provider (str): The name of the execution provider.
        ep_options (dict): The execution provider options.
    """
    start = time.perf_counter()
    # Load the model using the specified provider and its options.
    # Equivalent alternative via SessionOptions:
    # session_options = ort.SessionOptions()
    # session_options.add_provider(provider, ep_options)
    # session = ort.InferenceSession(model_path, sess_options=session_options)
    session = ort.InferenceSession(model_path, providers=[(provider, ep_options)])
    stop = time.perf_counter()

    print(f"> Session load time: {stop - start:.3f} sec")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Compile ONNX model with ONNX Runtime")
    parser.add_argument("-i", "--model_path", type=str, default=None, help="Path to the ONNX model file")
    parser.add_argument("-o", "--output_path", type=str, default="model_ctx.onnx",
                        help="Path to save the compiled EP context model")
    parser.add_argument("-p", "--provider", default=TRT_RTX_EP, type=str, help="Execution Provider")
    parser.add_argument("-popt", "--provider_options", default=[], type=str, nargs="+",
                        help="Execution Provider options as key=value pairs")
    # Boolean flag controlling whether compiled binary data is embedded in the EP context node.
    parser.add_argument("--embed", action=argparse.BooleanOptionalAction, help="Binary data embedded within EP context node")
    args = parser.parse_args()

    if args.model_path is None:
        from getting_started import create_model

        args.model_path = create_model()
    ep_options = {}
    for kv_pair in args.provider_options:
        key, value = kv_pair.split("=")
        ep_options[key] = value

    print(f"""
-----------------------------------------------
ONNX Runtime Model Compilation Script
-----------------------------------------------
> Using Execution Provider: {args.provider}
> Using Execution Provider options: {ep_options}
> Embed Mode: {'Embedded' if args.embed else 'External'}
-----------------------------------------------
Available execution provider(s): {ort.get_available_providers()}
""")

    # Load and time the original model.
    print("\n> Loading regular onnx...")
    load_session(args.model_path, args.provider, ep_options=ep_options)

    # Compile the model.
    compile(args.model_path, args.output_path, args.provider,
            ep_options=ep_options, embed_mode=args.embed)

    # Load and time the compiled model.
    print("\n> Loading EP context model...")
    load_session(args.output_path, args.provider, ep_options=ep_options)

    print("\nProgram finished successfully.")

python/api/onnxruntime-python-api.py renamed to python/api/device_bindings.py

Lines changed: 58 additions & 35 deletions
@@ -4,12 +4,15 @@
 
 import numpy as np
 import torch
+import os
+import re
 import onnxruntime
 
 MODEL_FILE = '.model.onnx'
 DEVICE_NAME = 'cuda' if torch.cuda.is_available() else 'cpu'
 DEVICE_INDEX = 0  # Replace this with the index of the device you want to run on
 DEVICE=f'{DEVICE_NAME}:{DEVICE_INDEX}'
+LIB_EXT = 'so' if os.name != 'nt' else 'dll'
 
 # A simple model to calculate addition of two tensors
 def model():
@@ -32,39 +35,38 @@ def create_model(type: torch.dtype = torch.float32):
 
 # Create an ONNX Runtime session with the provided model
 def create_session(model: str) -> onnxruntime.InferenceSession:
+    available_providers = {device.ep_name for device in onnxruntime.get_ep_devices()}
     providers = ['CPUExecutionProvider']
     if torch.cuda.is_available():
-        providers.insert(0, 'CUDAExecutionProvider')
+        if 'CUDAExecutionProvider' in available_providers:
+            providers.insert(0, 'CUDAExecutionProvider')
+        if 'NvTensorRTRTXExecutionProvider' in available_providers:
+            providers.insert(0, 'NvTensorRTRTXExecutionProvider')
     return onnxruntime.InferenceSession(model, providers=providers)
 
-# Run the model on CPU consuming and producing numpy arrays
-def run(x: np.array, y: np.array) -> np.array:
-    session = create_session(MODEL_FILE)
-
-    z = session.run(["z"], {"x": x, "y": y})
-
-    return z[0]
 
 # Run the model on device consuming and producing ORTValues
 def run_with_data_on_device(x: np.array, y: np.array) -> onnxruntime.OrtValue:
     session = create_session(MODEL_FILE)
+    mem_info = session.get_input_memory_infos()[0]
 
-    x_ortvalue = onnxruntime.OrtValue.ortvalue_from_numpy(x, DEVICE_NAME, DEVICE_INDEX)
-    y_ortvalue = onnxruntime.OrtValue.ortvalue_from_numpy(y, DEVICE_NAME, DEVICE_INDEX)
+    x_ortvalue = onnxruntime.OrtValue.ortvalue_from_numpy(x, 'gpu', device_id=mem_info.device_id, vendor_id=mem_info.device_vendor_id)
+    y_ortvalue = onnxruntime.OrtValue.ortvalue_from_numpy(y, 'gpu', device_id=mem_info.device_id, vendor_id=mem_info.device_vendor_id)
 
     io_binding = session.io_binding()
-    io_binding.bind_input(name='x', device_type=x_ortvalue.device_name(), device_id=0, element_type=x.dtype, shape=x_ortvalue.shape(), buffer_ptr=x_ortvalue.data_ptr())
-    io_binding.bind_input(name='y', device_type=y_ortvalue.device_name(), device_id=0, element_type=y.dtype, shape=y_ortvalue.shape(), buffer_ptr=y_ortvalue.data_ptr())
-    io_binding.bind_output(name='z', device_type=DEVICE_NAME, device_id=DEVICE_INDEX, element_type=x.dtype, shape=x_ortvalue.shape())
+    io_binding.bind_input(name='x', device_type=x_ortvalue.device_name(), device_id=mem_info.device_id, element_type=x.dtype, shape=x_ortvalue.shape(), buffer_ptr=x_ortvalue.data_ptr())
+    io_binding.bind_input(name='y', device_type=y_ortvalue.device_name(), device_id=mem_info.device_id, element_type=y.dtype, shape=y_ortvalue.shape(), buffer_ptr=y_ortvalue.data_ptr())
+    io_binding.bind_output(name='z', device_type=x_ortvalue.device_name(), device_id=mem_info.device_id, element_type=x.dtype, shape=x_ortvalue.shape())
    session.run_with_iobinding(io_binding)
 
     z = io_binding.get_outputs()
 
     return z[0]
 
 # Run the model on device consuming and producing native PyTorch tensors
-def run_with_torch_tensors_on_device(x: torch.Tensor, y: torch.Tensor, np_type: np.dtype = np.float32, torch_type: torch.dtype = torch.float32) -> torch.Tensor:
+def run_with_torch_tensors_on_device(x: torch.Tensor, y: torch.Tensor, np_type: np.dtype = np.float32, torch_type: torch.dtype = torch.float32, dlpack=False) -> torch.Tensor:
     session = create_session(MODEL_FILE)
+    mem_info = session.get_input_memory_infos()[0]
 
     binding = session.io_binding()
 
@@ -73,48 +75,69 @@ def run_with_torch_tensors_on_device(x: torch.Tensor, y: torch.Tensor, np_type:
 
     binding.bind_input(
         name='x',
-        device_type=DEVICE_NAME,
-        device_id=DEVICE_INDEX,
+        device_type="gpu",
+        device_id=mem_info.device_id,
         element_type=np_type,
         shape=tuple(x_tensor.shape),
         buffer_ptr=x_tensor.data_ptr(),
     )
 
     binding.bind_input(
         name='y',
-        device_type=DEVICE_NAME,
-        device_id=DEVICE_INDEX,
+        device_type="gpu",
+        device_id=mem_info.device_id,
         element_type=np_type,
         shape=tuple(y_tensor.shape),
         buffer_ptr=y_tensor.data_ptr(),
     )
-
-    ## Allocate the PyTorch tensor for the model output
-    z_tensor = torch.empty(x_tensor.shape, dtype=torch_type, device=DEVICE).contiguous()
-    binding.bind_output(
-        name='z',
-        device_type=DEVICE_NAME,
-        device_id=DEVICE_INDEX,
-        element_type=np_type,
-        shape=tuple(z_tensor.shape),
-        buffer_ptr=z_tensor.data_ptr(),
-    )
+    if dlpack:
+        binding.bind_output(
+            name='z',
+            device_type="gpu",
+        )
+    else:
+        ## Allocate the PyTorch tensor for the model output
+        z_tensor = torch.empty(x_tensor.shape, dtype=torch_type, device=DEVICE).contiguous()
+        binding.bind_output(
+            name='z',
+            device_type="gpu",
+            device_id=mem_info.device_id,
+            element_type=np_type,
+            shape=tuple(z_tensor.shape),
+            buffer_ptr=z_tensor.data_ptr(),
+        )
 
     session.run_with_iobinding(binding)
-
-    return z_tensor
+    if dlpack:
+        from onnxruntime.capi import _pybind_state as C
+        outputs = binding.get_outputs()
+        return torch.tensor(C.OrtValue.from_dlpack(outputs[0]._ortvalue.to_dlpack(), False))
+    else:
+        return z_tensor
 
 
 def main():
-    create_model()
+    # check if plugin based providers are available and register them
+    ort_capi_dir = os.path.dirname(onnxruntime.capi.__file__)
+    for p in os.listdir(ort_capi_dir):
+        match = re.match(r".*onnxruntime_providers_(.*)\."+LIB_EXT, p)
+        if match is not None:
+            ep_name = match.group(1)
+            if ep_name == 'shared': continue
+            onnxruntime.register_execution_provider_library(ep_name, os.path.join(ort_capi_dir, p))
+            print(f"Registered execution provider {ep_name} with library: {p}")
 
-    print(run(x=np.float32([1.0, 2.0, 3.0]),y=np.float32([4.0, 5.0, 6.0])))
-    # [array([5., 7., 9.], dtype=float32)]
+    create_model()
 
     print(run_with_data_on_device(x=np.float32([1.0, 2.0, 3.0, 4.0, 5.0]), y=np.float32([1.0, 2.0, 3.0, 4.0, 5.0])).numpy())
     # [ 2.  4.  6.  8. 10.]
 
-    print(run_with_torch_tensors_on_device(torch.rand(5).to(DEVICE), torch.rand(5).to(DEVICE)))
+    x = torch.rand(5).to(DEVICE)
+    y = torch.rand(5).to(DEVICE)
+    print(run_with_torch_tensors_on_device(x, y, dlpack=True))
+    # tensor([0.7023, 1.3127, 1.7289, 0.3982, 0.8386])
+
+    print(run_with_torch_tensors_on_device(x, y, dlpack=False))
     # tensor([0.7023, 1.3127, 1.7289, 0.3982, 0.8386])
 
     create_model(torch.int64)
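For reference, the DLPack capsule produced by `outputs[0]._ortvalue.to_dlpack()` above can also be consumed with `torch.utils.dlpack` instead of the private `onnxruntime.capi._pybind_state` binding. A minimal alternative sketch (not part of this commit):

```
# Alternative sketch: wrap the bound output in a torch tensor via torch.utils.dlpack.
from torch.utils import dlpack as torch_dlpack

outputs = binding.get_outputs()
z = torch_dlpack.from_dlpack(outputs[0]._ortvalue.to_dlpack())
```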

python/api/getting_started.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
# A set of code samples showing different usage of the ONNX Runtime Python API
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import numpy as np
import torch
import onnxruntime

MODEL_FILE = '.model.onnx'
DEVICE_NAME = 'cuda' if torch.cuda.is_available() else 'cpu'

# A simple model to calculate addition of two tensors
def model():
    class Model(torch.nn.Module):
        def __init__(self):
            super(Model, self).__init__()

        def forward(self, x, y):
            return x.add(y)

    return Model()

# Create an instance of the model and export it to ONNX graph format, with dynamic size for the data
def create_model(type: torch.dtype = torch.float32):
    sample_x = torch.ones(3, dtype=type)
    sample_y = torch.zeros(3, dtype=type)

    torch.onnx.export(model(), (sample_x, sample_y), MODEL_FILE, input_names=["x", "y"], output_names=["z"],
                      dynamic_axes={"x": {0: "array_length_x"}, "y": {0: "array_length_y"}})
    return MODEL_FILE

# Create an ONNX Runtime session with the provided model
def create_session(model: str) -> onnxruntime.InferenceSession:
    providers = ['CPUExecutionProvider']
    if torch.cuda.is_available():
        providers.insert(0, 'CUDAExecutionProvider')
    return onnxruntime.InferenceSession(model, providers=providers)

# Run the model on CPU consuming and producing numpy arrays
def run(x: np.array, y: np.array) -> np.array:
    session = create_session(MODEL_FILE)

    z = session.run(["z"], {"x": x, "y": y})

    return z[0]

def main():
    create_model()

    print(run(x=np.float32([1.0, 2.0, 3.0]), y=np.float32([4.0, 5.0, 6.0])))
    # [array([5., 7., 9.], dtype=float32)]

if __name__ == "__main__":
    main()
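After installing the dependencies listed in `requirements.txt`, the sample can be run directly; the expected output matches the comment in `main()`:

```
python getting_started.py
# [array([5., 7., 9.], dtype=float32)]
```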

python/api/requirements.txt

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
numpy
torch
onnx
--extra-index-url https://download.pytorch.org/whl/cu128

0 commit comments
