Skip to content

Commit 53a6996

Browse files
committed
Update ONNX inference metadata types
1 parent 47ba804 commit 53a6996

File tree

5 files changed: +32 −5 lines

Changelog.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
 ## 0.0.89
++ Added `OnnxRuntimeInferenceSessionMetadata.external_data_path` field for specifying the path to external data files.
++ Added `OnnxRuntimeInferenceSessionMetadata.providers` field for specifying enabled ONNXRuntime providers.
++ Added `OnnxRuntimeInferenceMetadata.providers` field for specifying enabled ONNXRuntime providers.
 + Removed `muna source` CLI command.

 ## 0.0.88

muna/beta/metadata/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,16 @@
 from .iree import IREEInferenceBackend, IREEInferenceMetadata
 from .litert import LiteRTInferenceMetadata
 from .llama import LlamaCppBackend, LlamaCppInferenceMetadata
-from .onnx import OnnxRuntimeInferenceMetadata, OnnxRuntimeOptimizationLevel
+from .onnx import (
+    OnnxRuntimeInferenceMetadata, OnnxRuntimeExecutionProvider,
+    OnnxRuntimeOptimizationLevel
+)
 from .onnxruntime import OnnxRuntimeInferenceSessionMetadata
 from .openvino import OpenVINOInferenceMetadata
 from .qnn import QnnInferenceBackend, QnnInferenceMetadata, QnnInferenceQuantization
-from .tensorrt import CudaArchitecture, TensorRTInferenceMetadata, TensorRTHardwareCompatibility, TensorRTPrecision
+from .tensorrt import (
+    CudaArchitecture, TensorRTInferenceMetadata,
+    TensorRTHardwareCompatibility, TensorRTPrecision
+)
 from .tensorrt_rtx import TensorRTRTXInferenceMetadata
 from .tflite import TFLiteInterpreterMetadata

muna/beta/metadata/onnx.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
 from ._torch import PyTorchInferenceMetadataBase

+OnnxRuntimeExecutionProvider = Literal["cpu", "coreml", "cuda", "openvino", "xnnpack"]
 OnnxRuntimeOptimizationLevel = Literal["none", "basic", "extended"]

 class OnnxRuntimeInferenceMetadata(PyTorchInferenceMetadataBase):
@@ -27,4 +28,9 @@ class OnnxRuntimeInferenceMetadata(PyTorchInferenceMetadataBase):
         default="none",
         description="ONNX model optimization level. Defaults to `none`.",
         exclude=True
     )
+    providers: list[OnnxRuntimeExecutionProvider] | None = Field(
+        default=None,
+        description="ONNXRuntime execution providers to build with.",
+        exclude=True
+    )

muna/beta/metadata/onnxruntime.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
 from pydantic import BaseModel, BeforeValidator, ConfigDict, Field
 from typing import Annotated, Literal

+from .onnx import OnnxRuntimeExecutionProvider
+
 def _validate_ort_inference_session(session: "onnxruntime.InferenceSession") -> "onnxruntime.InferenceSession": # type: ignore
     try:
         from onnxruntime import InferenceSession
@@ -32,4 +34,14 @@ class OnnxRuntimeInferenceSessionMetadata(BaseModel, **ConfigDict(arbitrary_types_allowed=True)):
     model_path: str | Path = Field(
         description="ONNX model path. The model must exist at this path in the compiler sandbox.",
         exclude=True
     )
+    external_data_path: str | Path | None = Field(
+        default=None,
+        description="Path to ONNX external data file (e.g. .onnx.data).",
+        exclude=True
+    )
+    providers: list[OnnxRuntimeExecutionProvider] | None = Field(
+        default=None,
+        description="ONNXRuntime execution providers to build with.",
+        exclude=True
+    )

muna/cli/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,14 @@
     help="Compile a Python function for deployment.",
     rich_help_panel="Compilation"
 )(compile_function)
+
+# Functions
 app.command(
     name="predict",
     help="Invoke a compiled Python function.",
     context_settings={ "allow_extra_args": True, "ignore_unknown_options": True },
-    rich_help_panel="Compilation"
+    rich_help_panel="Functions"
 )(create_prediction)
-
-# Predictors
 app.command(
     name="retrieve",
     help="Retrieve a compiled function.",

0 commit comments

Comments
 (0)