diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index e3fbfaa5872..e5da85185a1 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -579,49 +579,53 @@ def export_llama(
     if isinstance(export_options, argparse.Namespace):
         # Legacy CLI.
         args = export_options
-        llm_config = convert_args_to_llm_config(export_options)  # noqa: F841
+        llm_config = convert_args_to_llm_config(export_options)
     elif isinstance(export_options, DictConfig):
         # Hydra CLI.
-        llm_config = export_options  # noqa: F841
+        llm_config = export_options
+        # Create an args object for backward compatibility during transition
+        args = argparse.Namespace()
+        for key, value in llm_config.items():
+            setattr(args, key, value)
     else:
         raise ValueError(
             "Input to export_llama must be either of type argparse.Namespace or LlmConfig"
         )
 
-    # TODO: refactor rest of export_llama to use llm_config instead of args.
-
     # If a checkpoint isn't provided for an HF OSS model, download and convert the
     # weights first.
-    if not args.checkpoint and args.model in HUGGING_FACE_REPO_IDS:
-        repo_id = HUGGING_FACE_REPO_IDS[args.model]
-        if args.model == "qwen2_5":
+    model_name = llm_config.base.model_class
+    if not llm_config.base.checkpoint and model_name in HUGGING_FACE_REPO_IDS:
+        repo_id = HUGGING_FACE_REPO_IDS[model_name]
+        if model_name == "qwen2_5":
             from executorch.examples.models.qwen2_5 import (  # pyre-ignore[21]
                 convert_weights,
             )
-        elif args.model.startswith("qwen3"):
+        elif model_name.startswith("qwen3"):
             from executorch.examples.models.qwen3 import (  # pyre-ignore[21]
                 convert_weights,
             )
-        elif args.model == "phi_4_mini":
+        elif model_name == "phi_4_mini":
            from executorch.examples.models.phi_4_mini import (  # pyre-ignore[21]
                convert_weights,
            )
-        elif args.model == "smollm2":
+        elif model_name == "smollm2":
            from executorch.examples.models.smollm2 import (  # pyre-ignore[21]
                convert_weights,
            )
        else:
            raise ValueError(
-                f"Converting weights to meta format for {args.model} is not yet supported"
+                f"Converting weights to meta format for {model_name} is not yet supported"
            )
-        args.checkpoint = download_and_convert_hf_checkpoint(repo_id, convert_weights)
+        checkpoint = download_and_convert_hf_checkpoint(repo_id, convert_weights)
+        llm_config.base.checkpoint = checkpoint
 
-    if args.profile_path is not None:
+    if llm_config.debug.profile_path is not None:
         try:
             from executorch.util.python_profiler import CProfilerFlameGraph
 
-            with CProfilerFlameGraph(args.profile_path):
-                builder = _export_llama(args)
+            with CProfilerFlameGraph(llm_config.debug.profile_path):
+                builder = _export_llama(llm_config, args)
                 assert (
                     filename := builder.get_saved_pte_filename()
                 ) is not None, "Fail to get file name from builder"
@@ -632,14 +636,14 @@ def export_llama(
             )
             return ""
     else:
-        builder = _export_llama(args)
+        builder = _export_llama(llm_config, args)
         assert (
             filename := builder.get_saved_pte_filename()
         ) is not None, "Fail to get file name from builder"
         return filename
 
 
-def _prepare_for_llama_export(args) -> LLMEdgeManager:
+def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
     """
     Helper function for export_llama. Loads the model from checkpoint and params,
     and sets up a LLMEdgeManager with initial transforms and dtype conversion.
@@ -647,38 +651,46 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     Returns a LLMEdgeManager prior to calling export_to_edge with quantizers
     """
     # load model from checkpoint and params.json
-    checkpoint_path = canonical_path(args.checkpoint) if args.checkpoint else None
+    checkpoint_path = (
+        canonical_path(llm_config.base.checkpoint)
+        if llm_config.base.checkpoint
+        else None
+    )
     checkpoint_dir = (
-        canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
+        canonical_path(llm_config.base.checkpoint_dir)
+        if llm_config.base.checkpoint_dir
+        else None
+    )
+    params_path = (
+        canonical_path(llm_config.base.params) if llm_config.base.params else None
     )
-    params_path = canonical_path(args.params) if args.params else None
-    output_dir_path = canonical_path(args.output_dir, dir=True)
-    weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
+    output_dir_path = canonical_path(llm_config.export.output_dir, dir=True)
+    weight_type = WeightType.FAIRSEQ2 if llm_config.base.fairseq2 else WeightType.LLAMA
 
-    # Convert dtype override string arg to actual type.
-    dtype_override = DType[args.dtype_override]
+    # Convert dtype override string to actual type
+    dtype_override = DType[llm_config.model.dtype_override]
 
     edge_manager = _load_llama_model(
-        args.model,
+        llm_config.base.model_class,
         checkpoint=checkpoint_path,
         checkpoint_dir=checkpoint_dir,
         params_path=params_path,
-        use_kv_cache=args.use_kv_cache,
-        use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
-        generate_full_logits=args.generate_full_logits,
+        use_kv_cache=llm_config.model.use_kv_cache,
+        use_sdpa_with_kv_cache=llm_config.model.use_sdpa_with_kv_cache,
+        generate_full_logits=llm_config.debug.generate_full_logits,
         weight_type=weight_type,
-        enable_dynamic_shape=args.enable_dynamic_shape,
-        calibration_tasks=args.calibration_tasks,
-        calibration_limit=args.calibration_limit,
-        calibration_seq_length=args.calibration_seq_length,
-        calibration_data=args.calibration_data,
-        tokenizer_path=args.tokenizer_path,
-        verbose=args.verbose,
-        max_seq_len=args.max_seq_length,
-        max_context_len=args.max_context_length,
-        input_prune_map_path=args.input_prune_map,
-        output_prune_map_path=args.output_prune_map,
-        metadata_str=args.metadata,
+        enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
+        calibration_tasks=llm_config.quantization.calibration_tasks,
+        calibration_limit=llm_config.quantization.calibration_limit,
+        calibration_seq_length=llm_config.quantization.calibration_seq_length,
+        calibration_data=llm_config.quantization.calibration_data,
+        tokenizer_path=llm_config.base.tokenizer_path,
+        verbose=llm_config.debug.verbose,
+        max_seq_len=llm_config.export.max_seq_length,
+        max_context_len=llm_config.export.max_context_length,
+        input_prune_map_path=llm_config.model.input_prune_map,
+        output_prune_map_path=llm_config.model.output_prune_map,
+        metadata_str=llm_config.base.metadata,
         dtype_override=dtype_override,
         args=args,
     )
@@ -710,64 +722,64 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
             dtype_override=dtype_override,
-            checkpoint=args.checkpoint,
+            checkpoint=llm_config.base.checkpoint,
             checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
-            tokenizer_path=args.tokenizer_path,
-            use_spin_quant=args.use_spin_quant,
-            embedding_quantize=args.embedding_quantize,
-            use_shared_embedding=args.use_shared_embedding,
-            quantization_mode=args.quantization_mode,
-            group_size=args.group_size,
-            calibration_tasks=args.calibration_tasks,
-            calibration_limit=args.calibration_limit,
-            calibration_seq_length=args.calibration_seq_length,
-            expand_rope_table=args.expand_rope_table,
+            tokenizer_path=llm_config.base.tokenizer_path,
+            use_spin_quant=llm_config.quantization.use_spin_quant,
+            embedding_quantize=llm_config.quantization.embedding_quantize,
+            use_shared_embedding=llm_config.model.use_shared_embedding,
+            quantization_mode=llm_config.quantization.qmode,
+            group_size=llm_config.quantization.group_size,
+            calibration_tasks=llm_config.quantization.calibration_tasks,
+            calibration_limit=llm_config.quantization.calibration_limit,
+            calibration_seq_length=llm_config.quantization.calibration_seq_length,
+            expand_rope_table=llm_config.model.expand_rope_table,
             use_custom_sdpa_with_attention_mask=getattr(
-                args, "use_custom_sdpa_with_attention_mask", False
+                llm_config.model, "use_custom_sdpa_with_attention_mask", False
             ),
-            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
-            quantize_kv_cache=args.quantize_kv_cache,
-            use_kv_cache=args.use_kv_cache,
-            qnn=args.qnn,
-            use_qnn_sha=args.use_qnn_sha,
-            optimized_rotation_path=args.optimized_rotation_path,
-            mps=args.mps,
-            coreml=args.coreml,
-            coreml_ios=args.coreml_ios,
-            vulkan=args.vulkan,
-            use_qat=args.use_qat,
-            use_lora=args.use_lora,
-            preq_mode=args.preq_mode,
-            preq_group_size=args.preq_group_size,
-            preq_embedding_quantize=args.preq_embedding_quantize,
-            local_global_attention=args.local_global_attention,
+            use_sdpa_with_kv_cache=llm_config.model.use_sdpa_with_kv_cache,
+            quantize_kv_cache=llm_config.model.quantize_kv_cache,
+            use_kv_cache=llm_config.model.use_kv_cache,
+            qnn=llm_config.backend.qnn.enabled,
+            use_qnn_sha=llm_config.backend.qnn.use_sha,
+            optimized_rotation_path=llm_config.backend.qnn.optimized_rotation_path,
+            mps=llm_config.backend.mps.enabled,
+            coreml=llm_config.backend.coreml.enabled,
+            coreml_ios=llm_config.backend.coreml.ios,
+            vulkan=llm_config.backend.vulkan.enabled,
+            use_qat=llm_config.quantization.use_qat,
+            use_lora=llm_config.base.use_lora,
+            preq_mode=llm_config.base.preq_mode,
+            preq_group_size=llm_config.base.preq_group_size,
+            preq_embedding_quantize=llm_config.base.preq_embedding_quantize,
+            local_global_attention=llm_config.model.local_global_attention,
         )
     )
 
     return edge_manager
 
 
-def get_quantizer_and_quant_params(args):
+def get_quantizer_and_quant_params(llm_config):
     pt2e_quant_params = get_pt2e_quantization_params(
-        args.pt2e_quantize, args.quantization_mode
+        llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
     )
-    quantizers = get_pt2e_quantizers(pt2e_quant_params, args.so_library)
+    quantizers = get_pt2e_quantizers(pt2e_quant_params, llm_config.export.so_library)
     quant_dtype = None
-    if args.qnn and args.pt2e_quantize:
+    if llm_config.backend.qnn.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack and qnn"
         qnn_quantizer, quant_dtype = get_qnn_quantizer(
-            args.pt2e_quantize, args.quantization_mode
+            llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
         )
         quantizers.append(qnn_quantizer)
-    if args.coreml and args.pt2e_quantize:
+    if llm_config.backend.coreml.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
-        coreml_quantizer = get_coreml_quantizer(args.pt2e_quantize)
+        coreml_quantizer = get_coreml_quantizer(llm_config.quantization.pt2e_quantize)
         quantizers.append(coreml_quantizer)
-    if args.vulkan and args.pt2e_quantize:
+    if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize:
         assert (
             len(quantizers) == 0
         ), "Should not enable both vulkan and other quantizers"
-        vulkan_quantizer = get_vulkan_quantizer(args.pt2e_quantize)
+        vulkan_quantizer = get_vulkan_quantizer(llm_config.quantization.pt2e_quantize)
         quantizers.append(vulkan_quantizer)
     logging.info(f"Applying quantizers: {quantizers}")
     return pt2e_quant_params, quantizers, quant_dtype
@@ -790,28 +802,32 @@ def _qmode_type(value):
     )
 
 
-def _validate_args(args):
+def _validate_args(llm_config):
     """
     TODO: Combine all the backends under --backend args
     """
 
-    if args.max_context_length < args.max_seq_length:
+    if llm_config.export.max_context_length < llm_config.export.max_seq_length:
         raise ValueError(
-            f"max_context_length {args.max_context_length} must be >= max_seq_len {args.max_seq_length}. max_context_length impacts kv cache size that is used to remember history, while max_seq_length refers to user prompt length. Please use --max_context_length to specify context length."
+            f"max_context_length {llm_config.export.max_context_length} must be >= max_seq_len {llm_config.export.max_seq_length}. max_context_length impacts kv cache size that is used to remember history, while max_seq_length refers to user prompt length. Please use --max_context_length to specify context length."
         )
-    if args.enable_dynamic_shape and (args.coreml or args.mps or args.qnn):
+    if llm_config.model.enable_dynamic_shape and (
+        llm_config.backend.coreml.enabled
+        or llm_config.backend.mps.enabled
+        or llm_config.backend.qnn.enabled
+    ):
         raise ValueError(
             "Dynamic shape is not supported with coreml, MPS or qnn backends."
             " Please use --disable_dynamic_shape."
         )
 
-    if args.num_sharding > 0 and not args.qnn:
+    if llm_config.backend.qnn.num_sharding > 0 and not llm_config.backend.qnn.enabled:
         raise ValueError("Model shard is only supported with qnn backend now.")
 
-    if args.use_shared_embedding:
+    if llm_config.model.use_shared_embedding:
         if not (
-            args.embedding_quantize is not None
-            and args.embedding_quantize.startswith("torchao:")
+            llm_config.quantization.embedding_quantize is not None
+            and llm_config.quantization.embedding_quantize.startswith("torchao:")
         ):
             raise ValueError(
                 "Shared embedding is only supported with torchao quantization."
@@ -1039,28 +1055,30 @@ def _to_edge_and_lower_llama(  # noqa: C901
     return builder
 
 
-def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
-    _validate_args(args)
+def _export_llama(llm_config, args) -> LLMEdgeManager:  # noqa: C901
+    _validate_args(llm_config)
 
-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
+        llm_config
+    )
 
     additional_passes = []
-    if args.model in TORCHTUNE_DEFINED_MODELS:
+    if llm_config.base.model_class in TORCHTUNE_DEFINED_MODELS:
         additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
 
     # export_to_edge
-    builder_exported = _prepare_for_llama_export(args).export()
+    builder_exported = _prepare_for_llama_export(llm_config, args).export()
     builder_exported.run_canonical_optimizations()
     modelname = builder_exported.modelname
 
-    if args.export_only:
+    if llm_config.export.export_only:
         exit()
 
     if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:
-        # Force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
-        args.xnnpack = True
+        # Force xnnpack to be true if pt2e_quant_params is not None and xnnpack is False
+        llm_config.backend.xnnpack.enabled = True
 
-    if args.xnnpack:
+    if llm_config.backend.xnnpack.enabled:
         builder = _to_edge_and_lower_llama_xnnpack(
             builder_exported,
             modelname,
@@ -1068,9 +1086,9 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
             pt2e_quant_params,
             quantizers,
             quant_dtype,
-            xnnpack_extended_ops=args.xnnpack_extended_ops,
-            generate_etrecord=args.generate_etrecord,
-            verbose=args.verbose,
+            xnnpack_extended_ops=llm_config.backend.xnnpack.extended_ops,
+            generate_etrecord=llm_config.debug.generate_etrecord,
+            verbose=llm_config.debug.verbose,
         )
     else:
         builder = _to_edge_and_lower_llama(
@@ -1080,33 +1098,33 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
             pt2e_quant_params,
             quantizers,
             quant_dtype,
-            vulkan=args.vulkan,
-            mps=args.mps,
-            coreml=args.coreml,
-            qnn=args.qnn,
-            dtype_override=args.dtype_override,
-            enable_dynamic_shape=args.enable_dynamic_shape,
-            use_kv_cache=args.use_kv_cache,
-            embedding_quantize=args.embedding_quantize,
-            pt2e_quantize=args.pt2e_quantize,
-            coreml_ios=args.coreml_ios,
-            coreml_quantize=args.coreml_quantize,
-            coreml_compute_units=args.coreml_compute_units,
-            use_qnn_sha=args.use_qnn_sha,
-            num_sharding=args.num_sharding,
-            soc_model=args.soc_model,
-            generate_etrecord=args.generate_etrecord,
-            verbose=args.verbose,
+            vulkan=llm_config.backend.vulkan.enabled,
+            mps=llm_config.backend.mps.enabled,
+            coreml=llm_config.backend.coreml.enabled,
+            qnn=llm_config.backend.qnn.enabled,
+            dtype_override=llm_config.model.dtype_override,
+            enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
+            use_kv_cache=llm_config.model.use_kv_cache,
+            embedding_quantize=llm_config.quantization.embedding_quantize,
+            pt2e_quantize=llm_config.quantization.pt2e_quantize,
+            coreml_ios=llm_config.backend.coreml.ios,
+            coreml_quantize=llm_config.backend.coreml.quantize,
+            coreml_compute_units=llm_config.backend.coreml.compute_units,
+            use_qnn_sha=llm_config.backend.qnn.use_sha,
+            num_sharding=llm_config.backend.qnn.num_sharding,
+            soc_model=llm_config.backend.qnn.soc_model,
+            generate_etrecord=llm_config.debug.generate_etrecord,
+            verbose=llm_config.debug.verbose,
         )
 
-    if args.profile_memory:
+    if llm_config.debug.profile_memory:
         generate_memory_trace(builder.export_program, "memory_profile.json")
 
     if builder.dtype == DType.fp16:
         modelname = f"{modelname}_h"
 
-    if args.output_name:
-        modelname = args.output_name
+    if llm_config.export.output_name:
+        modelname = llm_config.export.output_name
     if modelname.endswith(".pte"):
         output_file = modelname
         modelname = modelname[:-4]
@@ -1179,6 +1197,8 @@ def _load_llama_model(
     output_prune_map_path: Optional[str] = None,
     metadata_str: Optional[str] = None,
    dtype_override: Optional[DType] = None,
+    use_qnn: bool = False,
+    export_only: bool = False,
     args,
 ) -> "LLMEdgeManager":
     """
@@ -1239,8 +1259,8 @@ def _load_llama_model(
         calibration_seq_length=calibration_seq_length,
         calibration_data=calibration_data,
         tokenizer_path=tokenizer_path,
-        use_legacy_export=args.qnn,
-        save_exported_program=args.export_only,
+        use_legacy_export=use_qnn,
+        save_exported_program=export_only,
         verbose=verbose,
         metadata=_load_llama_model_metadata(
             weight_type,
diff --git a/examples/models/llama/tests/test_export_llama_lib.py b/examples/models/llama/tests/test_export_llama_lib.py
index b94adb5fa0c..75223dc35a0 100644
--- a/examples/models/llama/tests/test_export_llama_lib.py
+++ b/examples/models/llama/tests/test_export_llama_lib.py
@@ -7,6 +7,7 @@
 import unittest
 
 from executorch.devtools.backend_debug import get_delegation_info
+from executorch.examples.models.llama.config.llm_config import LlmConfig
 from executorch.examples.models.llama.export_llama_lib import (
     _export_llama,
     build_args_parser,
@@ -34,13 +35,20 @@ def test_has_expected_ops_and_op_counts(self):
         # we cannot test quantization args in this way
         # since quantization requires promoting meta tensors
         # to device=cpu, which requires real weights.
+
+        llm_config = LlmConfig()
+        llm_config.model.use_sdpa_with_kv_cache = True
+        llm_config.model.use_kv_cache = True
+        llm_config.debug.verbose = True
+
+        # We still need args for backward compatibility during transition
         parser = build_args_parser()
         args = parser.parse_args([])
         args.use_sdpa_with_kv_cache = True
         args.use_kv_cache = True
         args.verbose = True
 
-        builder = _export_llama(args)
+        builder = _export_llama(llm_config, args)
 
         graph_module = builder.edge_manager.exported_program().graph_module
         delegation_info = get_delegation_info(graph_module)
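For reference, a minimal sketch of driving the migrated entry point the same way the updated test does. The imports and LlmConfig field names are taken from the diff above; it assumes LlmConfig's defaults cover the options the test does not set, and that the legacy argparse namespace is still required alongside llm_config during the transition.

    from executorch.examples.models.llama.config.llm_config import LlmConfig
    from executorch.examples.models.llama.export_llama_lib import (
        _export_llama,
        build_args_parser,
    )

    # Structured config path (new).
    llm_config = LlmConfig()
    llm_config.model.use_kv_cache = True
    llm_config.model.use_sdpa_with_kv_cache = True
    llm_config.debug.verbose = True

    # Legacy args namespace, still passed alongside llm_config during the transition.
    args = build_args_parser().parse_args([])
    args.use_kv_cache = True
    args.use_sdpa_with_kv_cache = True
    args.verbose = True

    builder = _export_llama(llm_config, args)
    print(builder.get_saved_pte_filename())

Once the remaining call sites stop reading from args, the extra parameter (and the Namespace shim built in export_llama) can be dropped without touching callers that already pass an LlmConfig.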