diff --git a/examples/models/llama/eval_llama.py b/examples/models/llama/eval_llama.py
index 7c959d08b9b..9f914ccbc34 100644
--- a/examples/models/llama/eval_llama.py
+++ b/examples/models/llama/eval_llama.py
@@ -11,6 +11,7 @@
 import torch
 
 from .eval_llama_lib import (
+    _convert_cli_to_config_format,
     build_args_parser,
     eval_llama,
     eval_llama_with_attention_sink,
@@ -28,10 +29,11 @@ def main() -> None:
     args = parser.parse_args()
     # Overrides this arg, because evaluation requires full logits.
     args.generate_full_logits = True
+    config = _convert_cli_to_config_format(args)
     if args.use_attention_sink:
-        eval_llama_with_attention_sink(modelname, args)  # pyre-ignore
+        eval_llama_with_attention_sink(modelname, config)  # pyre-ignore
     else:
-        eval_llama(modelname, args)  # pyre-ignore
+        eval_llama(modelname, config)  # pyre-ignore
 
 
 if __name__ == "__main__":
diff --git a/examples/models/llama/eval_llama_lib.py b/examples/models/llama/eval_llama_lib.py
index 6872222861d..88192bfff0c 100644
--- a/examples/models/llama/eval_llama_lib.py
+++ b/examples/models/llama/eval_llama_lib.py
@@ -4,35 +4,27 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-
 import argparse
-
 from typing import Optional, Union
 
 import torch
-
-from datasets import load_dataset
 from executorch.examples.models.llama.export_llama_lib import (
+    _convert_args_to_config,
+    _prepare_for_llama_export,
+    build_args_parser as _build_args_parser,
     get_quantizer_and_quant_params,
 )
 from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken
-
 from executorch.extension.llm.export.builder import LLMEdgeManager
 from executorch.extension.llm.tokenizer.tokenizer import (
     Tokenizer as SentencePieceTokenizer,
 )
 from executorch.extension.llm.tokenizer.utils import get_tokenizer
 from lm_eval.evaluator import simple_evaluate
-from torch.nn import CrossEntropyLoss
-from tqdm import tqdm
+from omegaconf import DictConfig, OmegaConf
 
 from .evaluate.eager_eval import EagerEvalWrapper
-from .export_llama_lib import (
-    _prepare_for_llama_export,
-    build_args_parser as _build_args_parser,
-)
-
 
 class GraphModuleEvalWrapper(EagerEvalWrapper):
     """
@@ -165,7 +157,7 @@ def _model_call(self, inps):
 
 def gen_eval_wrapper(
     model_name: str,
-    args: argparse.ArgumentParser,
+    config: DictConfig,
 ):
     """
     Generates a wrapper interface around the provided model and tokenizer for
@@ -174,17 +166,17 @@
     Returns:
         eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
     """
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore
+    tokenizer = get_tokenizer(config.export.tokenizer_path)
 
     # ExecuTorch Binary Evaluation
-    if (model := args.pte) is not None:  # pyre-ignore
-        if (tokenizer_bin := args.tokenizer_bin) is not None:  # pyre-ignore
+    if (model := config.eval.pte) is not None:
+        if (tokenizer_bin := config.eval.tokenizer_bin) is not None:
             # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime
             return ETRunnerEvalWrapper(
                 model=model,
                 tokenizer=tokenizer,
                 tokenizer_bin=tokenizer_bin,
-                max_seq_length=args.max_seq_length,  # pyre-ignore
+                max_seq_length=config.sequence.max_seq_length,
             )
 
         # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings
@@ -193,12 +185,12 @@
             tokenizer=tokenizer,
             # Exported model takes at most (max_seq_length - 1) tokens.
             # Note that the eager model takes at most max_seq_length tokens.
-            max_seq_length=args.max_seq_length - 1,
+            max_seq_length=config.sequence.max_seq_length - 1,
         )
 
-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(config)
     # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    manager: LLMEdgeManager = _prepare_for_llama_export(config)
 
     if len(quantizers) != 0:
         manager = manager.export().pt2e_quantize(quantizers)
@@ -210,9 +202,9 @@
         return GraphModuleEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,  # pyre-ignore
-            enable_dynamic_shape=args.enable_dynamic_shape,  # pyre-ignore
+            max_seq_length=config.sequence.max_seq_length,
+            use_kv_cache=config.kv_cache.use_kv_cache,  # pyre-ignore
+            enable_dynamic_shape=config.misc.enable_dynamic_shape,  # pyre-ignore
         )
     else:
         # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
@@ -230,18 +222,94 @@
         # that is not available in this eval_llama. We save the checkpoint
         # here for consistency with eval_llama. The accuracy results we
         # get from eval_llama can be used as a reference to other evaluations.
-        if args.output_eager_checkpoint_file is not None:  # pyre-ignore
-            torch.save(model, args.output_eager_checkpoint_file)
+        if config.eval.output_eager_checkpoint_file is not None:  # pyre-ignore
+            torch.save(model, config.eval.output_eager_checkpoint_file)
 
         return EagerEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,
+            max_seq_length=config.sequence.max_seq_length,
+            use_kv_cache=config.kv_cache.use_kv_cache,
+        )
+
+
+def eval_llama(
+    model_name: str,
+    config: DictConfig,
+) -> None:
+    # Generate the eval wrapper
+    eval_wrapper = gen_eval_wrapper(model_name, config)
+
+    # Needed for loading mmlu dataset.
+    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
+    if config.eval.tasks and "mmlu" in config.eval.tasks:
+        import datasets
+
+        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
+
+    # Evaluate the model
+    tasks = (
+        None if config.eval.tasks is None else OmegaConf.to_container(config.eval.tasks)
+    )
+    with torch.no_grad():
+        eval_results = simple_evaluate(
+            model=eval_wrapper,
+            tasks=tasks,
+            num_fewshot=config.eval.num_fewshot,
+            limit=config.eval.limit,
+        )
+
+    for task, res in eval_results["results"].items():
+        print(f"{task}: {res}")
+
+
+def eval_llama_with_attention_sink(
+    model_name: str,
+    config: DictConfig,
+) -> None:
+    # Generate the eval wrapper
+    eval_wrapper = gen_eval_wrapper(model_name, config)
+
+    # Needed for loading mmlu dataset.
+    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
+    if config.eval.tasks and "mmlu" in config.eval.tasks:
+        import datasets
+
+        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
+
+    # Evaluate the model
+    with torch.no_grad():
+        eval_results = simple_evaluate(
+            model=eval_wrapper,
+            tasks=OmegaConf.to_container(config.eval.tasks),
+            num_fewshot=config.eval.num_fewshot,
+            limit=config.eval.limit,
         )
 
+    for task, res in eval_results["results"].items():
+        print(f"{task}: {res}")
+
+
+def _convert_cli_to_config_format(args) -> DictConfig:
+    """Convert CLI arguments to config format."""
+    # First convert common args using the shared function
+    config = _convert_args_to_config(args)
+
+    # Add evaluation-specific settings
+    config.eval = OmegaConf.create()
+    config.eval.tasks = args.tasks
+    config.eval.limit = args.limit
+    config.eval.num_fewshot = args.num_fewshot
+    config.eval.pte = args.pte
+    config.eval.tokenizer_bin = args.tokenizer_bin
+    config.eval.output_eager_checkpoint_file = args.output_eager_checkpoint_file
+    config.eval.attention_sink_eval_tokens = args.attention_sink_eval_tokens
+
+    return config
+
 
 def build_args_parser() -> argparse.ArgumentParser:
+    """Build argument parser for evaluation, extending the export parser with eval-specific args."""
     # Start with arg parser from export_llama_lib
     parser = _build_args_parser()
 
@@ -288,92 +356,7 @@
         help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.",
     )
 
-    # Set of parameters secpific to AttentionSink.
+    # Set of parameters specific to AttentionSink.
     parser.add_argument("--attention_sink_eval_tokens", type=int, default=0)
 
     return parser
-
-
-def eval_llama(
-    model_name: str,
-    args: argparse.ArgumentParser,
-) -> None:
-    # Generate the eval wrapper
-    eval_wrapper = gen_eval_wrapper(model_name, args)
-
-    # Needed for loading mmlu dataset.
-    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
-    # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
-    if args.tasks and "mmlu" in args.tasks:
-        import datasets
-
-        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
-
-    # Evaluate the model
-    with torch.no_grad():
-        eval_results = simple_evaluate(
-            model=eval_wrapper,
-            tasks=args.tasks,
-            num_fewshot=args.num_fewshot,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot`
-            limit=args.limit,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit`
-        )
-
-    for task, res in eval_results["results"].items():
-        print(f"{task}: {res}")
-
-
-def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser):
-    """
-    Evaluate the model's perplexity when AttentionSink is enabled.
-
-    This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
-    """
-    assert args.use_attention_sink is not None  # pyre-ignore [16]
-    assert args.attention_sink_eval_tokens > 0  # pyre-ignore [16]
-    attention_sink_params = args.use_attention_sink.split(",")
-    assert len(attention_sink_params) == 3
-    sink_size = int(attention_sink_params[0])
-    window_size = int(attention_sink_params[1])
-
-    assert args.max_seq_length == sink_size + window_size  # pyre-ignore [16]
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
-    model = manager.model.eval().to(device=device)
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore [16]
-
-    eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
-
-    nlls = []
-    loss_fn = CrossEntropyLoss(reduction="none")
-    progress_bar = tqdm(total=args.attention_sink_eval_tokens)
-    input_pos = 0
-    while input_pos < args.attention_sink_eval_tokens:
-        for text in eval_data["text"]:  # pyre-ignore [16]
-            tokens = tokenizer.encode(text, bos=False, eos=False)
-            if len(tokens) <= 0:
-                continue
-            with torch.no_grad():
-                num_tokens = min(
-                    len(tokens) - 1, args.attention_sink_eval_tokens - input_pos
-                )
-                logits = model(
-                    torch.tensor(
-                        [tokens[:num_tokens]], dtype=torch.int64, device=device
-                    ),
-                    torch.tensor([input_pos], dtype=torch.int64, device=device),
-                ).squeeze(dim=0)
-                neg_log_likelihood = loss_fn(
-                    logits,
-                    torch.tensor(
-                        [tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device
-                    ).view(-1),
-                )
-                nlls.append(neg_log_likelihood)
-                input_pos += num_tokens
-                progress_bar.update(num_tokens)
-            if input_pos >= args.attention_sink_eval_tokens:
-                break
-    ppl = torch.exp(torch.cat(nlls).mean())
-    print(f"Perplexity: {ppl.item()}")
-    return ppl.item()
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 37a4e6952d8..80623ce375a 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -48,6 +48,7 @@
     get_vulkan_quantizer,
 )
 from executorch.util.activation_memory_profiler import generate_memory_trace
+from omegaconf import DictConfig, OmegaConf
 
 from ..model_factory import EagerModelFactory
 from .source_transformation.apply_spin_quant_r1_r2 import (
@@ -519,13 +520,127 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
     return return_val
 
 
-def export_llama(args) -> str:
-    if args.profile_path is not None:
+def _convert_args_to_config(args: argparse.Namespace) -> DictConfig:
+    """Convert argparse.Namespace to DictConfig."""
+    # Create a dictionary from args
+    args_dict = {}
+
+    # Add model settings
+    args_dict["model"] = {
+        "name": args.model,
+        "type": "LLAMA" if not args.fairseq2 else "FAIRSEQ2",
+        "dtype_override": args.dtype_override,
+        "params": args.params,
+    }
+
+    # Add export settings
+    args_dict["export"] = {
+        "output_dir": args.output_dir,
+        "checkpoint": args.checkpoint,
+        "checkpoint_dir": args.checkpoint_dir,
+        "tokenizer_path": args.tokenizer_path,
+        "output_name": args.output_name,
+        "metadata": args.metadata,
+        "so_library": args.so_library,
+        "export_only": args.export_only,
+    }
+
+    # Add sequence settings
+    args_dict["sequence"] = {
+        "max_seq_length": args.max_seq_length,
+        "max_context_length": args.max_context_length,
+    }
+
+    # Add KV cache settings
+    args_dict["kv_cache"] = {
+        "use_kv_cache": args.use_kv_cache,
+        "quantize_kv_cache": args.quantize_kv_cache,
+        "use_sdpa_with_kv_cache": args.use_sdpa_with_kv_cache,
+    }
+
+    # Add quantization settings
+    args_dict["quantization"] = {
+        "mode": args.quantization_mode,
+        "embedding_quantize": args.embedding_quantize,
+        "pt2e_quantize": args.pt2e_quantize,
+        "group_size": args.group_size,
+        "use_spin_quant": args.use_spin_quant,
+        "use_qat": args.use_qat,
+        "use_lora": args.use_lora,
+        "preq_mode": args.preq_mode,
+        "preq_group_size": args.preq_group_size,
+        "preq_embedding_quantize": args.preq_embedding_quantize,
+    }
+
+    # Add calibration settings
+    args_dict["calibration"] = {
+        "tasks": args.calibration_tasks,
+        "limit": args.calibration_limit,
+        "seq_length": args.calibration_seq_length,
+        "data": args.calibration_data,
+    }
+
+    # Add backend settings
+    args_dict["backend"] = {
+        "xnnpack": {
+            "enabled": args.xnnpack,
+            "extended_ops": args.xnnpack_extended_ops,
+        },
+        "coreml": {
+            "enabled": args.coreml,
+            "enable_state": args.coreml_enable_state,
+            "preserve_sdpa": args.coreml_preserve_sdpa,
+            "quantize": args.coreml_quantize,
+            "ios": args.coreml_ios,
+            "compute_units": args.coreml_compute_units,
+        },
+        "vulkan": {
+            "enabled": args.vulkan,
+        },
+        "qnn": {
+            "enabled": args.qnn,
+            "use_sha": args.use_qnn_sha,
+            "soc_model": args.soc_model,
+            "optimized_rotation_path": args.optimized_rotation_path,
+        },
+        "mps": {
+            "enabled": args.mps,
+        },
+    }
+
+    # Add additional settings
+    args_dict["misc"] = {
+        "profile_memory": args.profile_memory,
+        "profile_path": args.profile_path,
+        "enable_dynamic_shape": args.enable_dynamic_shape,
+        "num_sharding": args.num_sharding,
+        "expand_rope_table": args.expand_rope_table,
+        "generate_etrecord": args.generate_etrecord,
+        "generate_full_logits": args.generate_full_logits,
+        "use_attention_sink": args.use_attention_sink,
+        "output_prune_map": args.output_prune_map,
+        "input_prune_map": args.input_prune_map,
+        "verbose": args.verbose,
+    }
+
+    # Convert to DictConfig
+    return OmegaConf.create(args_dict)
+
+
+def export_llama(args: Union[argparse.Namespace, DictConfig]) -> str:
+    """Export Llama model to flatbuffer format."""
+    # Convert args to config if needed
+    if isinstance(args, argparse.Namespace):
+        config = _convert_args_to_config(args)
+    else:
+        config = args
+
+    if config.misc.profile_path is not None:
         try:
             from executorch.util.python_profiler import CProfilerFlameGraph
 
-            with CProfilerFlameGraph(args.profile_path):
-                builder = _export_llama(args)
+            with CProfilerFlameGraph(config.misc.profile_path):
+                builder = _export_llama(config)
             assert (
                 filename := builder.get_saved_pte_filename()
             ) is not None, "Fail to get file name from builder"
@@ -536,14 +651,14 @@
             )
             return ""
     else:
-        builder = _export_llama(args)
+        builder = _export_llama(config)
         assert (
             filename := builder.get_saved_pte_filename()
         ) is not None, "Fail to get file name from builder"
         return filename
 
 
-def _prepare_for_llama_export(args) -> LLMEdgeManager:
+def _prepare_for_llama_export(config: DictConfig) -> LLMEdgeManager:
     """
     Helper function for export_llama. Loads the model from checkpoint and params,
     and sets up a LLMEdgeManager with initial transforms and dtype conversion.
@@ -551,40 +666,51 @@
     Returns a LLMEdgeManager prior to calling export_to_edge with quantizers
     """
     # load model from checkpoint and params.json
-    checkpoint_path = canonical_path(args.checkpoint) if args.checkpoint else None
+    checkpoint_path = (
+        canonical_path(config.export.checkpoint) if config.export.checkpoint else None
+    )
     checkpoint_dir = (
-        canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
+        canonical_path(config.export.checkpoint_dir)
+        if config.export.checkpoint_dir
+        else None
+    )
+    params_path = canonical_path(config.model.params)
+    output_dir_path = canonical_path(config.export.output_dir, dir=True)
+    weight_type = (
+        WeightType.FAIRSEQ2 if config.model.type == "FAIRSEQ2" else WeightType.LLAMA
     )
-    params_path = canonical_path(args.params) if args.params else None
-    output_dir_path = canonical_path(args.output_dir, dir=True)
-    weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
 
     # Convert dtype override string arg to actual type.
-    dtype_override = DType[args.dtype_override]
+    dtype_override = DType[config.model.dtype_override]
 
+    calibration_tasks = (
+        None
+        if config.calibration.tasks is None
+        else OmegaConf.to_container(config.calibration.tasks)
+    )
     edge_manager = _load_llama_model(
-        args.model,
+        config.model.name,
         checkpoint=checkpoint_path,
         checkpoint_dir=checkpoint_dir,
         params_path=params_path,
-        use_kv_cache=args.use_kv_cache,
-        use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
-        generate_full_logits=args.generate_full_logits,
+        use_kv_cache=config.kv_cache.use_kv_cache,
+        use_sdpa_with_kv_cache=config.kv_cache.use_sdpa_with_kv_cache,
+        generate_full_logits=config.misc.generate_full_logits,
         weight_type=weight_type,
-        enable_dynamic_shape=args.enable_dynamic_shape,
-        calibration_tasks=args.calibration_tasks,
-        calibration_limit=args.calibration_limit,
-        calibration_seq_length=args.calibration_seq_length,
-        calibration_data=args.calibration_data,
-        tokenizer_path=args.tokenizer_path,
-        verbose=args.verbose,
-        max_seq_len=args.max_seq_length,
-        max_context_len=args.max_context_length,
-        input_prune_map_path=args.input_prune_map,
-        output_prune_map_path=args.output_prune_map,
-        metadata_str=args.metadata,
+        enable_dynamic_shape=config.misc.enable_dynamic_shape,
+        calibration_tasks=calibration_tasks,
+        calibration_limit=config.calibration.limit,
+        calibration_seq_length=config.calibration.seq_length,
+        calibration_data=config.calibration.data,
+        tokenizer_path=config.export.tokenizer_path,
+        verbose=config.misc.verbose,
+        max_seq_len=config.sequence.max_seq_length,
+        max_context_len=config.sequence.max_context_length,
+        input_prune_map_path=config.misc.input_prune_map,
+        output_prune_map_path=config.misc.output_prune_map,
+        metadata_str=config.export.metadata,
         dtype_override=dtype_override,
-        args=args,
+        args=config,
     )
 
     # At this point, the model is loaded in the default fp32.
@@ -613,37 +739,37 @@
     logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
-            modelname=args.model,
+            modelname=config.model.name,
             dtype_override=dtype_override,
             checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),
-            args=args,
+            config=config,
         )
     )
 
     return edge_manager
 
 
-def get_quantizer_and_quant_params(args):
+def get_quantizer_and_quant_params(config: DictConfig):
     pt2e_quant_params = get_pt2e_quantization_params(
-        args.pt2e_quantize, args.quantization_mode
+        config.quantization.pt2e_quantize, config.quantization.mode
     )
-    quantizers = get_pt2e_quantizers(pt2e_quant_params, args.so_library)
+    quantizers = get_pt2e_quantizers(pt2e_quant_params, config.export.so_library)
     quant_dtype = None
-    if args.qnn and args.pt2e_quantize:
+    if config.backend.qnn.enabled and config.quantization.pt2e_quantize:
        assert len(quantizers) == 0, "Should not enable both xnnpack and qnn"
        qnn_quantizer, quant_dtype = get_qnn_quantizer(
-            args.pt2e_quantize, args.quantization_mode
+            config.quantization.pt2e_quantize, config.quantization.mode
        )
        quantizers.append(qnn_quantizer)
-    if args.coreml and args.pt2e_quantize:
+    if config.backend.coreml.enabled and config.quantization.pt2e_quantize:
        assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
-        coreml_quantizer = get_coreml_quantizer(args.pt2e_quantize)
+        coreml_quantizer = get_coreml_quantizer(config.quantization.pt2e_quantize)
        quantizers.append(coreml_quantizer)
-    if args.vulkan and args.pt2e_quantize:
+    if config.backend.vulkan.enabled and config.quantization.pt2e_quantize:
        assert (
            len(quantizers) == 0
        ), "Should not enable both vulkan and other quantizers"
-        vulkan_quantizer = get_vulkan_quantizer(args.pt2e_quantize)
+        vulkan_quantizer = get_vulkan_quantizer(config.quantization.pt2e_quantize)
        quantizers.append(vulkan_quantizer)
     logging.info(f"Applying quantizers: {quantizers}")
     return pt2e_quant_params, quantizers, quant_dtype
@@ -705,8 +831,8 @@ def _to_edge_and_lower_llama_xnnpack(
     pt2e_quant_params,
     quantizers,
     quant_dtype,
-    args,
-) -> LLMEdgeManager:  # noqa: C901
+    config: DictConfig,
+) -> LLMEdgeManager:
     partitioners = []
 
     # Order matters here, dynamic quantization should be applied first when both xnnpack and xnnpack_extended_ops are enabled
@@ -714,7 +840,7 @@
 
         modelname = f"xnnpack_dq_{modelname}"
 
-    if args.xnnpack_extended_ops:
+    if config.backend.xnnpack.extended_ops:
         partitioners.append(
             get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
         )
@@ -725,7 +851,7 @@
         logging.info(f"--> {partitioner.__class__.__name__}")
 
     # TODO: Enable generating ETRecord with XNNPack and to_edge_transform_and_lower().
-    if args.generate_etrecord:
+    if config.misc.generate_etrecord:
         raise NotImplementedError(
             "export_llama does not support XNNPack and generating ETRecord at the moment."
         )
@@ -733,20 +859,20 @@
     builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
         partitioners
     )
-    if args.verbose:
+    if config.misc.verbose:
         print_delegation_info(builder.edge_manager.exported_program().graph_module)
 
     return builder.to_executorch(passes=additional_passes)
 
 
-def _to_edge_and_lower_llama(  # noqa: C901
+def _to_edge_and_lower_llama(
     builder_exported,
     modelname,
     additional_passes,
     pt2e_quant_params,
     quantizers,
     quant_dtype,
-    args,
+    config: DictConfig,
 ):
     builder_exported_to_edge = builder_exported.pt2e_quantize(
         quantizers
@@ -754,11 +880,11 @@
 
     # to_backend
     partitioners = []
-    if args.vulkan:
+    if config.backend.vulkan.enabled:
         partitioners.append(
             get_vulkan_partitioner(
-                args.dtype_override,
-                args.enable_dynamic_shape,
+                config.model.dtype_override,
+                config.misc.enable_dynamic_shape,
             )
         )
         # Apply XNNPACK after Vulkan so that undelegated ops can be accelerated by XNNPACK
@@ -770,27 +896,30 @@
         # Need to remove asserts from the graph to prevent graph breaks
         remove_asserts(builder_exported_to_edge.edge_manager.exported_program())
 
-    if args.mps:
-        partitioners.append(get_mps_partitioner(args.use_kv_cache))
+    if config.backend.mps.enabled:
+        partitioners.append(get_mps_partitioner(config.kv_cache.use_kv_cache))
         modelname = f"mps_{modelname}"
 
-    if args.coreml:
+    if config.backend.coreml.enabled:
         coreml_partitioner = get_coreml_partitioner(
-            args.coreml_ios,
-            args.embedding_quantize,
-            args.pt2e_quantize,
-            args.coreml_quantize,
-            args.coreml_compute_units,
+            config.backend.coreml.ios,
+            config.quantization.embedding_quantize,
+            config.quantization.pt2e_quantize,
+            config.backend.coreml.quantize,
+            config.backend.coreml.compute_units,
         )
         partitioners.append(coreml_partitioner)
         modelname = f"coreml_{modelname}"
 
-    if args.qnn:
+    if config.backend.qnn.enabled:
         from executorch.extension.llm.custom_ops import model_sharding
 
         partitioners.append(
             get_qnn_partitioner(
-                args.use_kv_cache, args.pt2e_quantize, args.num_sharding, args.soc_model
+                config.kv_cache.use_kv_cache,
+                config.quantization.pt2e_quantize,
+                config.misc.num_sharding,
+                config.backend.qnn.soc_model,
             )
         )
         # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
@@ -798,11 +927,11 @@
 
         _transform(builder_exported_to_edge.edge_manager.exported_program())
 
-        if args.num_sharding > 0:
+        if config.misc.num_sharding > 0:
             model_sharding.split_graph(
                 builder_exported_to_edge.edge_manager.exported_program(),
                 builder_exported_to_edge.metadata["get_n_layers"],
-                shares=args.num_sharding,
+                shares=config.misc.num_sharding,
             )
 
         # pyre-ignore
@@ -811,7 +940,7 @@
         )
 
         atten = builder_exported_to_edge.model.layers[0].attention
-        if args.use_qnn_sha:
+        if config.backend.qnn.use_sha:
             cache_shape = torch.Size(
                 (atten.max_batch_size, atten.max_context_len, atten.head_dim)
             )
@@ -833,7 +962,7 @@
     for partitioner in partitioners:
         logging.info(f"--> {partitioner.__class__.__name__}")
 
-    if args.generate_etrecord:
+    if config.misc.generate_etrecord:
         if not builder_exported_to_edge.edge_manager:
             raise ValueError("Unable to generate etrecord due to missing edge manager.")
 
@@ -841,9 +970,9 @@
         # Copy the edge manager which will be serialized into etrecord. This is memory-wise expensive.
         edge_manager_copy = copy.deepcopy(builder_exported_to_edge.edge_manager)
         builder = builder_exported_to_edge.to_backend(partitioners)
-        if args.verbose:
+        if config.misc.verbose:
             print_delegation_info(builder.edge_manager.exported_program().graph_module)
-        if args.num_sharding > 0 and args.qnn:
+        if config.misc.num_sharding > 0 and config.backend.qnn.enabled:
             from executorch.backends.qualcomm.utils.utils import canonicalize_program
 
             canonicalize_program(builder.edge_manager.exported_program())
@@ -862,9 +991,9 @@
         logging.info("Generated etrecord.bin")
     else:
         builder = builder_exported_to_edge.to_backend(partitioners)
-        if args.verbose:
+        if config.misc.verbose:
             print_delegation_info(builder.edge_manager.exported_program().graph_module)
-        if args.num_sharding > 0 and args.qnn:
+        if config.misc.num_sharding > 0 and config.backend.qnn.enabled:
             from executorch.backends.qualcomm.utils.utils import canonicalize_program
 
             canonicalize_program(builder.edge_manager.exported_program())
@@ -874,28 +1003,28 @@
     return builder
 
 
-def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
-    _validate_args(args)
+def _export_llama(config: DictConfig) -> LLMEdgeManager:  # noqa: C901
+    _validate_config(config)
 
-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(config)
 
     additional_passes = []
-    if args.model in TORCHTUNE_DEFINED_MODELS:
+    if config.model.name in TORCHTUNE_DEFINED_MODELS:
         additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
 
     # export_to_edge
-    builder_exported = _prepare_for_llama_export(args).export()
+    builder_exported = _prepare_for_llama_export(config).export()
     builder_exported.run_canonical_optimizations()
     modelname = builder_exported.modelname
 
-    if args.export_only:
+    if config.export.export_only:
         exit()
 
     if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:
-        # Force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
-        args.xnnpack = True
+        # Force xnnpack to be true if pt2e_quant_params is not None and config.backend.xnnpack.enabled is False
+        config.backend.xnnpack.enabled = True
 
-    if args.xnnpack:
+    if config.backend.xnnpack.enabled:
         builder = _to_edge_and_lower_llama_xnnpack(
             builder_exported,
             modelname,
@@ -903,7 +1032,7 @@
             pt2e_quant_params,
             quantizers,
             quant_dtype,
-            args,
+            config,
         )
     else:
         builder = _to_edge_and_lower_llama(
@@ -913,17 +1042,17 @@
             pt2e_quant_params,
             quantizers,
             quant_dtype,
-            args,
+            config,
         )
 
-    if args.profile_memory:
+    if config.misc.profile_memory:
         generate_memory_trace(builder.export_program, "memory_profile.json")
 
     if builder.dtype == DType.fp16:
         modelname = f"{modelname}_h"
 
-    if args.output_name:
-        modelname = args.output_name
+    if config.export.output_name:
+        modelname = config.export.output_name
         if modelname.endswith(".pte"):
             output_file = modelname
             modelname = modelname[:-4]
@@ -940,6 +1069,42 @@
     return builder
 
 
+def _validate_config(config: DictConfig) -> None:
+    """Validate configuration values."""
+    if config.sequence.max_context_length < config.sequence.max_seq_length:
+        raise ValueError(
+            f"max_context_length {config.sequence.max_context_length} must be >= max_seq_len {config.sequence.max_seq_length}. "
+            "max_context_length impacts kv cache size that is used to remember history, while max_seq_length refers to user prompt length. "
+            "Please use --max_context_length to specify context length."
+        )
+
+    if config.misc.enable_dynamic_shape and (
+        config.backend.coreml.enabled
+        or config.backend.mps.enabled
+        or config.backend.qnn.enabled
+    ):
+        raise ValueError(
+            "Dynamic shape is not supported with coreml, MPS or qnn backends. "
+            "Please use --disable_dynamic_shape."
+        )
+
+    if config.misc.num_sharding > 0 and not config.backend.qnn.enabled:
+        raise ValueError("Model shard is only supported with qnn backend now.")
+
+    if (
+        config.quantization.mode is not None
+        and config.quantization.mode.startswith("torchao:")
+    ) or (
+        config.quantization.embedding_quantize is not None
+        and config.quantization.embedding_quantize.startswith("torchao:")
+    ):
+        if config.misc.enable_dynamic_shape:
+            raise ValueError(
+                "Dynamic shape is not currently supported with torchao ops. Please use --disable_dynamic_shape. "
+                "If you need this feature, please file an issue."
+            )
+
+
 def _load_llama_model_metadata(
     weight_type: WeightType,
     use_kv_cache: bool,
@@ -1085,7 +1250,7 @@ def _get_source_transforms(  # noqa
     dtype_override: DType,
     *,
     checkpoint_dtype: Optional[DType] = None,
-    args,
+    config: DictConfig,
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     """
     Return a list of functions that transform a graph.
@@ -1108,21 +1273,21 @@
     transforms = []
 
-    if args.use_spin_quant:
-        if args.use_spin_quant == "cuda":
+    if config.quantization.use_spin_quant:
+        if config.quantization.use_spin_quant == "cuda":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_cuda_for_spin_quant,
             )
 
             transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
 
-        elif args.use_spin_quant == "native":
+        elif config.quantization.use_spin_quant == "native":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_native_for_spin_quant,
             )
 
             transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)
 
-    if args.quantization_mode:
+    if config.quantization.mode:
         """
         When this option is selected, it finds all linear layers and transforms
         into quantized linear equivalent module.
@@ -1139,13 +1304,13 @@
         modelname = f"{modelname}_q"
         transforms.append(
             get_quant_weight_transform(
-                args=args,
+                config=config,
                 computation_dtype=dtype_override,
                 checkpoint_dtype=checkpoint_dtype,
             )
         )
 
-    if args.embedding_quantize:
+    if config.quantization.embedding_quantize:
         """
         When this option is selected, it finds all embedding layers and transforms
         into quantized embedding equivalent module.
@@ -1156,64 +1321,66 @@
         this wil be a no-op.
""" modelname = f"{modelname}_e" - transforms.append(get_quant_embedding_transform(args, checkpoint_dtype)) + transforms.append(get_quant_embedding_transform(config, checkpoint_dtype)) - if args.expand_rope_table: + if config.misc.expand_rope_table: transforms.append(materialze_broadcast_of_rope_freq_cis) - if args.use_sdpa_with_kv_cache: + if config.kv_cache.use_sdpa_with_kv_cache: transforms.append(replace_kv_cache_with_custom_kv_cache) transforms.append(replace_sdpa_with_custom_op) - if args.quantize_kv_cache: - assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True" + if config.kv_cache.quantize_kv_cache: + assert ( + config.kv_cache.use_kv_cache + ), "quantize_kv_cache requires use_kv_cache=True" transforms.append(replace_kv_cache_with_quantized_kv_cache) - if args.use_kv_cache: - if args.qnn: + if config.kv_cache.use_kv_cache: + if config.backend.qnn.enabled: from executorch.backends.qualcomm.utils.utils import ( convert_linear_to_conv2d, ) - if args.use_qnn_sha: - if args.optimized_rotation_path: + if config.backend.qnn.use_sha: + if config.backend.qnn.optimized_rotation_path: transforms.append(fuse_layer_norms) transforms.append( - get_model_with_r1_r2(args.optimized_rotation_path) + get_model_with_r1_r2(config.backend.qnn.optimized_rotation_path) ) transforms.append(replace_attention_to_attention_sha) transforms.append(replace_causal_mask) transforms.append(replace_rms_norm_with_native_rms_norm) - # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`. + # pyre-ignore: Module `backends` has no attribute `qualcomm`. transforms.append(convert_linear_to_conv2d) else: transforms.append(replace_kv_cache_with_simple_kv_cache) transforms.append(replace_sdpa_with_flex_sdpa) transforms.append(replace_causal_mask) transforms.append(replace_rms_norm_with_native_rms_norm) - if args.optimized_rotation_path: + if config.backend.qnn.optimized_rotation_path: transforms.append(fuse_layer_norms) transforms.append( - get_model_with_r1_r2(args.optimized_rotation_path) + get_model_with_r1_r2(config.backend.qnn.optimized_rotation_path) ) - # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`. + # pyre-ignore: Module `backends` has no attribute `qualcomm`. transforms.append(convert_linear_to_conv2d) - elif args.mps: + elif config.backend.mps.enabled: # Currently mps doesn't support sdpa op, use the simpler decomposition # to get free perf gain. transforms.append(replace_sdpa_with_simple_sdpa) transforms.append(replace_causal_mask) - elif args.coreml: + elif config.backend.coreml.enabled: # iOS 18 introduced fused sdpa op - if args.coreml_ios >= 18: + if config.backend.coreml.ios >= 18: transforms.append(replace_sdpa_with_coreml_sdpa) else: transforms.append(replace_sdpa_with_simple_sdpa) transforms.append(replace_kv_cache_with_coreml_kv_cache) - if args.vulkan: + if config.backend.vulkan.enabled: transforms.append(replace_with_vulkan_rotary_emb) return transforms diff --git a/examples/models/llama/install_requirements.sh b/examples/models/llama/install_requirements.sh index cca6ede1d79..d0c5a7d6c0d 100755 --- a/examples/models/llama/install_requirements.sh +++ b/examples/models/llama/install_requirements.sh @@ -10,7 +10,7 @@ # Install tokenizers for hf .json tokenizer. # Install snakeviz for cProfile flamegraph # Install lm-eval for Model Evaluation with lm-evalution-harness. 
-pip install tiktoken sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile
+pip install tiktoken sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile omegaconf
 
 # Call the install helper for further setup
 python examples/models/llama/install_requirement_helper.py
diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py
index 0b842a8f976..f38e5b337ec 100644
--- a/examples/models/llama/runner/eager.py
+++ b/examples/models/llama/runner/eager.py
@@ -11,11 +11,13 @@
 import torch
 
 from executorch.examples.models.llama.export_llama_lib import (
+    _convert_args_to_config,
     _prepare_for_llama_export,
     build_args_parser as _build_args_parser,
 )
 from executorch.examples.models.llama.runner.generation import LlamaRunner
 from executorch.extension.llm.export.builder import LLMEdgeManager
+from omegaconf import DictConfig, OmegaConf
 
 
 class EagerLlamaRunner(LlamaRunner):
@@ -23,19 +25,19 @@ class EagerLlamaRunner(LlamaRunner):
     """
     Runs llama in eager mode with provided checkpoint file.
     """
 
-    def __init__(self, args):
-        with open(args.params, "r") as f:
+    def __init__(self, config):
+        with open(config.model.params, "r") as f:
             params = json.loads(f.read())
         super().__init__(
-            tokenizer_path=args.tokenizer_path,
-            tokenizer_config_path=args.tokenizer_config_path,
-            max_seq_len=args.max_seq_length,
+            tokenizer_path=config.export.tokenizer_path,
+            tokenizer_config_path=config.eager.tokenizer_config_path,
+            max_seq_len=config.sequence.max_seq_length,
             max_batch_size=1,
-            use_kv_cache=args.use_kv_cache,
+            use_kv_cache=config.kv_cache.use_kv_cache,
             vocab_size=params["vocab_size"],
             device="cuda" if torch.cuda.is_available() else "cpu",
         )
-        manager: LLMEdgeManager = _prepare_for_llama_export(args)
+        manager: LLMEdgeManager = _prepare_for_llama_export(config)
         self.model = manager.model.eval().to(device=self.device)
 
     def forward(
@@ -85,26 +87,46 @@ def build_args_parser() -> argparse.ArgumentParser:
     return parser
 
 
+def _convert_cli_to_config_format(args) -> DictConfig:
+    """Convert CLI arguments to config format."""
+    # First convert common args using the shared function
+    config = _convert_args_to_config(args)
+
+    # Add eager runner-specific settings
+    config.eager = OmegaConf.create()
+    config.eager.prompt = args.prompt
+    config.eager.temperature = args.temperature
+    config.eager.show_tokens = args.show_tokens
+    config.eager.chat = args.chat
+    config.eager.tokenizer_config_path = args.tokenizer_config_path
+
+    return config
+
+
 def execute_runner(runner_class: Type[LlamaRunner]) -> None:
     parser = build_args_parser()
     args = parser.parse_args()
-
+    config = _convert_cli_to_config_format(args)
     with torch.no_grad():
-        runner = runner_class(args)  # pyre-ignore: Missing argument [20]
+        runner = runner_class(config)  # pyre-ignore: Missing argument [20]
         generated_tokens = (
             runner.chat_completion(
-                max_seq_len=1000000 if args.use_attention_sink else args.max_seq_length,
-                temperature=args.temperature,
-                show_progress=args.show_tokens,
+                max_seq_len=(
+                    1000000
+                    if config.misc.use_attention_sink
+                    else config.sequence.max_seq_length
+                ),
+                temperature=config.eager.temperature,
+                show_progress=config.eager.show_tokens,
             )
-            if args.chat
+            if config.eager.chat
             else runner.text_completion(
-                prompt=args.prompt,
-                temperature=args.temperature,
+                prompt=config.eager.prompt,
+                temperature=config.eager.temperature,
                 echo=True,
             )
         )
-        if args.show_tokens:
+        if config.eager.show_tokens:
             print(f"Generated {len(generated_tokens)} tokens: {generated_tokens}")
diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py
index 17cff7c63fd..9ace8153fbf 100644
--- a/examples/models/llama/source_transformation/quantize.py
+++ b/examples/models/llama/source_transformation/quantize.py
@@ -15,6 +15,7 @@
 import torch.nn.functional as F
 
 from executorch.extension.llm.export.builder import DType
+from omegaconf import DictConfig, OmegaConf
 
 from sentencepiece import SentencePieceProcessor
 
@@ -784,9 +785,9 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
 ############################ Source Transform Start #######################
 
 
-def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
-    if args.embedding_quantize.startswith("torchao:"):
-        bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
+def get_quant_embedding_transform(config: DictConfig, dtype_override: Optional[DType] = None):
+    if config.quantization.embedding_quantize.startswith("torchao:"):
+        bitwidth, group_size = config.quantization.embedding_quantize.split(":")[1].split(",")
         group_size = int(group_size)
         bitwidth = int(bitwidth)
         from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer
@@ -803,7 +804,7 @@ def _torchao_embedding_quantizer(model):
 
         return _torchao_embedding_quantizer
 
-    bitwidth, group_size = args.embedding_quantize.split(",")
+    bitwidth, group_size = config.quantization.embedding_quantize.split(",")
     if group_size == "none" or group_size == "None" or group_size == "0":
         group_size = None
     else:
@@ -820,37 +821,34 @@
 
 
 def get_quant_weight_transform(
-    args,
+    config: DictConfig,
     computation_dtype: Optional[DType] = None,
     checkpoint_dtype: Optional[DType] = None,
 ):
     # If these optional args are None, don't provide them to quantize().
-    quant_args_str = [
-        "group_size",
-        "calibration_tasks",
-        "calibration_limit",
-        "calibration_seq_length",
-    ]
-    arg_dict = vars(args)
-    quant_args = {
-        param: val
-        for param in quant_args_str
-        if (val := arg_dict.get(param)) is not None
-    }
+    quant_args = {}
+    if config.quantization.group_size is not None:
+        quant_args['group_size'] = config.quantization.group_size
+    if config.calibration.tasks is not None:
+        quant_args['calibration_tasks'] = OmegaConf.to_container(config.calibration.tasks)
+    if config.calibration.limit is not None:
+        quant_args['calibration_limit'] = config.calibration.limit
+    if config.calibration.seq_length is not None:
+        quant_args['calibration_seq_length'] = config.calibration.seq_length
 
     return partial(
         quantize,
         **quant_args,
-        qmode=args.quantization_mode,
+        qmode=config.quantization.mode,
         computation_dtype=computation_dtype,
         checkpoint_dtype=checkpoint_dtype,
-        checkpoint_path=(Path(path) if (path := args.checkpoint) is not None else None),
+        checkpoint_path=(Path(path) if (path := config.export.checkpoint) is not None else None),
         tokenizer_path=(
-            Path(path) if (path := args.tokenizer_path) is not None else None
+            Path(path) if (path := config.export.tokenizer_path) is not None else None
         ),
     )
 
 
 def _load_torchao_aten_lib(libname):
     import glob
     import os
diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py
index 64def112908..941a0a70288 100644
--- a/examples/models/llava/export_llava.py
+++ b/examples/models/llava/export_llava.py
@@ -17,6 +17,7 @@
     XNNPACKQuantizer,
 )
 from executorch.examples.models.llama.export_llama_lib import (
+    _convert_args_to_config,
     build_args_parser,
     get_quantizer_and_quant_params,
 )
@@ -110,8 +111,9 @@ def forward(self, input_pos, embeddings):
             "4,32",
         ]
     )
-    quant_transform = get_quant_weight_transform(args, dtype_override)
-    _, quantizers, _ = get_quantizer_and_quant_params(args)
+    config = _convert_args_to_config(args)
+    quant_transform = get_quant_weight_transform(config, dtype_override)
+    _, quantizers, _ = get_quantizer_and_quant_params(config)
     source_transforms = []
     if llava.use_sdpa_with_kv_cache_op:
         source_transforms.append(replace_kv_cache_with_custom_kv_cache)