From 311074a740a5d5208bb5570a2fef5719d4f6f57e Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Tue, 27 May 2025 16:41:49 -0700
Subject: [PATCH] refactor: Use LlmConfig for model parameters instead of
 kwargs

[ghstack-poisoned]
---
 examples/models/llama/export_llama_lib.py |  8 ------
 examples/models/llama/model.py            | 35 ++++++++++++++++-------
 2 files changed, 25 insertions(+), 18 deletions(-)

diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index c1dd27dc390..0a4ecb73199 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -1229,15 +1229,7 @@ def _load_llama_model(
         checkpoint=checkpoint,
         checkpoint_dir=checkpoint_dir,
         params=params_path,
-        use_kv_cache=use_kv_cache,
-        use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
-        generate_full_logits=generate_full_logits,
         fairseq2=weight_type == WeightType.FAIRSEQ2,
-        max_seq_len=max_seq_len,
-        max_context_len=max_context_len,
-        enable_dynamic_shape=enable_dynamic_shape,
-        input_prune_map_path=input_prune_map_path,
-        output_prune_map_path=output_prune_map_path,
         dtype=torch_dtype,
         llm_config=llm_config,
     )
diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index ac0119bcbe5..4d441a5a32e 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -47,15 +47,30 @@ def __init__(self, **kwargs):
         # Params file.
         params_path = kwargs.get("params", None)

-        self.use_kv_cache = kwargs.get("use_kv_cache", False)
-        self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
-        self.generate_full_logits = kwargs.get("generate_full_logits", False)
-        self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
-        self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
-        self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
-        self.max_seq_len = kwargs.get("max_seq_len", 128)
-        self.max_context_len = kwargs.get("max_context_len", 128)
         self.llm_config = kwargs.get("llm_config", None)
+
+        # Set all parameters from llm_config if available, otherwise use kwargs as fallback
+        if self.llm_config:
+            self.use_kv_cache = self.llm_config.model.use_kv_cache
+            self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
+            self.generate_full_logits = self.llm_config.debug.generate_full_logits
+            self.enable_dynamic_shape = self.llm_config.model.enable_dynamic_shape
+            self.input_prune_map_path = self.llm_config.model.input_prune_map
+            self.output_prune_map_path = self.llm_config.model.output_prune_map
+            self.max_seq_len = self.llm_config.export.max_seq_length
+            self.max_context_len = self.llm_config.export.max_context_length
+            self.verbose = self.llm_config.debug.verbose
+        else:
+            # Fallback to kwargs for backward compatibility
+            self.use_kv_cache = kwargs.get("use_kv_cache", False)
+            self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
+            self.generate_full_logits = kwargs.get("generate_full_logits", False)
+            self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
+            self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
+            self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
+            self.max_seq_len = kwargs.get("max_seq_len", 128)
+            self.max_context_len = kwargs.get("max_context_len", 128)
+            self.verbose = kwargs.get("verbose", False)

         assert (
             self.max_context_len >= self.max_seq_len
@@ -165,7 +180,7 @@ def __init__(self, **kwargs):
             if model_name not in ["llama3", "llama3_1"]:
                 model_args.rope_scale_factor = 32

-        if kwargs.get("verbose", False):
+        if self.verbose:
             print("============= weights ================")
             print("{key} : {weights.numel()} : {weights.size()}")
             for key, weights in checkpoint.items():
@@ -280,7 +295,7 @@ def __init__(self, **kwargs):
                    f"The provided checkpoint is missing the following weights that are expected by the model: {missing_weights}. Please fix the fqn's in your checkpoint to match."
                )
        if unexpected:
-            if kwargs.get("verbose", False):
+            if self.verbose:
                print(f"Unexpected keys: {unexpected}")

        # Prune the input layer if input_prune_map is provided