diff --git a/examples/apple/coreml/scripts/debugger_cli.py b/examples/apple/coreml/scripts/debugger_cli.py index 88390f8d8cb..ba1002226bb 100644 --- a/examples/apple/coreml/scripts/debugger_cli.py +++ b/examples/apple/coreml/scripts/debugger_cli.py @@ -149,7 +149,7 @@ def main() -> None: root_dir_path=get_root_dir_path(), conda_env_name=args.conda_environment_name ) - model, example_inputs, _ = EagerModelFactory.create_model( + model, example_inputs, _, _ = EagerModelFactory.create_model( *MODEL_NAME_TO_MODEL[args.model_name] ) diff --git a/examples/apple/coreml/scripts/export.py b/examples/apple/coreml/scripts/export.py index 1aa5806e371..8aecfdabec0 100644 --- a/examples/apple/coreml/scripts/export.py +++ b/examples/apple/coreml/scripts/export.py @@ -158,7 +158,7 @@ def main(): f"Valid compute units are {valid_compute_units}." ) - model, example_inputs, _ = EagerModelFactory.create_model( + model, example_inputs, _, _ = EagerModelFactory.create_model( *MODEL_NAME_TO_MODEL[args.model_name] ) diff --git a/examples/apple/mps/scripts/mps_example.py b/examples/apple/mps/scripts/mps_example.py index dfb958dce53..1c68424a5fd 100644 --- a/examples/apple/mps/scripts/mps_example.py +++ b/examples/apple/mps/scripts/mps_example.py @@ -152,7 +152,7 @@ def get_model_config(args): raise RuntimeError(f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}.") model_config = get_model_config(args) - model, example_inputs, _ = EagerModelFactory.create_model(**model_config) + model, example_inputs, _, _ = EagerModelFactory.create_model(**model_config) model = model.eval() diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index ef13a3c346c..2f06076aade 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -50,7 +50,7 @@ def get_model_and_inputs_from_name(model_name: str): logging.warning( "Using a model from examples/models not all of these are currently supported" ) - model, example_inputs, _ = 
EagerModelFactory.create_model( + model, example_inputs, _, _ = EagerModelFactory.create_model( *MODEL_NAME_TO_MODEL[model_name] ) # Case 3: Model is in an external python file loaded as a module. diff --git a/examples/devtools/scripts/export_bundled_program.py b/examples/devtools/scripts/export_bundled_program.py index 143a7b0e666..8c3cee77e53 100644 --- a/examples/devtools/scripts/export_bundled_program.py +++ b/examples/devtools/scripts/export_bundled_program.py @@ -139,7 +139,7 @@ def main() -> None: f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}." ) - model, example_inputs, _ = EagerModelFactory.create_model( + model, example_inputs, _, _ = EagerModelFactory.create_model( *MODEL_NAME_TO_MODEL[args.model_name] ) diff --git a/examples/devtools/scripts/gen_sample_etrecord.py b/examples/devtools/scripts/gen_sample_etrecord.py index 9194b7caa23..55544395b5a 100644 --- a/examples/devtools/scripts/gen_sample_etrecord.py +++ b/examples/devtools/scripts/gen_sample_etrecord.py @@ -74,7 +74,7 @@ def main() -> None: f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}." 
) - model, example_inputs, _ = EagerModelFactory.create_model( + model, example_inputs, _, _ = EagerModelFactory.create_model( *MODEL_NAME_TO_MODEL[args.model_name] ) diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index 4b2d640003d..d7961da5ea7 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -774,7 +774,7 @@ def _load_llama_model( logging.info( f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}" ) - model, example_inputs, _ = EagerModelFactory.create_model( + model, example_inputs, example_kwarg_inputs, _ = EagerModelFactory.create_model( "llama2", "Llama2Model", checkpoint=checkpoint, @@ -824,6 +824,7 @@ def _load_llama_model( use_kv_cache=use_kv_cache, generate_full_logits=generate_full_logits, example_inputs=example_inputs, + example_kwarg_inputs=example_kwarg_inputs, enable_dynamic_shape=enable_dynamic_shape, calibration_tasks=calibration_tasks, calibration_limit=calibration_limit, diff --git a/examples/models/llama2/runner/eager.py b/examples/models/llama2/runner/eager.py index d246a2df212..42357d6e55c 100644 --- a/examples/models/llama2/runner/eager.py +++ b/examples/models/llama2/runner/eager.py @@ -31,7 +31,7 @@ def __init__(self, args): **params, ) super().__init__(tokenizer_path=args.tokenizer, model_args=model_args) - self.model, _, _ = EagerModelFactory.create_model( + self.model, _, _, _ = EagerModelFactory.create_model( "llama2", "Llama2Model", checkpoint=args.checkpoint, diff --git a/examples/models/model_factory.py b/examples/models/model_factory.py index fb317e3bca3..5abe5efe462 100644 --- a/examples/models/model_factory.py +++ b/examples/models/model_factory.py @@ -6,7 +6,7 @@ import importlib import os -from typing import Any, Tuple +from typing import Any, Dict, Tuple import torch @@ -19,7 +19,7 @@ class EagerModelFactory: @staticmethod def create_model( 
module_name, model_class_name, **kwargs - ) -> Tuple[torch.nn.Module, Any, Any]: + ) -> Tuple[torch.nn.Module, Tuple[Any], Dict[str, Any], Any]: """ Create an instance of a model class that implements EagerModelBase and retrieve related data. @@ -42,14 +42,18 @@ def create_model( if hasattr(module, model_class_name): model_class = getattr(module, model_class_name) model = model_class(**kwargs) + example_kwarg_inputs = None + dynamic_shapes = None + if hasattr(model, "get_example_kwarg_inputs"): + example_kwarg_inputs = model.get_example_kwarg_inputs() if hasattr(model, "get_dynamic_shapes"): - return ( - model.get_eager_model(), - model.get_example_inputs(), - model.get_dynamic_shapes(), - ) - else: - return model.get_eager_model(), model.get_example_inputs(), None + dynamic_shapes = model.get_dynamic_shapes() + return ( + model.get_eager_model(), + model.get_example_inputs(), + example_kwarg_inputs, + dynamic_shapes, + ) raise ValueError( f"Model class '{model_class_name}' not found in module '{module_name}'." 
diff --git a/examples/models/test/test_export.py b/examples/models/test/test_export.py index b3030c24fea..6a7c793029c 100644 --- a/examples/models/test/test_export.py +++ b/examples/models/test/test_export.py @@ -69,7 +69,7 @@ def validate_tensor_allclose( return self.assertTrue(result) def test_mv3_export_to_executorch(self): - eager_model, example_inputs, _ = EagerModelFactory.create_model( + eager_model, example_inputs, _, _ = EagerModelFactory.create_model( *MODEL_NAME_TO_MODEL["mv3"] ) eager_output, executorch_output = self.collect_executorch_and_eager_outputs( @@ -81,7 +81,7 @@ def test_mv3_export_to_executorch(self): ) def test_mv2_export_to_executorch(self): - eager_model, example_inputs, _ = EagerModelFactory.create_model( + eager_model, example_inputs, _, _ = EagerModelFactory.create_model( *MODEL_NAME_TO_MODEL["mv2"] ) eager_output, executorch_output = self.collect_executorch_and_eager_outputs( @@ -90,7 +90,7 @@ def test_mv2_export_to_executorch(self): self.validate_tensor_allclose(eager_output, executorch_output[0]) def test_vit_export_to_executorch(self): - eager_model, example_inputs, _ = EagerModelFactory.create_model( + eager_model, example_inputs, _, _ = EagerModelFactory.create_model( *MODEL_NAME_TO_MODEL["vit"] ) eager_output, executorch_output = self.collect_executorch_and_eager_outputs( @@ -102,7 +102,7 @@ def test_vit_export_to_executorch(self): ) def test_w2l_export_to_executorch(self): - eager_model, example_inputs, _ = EagerModelFactory.create_model( + eager_model, example_inputs, _, _ = EagerModelFactory.create_model( *MODEL_NAME_TO_MODEL["w2l"] ) eager_output, executorch_output = self.collect_executorch_and_eager_outputs( @@ -111,7 +111,7 @@ def test_w2l_export_to_executorch(self): self.validate_tensor_allclose(eager_output, executorch_output[0]) def test_ic3_export_to_executorch(self): - eager_model, example_inputs, _ = EagerModelFactory.create_model( + eager_model, example_inputs, _, _ = EagerModelFactory.create_model( 
*MODEL_NAME_TO_MODEL["ic3"] ) eager_output, executorch_output = self.collect_executorch_and_eager_outputs( @@ -123,7 +123,7 @@ def test_ic3_export_to_executorch(self): ) def test_resnet18_export_to_executorch(self): - eager_model, example_inputs, _ = EagerModelFactory.create_model( + eager_model, example_inputs, _, _ = EagerModelFactory.create_model( *MODEL_NAME_TO_MODEL["resnet18"] ) eager_output, executorch_output = self.collect_executorch_and_eager_outputs( @@ -132,7 +132,7 @@ def test_resnet18_export_to_executorch(self): self.validate_tensor_allclose(eager_output, executorch_output[0]) def test_resnet50_export_to_executorch(self): - eager_model, example_inputs, _ = EagerModelFactory.create_model( + eager_model, example_inputs, _, _ = EagerModelFactory.create_model( *MODEL_NAME_TO_MODEL["resnet50"] ) eager_output, executorch_output = self.collect_executorch_and_eager_outputs( @@ -141,7 +141,7 @@ def test_resnet50_export_to_executorch(self): self.validate_tensor_allclose(eager_output, executorch_output[0]) def test_dl3_export_to_executorch(self): - eager_model, example_inputs, _ = EagerModelFactory.create_model( + eager_model, example_inputs, _, _ = EagerModelFactory.create_model( *MODEL_NAME_TO_MODEL["dl3"] ) eager_output, executorch_output = self.collect_executorch_and_eager_outputs( diff --git a/examples/portable/scripts/export.py b/examples/portable/scripts/export.py index 6055ecef0f3..ec829aa2a7e 100644 --- a/examples/portable/scripts/export.py +++ b/examples/portable/scripts/export.py @@ -58,7 +58,7 @@ def main() -> None: f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}." 
) - model, example_inputs, dynamic_shapes = EagerModelFactory.create_model( + model, example_inputs, _, dynamic_shapes = EagerModelFactory.create_model( *MODEL_NAME_TO_MODEL[args.model_name] ) diff --git a/examples/portable/scripts/export_and_delegate.py b/examples/portable/scripts/export_and_delegate.py index 50f2ce6d901..6a8a28d5338 100644 --- a/examples/portable/scripts/export_and_delegate.py +++ b/examples/portable/scripts/export_and_delegate.py @@ -57,7 +57,7 @@ def export_composite_module_with_lower_graph(): "Running the example to export a composite module with lowered graph..." ) - m, m_inputs, _ = EagerModelFactory.create_model(*MODEL_NAME_TO_MODEL["add_mul"]) + m, m_inputs, _, _ = EagerModelFactory.create_model(*MODEL_NAME_TO_MODEL["add_mul"]) m_compile_spec = m.get_compile_spec() # pre-autograd export. eventually this will become torch.export @@ -166,7 +166,7 @@ def export_and_lower_the_whole_graph(): """ logging.info("Running the example to export and lower the whole graph...") - m, m_inputs, _ = EagerModelFactory.create_model(*MODEL_NAME_TO_MODEL["add_mul"]) + m, m_inputs, _, _ = EagerModelFactory.create_model(*MODEL_NAME_TO_MODEL["add_mul"]) m_compile_spec = m.get_compile_spec() m_inputs = m.get_example_inputs() diff --git a/examples/qualcomm/scripts/export_example.py b/examples/qualcomm/scripts/export_example.py index 9e073b998d7..2e49a2344b8 100644 --- a/examples/qualcomm/scripts/export_example.py +++ b/examples/qualcomm/scripts/export_example.py @@ -58,7 +58,7 @@ def main() -> None: f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}." 
) - model, example_inputs, _ = EagerModelFactory.create_model( + model, example_inputs, _, _ = EagerModelFactory.create_model( *MODEL_NAME_TO_MODEL[args.model_name] ) diff --git a/examples/xnnpack/aot_compiler.py b/examples/xnnpack/aot_compiler.py index f65f9b73a58..c3538db4d83 100644 --- a/examples/xnnpack/aot_compiler.py +++ b/examples/xnnpack/aot_compiler.py @@ -79,7 +79,7 @@ f"Available models are {list(MODEL_NAME_TO_OPTIONS.keys())}." ) - model, example_inputs, _ = EagerModelFactory.create_model( + model, example_inputs, _, _ = EagerModelFactory.create_model( *MODEL_NAME_TO_MODEL[args.model_name] ) diff --git a/examples/xnnpack/quantization/example.py b/examples/xnnpack/quantization/example.py index e5453842281..141b0701d0c 100644 --- a/examples/xnnpack/quantization/example.py +++ b/examples/xnnpack/quantization/example.py @@ -162,7 +162,7 @@ def main() -> None: ) start = time.perf_counter() - model, example_inputs, _ = EagerModelFactory.create_model( + model, example_inputs, _, _ = EagerModelFactory.create_model( *MODEL_NAME_TO_MODEL[args.model_name] ) end = time.perf_counter() diff --git a/extension/export_util/utils.py b/extension/export_util/utils.py index 40ceb6ffec2..66154b95faa 100644 --- a/extension/export_util/utils.py +++ b/extension/export_util/utils.py @@ -26,6 +26,8 @@ def _to_core_aten( model: Union[torch.fx.GraphModule, torch.nn.Module], example_inputs: Tuple[Value, ...], + *, + example_kwarg_inputs: Optional[Dict] = None, dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, strict=True, verbose=True, @@ -38,7 +40,11 @@ def _to_core_aten( f"Expected passed in model to be an instance of fx.GraphModule, got {type(model)}" ) core_aten_ep = export( - model, example_inputs, dynamic_shapes=dynamic_shapes, strict=strict + model, + example_inputs, + example_kwarg_inputs, + dynamic_shapes=dynamic_shapes, + strict=strict, ) if verbose: logging.info(f"Core ATen graph:\n{core_aten_ep.graph}") @@ -69,6 +75,8 @@ def _core_aten_to_edge( def 
export_to_edge( model: Union[torch.fx.GraphModule, torch.nn.Module], example_inputs: Tuple[Value, ...], + *, + example_kwarg_inputs: Optional[Dict] = None, dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, edge_constant_methods: Optional[Dict[str, Any]] = None, edge_compile_config=_EDGE_COMPILE_CONFIG, @@ -76,7 +84,12 @@ def export_to_edge( verbose=True, ) -> EdgeProgramManager: core_aten_ep = _to_core_aten( - model, example_inputs, dynamic_shapes, strict=strict, verbose=verbose + model, + example_inputs, + example_kwarg_inputs=example_kwarg_inputs, + dynamic_shapes=dynamic_shapes, + strict=strict, + verbose=verbose, ) return _core_aten_to_edge( core_aten_ep, edge_constant_methods, edge_compile_config, verbose=verbose @@ -86,6 +99,8 @@ def export_to_edge( def export_to_exec_prog( model: Union[torch.fx.GraphModule, torch.nn.Module], example_inputs: Tuple[Value, ...], + *, + example_kwarg_inputs: Optional[Dict[str, Any]] = None, dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, edge_constant_methods: Optional[Dict[str, Any]] = None, edge_compile_config=_EDGE_COMPILE_CONFIG, @@ -96,7 +111,13 @@ def export_to_exec_prog( # pre-autograd export. 
eventually this will become torch.export m = export_for_training(m, example_inputs).module() - core_aten_ep = _to_core_aten(m, example_inputs, dynamic_shapes, strict=strict) + core_aten_ep = _to_core_aten( + m, + example_inputs, + example_kwarg_inputs=example_kwarg_inputs, + dynamic_shapes=dynamic_shapes, + strict=strict, + ) edge_m = _core_aten_to_edge( core_aten_ep, edge_constant_methods, edge_compile_config diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index ee54fe3660d..11d92f32f56 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -10,7 +10,7 @@ import logging from enum import Enum -from typing import Any, Callable, List, Optional +from typing import Any, Callable, Dict, List, Optional import torch from executorch.backends.transforms.duplicate_dynamic_quant_chain import ( @@ -68,6 +68,7 @@ def __init__( dtype, use_kv_cache, example_inputs, + example_kwarg_inputs: Optional[Dict] = None, args: Optional[Any] = None, enable_dynamic_shape: bool = False, generate_full_logits: bool = False, @@ -87,6 +88,7 @@ def __init__( self.max_seq_len = max_seq_len self.dtype = dtype self.example_inputs = example_inputs + self.example_kwarg_inputs = example_kwarg_inputs self.use_kv_cache = use_kv_cache self.generate_full_logits = generate_full_logits self.enable_dynamic_shape = enable_dynamic_shape @@ -186,12 +188,16 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager": self.pre_autograd_graph_module = torch.export.export( self.model, self.example_inputs, + self.example_kwarg_inputs, dynamic_shapes=dynamic_shape, strict=True, ).module() else: self.pre_autograd_graph_module = capture_pre_autograd_graph( - self.model, self.example_inputs, dynamic_shapes=dynamic_shape + self.model, + self.example_inputs, + kwargs=self.example_kwarg_inputs, + dynamic_shapes=dynamic_shape, ) return self @@ -340,6 +346,7 @@ def export_to_edge(self) -> "LLMEdgeManager": self.edge_manager = export_to_edge( 
self.pre_autograd_graph_module, # pyre-fixme[6] self.example_inputs, + example_kwarg_inputs=self.example_kwarg_inputs, dynamic_shapes=dynamic_shape, edge_constant_methods=self.metadata, edge_compile_config=edge_config,