diff --git a/examples/models/llama2/eval_llama_lib.py b/examples/models/llama2/eval_llama_lib.py
index 57258fdbc11..95b3ff0fb7c 100644
--- a/examples/models/llama2/eval_llama_lib.py
+++ b/examples/models/llama2/eval_llama_lib.py
@@ -194,7 +194,7 @@ def gen_eval_wrapper(
     manager: LLMEdgeManager = _prepare_for_llama_export(model_name, args)

     if len(quantizers) != 0:
-        manager = manager.capture_pre_autograd_graph().pt2e_quantize(quantizers)
+        manager = manager.export().pt2e_quantize(quantizers)
         model = (
             manager.pre_autograd_graph_module.to(device="cuda")  # pyre-ignore
             if torch.cuda.is_available()
@@ -209,7 +209,7 @@ def gen_eval_wrapper(
         )
     else:
         # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
-        # for quantizers. Currently capture_pre_autograd_graph only works with --kv_cache, but
+        # for quantizers. Currently export_for_training only works with --kv_cache, but
         # fails without the kv_cache mode
         model = (
             manager.model.eval().to(device="cuda")
diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index 0d292b11e7b..8cff6e8e11a 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -581,7 +581,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
     # export_to_edge
     builder_exported_to_edge = (
         _prepare_for_llama_export(modelname, args)
-        .capture_pre_autograd_graph()
+        .export()
         .pt2e_quantize(quantizers)
         .export_to_edge()
     )
diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py
index 47a5407cf18..cc475f1b196 100644
--- a/examples/models/llava/export_llava.py
+++ b/examples/models/llava/export_llava.py
@@ -53,7 +53,7 @@ class LlavaEdgeManager(LLMEdgeManager):
-    def capture_pre_autograd_graph(self) -> "LlavaEdgeManager":
+    def export(self) -> "LlavaEdgeManager":
         dynamic_shape = self._get_dynamic_shape()
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
@@ -107,7 +107,7 @@ def forward(self, input_pos, embeddings):
         text_model_em.set_output_dir("./")
         .to_dtype(dtype_override)
         .source_transform(source_transforms)
-        .capture_pre_autograd_graph()
+        .export()
         .pt2e_quantize(quantizers)
     )
@@ -148,7 +148,7 @@ def forward(self, images):
             dynamic_shapes=dynamic_shapes,
             args=None,
         )
-        .capture_pre_autograd_graph()
+        .export()
         .pt2e_quantize([quantizer])
     )
diff --git a/examples/portable/scripts/export.py b/examples/portable/scripts/export.py
index ec829aa2a7e..353d8a034e0 100644
--- a/examples/portable/scripts/export.py
+++ b/examples/portable/scripts/export.py
@@ -65,9 +65,7 @@ def main() -> None:
     backend_config = ExecutorchBackendConfig()
     if args.segment_alignment is not None:
         backend_config.segment_alignment = int(args.segment_alignment, 16)
-    if (
-        dynamic_shapes is not None
-    ):  # capture_pre_autograd_graph does not work with dynamic shapes
+    if dynamic_shapes is not None:
         edge_manager = export_to_edge(
             model,
             example_inputs,
diff --git a/extension/llm/README.md b/extension/llm/README.md
index ddcf4c727d2..7f4baed7d31 100644
--- a/extension/llm/README.md
+++ b/extension/llm/README.md
@@ -10,7 +10,7 @@ Commonly used methods in this class include:
 - _source_transform_: execute a series of source transform passes. Some transform passes include
   - weight only quantization, which can be done at source (eager mode) level.
   - replace some torch operators to a custom operator. For example, _replace_sdpa_with_custom_op_.
-- _capture_pre_autograd_graph_: get a graph that is ready for pt2 graph-based quantization.
+- _torch.export_for_training_: get a graph that is ready for pt2 graph-based quantization.
 - _pt2e_quantize_ with passed in quantizers.
   - util functions in _quantizer_lib.py_ can help to get different quantizers based on the needs.
 - _export_to_edge_: export to edge dialect
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index 76779bdd636..d2a413fc793 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -82,7 +82,7 @@ def __init__(
         dynamic_shapes: Optional[Any] = None,
     ):
         self.model = model
-        # graph module returned from capture_pre_autograd_graph
+        # graph module returned from export()
         self.pre_autograd_graph_module: Optional[torch.fx.GraphModule] = None
         self.modelname = modelname
         self.max_seq_len = max_seq_len
@@ -176,7 +176,7 @@ def _get_edge_config(self) -> EdgeCompileConfig:
         )
         return edge_config

-    def capture_pre_autograd_graph(self) -> "LLMEdgeManager":
+    def export(self) -> "LLMEdgeManager":
         dynamic_shape = self._get_dynamic_shape()
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
@@ -296,7 +296,7 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManage
             composed_quantizer = ComposableQuantizer(quantizers)
             assert (
                 self.pre_autograd_graph_module is not None
-            ), "Please run capture_pre_autograd_graph first"
+            ), "Please run export() first"
             m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
             logging.info(
                 f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
             )
@@ -344,8 +344,8 @@ def export_to_edge(self) -> "LLMEdgeManager":
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
             if self.pre_autograd_graph_module is None:
-                # Run capture_pre_autograd_graph if it didn't run
-                self.capture_pre_autograd_graph()
+                # Run export() if it didn't run
+                self.export()
             self.edge_manager = export_to_edge(
                 self.pre_autograd_graph_module,  # pyre-fixme[6]
                 self.example_inputs,
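Reviewer note: the change is a mechanical rename of the builder step. LLMEdgeManager.capture_pre_autograd_graph() becomes LLMEdgeManager.export(), the captured graph still lands in pre_autograd_graph_module, pt2e_quantize() still asserts that it exists ("Please run export() first"), and export_to_edge() still captures it lazily when export() has not been called. A minimal sketch of a migrated call site is below; the helper name lower_with_optional_quant and the import scaffolding are illustrative assumptions, not code added by this patch.

# Illustrative sketch only: shows the renamed builder flow exercised by this diff.
# lower_with_optional_quant is a hypothetical helper, not part of the patch.
from typing import List

from executorch.extension.llm.export.builder import LLMEdgeManager
from torch.ao.quantization.quantizer import Quantizer


def lower_with_optional_quant(
    manager: LLMEdgeManager, quantizers: List[Quantizer]
) -> LLMEdgeManager:
    if len(quantizers) != 0:
        # export() captures the pre-autograd graph that pt2e_quantize() checks for;
        # running pt2e_quantize() first trips the "Please run export() first" assert.
        manager = manager.export().pt2e_quantize(quantizers)
    # export_to_edge() still calls export() on its own if the graph was never captured.
    return manager.export_to_edge()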