Skip to content

Commit 71c1e26

Browse files
committed
Add MLX metadata and Audio type for transcriptions
1 parent 1b2b7b8 commit 71c1e26

File tree

16 files changed

+74
-36
lines changed

16 files changed

+74
-36
lines changed

Changelog.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
## 0.0.92
2-
+ Added `muna.beta.Audio` type for creating transcriptions on raw audio buffers.
2+
+ Added `beta.MLXInferenceMetadata` to compile PyTorch models for inference with MLX on Apple Silicon.
3+
+ Added `beta.MLXInferenceSessionMetadata` to compile ONNX Runtime `InferenceSession` instances for inference with MLX on Apple Silicon.
4+
+ Added `beta.Audio` type for creating transcriptions on raw PCM audio buffers.
5+
+ Removed `beta.OnnxRuntimeInferenceMetadata.output_keys` field for specifying model output dictionary keys.
36

47
## 0.0.91
58
+ Fixed sporadic memory corruption when creating predictions with image inputs on Windows.

muna/beta/metadata/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .iree import IREEInferenceBackend, IREEInferenceMetadata
1010
from .litert import LiteRTInferenceMetadata
1111
from .llama import LlamaCppBackend, LlamaCppInferenceMetadata
12+
from .mlx import MLXInferenceMetadata, MLXInferenceSessionMetadata
1213
from .onnx import (
1314
OnnxRuntimeInferenceMetadata, OnnxRuntimeExecutionProvider,
1415
OnnxRuntimeOptimizationLevel

muna/beta/metadata/_torch.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,6 @@ class PyTorchInferenceMetadataBase(BaseModel, **ConfigDict(arbitrary_types_allow
6060
description="Model input tensor shapes. Use this to specify dynamic axes.",
6161
exclude=True
6262
)
63-
output_keys: list[str] | None = Field(
64-
default=None,
65-
description="Model output dictionary keys. Use this if the model returns a dictionary.",
66-
exclude=True
67-
)
6863
optimum_config: Annotated[object | None, BeforeValidator(_validate_optimum_exporter_config)] = Field(
6964
default=None,
7065
description="Optimum exporter configuration. Required when `exporter` is `optimum`.",

muna/beta/metadata/coreml.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@ class CoreMLInferenceMetadata(PyTorchInferenceMetadataBase):
1414
1515
Members:
1616
model (torch.nn.Module): PyTorch module to apply metadata to.
17-
model_args (tuple[Tensor,...]): Positional inputs to the model.
17+
model_args (tuple): Positional inputs to the model.
1818
input_shapes (list): Model input tensor shapes. Use this to specify dynamic axes.
19-
output_keys (list): Model output dictionary keys. Use this if the model returns a dictionary.
2019
"""
2120
kind: Literal["meta.inference.coreml"] = Field(default="meta.inference.coreml", init=False)

muna/beta/metadata/executorch.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pydantic import Field
77
from typing import Literal
88

9-
from ._torch import PyTorchInferenceMetadataBase, TorchExporter
9+
from ._torch import PyTorchInferenceMetadataBase
1010

1111
ExecuTorchInferenceBackend = Literal["xnnpack", "vulkan"]
1212

@@ -16,13 +16,12 @@ class ExecuTorchInferenceMetadata(PyTorchInferenceMetadataBase):
1616
1717
Members:
1818
model (torch.nn.Module): PyTorch module to apply metadata to.
19-
model_args (tuple[Tensor,...]): Positional inputs to the model.
19+
model_args (tuple): Positional inputs to the model.
2020
input_shapes (list): Model input tensor shapes. Use this to specify dynamic axes.
21-
output_keys (list): Model output dictionary keys. Use this if the model returns a dictionary.
2221
backend (ExecuTorchInferenceBackend): ExecuTorch backend to execute the model.
2322
"""
2423
kind: Literal["meta.inference.executorch"] = Field(default="meta.inference.executorch", init=False)
25-
exporter: TorchExporter | None = Field(default=None, init=False)
24+
exporter: None = Field(default=None, init=False, exclude=True)
2625
backend: ExecuTorchInferenceBackend = Field(
2726
default="xnnpack",
2827
description="ExecuTorch backend to execute the model.",

muna/beta/metadata/iree.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@ class IREEInferenceMetadata(PyTorchInferenceMetadataBase):
1717
Members:
1818
model (torch.nn.Module): PyTorch module to apply metadata to.
1919
exporter (TorchExporter): PyTorch exporter to use.
20-
model_args (tuple[Tensor,...]): Positional inputs to the model.
20+
model_args (tuple): Positional inputs to the model.
2121
input_shapes (list): Model input tensor shapes. Use this to specify dynamic axes.
22-
output_keys (list): Model output dictionary keys. Use this if the model returns a dictionary.
2322
"""
2423
kind: Literal["meta.inference.iree"] = Field(default="meta.inference.iree", init=False)
2524
backend: IREEInferenceBackend = Field(

muna/beta/metadata/litert.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,16 @@
66
from pydantic import Field
77
from typing import Literal
88

9-
from ._torch import PyTorchInferenceMetadataBase, TorchExporter
9+
from ._torch import PyTorchInferenceMetadataBase
1010

1111
class LiteRTInferenceMetadata(PyTorchInferenceMetadataBase):
1212
"""
1313
Metadata to compile a PyTorch model for inference with LiteRT.
1414
1515
Members:
1616
model (torch.nn.Module): PyTorch module to apply metadata to.
17-
model_args (tuple[Tensor,...]): Positional inputs to the model.
17+
model_args (tuple): Positional inputs to the model.
1818
input_shapes (list): Model input tensor shapes. Use this to specify dynamic axes.
19-
output_keys (list): Model output dictionary keys. Use this if the model returns a dictionary.
2019
"""
2120
kind: Literal["meta.inference.litert"] = Field(default="meta.inference.litert", init=False)
22-
exporter: TorchExporter | None = Field(default=None, init=False)
21+
exporter: None = Field(default=None, init=False, exclude=True)

muna/beta/metadata/mlx.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#
2+
# Muna
3+
# Copyright © 2026 NatML Inc. All Rights Reserved.
4+
#
5+
6+
from pydantic import Field
7+
from typing import Literal
8+
9+
from ._torch import PyTorchInferenceMetadataBase
10+
from .onnxruntime import OnnxRuntimeInferenceSessionMetadata
11+
12+
class MLXInferenceMetadata(PyTorchInferenceMetadataBase):
13+
"""
14+
Metadata to compile a PyTorch model for inference with MLX on Apple Silicon.
15+
16+
Members:
17+
model (torch.nn.Module): PyTorch module to apply metadata to.
18+
exporter (TorchExporter): PyTorch exporter to use.
19+
model_args (tuple): Positional inputs to the model.
20+
input_shapes (list): Model input tensor shapes. Use this to specify dynamic axes.
21+
optimum_config (optimum.ExporterConfig): Optimum exporter configuration. Required when `exporter` is `optimum`.
22+
"""
23+
kind: Literal["meta.inference.mlx"] = Field(default="meta.inference.mlx", init=False)
24+
25+
class MLXInferenceSessionMetadata(OnnxRuntimeInferenceSessionMetadata):
26+
"""
27+
Metadata to compile an OnnxRuntime `InferenceSession` for inference with MLX on Apple Silicon.
28+
29+
Members:
30+
session (onnxruntime.InferenceSession): OnnxRuntime inference session to apply metadata to.
31+
model_path (str | Path): ONNX model path. The file must exist in the compiler sandbox.
32+
external_data_path (str | Path): ONNX model external data path. The file must exist in the compiler sandbox.
33+
"""
34+
kind: Literal["meta.inference.mlx_onnx"] = Field(default="meta.inference.mlx_onnx", init=False)
35+
providers: None = Field(default=None, init=False, exclude=True)

muna/beta/metadata/onnx.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,8 @@ class OnnxRuntimeInferenceMetadata(PyTorchInferenceMetadataBase):
1818
Members:
1919
model (torch.nn.Module): PyTorch module to apply metadata to.
2020
exporter (TorchExporter): PyTorch exporter to use.
21-
model_args (tuple[Tensor,...]): Positional inputs to the model.
21+
model_args (tuple): Positional inputs to the model.
2222
input_shapes (list): Model input tensor shapes. Use this to specify dynamic axes.
23-
output_keys (list): Model output dictionary keys. Use this if the model returns a dictionary.
2423
optimum_config (optimum.ExporterConfig): Optimum exporter configuration. Required when `exporter` is `optimum`.
2524
optimization (OnnxRuntimeOptimizationLevel): ONNX model optimization level.
2625
providers (list): Execution providers that can be used to accelerate inference for this model.

muna/beta/metadata/onnxruntime.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ class OnnxRuntimeInferenceSessionMetadata(BaseModel, **ConfigDict(arbitrary_type
2424
2525
Members:
2626
session (onnxruntime.InferenceSession): OnnxRuntime inference session to apply metadata to.
27-
model_path (str | Path): ONNX model path. The model must exist at this path in the compiler sandbox.
27+
model_path (str | Path): ONNX model path. The file must exist in the compiler sandbox.
28+
external_data_path (str | Path): ONNX model external data path. The file must exist in the compiler sandbox.
29+
providers (list): Execution providers that can be used to accelerate inference for this model.
2830
"""
2931
kind: Literal["meta.inference.onnxruntime"] = Field(default="meta.inference.onnxruntime", init=False)
3032
session: Annotated[object, BeforeValidator(_validate_ort_inference_session)] = Field(

0 commit comments

Comments
 (0)