From f69655c06b9cc7125b1c85010e91f4f6c0a6e6c8 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Tue, 27 May 2025 15:27:12 -0700 Subject: [PATCH] Use llm_config instead of args in export_llama functions [ghstack-poisoned] --- examples/models/llama/eval_llama_lib.py | 66 +++-- examples/models/llama/export_llama_lib.py | 260 ++++++++++-------- examples/models/llama/runner/eager.py | 59 +++- .../llama/tests/test_export_llama_lib.py | 10 +- examples/models/llava/export_llava.py | 79 ++++-- 5 files changed, 289 insertions(+), 185 deletions(-) diff --git a/examples/models/llama/eval_llama_lib.py b/examples/models/llama/eval_llama_lib.py index 47b13df52e0..5a877ab85f2 100644 --- a/examples/models/llama/eval_llama_lib.py +++ b/examples/models/llama/eval_llama_lib.py @@ -164,6 +164,7 @@ def _model_call(self, inps): def gen_eval_wrapper( model_name: str, args: argparse.ArgumentParser, + llm_config=None, ): """ Generates a wrapper interface around the provided model and tokenizer for @@ -172,7 +173,15 @@ def gen_eval_wrapper( Returns: eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library. """ - tokenizer = get_tokenizer(args.tokenizer_path) # pyre-ignore + # If llm_config is not provided, convert args to llm_config + if llm_config is None: + from executorch.examples.models.llama.config.llm_config_utils import ( + convert_args_to_llm_config, + ) + + llm_config = convert_args_to_llm_config(args) + + tokenizer = get_tokenizer(llm_config.base.tokenizer_path) # ExecuTorch Binary Evaluation if (model := args.pte) is not None: # pyre-ignore @@ -182,7 +191,7 @@ def gen_eval_wrapper( model=model, tokenizer=tokenizer, tokenizer_bin=tokenizer_bin, - max_seq_length=args.max_seq_length, # pyre-ignore + max_seq_length=llm_config.export.max_seq_length, ) # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings @@ -191,12 +200,14 @@ def gen_eval_wrapper( tokenizer=tokenizer, # Exported model takes at most (max_seq_length - 1) tokens. # Note that the eager model takes at most max_seq_length tokens. 
- max_seq_length=args.max_seq_length - 1, + max_seq_length=llm_config.export.max_seq_length - 1, ) - pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args) + pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params( + llm_config + ) # GPTFastEvalWrapper: Create a wrapper around a pre-exported model - manager: LLMEdgeManager = _prepare_for_llama_export(args) + manager: LLMEdgeManager = _prepare_for_llama_export(llm_config, args) if len(quantizers) != 0: manager = manager.export().pt2e_quantize(quantizers) @@ -208,9 +219,9 @@ def gen_eval_wrapper( return GraphModuleEvalWrapper( model=model, tokenizer=tokenizer, - max_seq_length=args.max_seq_length, - use_kv_cache=args.use_kv_cache, # pyre-ignore - enable_dynamic_shape=args.enable_dynamic_shape, # pyre-ignore + max_seq_length=llm_config.export.max_seq_length, + use_kv_cache=llm_config.model.use_kv_cache, + enable_dynamic_shape=llm_config.model.enable_dynamic_shape, ) else: # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch @@ -234,8 +245,8 @@ def gen_eval_wrapper( return EagerEvalWrapper( model=model, tokenizer=tokenizer, - max_seq_length=args.max_seq_length, - use_kv_cache=args.use_kv_cache, + max_seq_length=llm_config.export.max_seq_length, + use_kv_cache=llm_config.model.use_kv_cache, ) @@ -296,12 +307,18 @@ def eval_llama( model_name: str, args: argparse.ArgumentParser, ) -> None: + # Convert args to LlmConfig + from executorch.examples.models.llama.config.llm_config_utils import ( + convert_args_to_llm_config, + ) + + llm_config = convert_args_to_llm_config(args) + # Generate the eval wrapper - eval_wrapper = gen_eval_wrapper(model_name, args) + eval_wrapper = gen_eval_wrapper(model_name, args, llm_config) # Needed for loading mmlu dataset. 
# See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files - # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks` if args.tasks and "mmlu" in args.tasks: import datasets @@ -312,8 +329,8 @@ def eval_llama( eval_results = simple_evaluate( model=eval_wrapper, tasks=args.tasks, - num_fewshot=args.num_fewshot, # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot` - limit=args.limit, # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit` + num_fewshot=args.num_fewshot, + limit=args.limit, ) for task, res in eval_results["results"].items(): @@ -326,19 +343,26 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py """ - assert args.use_attention_sink is not None # pyre-ignore [16] - assert args.attention_sink_eval_tokens > 0 # pyre-ignore [16] - attention_sink_params = args.use_attention_sink.split(",") + # Convert args to LlmConfig + from executorch.examples.models.llama.config.llm_config_utils import ( + convert_args_to_llm_config, + ) + + llm_config = convert_args_to_llm_config(args) + + assert llm_config.model.use_attention_sink is not None + assert args.attention_sink_eval_tokens > 0 + attention_sink_params = llm_config.model.use_attention_sink.split(",") assert len(attention_sink_params) == 3 sink_size = int(attention_sink_params[0]) window_size = int(attention_sink_params[1]) - assert args.max_seq_length == sink_size + window_size # pyre-ignore [16] + assert llm_config.export.max_seq_length == sink_size + window_size device = "cuda" if torch.cuda.is_available() else "cpu" - manager: LLMEdgeManager = _prepare_for_llama_export(args) + manager: LLMEdgeManager = _prepare_for_llama_export(llm_config, args) model = manager.model.eval().to(device=device) - tokenizer = get_tokenizer(args.tokenizer_path) # pyre-ignore [16] + tokenizer = get_tokenizer(llm_config.base.tokenizer_path) eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") @@ -347,7 +371,7 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse progress_bar = tqdm(total=args.attention_sink_eval_tokens) input_pos = 0 while input_pos < args.attention_sink_eval_tokens: - for text in eval_data["text"]: # pyre-ignore [16] + for text in eval_data["text"]: tokens = tokenizer.encode(text, bos=False, eos=False) if len(tokens) <= 0: continue diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index e3fbfaa5872..427fc8e8a74 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -579,49 +579,53 @@ def export_llama( if isinstance(export_options, argparse.Namespace): # Legacy CLI. args = export_options - llm_config = convert_args_to_llm_config(export_options) # noqa: F841 + llm_config = convert_args_to_llm_config(export_options) elif isinstance(export_options, DictConfig): # Hydra CLI. - llm_config = export_options # noqa: F841 + llm_config = export_options + # Create an args object for backward compatibility during transition + args = argparse.Namespace() + for key, value in llm_config.items(): + setattr(args, key, value) else: raise ValueError( "Input to export_llama must be either of type argparse.Namespace or LlmConfig" ) - # TODO: refactor rest of export_llama to use llm_config instead of args. 
- # If a checkpoint isn't provided for an HF OSS model, download and convert the # weights first. - if not args.checkpoint and args.model in HUGGING_FACE_REPO_IDS: - repo_id = HUGGING_FACE_REPO_IDS[args.model] - if args.model == "qwen2_5": + model_name = llm_config.base.model_class + if not llm_config.base.checkpoint and model_name in HUGGING_FACE_REPO_IDS: + repo_id = HUGGING_FACE_REPO_IDS[model_name] + if model_name == "qwen2_5": from executorch.examples.models.qwen2_5 import ( # pyre-ignore[21] convert_weights, ) - elif args.model.startswith("qwen3"): + elif model_name.startswith("qwen3"): from executorch.examples.models.qwen3 import ( # pyre-ignore[21] convert_weights, ) - elif args.model == "phi_4_mini": + elif model_name == "phi_4_mini": from executorch.examples.models.phi_4_mini import ( # pyre-ignore[21] convert_weights, ) - elif args.model == "smollm2": + elif model_name == "smollm2": from executorch.examples.models.smollm2 import ( # pyre-ignore[21] convert_weights, ) else: raise ValueError( - f"Converting weights to meta format for {args.model} is not yet supported" + f"Converting weights to meta format for {model_name} is not yet supported" ) - args.checkpoint = download_and_convert_hf_checkpoint(repo_id, convert_weights) + checkpoint = download_and_convert_hf_checkpoint(repo_id, convert_weights) + llm_config.base.checkpoint = checkpoint - if args.profile_path is not None: + if llm_config.debug.profile_path is not None: try: from executorch.util.python_profiler import CProfilerFlameGraph - with CProfilerFlameGraph(args.profile_path): - builder = _export_llama(args) + with CProfilerFlameGraph(llm_config.debug.profile_path): + builder = _export_llama(llm_config, args) assert ( filename := builder.get_saved_pte_filename() ) is not None, "Fail to get file name from builder" @@ -632,14 +636,14 @@ def export_llama( ) return "" else: - builder = _export_llama(args) + builder = _export_llama(llm_config, args) assert ( filename := builder.get_saved_pte_filename() ) is not None, "Fail to get file name from builder" return filename -def _prepare_for_llama_export(args) -> LLMEdgeManager: +def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager: """ Helper function for export_llama. Loads the model from checkpoint and params, and sets up a LLMEdgeManager with initial transforms and dtype conversion. @@ -647,39 +651,49 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager: Returns a LLMEdgeManager prior to calling export_to_edge with quantizers """ # load model from checkpoint and params.json - checkpoint_path = canonical_path(args.checkpoint) if args.checkpoint else None + checkpoint_path = ( + canonical_path(llm_config.base.checkpoint) + if llm_config.base.checkpoint + else None + ) checkpoint_dir = ( - canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None + canonical_path(llm_config.base.checkpoint_dir) + if llm_config.base.checkpoint_dir + else None + ) + params_path = ( + canonical_path(llm_config.base.params) if llm_config.base.params else None ) - params_path = canonical_path(args.params) if args.params else None - output_dir_path = canonical_path(args.output_dir, dir=True) - weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA + output_dir_path = canonical_path(llm_config.export.output_dir, dir=True) + weight_type = WeightType.FAIRSEQ2 if llm_config.base.fairseq2 else WeightType.LLAMA - # Convert dtype override string arg to actual type. 
- dtype_override = DType[args.dtype_override] + # Convert dtype override string to actual type + dtype_override = DType[llm_config.model.dtype_override] edge_manager = _load_llama_model( - args.model, + llm_config.base.model_class, checkpoint=checkpoint_path, checkpoint_dir=checkpoint_dir, params_path=params_path, - use_kv_cache=args.use_kv_cache, - use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache, - generate_full_logits=args.generate_full_logits, + use_kv_cache=llm_config.model.use_kv_cache, + use_sdpa_with_kv_cache=llm_config.model.use_sdpa_with_kv_cache, + generate_full_logits=llm_config.debug.generate_full_logits, weight_type=weight_type, - enable_dynamic_shape=args.enable_dynamic_shape, - calibration_tasks=args.calibration_tasks, - calibration_limit=args.calibration_limit, - calibration_seq_length=args.calibration_seq_length, - calibration_data=args.calibration_data, - tokenizer_path=args.tokenizer_path, - verbose=args.verbose, - max_seq_len=args.max_seq_length, - max_context_len=args.max_context_length, - input_prune_map_path=args.input_prune_map, - output_prune_map_path=args.output_prune_map, - metadata_str=args.metadata, + enable_dynamic_shape=llm_config.model.enable_dynamic_shape, + calibration_tasks=llm_config.quantization.calibration_tasks, + calibration_limit=llm_config.quantization.calibration_limit, + calibration_seq_length=llm_config.quantization.calibration_seq_length, + calibration_data=llm_config.quantization.calibration_data, + tokenizer_path=llm_config.base.tokenizer_path, + verbose=llm_config.debug.verbose, + max_seq_len=llm_config.export.max_seq_length, + max_context_len=llm_config.export.max_context_length, + input_prune_map_path=llm_config.model.input_prune_map, + output_prune_map_path=llm_config.model.output_prune_map, + metadata_str=llm_config.base.metadata, dtype_override=dtype_override, + use_qnn=llm_config.backend.qnn.enabled, + export_only=llm_config.export.export_only, args=args, ) @@ -710,64 +724,64 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager: edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform( _get_source_transforms( dtype_override=dtype_override, - checkpoint=args.checkpoint, + checkpoint=llm_config.base.checkpoint, checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype), # type: ignore - tokenizer_path=args.tokenizer_path, - use_spin_quant=args.use_spin_quant, - embedding_quantize=args.embedding_quantize, - use_shared_embedding=args.use_shared_embedding, - quantization_mode=args.quantization_mode, - group_size=args.group_size, - calibration_tasks=args.calibration_tasks, - calibration_limit=args.calibration_limit, - calibration_seq_length=args.calibration_seq_length, - expand_rope_table=args.expand_rope_table, + tokenizer_path=llm_config.base.tokenizer_path, + use_spin_quant=llm_config.quantization.use_spin_quant, + embedding_quantize=llm_config.quantization.embedding_quantize, + use_shared_embedding=llm_config.model.use_shared_embedding, + quantization_mode=llm_config.quantization.qmode, + group_size=llm_config.quantization.group_size, + calibration_tasks=llm_config.quantization.calibration_tasks, + calibration_limit=llm_config.quantization.calibration_limit, + calibration_seq_length=llm_config.quantization.calibration_seq_length, + expand_rope_table=llm_config.model.expand_rope_table, use_custom_sdpa_with_attention_mask=getattr( - args, "use_custom_sdpa_with_attention_mask", False + llm_config.model, "use_custom_sdpa_with_attention_mask", False ), - use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache, - 
quantize_kv_cache=args.quantize_kv_cache, - use_kv_cache=args.use_kv_cache, - qnn=args.qnn, - use_qnn_sha=args.use_qnn_sha, - optimized_rotation_path=args.optimized_rotation_path, - mps=args.mps, - coreml=args.coreml, - coreml_ios=args.coreml_ios, - vulkan=args.vulkan, - use_qat=args.use_qat, - use_lora=args.use_lora, - preq_mode=args.preq_mode, - preq_group_size=args.preq_group_size, - preq_embedding_quantize=args.preq_embedding_quantize, - local_global_attention=args.local_global_attention, + use_sdpa_with_kv_cache=llm_config.model.use_sdpa_with_kv_cache, + quantize_kv_cache=llm_config.model.quantize_kv_cache, + use_kv_cache=llm_config.model.use_kv_cache, + qnn=llm_config.backend.qnn.enabled, + use_qnn_sha=llm_config.backend.qnn.use_sha, + optimized_rotation_path=llm_config.backend.qnn.optimized_rotation_path, + mps=llm_config.backend.mps.enabled, + coreml=llm_config.backend.coreml.enabled, + coreml_ios=llm_config.backend.coreml.ios, + vulkan=llm_config.backend.vulkan.enabled, + use_qat=llm_config.quantization.use_qat, + use_lora=llm_config.base.use_lora, + preq_mode=llm_config.base.preq_mode, + preq_group_size=llm_config.base.preq_group_size, + preq_embedding_quantize=llm_config.base.preq_embedding_quantize, + local_global_attention=llm_config.model.local_global_attention, ) ) return edge_manager -def get_quantizer_and_quant_params(args): +def get_quantizer_and_quant_params(llm_config): pt2e_quant_params = get_pt2e_quantization_params( - args.pt2e_quantize, args.quantization_mode + llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode ) - quantizers = get_pt2e_quantizers(pt2e_quant_params, args.so_library) + quantizers = get_pt2e_quantizers(pt2e_quant_params, llm_config.export.so_library) quant_dtype = None - if args.qnn and args.pt2e_quantize: + if llm_config.backend.qnn.enabled and llm_config.quantization.pt2e_quantize: assert len(quantizers) == 0, "Should not enable both xnnpack and qnn" qnn_quantizer, quant_dtype = get_qnn_quantizer( - args.pt2e_quantize, args.quantization_mode + llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode ) quantizers.append(qnn_quantizer) - if args.coreml and args.pt2e_quantize: + if llm_config.backend.coreml.enabled and llm_config.quantization.pt2e_quantize: assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml" - coreml_quantizer = get_coreml_quantizer(args.pt2e_quantize) + coreml_quantizer = get_coreml_quantizer(llm_config.quantization.pt2e_quantize) quantizers.append(coreml_quantizer) - if args.vulkan and args.pt2e_quantize: + if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize: assert ( len(quantizers) == 0 ), "Should not enable both vulkan and other quantizers" - vulkan_quantizer = get_vulkan_quantizer(args.pt2e_quantize) + vulkan_quantizer = get_vulkan_quantizer(llm_config.quantization.pt2e_quantize) quantizers.append(vulkan_quantizer) logging.info(f"Applying quantizers: {quantizers}") return pt2e_quant_params, quantizers, quant_dtype @@ -790,28 +804,32 @@ def _qmode_type(value): ) -def _validate_args(args): +def _validate_args(llm_config): """ TODO: Combine all the backends under --backend args """ - if args.max_context_length < args.max_seq_length: + if llm_config.export.max_context_length < llm_config.export.max_seq_length: raise ValueError( - f"max_context_length {args.max_context_length} must be >= max_seq_len {args.max_seq_length}. max_context_length impacts kv cache size that is used to remember history, while max_seq_length refers to user prompt length. 
Please use --max_context_length to specify context length." + f"max_context_length {llm_config.export.max_context_length} must be >= max_seq_len {llm_config.export.max_seq_length}. max_context_length impacts kv cache size that is used to remember history, while max_seq_length refers to user prompt length. Please use --max_context_length to specify context length." ) - if args.enable_dynamic_shape and (args.coreml or args.mps or args.qnn): + if llm_config.model.enable_dynamic_shape and ( + llm_config.backend.coreml.enabled + or llm_config.backend.mps.enabled + or llm_config.backend.qnn.enabled + ): raise ValueError( "Dynamic shape is not supported with coreml, MPS or qnn backends." " Please use --disable_dynamic_shape." ) - if args.num_sharding > 0 and not args.qnn: + if llm_config.backend.qnn.num_sharding > 0 and not llm_config.backend.qnn.enabled: raise ValueError("Model shard is only supported with qnn backend now.") - if args.use_shared_embedding: + if llm_config.model.use_shared_embedding: if not ( - args.embedding_quantize is not None - and args.embedding_quantize.startswith("torchao:") + llm_config.quantization.embedding_quantize is not None + and llm_config.quantization.embedding_quantize.startswith("torchao:") ): raise ValueError( "Shared embedding is only supported with torchao quantization." @@ -1039,28 +1057,30 @@ def _to_edge_and_lower_llama( # noqa: C901 return builder -def _export_llama(args) -> LLMEdgeManager: # noqa: C901 - _validate_args(args) +def _export_llama(llm_config, args) -> LLMEdgeManager: # noqa: C901 + _validate_args(llm_config) - pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args) + pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params( + llm_config + ) additional_passes = [] - if args.model in TORCHTUNE_DEFINED_MODELS: + if llm_config.base.model_class in TORCHTUNE_DEFINED_MODELS: additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])] # export_to_edge - builder_exported = _prepare_for_llama_export(args).export() + builder_exported = _prepare_for_llama_export(llm_config, args).export() builder_exported.run_canonical_optimizations() modelname = builder_exported.modelname - if args.export_only: + if llm_config.export.export_only: exit() if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None: - # Force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False - args.xnnpack = True + # Force xnnpack to be true if pt2e_quant_params is not None and xnnpack is False + llm_config.backend.xnnpack.enabled = True - if args.xnnpack: + if llm_config.backend.xnnpack.enabled: builder = _to_edge_and_lower_llama_xnnpack( builder_exported, modelname, @@ -1068,9 +1088,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901 pt2e_quant_params, quantizers, quant_dtype, - xnnpack_extended_ops=args.xnnpack_extended_ops, - generate_etrecord=args.generate_etrecord, - verbose=args.verbose, + xnnpack_extended_ops=llm_config.backend.xnnpack.extended_ops, + generate_etrecord=llm_config.debug.generate_etrecord, + verbose=llm_config.debug.verbose, ) else: builder = _to_edge_and_lower_llama( @@ -1080,33 +1100,33 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901 pt2e_quant_params, quantizers, quant_dtype, - vulkan=args.vulkan, - mps=args.mps, - coreml=args.coreml, - qnn=args.qnn, - dtype_override=args.dtype_override, - enable_dynamic_shape=args.enable_dynamic_shape, - use_kv_cache=args.use_kv_cache, - embedding_quantize=args.embedding_quantize, - 
pt2e_quantize=args.pt2e_quantize, - coreml_ios=args.coreml_ios, - coreml_quantize=args.coreml_quantize, - coreml_compute_units=args.coreml_compute_units, - use_qnn_sha=args.use_qnn_sha, - num_sharding=args.num_sharding, - soc_model=args.soc_model, - generate_etrecord=args.generate_etrecord, - verbose=args.verbose, + vulkan=llm_config.backend.vulkan.enabled, + mps=llm_config.backend.mps.enabled, + coreml=llm_config.backend.coreml.enabled, + qnn=llm_config.backend.qnn.enabled, + dtype_override=llm_config.model.dtype_override, + enable_dynamic_shape=llm_config.model.enable_dynamic_shape, + use_kv_cache=llm_config.model.use_kv_cache, + embedding_quantize=llm_config.quantization.embedding_quantize, + pt2e_quantize=llm_config.quantization.pt2e_quantize, + coreml_ios=llm_config.backend.coreml.ios, + coreml_quantize=llm_config.backend.coreml.quantize, + coreml_compute_units=llm_config.backend.coreml.compute_units, + use_qnn_sha=llm_config.backend.qnn.use_sha, + num_sharding=llm_config.backend.qnn.num_sharding, + soc_model=llm_config.backend.qnn.soc_model, + generate_etrecord=llm_config.debug.generate_etrecord, + verbose=llm_config.debug.verbose, ) - if args.profile_memory: + if llm_config.debug.profile_memory: generate_memory_trace(builder.export_program, "memory_profile.json") if builder.dtype == DType.fp16: modelname = f"{modelname}_h" - if args.output_name: - modelname = args.output_name + if llm_config.export.output_name: + modelname = llm_config.export.output_name if modelname.endswith(".pte"): output_file = modelname modelname = modelname[:-4] @@ -1179,6 +1199,8 @@ def _load_llama_model( output_prune_map_path: Optional[str] = None, metadata_str: Optional[str] = None, dtype_override: Optional[DType] = None, + use_qnn: bool = False, + export_only: bool = False, args, ) -> "LLMEdgeManager": """ @@ -1239,8 +1261,8 @@ def _load_llama_model( calibration_seq_length=calibration_seq_length, calibration_data=calibration_data, tokenizer_path=tokenizer_path, - use_legacy_export=args.qnn, - save_exported_program=args.export_only, + use_legacy_export=use_qnn, + save_exported_program=export_only, verbose=verbose, metadata=_load_llama_model_metadata( weight_type, diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py index 0b842a8f976..b308f287782 100644 --- a/examples/models/llama/runner/eager.py +++ b/examples/models/llama/runner/eager.py @@ -10,6 +10,10 @@ import torch +from executorch.examples.models.llama.config.llm_config import LlmConfig +from executorch.examples.models.llama.config.llm_config_utils import ( + convert_args_to_llm_config, +) from executorch.examples.models.llama.export_llama_lib import ( _prepare_for_llama_export, build_args_parser as _build_args_parser, @@ -23,19 +27,24 @@ class EagerLlamaRunner(LlamaRunner): Runs llama in eager mode with provided checkpoint file. 
""" - def __init__(self, args): - with open(args.params, "r") as f: + def __init__( + self, + llm_config: LlmConfig, + tokenizer_config_path: Optional[str] = None, + use_attention_sink: bool = False, + ): + with open(llm_config.base.params, "r") as f: params = json.loads(f.read()) super().__init__( - tokenizer_path=args.tokenizer_path, - tokenizer_config_path=args.tokenizer_config_path, - max_seq_len=args.max_seq_length, + tokenizer_path=llm_config.base.tokenizer_path, + tokenizer_config_path=tokenizer_config_path, + max_seq_len=llm_config.export.max_seq_length, max_batch_size=1, - use_kv_cache=args.use_kv_cache, + use_kv_cache=llm_config.model.use_kv_cache, vocab_size=params["vocab_size"], device="cuda" if torch.cuda.is_available() else "cpu", ) - manager: LLMEdgeManager = _prepare_for_llama_export(args) + manager: LLMEdgeManager = _prepare_for_llama_export(llm_config, None) self.model = manager.model.eval().to(device=self.device) def forward( @@ -49,6 +58,7 @@ def forward( def build_args_parser() -> argparse.ArgumentParser: parser = _build_args_parser() + # Runner-specific arguments that aren't part of LlmConfig parser.add_argument( "--prompt", type=str, @@ -89,22 +99,41 @@ def execute_runner(runner_class: Type[LlamaRunner]) -> None: parser = build_args_parser() args = parser.parse_args() + # Convert args to LlmConfig for model configuration. + llm_config = convert_args_to_llm_config(args) + + # Extract runner-specific parameters. + prompt = args.prompt + temperature = args.temperature + show_tokens = args.show_tokens + chat_mode = args.chat + tokenizer_config_path = args.tokenizer_config_path + use_attention_sink = args.use_attention_sink + with torch.no_grad(): - runner = runner_class(args) # pyre-ignore: Missing argument [20] + # Create runner with LlmConfig and separate runner parameters. + runner = runner_class( + llm_config=llm_config, + tokenizer_config_path=tokenizer_config_path, + use_attention_sink=use_attention_sink, + ) + generated_tokens = ( runner.chat_completion( - max_seq_len=1000000 if args.use_attention_sink else args.max_seq_length, - temperature=args.temperature, - show_progress=args.show_tokens, + max_seq_len=( + 1000000 if use_attention_sink else llm_config.export.max_seq_length + ), + temperature=temperature, + show_progress=show_tokens, ) - if args.chat + if chat_mode else runner.text_completion( - prompt=args.prompt, - temperature=args.temperature, + prompt=prompt, + temperature=temperature, echo=True, ) ) - if args.show_tokens: + if show_tokens: print(f"Generated {len(generated_tokens)} tokens: {generated_tokens}") diff --git a/examples/models/llama/tests/test_export_llama_lib.py b/examples/models/llama/tests/test_export_llama_lib.py index b94adb5fa0c..75223dc35a0 100644 --- a/examples/models/llama/tests/test_export_llama_lib.py +++ b/examples/models/llama/tests/test_export_llama_lib.py @@ -7,6 +7,7 @@ import unittest from executorch.devtools.backend_debug import get_delegation_info +from executorch.examples.models.llama.config.llm_config import LlmConfig from executorch.examples.models.llama.export_llama_lib import ( _export_llama, build_args_parser, @@ -34,13 +35,20 @@ def test_has_expected_ops_and_op_counts(self): # we cannot test quantization args in this way # since quantization requires promoting meta tensors # to device=cpu, which requires real weights. 
+ + llm_config = LlmConfig() + llm_config.model.use_sdpa_with_kv_cache = True + llm_config.model.use_kv_cache = True + llm_config.debug.verbose = True + + # We still need args for backward compatibility during transition parser = build_args_parser() args = parser.parse_args([]) args.use_sdpa_with_kv_cache = True args.use_kv_cache = True args.verbose = True - builder = _export_llama(args) + builder = _export_llama(llm_config, args) graph_module = builder.edge_manager.exported_program().graph_module delegation_info = get_delegation_info(graph_module) diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py index 18ef83ee1e4..41f0c60980b 100644 --- a/examples/models/llava/export_llava.py +++ b/examples/models/llava/export_llava.py @@ -16,6 +16,10 @@ get_symmetric_quantization_config, XNNPACKQuantizer, ) +from executorch.examples.models.llama.config.llm_config import LlmConfig +from executorch.examples.models.llama.config.llm_config_utils import ( + convert_args_to_llm_config, +) from executorch.examples.models.llama.export_llama_lib import ( build_args_parser, get_quantizer_and_quant_params, @@ -92,32 +96,30 @@ def forward(self, input_pos, embeddings): dynamic_shapes=dynamic_shapes, ) - dtype_override = DType.fp32 + # (Legacy) parse args then convert to LlmConfig. parser = build_args_parser() - args = parser.parse_args( - [ - "-p", - "params.json", - "-X", - "-qmode", - "8da4w", - "--group_size", - "128", - "--embedding-quantize", - "4,32", - ] - ) + args = parser.parse_args() + llm_config = convert_args_to_llm_config(args) + + # Manually set some LlmConfig options. + llm_config.base.params = "params.json" + llm_config.backend.xnnpack.enabled = True + llm_config.quantization.qmode = "8da4w" + llm_config.quantization.group_size = 128 + llm_config.quantization.embedding_quantize = "4,32" + + dtype_override = DType.fp32 quant_transform = get_quant_weight_transform( - quantization_mode=args.quantization_mode, - group_size=args.group_size, + quantization_mode=llm_config.quantization.qmode, + group_size=llm_config.quantization.group_size, computation_dtype=dtype_override, - checkpoint_path=args.checkpoint, - tokenizer_path=args.tokenizer_path, - calibration_tasks=args.calibration_tasks, - calibration_limit=args.calibration_limit, - calibration_seq_length=args.calibration_seq_length, + checkpoint_path=llm_config.base.checkpoint, + tokenizer_path=llm_config.base.tokenizer_path, + calibration_tasks=llm_config.quantization.calibration_tasks, + calibration_limit=llm_config.quantization.calibration_limit, + calibration_seq_length=llm_config.quantization.calibration_seq_length, ) - _, quantizers, _ = get_quantizer_and_quant_params(args) + _, quantizers, _ = get_quantizer_and_quant_params(llm_config) source_transforms = [] if llava.use_sdpa_with_kv_cache_op: source_transforms.append(replace_kv_cache_with_custom_kv_cache) @@ -279,6 +281,20 @@ def get_tokenizer_for_llava_runner(llava_model): t.export("tokenizer.bin") +def create_llava_config_from_args(args): + """ + Create an LlmConfig from command line arguments for LLaVA export + """ + llm_config = LlmConfig() + + llm_config.model.use_sdpa_with_kv_cache = args.use_sdpa_with_kv_cache + llm_config.export.max_seq_length = args.max_seq_len + llm_config.export.output_name = args.pte_name + llm_config.debug.profile_memory = args.profile_memory + + return llm_config + + def main(): parser = ArgumentParser() parser.add_argument( @@ -311,28 +327,33 @@ def main(): help="Generate chrome trace of activation memory for intermediate 
tensors.", ) args = parser.parse_args() + + # Create LlmConfig from args + llm_config = create_llava_config_from_args(args) + logging.info( - f"Exporting Llava model to ExecuTorch with sdpa_with_kv_cache: {args.use_sdpa_with_kv_cache}, max_seq_len: {args.max_seq_len}" + f"Exporting Llava model to ExecuTorch with sdpa_with_kv_cache: {llm_config.model.use_sdpa_with_kv_cache}, max_seq_len: {llm_config.export.max_seq_length}" ) + llava_model = LlavaModel( - use_sdpa_with_kv_cache_op=args.use_sdpa_with_kv_cache, - max_seq_len=args.max_seq_len, + use_sdpa_with_kv_cache_op=llm_config.model.use_sdpa_with_kv_cache, + max_seq_len=llm_config.export.max_seq_length, ) executorch_program = export_all(llava_model) # memory profiling - if args.profile_memory: + if llm_config.debug.profile_memory: for method_name in executorch_program.methods: generate_memory_trace( executorch_program, - f"{args.pte_name}_{method_name}.json", + f"{llm_config.export.output_name}_{method_name}.json", method_name=method_name, ) - with open(args.pte_name, "wb") as f: + with open(llm_config.export.output_name, "wb") as f: executorch_program.write_to_file(f) - logging.info(f"Exported ExecuTorch program to {args.pte_name}") + logging.info(f"Exported ExecuTorch program to {llm_config.export.output_name}") # artifacts if args.with_artifacts: