66from pydantic import BaseModel , BeforeValidator , ConfigDict , Field
77from typing import Annotated , Literal
88
9- TorchExporter = Literal ["none" , "dynamo" , "torchscript" ]
9+ TorchExporter = Literal ["none" , "dynamo" , "torchscript" , "optimum" ]
1010
1111def _validate_torch_module (module : "torch.nn.Module" ) -> "torch.nn.Module" : # type: ignore
1212 try :
@@ -17,7 +17,9 @@ def _validate_torch_module(module: "torch.nn.Module") -> "torch.nn.Module": # ty
1717 except ImportError :
1818 raise ImportError ("PyTorch is required to create this metadata but it is not installed." )
1919
20- def _validate_torch_tensor_args (args : list ) -> list :
20+ def _validate_torch_tensor_args (args : list ) -> list | None :
21+ if args is None :
22+ return args
2123 try :
2224 from torch import Tensor
2325 for idx , arg in enumerate (args ):
@@ -27,13 +29,30 @@ def _validate_torch_tensor_args(args: list) -> list:
2729 except ImportError :
2830 raise ImportError ("PyTorch is required to create this metadata but it is not installed." )
2931
32+ def _validate_optimum_exporter_config (config : object ) -> object :
33+ if config is None :
34+ return config
35+ try :
36+ from optimum .exporters .base import ExporterConfig
37+ if not isinstance (config , ExporterConfig ):
38+ raise ValueError (f"Expected `optimum.exporters.base.ExporterConfig` instance for `optimum_config` but got `{ type (config ).__qualname__ } `" )
39+ return config
40+ except ImportError :
41+ pass
42+
3043class PyTorchInferenceMetadataBase (BaseModel , ** ConfigDict (arbitrary_types_allowed = True , frozen = True )):
3144 model : Annotated [object , BeforeValidator (_validate_torch_module )] = Field (
3245 description = "PyTorch module to apply metadata to." ,
3346 exclude = True
3447 )
35- model_args : Annotated [list [object ], BeforeValidator (_validate_torch_tensor_args )] = Field (
36- description = "Positional inputs to the model." ,
48+ exporter : TorchExporter | None = Field (
49+ default = None ,
50+ description = "PyTorch exporter to use." ,
51+ exclude = True
52+ )
53+ model_args : Annotated [list [object ] | None , BeforeValidator (_validate_torch_tensor_args )] = Field (
54+ default = None ,
55+ description = "Positional inputs to the model. Required except when `exporter` is `optimum`." ,
3756 exclude = True
3857 )
3958 input_shapes : list [tuple ] | None = Field (
@@ -46,8 +65,8 @@ class PyTorchInferenceMetadataBase(BaseModel, **ConfigDict(arbitrary_types_allow
4665 description = "Model output dictionary keys. Use this if the model returns a dictionary." ,
4766 exclude = True
4867 )
49- exporter : TorchExporter | None = Field (
68+ optimum_config : Annotated [ object | None , BeforeValidator ( _validate_optimum_exporter_config )] = Field (
5069 default = None ,
51- description = "PyTorch exporter to use ." ,
70+ description = "Optimum exporter configuration. Required when `exporter` is `optimum` ." ,
5271 exclude = True
5372 )
0 commit comments