diff --git a/backends/arm/test/models/test_llama.py b/backends/arm/test/models/test_llama.py index d0a18d88b9d..c11ff478e6f 100644 --- a/backends/arm/test/models/test_llama.py +++ b/backends/arm/test/models/test_llama.py @@ -22,6 +22,7 @@ TosaPipelineMI, ) +from executorch.examples.models.llama.config.llm_config import LlmConfig from executorch.examples.models.llama.export_llama_lib import ( build_args_parser, get_llama_model, @@ -89,8 +90,9 @@ def prepare_model(self): ] parser = build_args_parser() args = parser.parse_args(args) + llm_config = LlmConfig.from_args(args) - llama_model, llama_inputs, llama_meta = get_llama_model(args) + llama_model, llama_inputs, llama_meta = get_llama_model(llm_config) return llama_model, llama_inputs, llama_meta diff --git a/examples/apple/mps/scripts/mps_example.py b/examples/apple/mps/scripts/mps_example.py index 2fc67bcca0e..5ccbc987b4d 100644 --- a/examples/apple/mps/scripts/mps_example.py +++ b/examples/apple/mps/scripts/mps_example.py @@ -20,6 +20,7 @@ serialize_from_bundled_program_to_flatbuffer, ) +from executorch.examples.models.llama.config.llm_config import LlmConfig from executorch.exir import ( EdgeCompileConfig, EdgeProgramManager, @@ -131,28 +132,24 @@ def parse_args(): return args -def get_model_config(args): - model_config = {} - model_config["module_name"] = MODEL_NAME_TO_MODEL[args.model_name][0] - model_config["model_class_name"] = MODEL_NAME_TO_MODEL[args.model_name][1] - - if args.model_name == "llama2": - if args.checkpoint: - model_config["checkpoint"] = args.checkpoint - if args.params: - model_config["params"] = args.params - model_config["use_kv_cache"] = True - return model_config - - if __name__ == "__main__": # noqa: C901 args = parse_args() if args.model_name not in MODEL_NAME_TO_MODEL: raise RuntimeError(f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}.") - model_config = get_model_config(args) - model, example_inputs, _, _ = EagerModelFactory.create_model(**model_config) + llm_config = LlmConfig() + if args.model_name == "llama2": + if args.checkpoint: + llm_config.base.checkpoint = args.checkpoint + if args.params: + llm_config.base.params = args.params + llm_config.model.use_kv_cache = True + model, example_inputs, _, _ = EagerModelFactory.create_model( + module_name=MODEL_NAME_TO_MODEL[args.model_name][0], + model_class_name=MODEL_NAME_TO_MODEL[args.model_name][1], + llm_config=llm_config, + ) model = model.eval() diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS index b51e164d483..86b7e957628 100644 --- a/examples/models/llama/TARGETS +++ b/examples/models/llama/TARGETS @@ -67,6 +67,7 @@ runtime.python_library( "//caffe2:torch", "//executorch/examples/models:model_base", "//executorch/examples/models/llama:llama_transformer", + "//executorch/examples/models/llama/config:llm_config", "//executorch/examples/models:checkpoint", ], ) diff --git a/examples/models/llama/eval_llama_lib.py b/examples/models/llama/eval_llama_lib.py index 47b13df52e0..20ba6dbaa9f 100644 --- a/examples/models/llama/eval_llama_lib.py +++ b/examples/models/llama/eval_llama_lib.py @@ -164,6 +164,7 @@ def _model_call(self, inps): def gen_eval_wrapper( model_name: str, args: argparse.ArgumentParser, + llm_config=None, ): """ Generates a wrapper interface around the provided model and tokenizer for @@ -172,7 +173,13 @@ def gen_eval_wrapper( Returns: eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library. 
""" - tokenizer = get_tokenizer(args.tokenizer_path) # pyre-ignore + # If llm_config is not provided, convert args to llm_config + if llm_config is None: + from executorch.examples.models.llama.config.llm_config import LlmConfig + + llm_config = LlmConfig.from_args(args) + + tokenizer = get_tokenizer(llm_config.base.tokenizer_path) # ExecuTorch Binary Evaluation if (model := args.pte) is not None: # pyre-ignore @@ -182,7 +189,7 @@ def gen_eval_wrapper( model=model, tokenizer=tokenizer, tokenizer_bin=tokenizer_bin, - max_seq_length=args.max_seq_length, # pyre-ignore + max_seq_length=llm_config.export.max_seq_length, ) # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings @@ -191,12 +198,14 @@ def gen_eval_wrapper( tokenizer=tokenizer, # Exported model takes at most (max_seq_length - 1) tokens. # Note that the eager model takes at most max_seq_length tokens. - max_seq_length=args.max_seq_length - 1, + max_seq_length=llm_config.export.max_seq_length - 1, ) - pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args) + pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params( + llm_config + ) # GPTFastEvalWrapper: Create a wrapper around a pre-exported model - manager: LLMEdgeManager = _prepare_for_llama_export(args) + manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) if len(quantizers) != 0: manager = manager.export().pt2e_quantize(quantizers) @@ -208,9 +217,9 @@ def gen_eval_wrapper( return GraphModuleEvalWrapper( model=model, tokenizer=tokenizer, - max_seq_length=args.max_seq_length, - use_kv_cache=args.use_kv_cache, # pyre-ignore - enable_dynamic_shape=args.enable_dynamic_shape, # pyre-ignore + max_seq_length=llm_config.export.max_seq_length, + use_kv_cache=llm_config.model.use_kv_cache, + enable_dynamic_shape=llm_config.model.enable_dynamic_shape, ) else: # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch @@ -234,8 +243,8 @@ def gen_eval_wrapper( return EagerEvalWrapper( model=model, tokenizer=tokenizer, - max_seq_length=args.max_seq_length, - use_kv_cache=args.use_kv_cache, + max_seq_length=llm_config.export.max_seq_length, + use_kv_cache=llm_config.model.use_kv_cache, ) @@ -296,12 +305,16 @@ def eval_llama( model_name: str, args: argparse.ArgumentParser, ) -> None: + # Convert args to LlmConfig + from executorch.examples.models.llama.config.llm_config import LlmConfig + + llm_config = LlmConfig.from_args(args) + # Generate the eval wrapper - eval_wrapper = gen_eval_wrapper(model_name, args) + eval_wrapper = gen_eval_wrapper(model_name, args, llm_config) # Needed for loading mmlu dataset. 
# See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files - # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks` if args.tasks and "mmlu" in args.tasks: import datasets @@ -312,8 +325,8 @@ def eval_llama( eval_results = simple_evaluate( model=eval_wrapper, tasks=args.tasks, - num_fewshot=args.num_fewshot, # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot` - limit=args.limit, # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit` + num_fewshot=args.num_fewshot, + limit=args.limit, ) for task, res in eval_results["results"].items(): @@ -326,19 +339,24 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py """ - assert args.use_attention_sink is not None # pyre-ignore [16] - assert args.attention_sink_eval_tokens > 0 # pyre-ignore [16] - attention_sink_params = args.use_attention_sink.split(",") + # Convert args to LlmConfig + from executorch.examples.models.llama.config.llm_config import LlmConfig + + llm_config = LlmConfig.from_args(args) + + assert llm_config.model.use_attention_sink is not None + assert args.attention_sink_eval_tokens > 0 + attention_sink_params = llm_config.model.use_attention_sink.split(",") assert len(attention_sink_params) == 3 sink_size = int(attention_sink_params[0]) window_size = int(attention_sink_params[1]) - assert args.max_seq_length == sink_size + window_size # pyre-ignore [16] + assert llm_config.export.max_seq_length == sink_size + window_size device = "cuda" if torch.cuda.is_available() else "cpu" - manager: LLMEdgeManager = _prepare_for_llama_export(args) + manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) model = manager.model.eval().to(device=device) - tokenizer = get_tokenizer(args.tokenizer_path) # pyre-ignore [16] + tokenizer = get_tokenizer(llm_config.base.tokenizer_path) eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") @@ -347,7 +365,7 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse progress_bar = tqdm(total=args.attention_sink_eval_tokens) input_pos = 0 while input_pos < args.attention_sink_eval_tokens: - for text in eval_data["text"]: # pyre-ignore [16] + for text in eval_data["text"]: tokens = tokenizer.encode(text, bos=False, eos=False) if len(tokens) <= 0: continue diff --git a/examples/models/llama/export_llama_hydra.py b/examples/models/llama/export_llama_hydra.py index 73eca7e2a5a..4871de00e25 100644 --- a/examples/models/llama/export_llama_hydra.py +++ b/examples/models/llama/export_llama_hydra.py @@ -13,6 +13,7 @@ from executorch.examples.models.llama.config.llm_config import LlmConfig from executorch.examples.models.llama.export_llama_lib import export_llama from hydra.core.config_store import ConfigStore +from omegaconf import OmegaConf cs = ConfigStore.instance() cs.store(name="llm_config", node=LlmConfig) @@ -20,7 +21,7 @@ @hydra.main(version_base=None, config_name="llm_config") def main(llm_config: LlmConfig) -> None: - export_llama(llm_config) + export_llama(OmegaConf.to_object(llm_config)) if __name__ == "__main__": diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 12406cc762e..1f055d65822 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -52,7 +52,6 @@ get_vulkan_quantizer, ) from 
executorch.util.activation_memory_profiler import generate_memory_trace -from omegaconf.dictconfig import DictConfig from ..model_factory import EagerModelFactory from .source_transformation.apply_spin_quant_r1_r2 import ( @@ -153,7 +152,8 @@ def build_model( argString = f"--model {model} --checkpoint {checkpoint} --params {params} {extra_opts} --output-dir {output_dir}" parser = build_args_parser() args = parser.parse_args(shlex.split(argString)) - return export_llama(args) + llm_config = LlmConfig.from_args(args) + return export_llama(llm_config) def parse_list_of_ints(s): @@ -571,54 +571,53 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str: def export_llama( - export_options: Union[argparse.Namespace, DictConfig], + export_options: Union[argparse.Namespace, LlmConfig], ) -> str: if isinstance(export_options, argparse.Namespace): # Legacy CLI. - args = export_options - llm_config = LlmConfig.from_args(export_options) # noqa: F841 - elif isinstance(export_options, DictConfig): + llm_config = LlmConfig.from_args(export_options) + elif isinstance(export_options, LlmConfig): # Hydra CLI. - llm_config = export_options # noqa: F841 + llm_config = export_options else: raise ValueError( "Input to export_llama must be either of type argparse.Namespace or LlmConfig" ) - # TODO: refactor rest of export_llama to use llm_config instead of args. - # If a checkpoint isn't provided for an HF OSS model, download and convert the # weights first. - if not args.checkpoint and args.model in HUGGING_FACE_REPO_IDS: - repo_id = HUGGING_FACE_REPO_IDS[args.model] - if args.model == "qwen2_5": + model_name = llm_config.base.model_class + if not llm_config.base.checkpoint and model_name in HUGGING_FACE_REPO_IDS: + repo_id = HUGGING_FACE_REPO_IDS[model_name] + if model_name == "qwen2_5": from executorch.examples.models.qwen2_5 import ( # pyre-ignore[21] convert_weights, ) - elif args.model.startswith("qwen3"): + elif model_name.startswith("qwen3"): from executorch.examples.models.qwen3 import ( # pyre-ignore[21] convert_weights, ) - elif args.model == "phi_4_mini": + elif model_name == "phi_4_mini": from executorch.examples.models.phi_4_mini import ( # pyre-ignore[21] convert_weights, ) - elif args.model == "smollm2": + elif model_name == "smollm2": from executorch.examples.models.smollm2 import ( # pyre-ignore[21] convert_weights, ) else: raise ValueError( - f"Converting weights to meta format for {args.model} is not yet supported" + f"Converting weights to meta format for {model_name} is not yet supported" ) - args.checkpoint = download_and_convert_hf_checkpoint(repo_id, convert_weights) + checkpoint = download_and_convert_hf_checkpoint(repo_id, convert_weights) + llm_config.base.checkpoint = checkpoint - if args.profile_path is not None: + if llm_config.debug.profile_path is not None: try: from executorch.util.python_profiler import CProfilerFlameGraph - with CProfilerFlameGraph(args.profile_path): - builder = _export_llama(args) + with CProfilerFlameGraph(llm_config.debug.profile_path): + builder = _export_llama(llm_config) assert ( filename := builder.get_saved_pte_filename() ) is not None, "Fail to get file name from builder" @@ -629,14 +628,14 @@ def export_llama( ) return "" else: - builder = _export_llama(args) + builder = _export_llama(llm_config) assert ( filename := builder.get_saved_pte_filename() ) is not None, "Fail to get file name from builder" return filename -def _prepare_for_llama_export(args) -> LLMEdgeManager: +def _prepare_for_llama_export(llm_config: LlmConfig) -> 
LLMEdgeManager: """ Helper function for export_llama. Loads the model from checkpoint and params, and sets up a LLMEdgeManager with initial transforms and dtype conversion. @@ -644,41 +643,30 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager: Returns a LLMEdgeManager prior to calling export_to_edge with quantizers """ # load model from checkpoint and params.json - checkpoint_path = canonical_path(args.checkpoint) if args.checkpoint else None + checkpoint_path = ( + canonical_path(llm_config.base.checkpoint) + if llm_config.base.checkpoint + else None + ) checkpoint_dir = ( - canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None + canonical_path(llm_config.base.checkpoint_dir) + if llm_config.base.checkpoint_dir + else None ) - params_path = canonical_path(args.params) if args.params else None - output_dir_path = canonical_path(args.output_dir, dir=True) - weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA - - # Convert dtype override string arg to actual type. - dtype_override = DType[args.dtype_override] - - edge_manager = _load_llama_model( - args.model, - checkpoint=checkpoint_path, - checkpoint_dir=checkpoint_dir, - params_path=params_path, - use_kv_cache=args.use_kv_cache, - use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache, - generate_full_logits=args.generate_full_logits, - weight_type=weight_type, - enable_dynamic_shape=args.enable_dynamic_shape, - calibration_tasks=args.calibration_tasks, - calibration_limit=args.calibration_limit, - calibration_seq_length=args.calibration_seq_length, - calibration_data=args.calibration_data, - tokenizer_path=args.tokenizer_path, - verbose=args.verbose, - max_seq_len=args.max_seq_length, - max_context_len=args.max_context_length, - input_prune_map_path=args.input_prune_map, - output_prune_map_path=args.output_prune_map, - metadata_str=args.metadata, - dtype_override=dtype_override, - args=args, + params_path = ( + canonical_path(llm_config.base.params) if llm_config.base.params else None ) + output_dir_path = canonical_path(llm_config.export.output_dir, dir=True) + + llm_config.base.checkpoint = checkpoint_path + llm_config.base.checkpoint_dir = checkpoint_dir + llm_config.base.params = params_path + llm_config.export.output_dir = output_dir_path + + # Convert dtype override string to actual type. + dtype_override = DType[llm_config.model.dtype_override] + + edge_manager = _load_llama_model(llm_config) # At this point, the model is loaded in the default fp32. 
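Illustrative aside (not part of the patch): the `DType[llm_config.model.dtype_override]` lookup above converts the enum-name string stored in the config into the `DType` enum, replacing the removed per-argument plumbing. A minimal sketch, assuming `DType` is the enum that export_llama_lib already imports from executorch.extension.llm.export.builder and that "fp32" is a valid member name:

    from executorch.extension.llm.export.builder import DType

    dtype_override = DType["fp32"]                 # enum-name string -> enum member, i.e. DType.fp32
    torch_dtype = dtype_override.to_torch_dtype()  # torch.float32, the dtype the checkpoint is loaded in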
@@ -707,64 +695,64 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager: edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform( _get_source_transforms( dtype_override=dtype_override, - checkpoint=args.checkpoint, + checkpoint=llm_config.base.checkpoint, checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype), # type: ignore - tokenizer_path=args.tokenizer_path, - use_spin_quant=args.use_spin_quant, - embedding_quantize=args.embedding_quantize, - use_shared_embedding=args.use_shared_embedding, - quantization_mode=args.quantization_mode, - group_size=args.group_size, - calibration_tasks=args.calibration_tasks, - calibration_limit=args.calibration_limit, - calibration_seq_length=args.calibration_seq_length, - expand_rope_table=args.expand_rope_table, + tokenizer_path=llm_config.base.tokenizer_path, + use_spin_quant=llm_config.quantization.use_spin_quant, + embedding_quantize=llm_config.quantization.embedding_quantize, + use_shared_embedding=llm_config.model.use_shared_embedding, + quantization_mode=llm_config.quantization.qmode, + group_size=llm_config.quantization.group_size, + calibration_tasks=llm_config.quantization.calibration_tasks, + calibration_limit=llm_config.quantization.calibration_limit, + calibration_seq_length=llm_config.quantization.calibration_seq_length, + expand_rope_table=llm_config.model.expand_rope_table, use_custom_sdpa_with_attention_mask=getattr( - args, "use_custom_sdpa_with_attention_mask", False + llm_config.model, "use_custom_sdpa_with_attention_mask", False ), - use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache, - quantize_kv_cache=args.quantize_kv_cache, - use_kv_cache=args.use_kv_cache, - qnn=args.qnn, - use_qnn_sha=args.use_qnn_sha, - optimized_rotation_path=args.optimized_rotation_path, - mps=args.mps, - coreml=args.coreml, - coreml_ios=args.coreml_ios, - vulkan=args.vulkan, - use_qat=args.use_qat, - use_lora=args.use_lora, - preq_mode=args.preq_mode, - preq_group_size=args.preq_group_size, - preq_embedding_quantize=args.preq_embedding_quantize, - local_global_attention=args.local_global_attention, + use_sdpa_with_kv_cache=llm_config.model.use_sdpa_with_kv_cache, + quantize_kv_cache=llm_config.model.quantize_kv_cache, + use_kv_cache=llm_config.model.use_kv_cache, + qnn=llm_config.backend.qnn.enabled, + use_qnn_sha=llm_config.backend.qnn.use_sha, + optimized_rotation_path=llm_config.backend.qnn.optimized_rotation_path, + mps=llm_config.backend.mps.enabled, + coreml=llm_config.backend.coreml.enabled, + coreml_ios=llm_config.backend.coreml.ios, + vulkan=llm_config.backend.vulkan.enabled, + use_qat=llm_config.quantization.use_qat, + use_lora=llm_config.base.use_lora, + preq_mode=llm_config.base.preq_mode, + preq_group_size=llm_config.base.preq_group_size, + preq_embedding_quantize=llm_config.base.preq_embedding_quantize, + local_global_attention=llm_config.model.local_global_attention, ) ) return edge_manager -def get_quantizer_and_quant_params(args): +def get_quantizer_and_quant_params(llm_config): pt2e_quant_params = get_pt2e_quantization_params( - args.pt2e_quantize, args.quantization_mode + llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode ) - quantizers = get_pt2e_quantizers(pt2e_quant_params, args.so_library) + quantizers = get_pt2e_quantizers(pt2e_quant_params, llm_config.export.so_library) quant_dtype = None - if args.qnn and args.pt2e_quantize: + if llm_config.backend.qnn.enabled and llm_config.quantization.pt2e_quantize: assert len(quantizers) == 0, "Should not enable both xnnpack and qnn" qnn_quantizer, 
quant_dtype = get_qnn_quantizer( - args.pt2e_quantize, args.quantization_mode + llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode ) quantizers.append(qnn_quantizer) - if args.coreml and args.pt2e_quantize: + if llm_config.backend.coreml.enabled and llm_config.quantization.pt2e_quantize: assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml" - coreml_quantizer = get_coreml_quantizer(args.pt2e_quantize) + coreml_quantizer = get_coreml_quantizer(llm_config.quantization.pt2e_quantize) quantizers.append(coreml_quantizer) - if args.vulkan and args.pt2e_quantize: + if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize: assert ( len(quantizers) == 0 ), "Should not enable both vulkan and other quantizers" - vulkan_quantizer = get_vulkan_quantizer(args.pt2e_quantize) + vulkan_quantizer = get_vulkan_quantizer(llm_config.quantization.pt2e_quantize) quantizers.append(vulkan_quantizer) logging.info(f"Applying quantizers: {quantizers}") return pt2e_quant_params, quantizers, quant_dtype @@ -787,28 +775,28 @@ def _qmode_type(value): ) -def _validate_args(args): - """ - TODO: Combine all the backends under --backend args - """ - - if args.max_context_length < args.max_seq_length: +def _validate_args(llm_config): + if llm_config.export.max_context_length < llm_config.export.max_seq_length: raise ValueError( - f"max_context_length {args.max_context_length} must be >= max_seq_len {args.max_seq_length}. max_context_length impacts kv cache size that is used to remember history, while max_seq_length refers to user prompt length. Please use --max_context_length to specify context length." + f"max_context_length {llm_config.export.max_context_length} must be >= max_seq_len {llm_config.export.max_seq_length}. max_context_length impacts kv cache size that is used to remember history, while max_seq_length refers to user prompt length. Please use --max_context_length to specify context length." ) - if args.enable_dynamic_shape and (args.coreml or args.mps or args.qnn): + if llm_config.model.enable_dynamic_shape and ( + llm_config.backend.coreml.enabled + or llm_config.backend.mps.enabled + or llm_config.backend.qnn.enabled + ): raise ValueError( "Dynamic shape is not supported with coreml, MPS or qnn backends." " Please use --disable_dynamic_shape." ) - if args.num_sharding > 0 and not args.qnn: + if llm_config.backend.qnn.num_sharding > 0 and not llm_config.backend.qnn.enabled: raise ValueError("Model shard is only supported with qnn backend now.") - if args.use_shared_embedding: + if llm_config.model.use_shared_embedding: if not ( - args.embedding_quantize is not None - and args.embedding_quantize.startswith("torchao:") + llm_config.quantization.embedding_quantize is not None + and llm_config.quantization.embedding_quantize.startswith("torchao:") ): raise ValueError( "Shared embedding is only supported with torchao quantization." 
@@ -1033,28 +1021,30 @@ def _to_edge_and_lower_llama( # noqa: C901 return builder -def _export_llama(args) -> LLMEdgeManager: # noqa: C901 - _validate_args(args) +def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901 + _validate_args(llm_config) - pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args) + pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params( + llm_config + ) additional_passes = [] - if args.model in TORCHTUNE_DEFINED_MODELS: + if llm_config.base.model_class in TORCHTUNE_DEFINED_MODELS: additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])] # export_to_edge - builder_exported = _prepare_for_llama_export(args).export() + builder_exported = _prepare_for_llama_export(llm_config).export() builder_exported.run_canonical_optimizations() modelname = builder_exported.modelname - if args.export_only: + if llm_config.export.export_only: exit() if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None: - # Force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False - args.xnnpack = True + # Force xnnpack to be true if pt2e_quant_params is not None and xnnpack is False + llm_config.backend.xnnpack.enabled = True - if args.xnnpack: + if llm_config.backend.xnnpack.enabled: builder = _to_edge_and_lower_llama_xnnpack( builder_exported, modelname, @@ -1062,9 +1052,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901 pt2e_quant_params, quantizers, quant_dtype, - xnnpack_extended_ops=args.xnnpack_extended_ops, - generate_etrecord=args.generate_etrecord, - verbose=args.verbose, + xnnpack_extended_ops=llm_config.backend.xnnpack.extended_ops, + generate_etrecord=llm_config.debug.generate_etrecord, + verbose=llm_config.debug.verbose, ) else: builder = _to_edge_and_lower_llama( @@ -1074,33 +1064,33 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901 pt2e_quant_params, quantizers, quant_dtype, - vulkan=args.vulkan, - mps=args.mps, - coreml=args.coreml, - qnn=args.qnn, - dtype_override=args.dtype_override, - enable_dynamic_shape=args.enable_dynamic_shape, - use_kv_cache=args.use_kv_cache, - embedding_quantize=args.embedding_quantize, - pt2e_quantize=args.pt2e_quantize, - coreml_ios=args.coreml_ios, - coreml_quantize=args.coreml_quantize, - coreml_compute_units=args.coreml_compute_units, - use_qnn_sha=args.use_qnn_sha, - num_sharding=args.num_sharding, - soc_model=args.soc_model, - generate_etrecord=args.generate_etrecord, - verbose=args.verbose, + vulkan=llm_config.backend.vulkan.enabled, + mps=llm_config.backend.mps.enabled, + coreml=llm_config.backend.coreml.enabled, + qnn=llm_config.backend.qnn.enabled, + dtype_override=llm_config.model.dtype_override, + enable_dynamic_shape=llm_config.model.enable_dynamic_shape, + use_kv_cache=llm_config.model.use_kv_cache, + embedding_quantize=llm_config.quantization.embedding_quantize, + pt2e_quantize=llm_config.quantization.pt2e_quantize, + coreml_ios=llm_config.backend.coreml.ios, + coreml_quantize=llm_config.backend.coreml.quantize, + coreml_compute_units=llm_config.backend.coreml.compute_units, + use_qnn_sha=llm_config.backend.qnn.use_sha, + num_sharding=llm_config.backend.qnn.num_sharding, + soc_model=llm_config.backend.qnn.soc_model, + generate_etrecord=llm_config.debug.generate_etrecord, + verbose=llm_config.debug.verbose, ) - if args.profile_memory: + if llm_config.debug.profile_memory: generate_memory_trace(builder.export_program, "memory_profile.json") if builder.dtype == DType.fp16: modelname = 
f"{modelname}_h" - if args.output_name: - modelname = args.output_name + if llm_config.export.output_name: + modelname = llm_config.export.output_name if modelname.endswith(".pte"): output_file = modelname modelname = modelname[:-4] @@ -1150,31 +1140,7 @@ def _load_llama_model_metadata( return metadata -def _load_llama_model( - modelname: str = "llama3", - *, - checkpoint: Optional[str] = None, - checkpoint_dir: Optional[str] = None, - params_path: Optional[str] = None, - use_kv_cache: bool = False, - use_sdpa_with_kv_cache: bool = False, - generate_full_logits: bool = False, - weight_type: WeightType = WeightType.LLAMA, - enable_dynamic_shape: bool = False, - calibration_tasks: Optional[List[str]] = None, - calibration_limit: Optional[int] = None, - calibration_seq_length: Optional[int] = None, - calibration_data: Optional[str] = None, - tokenizer_path: Optional[str] = None, - verbose: bool = False, - max_seq_len: int = 128, - max_context_len: int = 128, - input_prune_map_path: Optional[str] = None, - output_prune_map_path: Optional[str] = None, - metadata_str: Optional[str] = None, - dtype_override: Optional[DType] = None, - args, -) -> "LLMEdgeManager": +def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager": """ A helper util that builds a Llama2 model. It returns a LLMEdgeManager that can help further lower the model to ExecuTorch. @@ -1182,6 +1148,7 @@ def _load_llama_model( An instance of LLMEdgeManager which contains the eager mode model. """ + modelname = llm_config.base.model_class if modelname in EXECUTORCH_DEFINED_MODELS: module_name = "llama" model_class_name = "Llama2Model" # TODO: Change to "LlamaModel" in examples/models/llama/model.py. @@ -1194,53 +1161,40 @@ def _load_llama_model( else: raise ValueError(f"{modelname} is not a valid Llama model.") - torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None - model, example_inputs, example_kwarg_inputs, dynamic_shapes = ( EagerModelFactory.create_model( module_name, model_class_name, - checkpoint=checkpoint, - checkpoint_dir=checkpoint_dir, - params=params_path, - use_kv_cache=use_kv_cache, - use_sdpa_with_kv_cache=use_sdpa_with_kv_cache, - generate_full_logits=generate_full_logits, - fairseq2=weight_type == WeightType.FAIRSEQ2, - max_seq_len=max_seq_len, - max_context_len=max_context_len, - enable_dynamic_shape=enable_dynamic_shape, - input_prune_map_path=input_prune_map_path, - output_prune_map_path=output_prune_map_path, - dtype=torch_dtype, - args=args, + llm_config=llm_config, ) ) + # Convert dtype override string to actual type. 
+ dtype_override = DType[llm_config.model.dtype_override] return LLMEdgeManager( model=model, modelname=modelname, max_seq_len=model.max_seq_len, # type: ignore dtype=dtype_override, - use_kv_cache=use_kv_cache, - generate_full_logits=generate_full_logits, + use_kv_cache=llm_config.model.use_kv_cache, + generate_full_logits=llm_config.debug.generate_full_logits, example_inputs=example_inputs, example_kwarg_inputs=example_kwarg_inputs, dynamic_shapes=dynamic_shapes, - enable_dynamic_shape=enable_dynamic_shape, - calibration_tasks=calibration_tasks, - calibration_limit=calibration_limit, - calibration_seq_length=calibration_seq_length, - calibration_data=calibration_data, - tokenizer_path=tokenizer_path, - use_legacy_export=args.qnn, - save_exported_program=args.export_only, - verbose=verbose, + enable_dynamic_shape=llm_config.model.enable_dynamic_shape, + calibration_tasks=llm_config.quantization.calibration_tasks, + calibration_limit=llm_config.quantization.calibration_limit, + calibration_seq_length=llm_config.quantization.calibration_seq_length, + calibration_data=llm_config.quantization.calibration_data, + tokenizer_path=llm_config.base.tokenizer_path, + use_legacy_export=llm_config.backend.qnn.enabled, + save_exported_program=llm_config.export.export_only, + verbose=llm_config.debug.verbose, metadata=_load_llama_model_metadata( - weight_type, - use_kv_cache, - use_sdpa_with_kv_cache, - enable_dynamic_shape, + WeightType.FAIRSEQ2 if llm_config.base.fairseq2 else WeightType.LLAMA, + llm_config.model.use_kv_cache, + llm_config.model.use_sdpa_with_kv_cache, + llm_config.model.enable_dynamic_shape, # pyre-fixme[6]: For 5th argument expected `ModelArgs` but got # `Union[Tensor, Module]`. model.max_seq_len, @@ -1253,7 +1207,7 @@ def _load_llama_model( # pyre-fixme[6]: For 8th argument expected `int` but got `Union[Tensor, # Module]`. model.vocab_size, - metadata_str, + llm_config.base.metadata, ), ) @@ -1470,9 +1424,9 @@ def _get_source_transforms( # noqa return transforms -def get_llama_model(args): - _validate_args(args) - e_mgr = _prepare_for_llama_export(args) +def get_llama_model(llm_config: LlmConfig): + _validate_args(llm_config) + e_mgr = _prepare_for_llama_export(llm_config) model = ( e_mgr.model.eval().to(device="cuda") if torch.cuda.is_available() diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py index d6400c29db8..ec9646be6f4 100644 --- a/examples/models/llama/model.py +++ b/examples/models/llama/model.py @@ -8,7 +8,7 @@ import json import os -from typing import Dict, Tuple +from typing import Dict, Optional, Tuple import torch from executorch.examples.models.checkpoint import ( @@ -16,6 +16,7 @@ get_default_model_resource_dir, ) +from executorch.examples.models.llama.config.llm_config import LlmConfig from executorch.examples.models.llama.llama_transformer import construct_transformer from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.rope import Rope @@ -36,26 +37,24 @@ def convert_to_llama_checkpoint(**kwargs): class Llama2Model(EagerModelBase): - def __init__(self, **kwargs): + def __init__(self, llm_config: Optional[LlmConfig] = None): resource_dir = get_default_model_resource_dir(__file__) - # Use single checkpoint file. - checkpoint_path = kwargs.get("checkpoint", None) - # Check if checkpoint_dir was provided for a sharded checkpoint. - checkpoint_dir = kwargs.get("checkpoint_dir", None) + self.llm_config = llm_config if llm_config else LlmConfig() - # Params file. 
- params_path = kwargs.get("params", None) + checkpoint_path = self.llm_config.base.checkpoint + checkpoint_dir = self.llm_config.base.checkpoint_dir + params_path = self.llm_config.base.params - self.use_kv_cache = kwargs.get("use_kv_cache", False) - self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False) - self.generate_full_logits = kwargs.get("generate_full_logits", False) - self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False) - self.input_prune_map_path = kwargs.get("input_prune_map_path", None) - self.output_prune_map_path = kwargs.get("output_prune_map_path", None) - self.max_seq_len = kwargs.get("max_seq_len", 128) - self.max_context_len = kwargs.get("max_context_len", 128) - self.args = kwargs.get("args", None) + self.use_kv_cache = self.llm_config.model.use_kv_cache + self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache + self.generate_full_logits = self.llm_config.debug.generate_full_logits + self.enable_dynamic_shape = self.llm_config.model.enable_dynamic_shape + self.input_prune_map_path = self.llm_config.model.input_prune_map + self.output_prune_map_path = self.llm_config.model.output_prune_map + self.max_seq_len = self.llm_config.export.max_seq_length + self.max_context_len = self.llm_config.export.max_context_length + self.verbose = self.llm_config.debug.verbose assert ( self.max_context_len >= self.max_seq_len @@ -99,7 +98,7 @@ def __init__(self, **kwargs): checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True) # If given checkpoint is fairseq, convert to llama checkpoint. - fairseq2_checkpoint = kwargs.get("fairseq2", False) + fairseq2_checkpoint = self.llm_config.base.fairseq2 if fairseq2_checkpoint: print("Using fairseq2 checkpoint") checkpoint = convert_to_llama_checkpoint(checkpoint=checkpoint) @@ -158,13 +157,14 @@ def __init__(self, **kwargs): if model_args.use_scaled_rope: # Older models don't have use_scaled_rope configuration - assert self.args.model not in ["llama2", "stories110m"] + model_name = str(self.llm_config.base.model_class) + assert model_name not in ["llama2", "stories110m"] # Llama3_2 and newer models in ExecuTorch repo should set larger scale factor - if self.args.model not in ["llama3", "llama3_1"]: + if model_name not in ["llama3", "llama3_1"]: model_args.rope_scale_factor = 32 - if kwargs.get("verbose", False): + if self.verbose: print("============= weights ================") print("{key} : {weights.numel()} : {weights.size()}") for key, weights in checkpoint.items(): @@ -196,7 +196,7 @@ def __init__(self, **kwargs): self.model_ = Int8DynActInt4WeightQuantizer()._convert_for_runtime( self.model_ ) - elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant: + elif self.llm_config.quantization.use_spin_quant: print("Using SPIN quantization.") self._transform_for_pre_quantization(checkpoint, model_args) @@ -205,11 +205,12 @@ def __init__(self, **kwargs): ) sanitize_checkpoint_from_pre_quantization(checkpoint) - elif hasattr(self.args, "use_qat") and self.args.use_qat: + elif self.llm_config.quantization.use_qat: print("Using QAT quantization.") self._transform_for_pre_quantization(checkpoint, model_args) - if hasattr(self.args, "use_lora") and self.args.use_lora: - assert model_args.lora_args["rank"] == self.args.use_lora + if self.llm_config.base.use_lora: + lora_rank = self.llm_config.base.use_lora + assert model_args.lora_args["rank"] == lora_rank from .source_transformation.lora import ( transform_linear_for_lora_after_quantization, ) @@ -217,7 +218,7 
@@ def __init__(self, **kwargs): self.model_ = transform_linear_for_lora_after_quantization( self.model_, checkpoint, - self.args.use_lora, + lora_rank, ) from .source_transformation.pre_quantization import ( @@ -226,16 +227,16 @@ def __init__(self, **kwargs): sanitize_checkpoint_from_pre_quantization(checkpoint) - if hasattr(self.args, "use_attention_sink") and self.args.use_attention_sink: + if self.llm_config.model.use_attention_sink: from .source_transformation.attention_sink import enable_attention_sink - attention_sink_params = self.args.use_attention_sink.split(",") + attention_sink_params = self.llm_config.model.use_attention_sink.split(",") assert len(attention_sink_params) == 3 sink_size = int(attention_sink_params[0]) window_size = int(attention_sink_params[1]) eviction_batch_size = int(attention_sink_params[2]) - assert self.args.max_context_length == sink_size + window_size + assert self.llm_config.export.max_context_length == sink_size + window_size self.model_ = enable_attention_sink( module=self.model_, @@ -278,7 +279,7 @@ def __init__(self, **kwargs): f"The provided checkpoint is missing the following weights that are expected by the model: {missing_weights}. Please fix the fqn's in your checkpoint to match." ) if unexpected: - if kwargs.get("verbose", False): + if self.verbose: print(f"Unexpected keys: {unexpected}") # Prune the input layer if input_prune_map is provided @@ -326,20 +327,22 @@ def get_example_inputs_kvcache_sdpa(self): ) def _transform_for_pre_quantization(self, checkpoint, model_args): - assert hasattr(self.args, "preq_mode"), "preq_mode must be specified" - assert self.args.preq_mode in [ + assert self.llm_config.base.preq_mode, "preq_mode must be specified" + assert self.llm_config.base.preq_mode in [ "8da4w", "8da4w_output_8da8w", - ], f"Quantization mode {self.args.preq_mode} is not compatible with SpinQuant." - assert hasattr( - self.args, "preq_group_size" - ), "preq_group_size must be specified" - assert hasattr(self.args, "dtype_override"), "dtype_override must be specified" + ], f"Quantization mode {self.llm_config.base.preq_mode} is not compatible with SpinQuant." + assert self.llm_config.base.preq_group_size, "preq_group_size must be specified" + assert self.llm_config.model.dtype_override, "dtype_override must be specified" + from .source_transformation.pre_quantization import ( transform_linear_for_pre_quantization, ) - assert self.args.preq_group_size == model_args.quantization_args["group_size"] + assert ( + self.llm_config.base.preq_group_size + == model_args.quantization_args["group_size"] + ) mapping = { "fp32": torch.float32, @@ -348,7 +351,7 @@ def _transform_for_pre_quantization(self, checkpoint, model_args): } # Transform the output layer first if needed. 
- if self.args.preq_mode == "8da4w_output_8da8w": + if self.llm_config.base.preq_mode == "8da4w_output_8da8w": from .source_transformation.pre_quantization import ( transform_output_linear_for_pre_quantization, ) @@ -356,20 +359,20 @@ def _transform_for_pre_quantization(self, checkpoint, model_args): self.model_ = transform_output_linear_for_pre_quantization( module=self.model_, checkpoint=checkpoint, - dtype=mapping[self.args.dtype_override], + dtype=mapping[self.llm_config.model.dtype_override], ) self.model_ = transform_linear_for_pre_quantization( self.model_, checkpoint, - self.args.preq_group_size, - mapping[self.args.dtype_override], + self.llm_config.base.preq_group_size, + mapping[self.llm_config.model.dtype_override], ) embedding_bit_width, embedding_group_size = None, None - if hasattr(self.args, "preq_embedding_quantize"): + if self.llm_config.base.preq_embedding_quantize: embedding_bit_width, embedding_group_size = ( - self.args.preq_embedding_quantize.split(",") + self.llm_config.base.preq_embedding_quantize.split(",") ) from .source_transformation.pre_quantization import ( transform_embedding_for_pre_quantization, @@ -387,7 +390,7 @@ def _transform_for_pre_quantization(self, checkpoint, model_args): self.model_ = transform_embedding_for_pre_quantization( self.model_, checkpoint, - mapping[self.args.dtype_override], + mapping[self.llm_config.model.dtype_override], int(embedding_bit_width), embedding_group_size, ) diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py index 0b842a8f976..c55ad0eea28 100644 --- a/examples/models/llama/runner/eager.py +++ b/examples/models/llama/runner/eager.py @@ -10,6 +10,7 @@ import torch +from executorch.examples.models.llama.config.llm_config import LlmConfig from executorch.examples.models.llama.export_llama_lib import ( _prepare_for_llama_export, build_args_parser as _build_args_parser, @@ -23,19 +24,24 @@ class EagerLlamaRunner(LlamaRunner): Runs llama in eager mode with provided checkpoint file. """ - def __init__(self, args): - with open(args.params, "r") as f: + def __init__( + self, + llm_config: LlmConfig, + tokenizer_config_path: Optional[str] = None, + use_attention_sink: bool = False, + ): + with open(llm_config.base.params, "r") as f: params = json.loads(f.read()) super().__init__( - tokenizer_path=args.tokenizer_path, - tokenizer_config_path=args.tokenizer_config_path, - max_seq_len=args.max_seq_length, + tokenizer_path=llm_config.base.tokenizer_path, + tokenizer_config_path=tokenizer_config_path, + max_seq_len=llm_config.export.max_seq_length, max_batch_size=1, - use_kv_cache=args.use_kv_cache, + use_kv_cache=llm_config.model.use_kv_cache, vocab_size=params["vocab_size"], device="cuda" if torch.cuda.is_available() else "cpu", ) - manager: LLMEdgeManager = _prepare_for_llama_export(args) + manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) self.model = manager.model.eval().to(device=self.device) def forward( @@ -49,6 +55,7 @@ def forward( def build_args_parser() -> argparse.ArgumentParser: parser = _build_args_parser() + # Runner-specific arguments that aren't part of LlmConfig parser.add_argument( "--prompt", type=str, @@ -89,22 +96,41 @@ def execute_runner(runner_class: Type[LlamaRunner]) -> None: parser = build_args_parser() args = parser.parse_args() + # Convert args to LlmConfig for model configuration. + llm_config = LlmConfig.from_args(args) + + # Extract runner-specific parameters. 
+ prompt = args.prompt + temperature = args.temperature + show_tokens = args.show_tokens + chat_mode = args.chat + tokenizer_config_path = args.tokenizer_config_path + use_attention_sink = args.use_attention_sink + with torch.no_grad(): - runner = runner_class(args) # pyre-ignore: Missing argument [20] + # Create runner with LlmConfig and separate runner parameters. + runner = runner_class( + llm_config=llm_config, + tokenizer_config_path=tokenizer_config_path, + use_attention_sink=use_attention_sink, + ) + generated_tokens = ( runner.chat_completion( - max_seq_len=1000000 if args.use_attention_sink else args.max_seq_length, - temperature=args.temperature, - show_progress=args.show_tokens, + max_seq_len=( + 1000000 if use_attention_sink else llm_config.export.max_seq_length + ), + temperature=temperature, + show_progress=show_tokens, ) - if args.chat + if chat_mode else runner.text_completion( - prompt=args.prompt, - temperature=args.temperature, + prompt=prompt, + temperature=temperature, echo=True, ) ) - if args.show_tokens: + if show_tokens: print(f"Generated {len(generated_tokens)} tokens: {generated_tokens}") diff --git a/examples/models/llama/tests/test_export_llama_lib.py b/examples/models/llama/tests/test_export_llama_lib.py index b94adb5fa0c..f2ac9497604 100644 --- a/examples/models/llama/tests/test_export_llama_lib.py +++ b/examples/models/llama/tests/test_export_llama_lib.py @@ -7,6 +7,7 @@ import unittest from executorch.devtools.backend_debug import get_delegation_info +from executorch.examples.models.llama.config.llm_config import LlmConfig from executorch.examples.models.llama.export_llama_lib import ( _export_llama, build_args_parser, @@ -40,7 +41,8 @@ def test_has_expected_ops_and_op_counts(self): args.use_kv_cache = True args.verbose = True - builder = _export_llama(args) + llm_config = LlmConfig.from_args(args) + builder = _export_llama(llm_config) graph_module = builder.edge_manager.exported_program().graph_module delegation_info = get_delegation_info(graph_module) diff --git a/examples/models/llama3_2_vision/runner/eager.py b/examples/models/llama3_2_vision/runner/eager.py index c5d91013077..5e68a43bf8e 100644 --- a/examples/models/llama3_2_vision/runner/eager.py +++ b/examples/models/llama3_2_vision/runner/eager.py @@ -8,6 +8,7 @@ from typing import Optional import torch +from executorch.examples.models.llama.config.llm_config import LlmConfig from executorch.examples.models.llama.export_llama_lib import _prepare_for_llama_export from executorch.examples.models.llama.runner.eager import execute_runner @@ -22,18 +23,23 @@ class EagerLlamaRunner(TorchTuneLlamaRunner): Runs llama in eager mode with provided checkpoint file. 
""" - def __init__(self, args): - with open(args.params, "r") as f: + def __init__( + self, + llm_config: LlmConfig, + tokenizer_config_path: Optional[str] = None, + use_attention_sink: bool = False, + ): + with open(llm_config.base.params, "r") as f: params = json.loads(f.read()) super().__init__( - tokenizer_path=args.tokenizer_path, - max_seq_len=args.max_seq_length, + tokenizer_path=llm_config.base.tokenizer_path, + max_seq_len=llm_config.export.max_seq_length, max_batch_size=1, - use_kv_cache=args.use_kv_cache, + use_kv_cache=llm_config.model.use_kv_cache, vocab_size=params["vocab_size"], device="cuda" if torch.cuda.is_available() else "cpu", ) - manager: LLMEdgeManager = _prepare_for_llama_export(args) + manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) self.model = manager.model.eval().to(device=self.device) def forward( diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py index 18ef83ee1e4..32b3ff448ac 100644 --- a/examples/models/llava/export_llava.py +++ b/examples/models/llava/export_llava.py @@ -16,8 +16,8 @@ get_symmetric_quantization_config, XNNPACKQuantizer, ) +from executorch.examples.models.llama.config.llm_config import LlmConfig from executorch.examples.models.llama.export_llama_lib import ( - build_args_parser, get_quantizer_and_quant_params, ) from executorch.examples.models.llama.source_transformation.custom_kv_cache import ( @@ -92,32 +92,26 @@ def forward(self, input_pos, embeddings): dynamic_shapes=dynamic_shapes, ) + # Manually set some LlmConfig options. + llm_config = LlmConfig() + llm_config.base.params = "params.json" + llm_config.backend.xnnpack.enabled = True + llm_config.quantization.qmode = "8da4w" + llm_config.quantization.group_size = 128 + llm_config.quantization.embedding_quantize = "4,32" + dtype_override = DType.fp32 - parser = build_args_parser() - args = parser.parse_args( - [ - "-p", - "params.json", - "-X", - "-qmode", - "8da4w", - "--group_size", - "128", - "--embedding-quantize", - "4,32", - ] - ) quant_transform = get_quant_weight_transform( - quantization_mode=args.quantization_mode, - group_size=args.group_size, + quantization_mode=llm_config.quantization.qmode, + group_size=llm_config.quantization.group_size, computation_dtype=dtype_override, - checkpoint_path=args.checkpoint, - tokenizer_path=args.tokenizer_path, - calibration_tasks=args.calibration_tasks, - calibration_limit=args.calibration_limit, - calibration_seq_length=args.calibration_seq_length, + checkpoint_path=llm_config.base.checkpoint, + tokenizer_path=llm_config.base.tokenizer_path, + calibration_tasks=llm_config.quantization.calibration_tasks, + calibration_limit=llm_config.quantization.calibration_limit, + calibration_seq_length=llm_config.quantization.calibration_seq_length, ) - _, quantizers, _ = get_quantizer_and_quant_params(args) + _, quantizers, _ = get_quantizer_and_quant_params(llm_config) source_transforms = [] if llava.use_sdpa_with_kv_cache_op: source_transforms.append(replace_kv_cache_with_custom_kv_cache) @@ -279,6 +273,20 @@ def get_tokenizer_for_llava_runner(llava_model): t.export("tokenizer.bin") +def create_llava_config_from_args(args): + """ + Create an LlmConfig from command line arguments for LLaVA export + """ + llm_config = LlmConfig() + + llm_config.model.use_sdpa_with_kv_cache = args.use_sdpa_with_kv_cache + llm_config.export.max_seq_length = args.max_seq_len + llm_config.export.output_name = args.pte_name + llm_config.debug.profile_memory = args.profile_memory + + return llm_config + + def main(): 
parser = ArgumentParser() parser.add_argument( @@ -311,28 +319,33 @@ def main(): help="Generate chrome trace of activation memory for intermediate tensors.", ) args = parser.parse_args() + + # Create LlmConfig from args + llm_config = create_llava_config_from_args(args) + logging.info( - f"Exporting Llava model to ExecuTorch with sdpa_with_kv_cache: {args.use_sdpa_with_kv_cache}, max_seq_len: {args.max_seq_len}" + f"Exporting Llava model to ExecuTorch with sdpa_with_kv_cache: {llm_config.model.use_sdpa_with_kv_cache}, max_seq_len: {llm_config.export.max_seq_length}" ) + llava_model = LlavaModel( - use_sdpa_with_kv_cache_op=args.use_sdpa_with_kv_cache, - max_seq_len=args.max_seq_len, + use_sdpa_with_kv_cache_op=llm_config.model.use_sdpa_with_kv_cache, + max_seq_len=llm_config.export.max_seq_length, ) executorch_program = export_all(llava_model) # memory profiling - if args.profile_memory: + if llm_config.debug.profile_memory: for method_name in executorch_program.methods: generate_memory_trace( executorch_program, - f"{args.pte_name}_{method_name}.json", + f"{llm_config.export.output_name}_{method_name}.json", method_name=method_name, ) - with open(args.pte_name, "wb") as f: + with open(llm_config.export.output_name, "wb") as f: executorch_program.write_to_file(f) - logging.info(f"Exported ExecuTorch program to {args.pte_name}") + logging.info(f"Exported ExecuTorch program to {llm_config.export.output_name}") # artifacts if args.with_artifacts:
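Usage note (illustrative, not part of the patch): the hunks above move every entry point from a raw argparse.Namespace to LlmConfig. Legacy CLI callers convert once at the boundary with LlmConfig.from_args(args); programmatic callers (mps_example.py, export_llava.py) build an LlmConfig directly and set fields on its base/model/export/quantization/backend/debug groups; the Hydra CLI hands export_llama an OmegaConf.to_object(llm_config). A minimal sketch of the two common paths, assuming only the fields and flags that appear in the hunks above; the checkpoint and params paths are placeholders:

    from executorch.examples.models.llama.config.llm_config import LlmConfig
    from executorch.examples.models.llama.export_llama_lib import (
        build_args_parser,
        export_llama,
    )

    # 1) Legacy CLI path: parse the existing flags, then convert once at the boundary.
    parser = build_args_parser()
    args = parser.parse_args(
        ["--checkpoint", "ckpt.pth", "--params", "params.json"]  # placeholder paths
    )
    llm_config = LlmConfig.from_args(args)  # argparse.Namespace -> LlmConfig
    pte_path = export_llama(llm_config)     # returns the saved .pte filename

    # 2) Programmatic path: build the config directly, no argparse involved.
    cfg = LlmConfig()
    cfg.base.checkpoint = "ckpt.pth"   # placeholder path
    cfg.base.params = "params.json"    # placeholder path
    cfg.model.use_kv_cache = True
    pte_path = export_llama(cfg)

export_llama() still accepts an argparse.Namespace for backward compatibility and normalizes it to LlmConfig internally, so existing scripts keep working while new call sites (including EagerModelFactory.create_model(..., llm_config=llm_config)) can construct the config directly.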