diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index a56dceaa24f4..f9bc88eaa138 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -198,34 +198,33 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module): def __init__( self, model: PreTrainedModel, - max_batch_size: int = 1, - max_cache_len: int = 4096, ): """ Initializes the exportable module with `HybridCache`. Args: model (`PreTrainedModel`): The pretrained model to wrap. - max_batch_size (int): Maximum batch size for the cache. - max_cache_len (int): Maximum sequence length for the cache. Raises: ValueError: If the model is configured with a unsupported cache implementation. """ super().__init__() - if not hasattr(model.config, "use_cache") or model.config.use_cache is False: + config = model.config.get_text_config() + _generation_config = model.generation_config + + if not hasattr(config, "use_cache") or config.use_cache is False: raise ValueError("The model must have caching enabled to be performant.") - if hasattr(model.config, "layer_types") and getattr(model.config, "sliding_window", None) is not None: - self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len) + if hasattr(config, "layer_types") and getattr(config, "sliding_window", None) is not None: + self.model = TorchExportableModuleWithHybridCache(model) else: # If `layer_types` is not specified explicitly in the config or `sliding_window` is null, # there is only 1 type of layers, so export will use `StaticCache` by default. logging.info( "Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config." ) - self.model = TorchExportableModuleWithStaticCache(model, max_batch_size, max_cache_len) + self.model = TorchExportableModuleWithStaticCache(model) # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap) ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"]) @@ -233,24 +232,31 @@ def __init__( def forward( self, - input_ids: torch.Tensor, - cache_position: torch.Tensor, + input_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Forward pass of the module, which is compatible with the ExecuTorch llm runner. Args: input_ids (`torch.Tensor`): Tensor representing current input token id to the module. + inputs_embeds (`torch.Tensor`): Tensor representing current input embeddings to the module. cache_position (`torch.Tensor`): Tensor representing current input position in the cache. Returns: torch.Tensor: Logits output from the model. """ - return self.model.forward(input_ids, cache_position) + return self.model.forward( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + ) def export( self, input_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, cache_position: Optional[torch.Tensor] = None, dynamic_shapes: Optional[dict] = None, strict: Optional[bool] = None, @@ -260,14 +266,49 @@ def export( Args: input_ids (`Optional[torch.Tensor]`): - Tensor representing current input token id to the module. If not provided, a default tensor will be used. + Tensor representing current input token id to the module. Must specify either this or inputs_embeds. 
+ inputs_embeds (`Optional[torch.Tensor]`): + Tensor representing current input embeddings to the module. Must specify either this or input_ids. cache_position (`Optional[torch.Tensor]`): Tensor representing current input position in the cache. If not provided, a default tensor will be used. dynamic_shapes (`Optional[dict]`): Dynamic shapes to use for export if specified. strict(`Optional[bool]`): Flag to instruct `torch.export` to use `torchdynamo`. + + Returns: + torch.export.ExportedProgram: The exported program that can be used for inference. + + Examples: + Export with input_ids: + ```python + # Prepare inputs + input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long, device=model.device) + cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long, device=model.device) + + # Export + exported = exportable_module.export( + input_ids=input_ids, + cache_position=cache_position + ) + ``` + + Export with inputs_embeds: + ```python + # Prepare embeddings + inputs_embeds = torch.randn(1, 3, 768, device=model.device) # batch_size=1, seq_len=3, hidden_size=768 + cache_position = torch.arange(inputs_embeds.shape[1], dtype=torch.long, device=model.device) + + # Export + exported = exportable_module.export( + inputs_embeds=inputs_embeds, + cache_position=cache_position + ) + ``` """ + if not (input_ids is None) ^ (inputs_embeds is None): + raise ValueError("Need to specify either input_ids or inputs_embeds.") + if hasattr(self.model, "base_model_prefix"): base = getattr(self.model, self.model.base_model_prefix, self.model) model_device = base.device @@ -279,20 +320,29 @@ def export( "TorchExportableModuleForDecoderOnlyLM.export Can't infer device from the model. Set to CPU by default." ) - example_input_ids = ( - input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long, device=model_device) - ) - example_cache_position = ( - cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long, device=model_device) - ) + if input_ids is not None: + input_kwargs = { + "input_ids": input_ids, + "cache_position": cache_position + if cache_position is not None + else torch.arange(input_ids.shape[-1], dtype=torch.long, device=model_device), + } + else: # inputs_embeds + input_kwargs = { + "inputs_embeds": inputs_embeds, + "cache_position": cache_position + if cache_position is not None + else torch.arange(inputs_embeds.shape[1], dtype=torch.long, device=model_device), + } exported_program = torch.export.export( self.model, - args=(example_input_ids, example_cache_position), - kwargs={}, + args=(), + kwargs=input_kwargs, dynamic_shapes=dynamic_shapes, strict=strict if strict is not None else True, ) + return exported_program @staticmethod @@ -341,7 +391,7 @@ def generate( curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device) # Forward pass - _ = exported_module(curr_input_ids, curr_cache_position) + _ = exported_module(input_ids=curr_input_ids, cache_position=curr_cache_position) curr_position += 1 # Generate new tokens @@ -351,7 +401,7 @@ def generate( curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device) # Forward pass to get next token logits - outputs = exported_module(curr_input_ids, curr_cache_position) + outputs = exported_module(input_ids=curr_input_ids, cache_position=curr_cache_position) # Get the next token ID if do_sample: @@ -418,15 +468,13 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module): def __init__( self, model: PreTrainedModel, - max_batch_size: int = 1, - 
max_cache_len: int = 4096, ): """ Initializes the wrapper module with the pretrained model. Args: model (`PreTrainedModel`): The pretrained model to wrap. The model must have caching - enabled and use a 'static' caching implementation. + enabled and use a 'static' caching implementation. Raises: AssertionError: If the pretrained model does not have caching enabled or if it does @@ -434,27 +482,31 @@ def __init__( """ super().__init__() + config = model.config.get_text_config() + generation_config = model.generation_config + # Sanity checks - if model.generation_config is None: - # Use default generation config if not specified - model.generation_config = GenerationConfig( - use_cache=model.config.use_cache, - cache_implementation="static", - max_length=max_cache_len, - cache_config={ - "batch_size": max_batch_size, - "max_cache_len": max_cache_len, - "device": "cpu", - }, + if generation_config is None: + raise AssertionError( + "The model must have a generation config to be exported with static caching. " + "Please set `generation_config` in `model`." ) - - if not model.generation_config.use_cache: + if "batch_size" not in generation_config.cache_config: + raise ValueError( + "The model's generation config must specify a batch_size in its cache_config. " + 'Try GenerationConfig( ... cache_config={"batch_size": 1, ...} ...)' + ) + if "max_cache_len" not in generation_config.cache_config: + raise ValueError( + "The model's generation config must specify a max_cache_len in its cache_config. " + 'Try GenerationConfig( ... cache_config={"max_cache_len": 4096, ...} ...)' + ) + if not generation_config.use_cache: raise AssertionError( "The model must have caching enabled to be exported with static caching. " "Please set `generation_config.use_cache=True`." ) - - if model.generation_config.cache_implementation != "static": + if generation_config.cache_implementation != "static": raise AssertionError( "The model must use a 'static' caching implementation to be exported with static caching. " "Please set `generation_config.cache_implementation='static'`." @@ -462,22 +514,29 @@ def __init__( self.model = model self.static_cache = StaticCache( - config=self.model.config, - max_batch_size=self.model.generation_config.cache_config.get("batch_size"), - max_cache_len=self.model.generation_config.cache_config.get("max_cache_len"), - device=self.model.generation_config.cache_config.get("device"), + config=config, + max_batch_size=generation_config.cache_config.get("batch_size"), + max_cache_len=generation_config.cache_config.get("max_cache_len"), + device=generation_config.cache_config.get("device"), dtype=self.model.dtype, ) + for i in range(len(self.static_cache)): self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False) self.register_buffer(f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False) - def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + cache_position: Optional[torch.Tensor] = None, + ): """ Forward pass of the module, which is compatible with the ExecuTorch runtime. Args: input_ids (`torch.Tensor`): Tensor representing current input token id to the module. + inputs_embeds (`torch.Tensor`): Tensor representing current input embeddings to the module. cache_position (`torch.Tensor`): Tensor representing current input position in the cache. 
Returns: @@ -493,15 +552,13 @@ def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor): The adapter matches the model's forward signature with that in `executorch/extension/llm/runner`, ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box. """ - _, seqlen = input_ids.shape - position_ids = cache_position.unsqueeze(0) past_key_values = self.static_cache outs = self.model( input_ids=input_ids, - attention_mask=None, - position_ids=position_ids, + inputs_embeds=inputs_embeds, cache_position=cache_position, + attention_mask=None, past_key_values=past_key_values, use_cache=True, ) @@ -576,33 +633,45 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module): def __init__( self, model: PreTrainedModel, - max_batch_size: int = 1, - max_cache_len: int = 4096, ): """ Initializes the exportable module with `HybridCache`. Args: model (`PreTrainedModel`): The pretrained model to wrap. - max_batch_size (int): Maximum batch size for the cache. - max_cache_len (int): Maximum sequence length for the cache. Raises: AssertionError: If the model doesn't have the expected configuration for HybridCache. """ super().__init__() self.model = model + config = model.config.get_text_config() + generation_config = model.generation_config - # Verify the model is configured for HybridCache - if not self.model.config.use_cache: - raise AssertionError("Model must have caching enabled") + if generation_config is None: + raise AssertionError( + "The model must have a generation config to be exported with hybrid caching. " + "Please set `generation_config` in `model`." + ) + if "batch_size" not in generation_config.cache_config: + raise ValueError( + "The model's generation config must specify a batch_size in its cache_config. " + 'Try GenerationConfig( ... cache_config={"batch_size": 1, ...} ...)' + ) + if "max_cache_len" not in generation_config.cache_config: + raise ValueError( + "The model's generation config must specify a max_cache_len in its cache_config. " + 'Try GenerationConfig( ... cache_config={"max_cache_len": 4096, ...} ...)' + ) + if not config.use_cache: + raise AssertionError("Model must have caching enabled.") # Initialize the HybridCache self.cache = HybridCache( - config=self.model.config, - max_batch_size=max_batch_size, - max_cache_len=max_cache_len, - device=self.model.device, + config=config, + max_batch_size=generation_config.cache_config.get("batch_size"), + max_cache_len=generation_config.cache_config.get("max_cache_len"), + device=generation_config.cache_config.get("device"), dtype=self.model.dtype, ) @@ -613,32 +682,29 @@ def __init__( def forward( self, - input_ids: torch.Tensor, - cache_position: torch.Tensor, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Forward pass of the module, which is compatible with the ExecuTorch llm runner. Args: input_ids (`torch.Tensor`): Tensor representing current input token id to the module. + inputs_embeds (`Optional[torch.Tensor]`): Tensor representing current input embeddings to the module. cache_position (`torch.Tensor`): Tensor representing current input position in the cache. Returns: torch.Tensor: Logits output from the model. 
""" - batch_size = input_ids.shape[0] - - # Generate position_ids from cache_position - position_ids = cache_position.unsqueeze(0).expand(batch_size, -1) - # Forward pass with the model outputs = self.model( input_ids=input_ids, + inputs_embeds=inputs_embeds, + cache_position=cache_position, attention_mask=None, - position_ids=position_ids, past_key_values=self.cache, use_cache=True, - cache_position=cache_position, ) # Return only the logits to simplify the export @@ -692,8 +758,8 @@ def convert_and_export_with_cache( if is_torch_greater_or_equal("2.6.0"): exported_program = torch.export.export( TorchExportableModuleWithStaticCache(model), - args=(example_input_ids, example_cache_position), - kwargs={}, + args=(), + kwargs={"input_ids": example_input_ids, "cache_position": example_cache_position}, dynamic_shapes=dynamic_shapes, strict=strict if strict is not None else True, ) @@ -710,8 +776,8 @@ def convert_and_export_with_cache( # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release. exported_program = torch.export._trace._export( TorchExportableModuleWithStaticCache(model), - args=(example_input_ids,), - kwargs={"cache_position": example_cache_position}, + args=(), + kwargs={"input_ids": example_input_ids, "cache_position": example_cache_position}, pre_dispatch=False, strict=True, ) diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py index 9276fb12b328..8a1e2ea9eb7f 100644 --- a/tests/models/gemma/test_modeling_gemma.py +++ b/tests/models/gemma/test_modeling_gemma.py @@ -460,7 +460,10 @@ def test_export_static_cache(self): from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM exportable_module = TorchExportableModuleForDecoderOnlyLM(model) - exported_program = exportable_module.export() + exported_program = exportable_module.export( + input_ids=prompt_token_ids, + cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device), + ) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens ) diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index 589e08dd1d98..4a6c326f780c 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -365,7 +365,10 @@ def test_export_static_cache(self): from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM exportable_module = TorchExportableModuleForDecoderOnlyLM(model) - exported_program = exportable_module.export() + exported_program = exportable_module.export( + input_ids=prompt_token_ids, + cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device), + ) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens ) @@ -389,7 +392,10 @@ def test_export_hybrid_cache(self): # Export + HybridCache model.eval() exportable_module = TorchExportableModuleForDecoderOnlyLM(model) - exported_program = exportable_module.export() + exported_program = exportable_module.export( + input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device), + cache_position=torch.tensor([0], dtype=torch.long, device=model.device), + ) # Test generation with the exported model prompt = "What is the capital of France?" 
diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py index 43ac57dbb566..e1b444e2c546 100644 --- a/tests/models/gemma3/test_modeling_gemma3.py +++ b/tests/models/gemma3/test_modeling_gemma3.py @@ -809,7 +809,10 @@ def test_export_text_only_with_hybrid_cache(self): # Export + HybridCache model.eval() exportable_module = TorchExportableModuleForDecoderOnlyLM(model) - exported_program = exportable_module.export() + exported_program = exportable_module.export( + input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device), + cache_position=torch.tensor([0], dtype=torch.long, device=model.device), + ) logging.info(f"\nExported program: {exported_program}") # Test generation with the exported model diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 136f76f48c9a..a6c2c3eee2b6 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -353,7 +353,10 @@ def test_export_static_cache(self): from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM exportable_module = TorchExportableModuleForDecoderOnlyLM(model) - exported_program = exportable_module.export() + exported_program = exportable_module.export( + input_ids=prompt_token_ids, + cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device), + ) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens ) diff --git a/tests/models/olmo/test_modeling_olmo.py b/tests/models/olmo/test_modeling_olmo.py index 86913f254fbb..ea23f4e96fda 100644 --- a/tests/models/olmo/test_modeling_olmo.py +++ b/tests/models/olmo/test_modeling_olmo.py @@ -384,7 +384,10 @@ def test_export_static_cache(self): from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM exportable_module = TorchExportableModuleForDecoderOnlyLM(model) - exported_program = exportable_module.export() + exported_program = exportable_module.export( + input_ids=prompt_token_ids, + cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device), + ) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens ) diff --git a/tests/models/phi3/test_modeling_phi3.py b/tests/models/phi3/test_modeling_phi3.py index 387eb6c4df79..6887c0c6cd64 100644 --- a/tests/models/phi3/test_modeling_phi3.py +++ b/tests/models/phi3/test_modeling_phi3.py @@ -417,7 +417,10 @@ def test_export_static_cache(self): from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM exportable_module = TorchExportableModuleForDecoderOnlyLM(model) - exported_program = exportable_module.export() + exported_program = exportable_module.export( + input_ids=prompt_token_ids, + cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device), + ) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens ) diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py index d48226394c33..51bd943cf916 100644 --- a/tests/models/qwen2/test_modeling_qwen2.py +++ b/tests/models/qwen2/test_modeling_qwen2.py @@ -303,7 +303,11 @@ def test_export_static_cache(self): 
strict = version.parse(torch.__version__) != version.parse( "2.7.0" ) # Due to https://github.com/pytorch/pytorch/issues/150994 - exported_program = exportable_module.export(strict=strict) + exported_program = exportable_module.export( + input_ids=prompt_token_ids, + cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device), + strict=strict, + ) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens ) diff --git a/tests/models/qwen3/test_modeling_qwen3.py b/tests/models/qwen3/test_modeling_qwen3.py index a37df40ed4a8..205228073e19 100644 --- a/tests/models/qwen3/test_modeling_qwen3.py +++ b/tests/models/qwen3/test_modeling_qwen3.py @@ -293,7 +293,11 @@ def test_export_static_cache(self): from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM exportable_module = TorchExportableModuleForDecoderOnlyLM(model) - exported_program = exportable_module.export(strict=strict) + exported_program = exportable_module.export( + input_ids=prompt_token_ids, + cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device), + strict=strict, + ) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens ) diff --git a/tests/test_executorch.py b/tests/test_executorch.py new file mode 100644 index 000000000000..0e33253c08f1 --- /dev/null +++ b/tests/test_executorch.py @@ -0,0 +1,129 @@ +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch + +from transformers import AutoModelForCausalLM, set_seed +from transformers.generation.configuration_utils import GenerationConfig +from transformers.integrations.executorch import ( + TorchExportableModuleForDecoderOnlyLM, + TorchExportableModuleWithHybridCache, + TorchExportableModuleWithStaticCache, +) +from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3 +from transformers.testing_utils import require_torch + + +@require_torch +class ExecutorchTest(unittest.TestCase): + def setUp(self): + if not is_torch_greater_or_equal_than_2_3: + self.skipTest("torch >= 2.3 is required") + + set_seed(0) + self.model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM") + self.model.eval() + + # Create generation config with static cache for the model + self.model.generation_config = GenerationConfig( + use_cache=True, + cache_implementation="static", + cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"}, + ) + + self.input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long) + self.inputs_embeds = torch.randn(1, 3, self.model.config.hidden_size) + self.cache_position = torch.arange(3, dtype=torch.long) + + def test_static_cache_module_forward(self): + """Test TorchExportableModuleWithStaticCache forward with both input types""" + generation_config = GenerationConfig( + use_cache=True, + cache_implementation="static", + cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"}, + ) + + # Set generation config on model + self.model.generation_config = generation_config + module = TorchExportableModuleWithStaticCache(self.model) + + # Test with input_ids + eager_output_ids = self.model(input_ids=self.input_ids, use_cache=False).logits + wrapped_output_ids = module.forward(input_ids=self.input_ids, cache_position=self.cache_position) + torch.testing.assert_close(eager_output_ids, wrapped_output_ids, atol=1e-4, rtol=1e-4) + + # Test with inputs_embeds + eager_output_embeds = self.model(inputs_embeds=self.inputs_embeds, use_cache=False).logits + wrapped_output_embeds = module.forward(inputs_embeds=self.inputs_embeds, cache_position=self.cache_position) + torch.testing.assert_close(eager_output_embeds, wrapped_output_embeds, atol=1e-4, rtol=1e-4) + + def test_hybrid_cache_module_forward(self): + """Test TorchExportableModuleWithHybridCache forward with both input types""" + config = self.model.config + config.sliding_window = 16 + config.layer_types = ["full_attention"] * config.num_hidden_layers + + generation_config = GenerationConfig( + use_cache=True, + cache_implementation="hybrid", + cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"}, + ) + + # Set generation config on model + self.model.generation_config = generation_config + module = TorchExportableModuleWithHybridCache(self.model) + + # Test with input_ids + eager_output_ids = self.model(input_ids=self.input_ids, use_cache=False).logits + wrapped_output_ids = module.forward(input_ids=self.input_ids, cache_position=self.cache_position) + torch.testing.assert_close(eager_output_ids, wrapped_output_ids, atol=1e-4, rtol=1e-4) + + # Test with inputs_embeds + eager_output_embeds = self.model(inputs_embeds=self.inputs_embeds, use_cache=False).logits + wrapped_output_embeds = module.forward(inputs_embeds=self.inputs_embeds, cache_position=self.cache_position) + torch.testing.assert_close(eager_output_embeds, wrapped_output_embeds, atol=1e-4, rtol=1e-4) + + def test_decoder_only_lm_export_validation(self): + """Test 
TorchExportableModuleForDecoderOnlyLM export validation""" + module = TorchExportableModuleForDecoderOnlyLM(self.model) + + # Should fail with both input_ids and inputs_embeds + with self.assertRaises(ValueError): + module.export(input_ids=self.input_ids, inputs_embeds=self.inputs_embeds) + + # Should fail with neither + with self.assertRaises(ValueError): + module.export() + + def test_decoder_only_lm_export(self): + """Test TorchExportableModuleForDecoderOnlyLM export with both input types""" + module = TorchExportableModuleForDecoderOnlyLM(self.model) + + # Test export with input_ids + exported_program_ids = module.export(input_ids=self.input_ids, cache_position=self.cache_position) + eager_output_ids = self.model(input_ids=self.input_ids, use_cache=False).logits + exported_output_ids = exported_program_ids.module()( + input_ids=self.input_ids, cache_position=self.cache_position + ) + torch.testing.assert_close(eager_output_ids, exported_output_ids, atol=1e-4, rtol=1e-4) + + # Test export with inputs_embeds + exported_program_embeds = module.export(inputs_embeds=self.inputs_embeds, cache_position=self.cache_position) + eager_output_embeds = self.model(inputs_embeds=self.inputs_embeds, use_cache=False).logits + exported_output_embeds = exported_program_embeds.module()( + inputs_embeds=self.inputs_embeds, cache_position=self.cache_position + ) + torch.testing.assert_close(eager_output_embeds, exported_output_embeds, atol=1e-4, rtol=1e-4) diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 14b29344f190..74b19395a67f 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -841,8 +841,24 @@ def test_hybrid_cache_exportability(self): model.eval() max_batch_size = 1 max_cache_len = 23 - exportable_module = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size, max_cache_len) - exported_program = exportable_module.export() + # Set a generation config with a hybrid cache on the model + from transformers.generation.configuration_utils import GenerationConfig + + model.generation_config = GenerationConfig( + use_cache=True, + cache_implementation="hybrid", + max_length=max_cache_len, + cache_config={ + "batch_size": max_batch_size, + "max_cache_len": max_cache_len, + "device": model.device, + }, + ) + exportable_module = TorchExportableModuleForDecoderOnlyLM(model) + exported_program = exportable_module.export( + input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device), + cache_position=torch.tensor([0], dtype=torch.long, device=model.device), + ) n_g_key_caches = n_g_value_caches = 0 for buffer_name, buffer in exported_program.named_buffers(): if buffer_name.startswith("key_cache"):
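Taken together, these changes move cache sizing out of the wrapper constructors and into `GenerationConfig.cache_config`, and require the caller to pass exactly one of `input_ids` or `inputs_embeds` (plus `cache_position`) to `export()`. A minimal sketch of the resulting flow, assuming the tiny test checkpoint used in `tests/test_executorch.py` and placeholder cache sizes:

```python
# Minimal sketch of the export flow after this patch (illustrative only).
# The checkpoint is the tiny test model from tests/test_executorch.py;
# the cache sizes are placeholder assumptions.
import torch

from transformers import AutoModelForCausalLM, GenerationConfig
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
model.eval()

# Cache geometry is now read from the generation config instead of the
# removed max_batch_size / max_cache_len constructor arguments.
model.generation_config = GenerationConfig(
    use_cache=True,
    cache_implementation="static",
    cache_config={"batch_size": 1, "max_cache_len": 64, "device": "cpu"},
)

exportable_module = TorchExportableModuleForDecoderOnlyLM(model)

# export() now takes exactly one of input_ids or inputs_embeds, plus cache_position.
input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long, device=model.device)
cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long, device=model.device)
exported_program = exportable_module.export(input_ids=input_ids, cache_position=cache_position)
```

The embeddings path is symmetric: pass `inputs_embeds` of shape `(batch, seq_len, hidden_size)` instead of `input_ids`.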