Skip to content

Commit c220a1f

Browse files
committed
Add optimum exporter
1 parent 1acdc19 commit c220a1f

File tree

2 files changed

+26
-7
lines changed

2 files changed

+26
-7
lines changed

Changelog.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
## 0.0.90
2-
*INCOMPLETE*
2+
+ Added `optimum` exporter in `OnnxRuntimeInferenceMetadata` to export LLMs using 🤗 Optimum.
33

44
## 0.0.89
55
+ Added `OnnxRuntimeInferenceSessionMetadata.external_data_path` field for specifying the path to external data files.

muna/beta/metadata/_torch.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pydantic import BaseModel, BeforeValidator, ConfigDict, Field
77
from typing import Annotated, Literal
88

9-
TorchExporter = Literal["none", "dynamo", "torchscript"]
9+
TorchExporter = Literal["none", "dynamo", "torchscript", "optimum"]
1010

1111
def _validate_torch_module(module: "torch.nn.Module") -> "torch.nn.Module": # type: ignore
1212
try:
@@ -17,7 +17,9 @@ def _validate_torch_module(module: "torch.nn.Module") -> "torch.nn.Module": # ty
1717
except ImportError:
1818
raise ImportError("PyTorch is required to create this metadata but it is not installed.")
1919

20-
def _validate_torch_tensor_args(args: list) -> list:
20+
def _validate_torch_tensor_args(args: list) -> list | None:
21+
if args is None:
22+
return args
2123
try:
2224
from torch import Tensor
2325
for idx, arg in enumerate(args):
@@ -27,13 +29,30 @@ def _validate_torch_tensor_args(args: list) -> list:
2729
except ImportError:
2830
raise ImportError("PyTorch is required to create this metadata but it is not installed.")
2931

32+
def _validate_optimum_exporter_config(config: object) -> object:
33+
if config is None:
34+
return config
35+
try:
36+
from optimum.exporters.base import ExporterConfig
37+
if not isinstance(config, ExporterConfig):
38+
raise ValueError(f"Expected `optimum.exporters.base.ExporterConfig` instance for `optimum_config` but got `{type(config).__qualname__}`")
39+
return config
40+
except ImportError:
41+
pass
42+
3043
class PyTorchInferenceMetadataBase(BaseModel, **ConfigDict(arbitrary_types_allowed=True, frozen=True)):
3144
model: Annotated[object, BeforeValidator(_validate_torch_module)] = Field(
3245
description="PyTorch module to apply metadata to.",
3346
exclude=True
3447
)
35-
model_args: Annotated[list[object], BeforeValidator(_validate_torch_tensor_args)] = Field(
36-
description="Positional inputs to the model.",
48+
exporter: TorchExporter | None = Field(
49+
default=None,
50+
description="PyTorch exporter to use.",
51+
exclude=True
52+
)
53+
model_args: Annotated[list[object] | None, BeforeValidator(_validate_torch_tensor_args)] = Field(
54+
default=None,
55+
description="Positional inputs to the model. Required except when `exporter` is `optimum`.",
3756
exclude=True
3857
)
3958
input_shapes: list[tuple] | None = Field(
@@ -46,8 +65,8 @@ class PyTorchInferenceMetadataBase(BaseModel, **ConfigDict(arbitrary_types_allow
4665
description="Model output dictionary keys. Use this if the model returns a dictionary.",
4766
exclude=True
4867
)
49-
exporter: TorchExporter | None = Field(
68+
optimum_config: Annotated[object | None, BeforeValidator(_validate_optimum_exporter_config)] = Field(
5069
default=None,
51-
description="PyTorch exporter to use.",
70+
description="Optimum exporter configuration. Required when `exporter` is `optimum`.",
5271
exclude=True
5372
)

0 commit comments

Comments
 (0)