Skip to content

Commit e0b4007

Browse files
authored
Fix pybind API and docs
Differential Revision: D84519459 Pull Request resolved: #15051
1 parent e0882af commit e0b4007

File tree

3 files changed

+167
-34
lines changed

3 files changed

+167
-34
lines changed

runtime/__init__.py

Lines changed: 73 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,10 @@
1212
from pathlib import Path
1313
1414
import torch
15-
from executorch.runtime import Verification, Runtime, Program, Method
15+
from executorch.runtime import Runtime, Program, Method
1616
1717
et_runtime: Runtime = Runtime.get()
18-
program: Program = et_runtime.load_program(
19-
Path("/tmp/program.pte"),
20-
verification=Verification.Minimal,
21-
)
18+
program: Program = et_runtime.load_program(Path("/tmp/program.pte"))
2219
print("Program methods:", program.method_names)
2320
forward: Method = program.load_method("forward")
2421
@@ -40,21 +37,23 @@
4037
4138
Example usage with ETDump generation:
4239
40+
Note: ETDump requires building ExecuTorch with event tracing enabled
41+
(CMake option ``EXECUTORCH_ENABLE_EVENT_TRACER=ON``).
42+
4343
.. code-block:: python
4444
4545
from pathlib import Path
4646
import os
4747
4848
import torch
49-
from executorch.runtime import Verification, Runtime, Program, Method
49+
from executorch.runtime import Runtime, Program, Method
5050
5151
# Create program with etdump generation enabled
5252
et_runtime: Runtime = Runtime.get()
5353
program: Program = et_runtime.load_program(
5454
Path("/tmp/program.pte"),
55-
verification=Verification.Minimal,
5655
enable_etdump=True,
57-
debug_buffer_size=1e7, # A large buffer size to ensure that all debug info is captured
56+
debug_buffer_size=int(1e7), # 10MB buffer to capture all debug info
5857
)
5958
6059
# Load method and execute
@@ -76,10 +75,37 @@
7675
7776
.. code-block:: text
7877
79-
Program methods: {'forward'}
8078
ETDump file created: True
8179
Debug file created: True
8280
Directory contents: ['program.pte', 'etdump_output.etdp', 'debug_output.bin']
81+
82+
Example usage with backend and operator introspection:
83+
84+
.. code-block:: python
85+
86+
from executorch.runtime import Runtime
87+
88+
runtime = Runtime.get()
89+
90+
# Check available backends
91+
backends = runtime.backend_registry.registered_backend_names
92+
print(f"Available backends: {backends}")
93+
94+
# Check if a specific backend is available
95+
if runtime.backend_registry.is_available("XnnpackBackend"):
96+
print("XNNPACK backend is available")
97+
98+
# List all registered operators
99+
operators = runtime.operator_registry.operator_names
100+
print(f"Number of registered operators: {len(operators)}")
101+
102+
Example output:
103+
104+
.. code-block:: text
105+
106+
Available backends: ['XnnpackBackend', ...] # Depends on your build configuration
107+
XNNPACK backend is available
108+
Number of registered operators: 247 # Depends on linked kernels
83109
"""
84110

85111
import functools
@@ -113,19 +139,22 @@ def execute(self, inputs: Sequence[Any]) -> Sequence[Any]:
113139
"""Executes the method with the given inputs.
114140
115141
Args:
116-
inputs: The inputs to the method.
142+
inputs: A sequence of input values, typically torch.Tensor objects.
117143
118144
Returns:
119-
The outputs of the method.
145+
A list of output values, typically torch.Tensor objects.
120146
"""
121147
return self._method(inputs)
122148

123149
@property
124150
def metadata(self) -> MethodMeta:
125151
"""Gets the metadata for the method.
126152
153+
The metadata includes information about input and output specifications,
154+
such as tensor shapes, data types, and memory requirements.
155+
127156
Returns:
128-
The metadata for the method.
157+
The MethodMeta object containing method specifications.
129158
"""
130159
return self._method.method_meta()
131160

@@ -148,9 +177,7 @@ def __init__(self, program: ExecuTorchProgram, data: Optional[bytes]) -> None:
148177

149178
@property
150179
def method_names(self) -> Set[str]:
151-
"""
152-
Returns method names of the `Program` as a set of strings.
153-
"""
180+
"""Returns method names of the Program as a set of strings."""
154181
return set(self._methods.keys())
155182

156183
def load_method(self, name: str) -> Optional[Method]:
@@ -170,13 +197,13 @@ def load_method(self, name: str) -> Optional[Method]:
170197
return method
171198

172199
def metadata(self, method_name: str) -> MethodMeta:
173-
"""Gets the metadata for the specified method.
200+
"""Gets the metadata for the specified method without loading it.
174201
175202
Args:
176203
method_name: The name of the method.
177204
178205
Returns:
179-
The outputs of the method.
206+
The metadata for the method, including input/output specifications.
180207
"""
181208
return self._program.method_meta(method_name)
182209

@@ -201,14 +228,17 @@ def __init__(self, legacy_module: ModuleType) -> None:
201228

202229
@property
203230
def registered_backend_names(self) -> List[str]:
204-
"""
205-
Returns the names of all registered backends as a list of strings.
206-
"""
231+
"""Returns the names of all registered backends as a list of strings."""
207232
return self._legacy_module._get_registered_backend_names()
208233

209234
def is_available(self, backend_name: str) -> bool:
210-
"""
211-
Returns the names of all registered backends as a list of strings.
235+
"""Checks if a specific backend is available in the runtime.
236+
237+
Args:
238+
backend_name: The name of the backend to check (e.g., "XnnpackBackend").
239+
240+
Returns:
241+
True if the backend is available, False otherwise.
212242
"""
213243
return self._legacy_module._is_available(backend_name)
214244

@@ -222,9 +252,7 @@ def __init__(self, legacy_module: ModuleType) -> None:
222252

223253
@property
224254
def operator_names(self) -> Set[str]:
225-
"""
226-
Returns the names of all registered operators as a set of strings.
227-
"""
255+
"""Returns the names of all registered operators as a set of strings."""
228256
return set(self._legacy_module._get_operator_names())
229257

230258

@@ -233,6 +261,10 @@ class Runtime:
233261
234262
This can be used to concurrently load and execute any number of ExecuTorch
235263
programs and methods.
264+
265+
Attributes:
266+
backend_registry: Registry for querying available hardware backends.
267+
operator_registry: Registry for querying available operators/kernels.
236268
"""
237269

238270
@staticmethod
@@ -261,11 +293,17 @@ def load_program(
261293
"""Loads an ExecuTorch program from a PTE binary.
262294
263295
Args:
264-
data: The binary program data to load; typically PTE data.
265-
verification: level of program verification to perform.
296+
data: The binary program data to load. Can be a file path (str or Path),
297+
bytes/bytearray, or a file-like object.
298+
verification: Level of program verification to perform (Minimal or InternalConsistency).
299+
Default is InternalConsistency.
300+
enable_etdump: If True, enables ETDump profiling for runtime performance analysis.
301+
Default is False.
302+
debug_buffer_size: Size of the debug buffer in bytes for ETDump data.
303+
Only used when enable_etdump=True. Default is 0.
266304
267305
Returns:
268-
The loaded program.
306+
The loaded Program instance.
269307
"""
270308
if isinstance(data, (Path, str)):
271309
p = self._legacy_module._load_program(
@@ -275,20 +313,21 @@ def load_program(
275313
program_verification=verification,
276314
)
277315
return Program(p, data=None)
278-
elif isinstance(data, BinaryIO):
279-
data_bytes = data.read()
280-
elif isinstance(data, bytearray):
281-
data_bytes = bytes(data)
282316
elif isinstance(data, bytes):
283317
data_bytes = data
318+
elif isinstance(data, bytearray):
319+
data_bytes = bytes(data)
320+
elif hasattr(data, "read"):
321+
# File-like object with read() method
322+
data_bytes = data.read()
284323
else:
285324
raise TypeError(
286325
f"Expected data to be bytes, bytearray, a path to a .pte file, or a file-like object, but got {type(data).__name__}."
287326
)
288327
p = self._legacy_module._load_program_from_buffer(
289328
data_bytes,
290-
enable_etdump=False,
291-
debug_buffer_size=0,
329+
enable_etdump=enable_etdump,
330+
debug_buffer_size=debug_buffer_size,
292331
program_verification=verification,
293332
)
294333

runtime/test/test_runtime.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import io
78
import tempfile
89
import unittest
910
from pathlib import Path
@@ -76,3 +77,30 @@ def test_add(program):
7677
with open(f.name, "rb") as f:
7778
program = runtime.load_program(f.read())
7879
test_add(program)
80+
81+
def test_load_program_with_file_like_objects(self):
82+
"""Regression test: Ensure file-like objects (BytesIO, etc.) work correctly.
83+
84+
Previously, isinstance(data, BinaryIO) check didn't work because BinaryIO
85+
is a typing protocol. Fixed by using hasattr(data, 'read') duck-typing.
86+
"""
87+
ep, inputs = create_program(ModuleAdd())
88+
runtime = Runtime.get()
89+
90+
def test_add(program):
91+
method = program.load_method("forward")
92+
outputs = method.execute(inputs)
93+
self.assertTrue(torch.allclose(outputs[0], inputs[0] + inputs[1]))
94+
95+
# Test with BytesIO
96+
bytesio = io.BytesIO(ep.buffer)
97+
program = runtime.load_program(bytesio)
98+
test_add(program)
99+
100+
# Test with bytes
101+
program = runtime.load_program(bytes(ep.buffer))
102+
test_add(program)
103+
104+
# Test with bytearray
105+
program = runtime.load_program(bytearray(ep.buffer))
106+
test_add(program)

runtime/test/test_runtime_etdump_gen.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import io
78
import os
89
import tempfile
910
import unittest
@@ -95,3 +96,68 @@ def test_etdump_generation(self):
9596
self.assertGreater(
9697
len(run_data.events), 0, "Run data should contain at least one events"
9798
)
99+
100+
def test_etdump_params_with_bytes_and_buffer(self):
101+
"""Regression test: Ensure enable_etdump and debug_buffer_size work with bytes/buffer.
102+
103+
Previously, when loading from bytes/bytearray/file-like objects, these parameters
104+
were hardcoded to False/0 instead of using the provided values.
105+
"""
106+
ep, inputs = create_program(ModuleAdd())
107+
runtime = Runtime.get()
108+
109+
with tempfile.TemporaryDirectory() as temp_dir:
110+
etdump_path = os.path.join(temp_dir, "etdump_output.etdp")
111+
debug_path = os.path.join(temp_dir, "debug_output.bin")
112+
113+
def test_etdump_with_data(data, data_type):
114+
"""Helper to test ETDump with different data types."""
115+
# Load program with etdump enabled
116+
program = runtime.load_program(
117+
data,
118+
verification=Verification.Minimal,
119+
enable_etdump=True,
120+
debug_buffer_size=int(1e7),
121+
)
122+
123+
# Execute the method
124+
method = program.load_method("forward")
125+
outputs = method.execute(inputs)
126+
127+
# Verify computation
128+
self.assertTrue(
129+
torch.allclose(outputs[0], inputs[0] + inputs[1]),
130+
f"Computation failed for {data_type}",
131+
)
132+
133+
# Write etdump result
134+
program.write_etdump_result_to_file(etdump_path, debug_path)
135+
136+
# Verify files were created
137+
self.assertTrue(
138+
os.path.exists(etdump_path),
139+
f"ETDump file not created for {data_type}",
140+
)
141+
self.assertTrue(
142+
os.path.exists(debug_path),
143+
f"Debug file not created for {data_type}",
144+
)
145+
146+
# Verify etdump file is not empty
147+
etdump_size = os.path.getsize(etdump_path)
148+
self.assertGreater(
149+
etdump_size, 0, f"ETDump file is empty for {data_type}"
150+
)
151+
152+
# Clean up for next test
153+
os.remove(etdump_path)
154+
os.remove(debug_path)
155+
156+
# Test with bytes
157+
test_etdump_with_data(ep.buffer, "bytes")
158+
159+
# Test with bytearray
160+
test_etdump_with_data(bytearray(ep.buffer), "bytearray")
161+
162+
# Test with BytesIO (file-like object)
163+
test_etdump_with_data(io.BytesIO(ep.buffer), "BytesIO")

0 commit comments

Comments
 (0)