diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py
index 49b666912246..cd9b12847f3a 100644
--- a/src/transformers/integrations/executorch.py
+++ b/src/transformers/integrations/executorch.py
@@ -201,7 +201,10 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
     def __init__(
         self,
         model: PreTrainedModel,
-    ):
+        batch_size: Optional[int] = None,
+        max_cache_len: Optional[int] = None,
+        device: Optional[torch.device] = None,
+    ) -> None:
         """
         Initializes the exportable module.
 
@@ -214,20 +217,19 @@ def __init__(
         super().__init__()
 
         config = model.config.get_text_config()
-        _generation_config = model.generation_config
         if not hasattr(config, "use_cache") or config.use_cache is False:
             raise ValueError("The model must have caching enabled to be performant.")
 
         if hasattr(config, "layer_types") and getattr(config, "sliding_window", None) is not None:
-            self.model = TorchExportableModuleWithHybridCache(model)
+            self.model = TorchExportableModuleWithHybridCache(model, batch_size, max_cache_len, device)
         else:
             # If `layer_types` is not specified explicitly in the config or `sliding_window` is null,
             # there is only 1 type of layers, so export will use `StaticCache` by default.
             logging.info(
                 "Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
             )
-            self.model = TorchExportableModuleWithStaticCache(model)
+            self.model = TorchExportableModuleWithStaticCache(model, batch_size, max_cache_len, device)
 
         # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
         ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
         ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
@@ -471,17 +473,27 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
     def __init__(
         self,
         model: PreTrainedModel,
-    ):
+        batch_size: Optional[int] = None,
+        max_cache_len: Optional[int] = None,
+        device: Optional[torch.device] = None,
+    ) -> None:
         """
         Initializes the wrapper module with the pretrained model.
 
         Args:
             model (`PreTrainedModel`): The pretrained model to wrap. The model must have caching
                 enabled and use a 'static' caching implementation.
+            batch_size (`Optional[int]`): The batch size of the model. If not provided, we check if a value can be found
+                in `generation_config.cache_config` and otherwise we raise a ValueError.
+            max_cache_len (`Optional[int]`): The maximum cache length for generation. Same mechanism as `batch_size` if
+                not provided.
+            device (`Optional[torch.device]`): The device to use. If not provided, we check if a value can be found
+                in `generation_config.cache_config` and otherwise we use `model.device` (no error is raised).
 
         Raises:
             AssertionError: If the pretrained model does not have caching enabled or if it does
                 not use a 'static' caching implementation in `model.generation_config`.
+            ValueError: If `batch_size` or `max_cache_len` is not provided, either as an argument or in `cache_config`.
         """
         super().__init__()
 
@@ -494,16 +506,6 @@ def __init__(
                 "The model must have a generation config to be exported with static caching. "
                 "Please set `generation_config` in `model`."
             )
-        if "batch_size" not in generation_config.cache_config:
-            raise ValueError(
-                "The model's generation config must specify a batch_size in its cache_config. "
-                'Try GenerationConfig( ... cache_config={"batch_size": 1, ...} ...)'
-            )
-        if "max_cache_len" not in generation_config.cache_config:
-            raise ValueError(
-                "The model's generation config must specify a max_cache_len in its cache_config. "
-                'Try GenerationConfig( ... cache_config={"max_cache_len": 4096, ...} ...)'
-            )
         if not generation_config.use_cache:
             raise AssertionError(
                 "The model must have caching enabled to be exported with static caching. "
@@ -515,15 +517,26 @@
                 "Please set `generation_config.cache_implementation='static'`."
             )
 
+        cache_config = {} if generation_config.cache_config is None else generation_config.cache_config
+
+        # Ensure batch_size and max_cache_len are set
+        if batch_size is None:
+            batch_size = cache_config.get("batch_size", None)
+            if batch_size is None:
+                raise ValueError("batch_size must be provided, either as an argument or in cache_config.")
+        if max_cache_len is None:
+            max_cache_len = cache_config.get("max_cache_len", None)
+            if max_cache_len is None:
+                raise ValueError("max_cache_len must be provided, either as an argument or in cache_config.")
+        # Infer device if not provided
+        if device is None:
+            device = cache_config.get("device", model.device)
+
+        # Initialize the static cache
         self.model = model
-        self.static_cache = StaticCache(
-            max_cache_len=generation_config.cache_config.get("max_cache_len"),
-            config=config,
-        )
-        batch_size = generation_config.cache_config.get("batch_size")
+        self.static_cache = StaticCache(max_cache_len=max_cache_len, config=config)
         head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
-        device = generation_config.cache_config.get("device")
         dtype = self.model.dtype
         # We need this call to initialize all the layers (otherwise it's done lazily, which is not exportable)
         self.static_cache.early_initialization(batch_size, num_heads, head_dim, dtype, device)
@@ -639,48 +652,60 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module):
     def __init__(
         self,
         model: PreTrainedModel,
-    ):
+        batch_size: Optional[int] = None,
+        max_cache_len: Optional[int] = None,
+        device: Optional[torch.device] = None,
+    ) -> None:
         """
         Initializes the exportable module.
 
         Args:
             model (`PreTrainedModel`): The pretrained model to wrap.
-
+            batch_size (`Optional[int]`): The batch size of the model. If not provided, we check if a value can be found
+                in `generation_config.cache_config` and otherwise we raise a ValueError.
+            max_cache_len (`Optional[int]`): The maximum cache length for generation. Same mechanism as `batch_size` if
+                not provided.
+            device (`Optional[torch.device]`): The device to use. If not provided, we check if a value can be found
+                in `generation_config.cache_config` and otherwise we use `model.device` (no error is raised).
         Raises:
-            AssertionError: If the model doesn't have the expected configuration for an hybrid StaticCache.
+            AssertionError: If the model doesn't have the expected configuration for hybrid StaticCache.
+            ValueError: If `batch_size` or `max_cache_len` is not provided, either as an argument or in `cache_config`.
         """
         super().__init__()
         self.model = model
         config = model.config.get_text_config()
         generation_config = model.generation_config
 
+        # Sanity checks
         if generation_config is None:
             raise AssertionError(
                 "The model must have a generation config to be exported with static caching. "
                 "Please set `generation_config` in `model`."
             )
-        if "batch_size" not in generation_config.cache_config:
-            raise ValueError(
-                "The model's generation config must specify a batch_size in its cache_config. "
-                'Try GenerationConfig( ... cache_config={"batch_size": 1, ...} ...)'
-            )
-        if "max_cache_len" not in generation_config.cache_config:
-            raise ValueError(
-                "The model's generation config must specify a max_cache_len in its cache_config. "
-                'Try GenerationConfig( ... cache_config={"max_cache_len": 4096, ...} ...)'
-            )
         if not config.use_cache:
             raise AssertionError("Model must have caching enabled.")
 
+        cache_config = {} if generation_config.cache_config is None else generation_config.cache_config
+        # Ensure batch_size and max_cache_len are set
+        if batch_size is None:
+            batch_size = cache_config.get("batch_size", None)
+            if batch_size is None:
+                raise ValueError("batch_size must be provided, either as an argument or in cache_config.")
+        if max_cache_len is None:
+            max_cache_len = cache_config.get("max_cache_len", None)
+            if max_cache_len is None:
+                raise ValueError("max_cache_len must be provided, either as an argument or in cache_config.")
+        # Infer device if not provided
+        if device is None:
+            device = cache_config.get("device", model.device)
+
         # Initialize the cache
-        self.cache = StaticCache(config=config, max_cache_len=generation_config.cache_config.get("max_cache_len"))
+        self.cache = StaticCache(config=config, max_cache_len=max_cache_len)
         head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
-        max_batch_size = generation_config.cache_config.get("batch_size")
-        device = generation_config.cache_config.get("device")
         dtype = self.model.dtype
         # We need this call to initialize all the layers (otherwise it's done lazily, which is not exportable)
-        self.cache.early_initialization(max_batch_size, num_heads, head_dim, dtype, device)
+        self.cache.early_initialization(batch_size, num_heads, head_dim, dtype, device)
 
         # Register all key and value cache tensors as buffers
         for i in range(len(self.cache)):
diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py
index 284cd4c19909..097c82a0e5a0 100644
--- a/tests/models/gemma/test_modeling_gemma.py
+++ b/tests/models/gemma/test_modeling_gemma.py
@@ -416,6 +416,9 @@ def test_export_static_cache(self):
                 ("cuda", 8): [
                     "Hello I am doing a project on the 1990s and I need to know what the most popular music was in the 1990s. I have been looking on the internet and I have"
                 ],
+                ("rocm", (9, 5)): [
+                    "Hello I am doing a project on the 1990s and I need to know what the most popular music was in the 1990s. I have been looking on the internet and I have"
+                ],
             }
         )
         EXPECTED_TEXT_COMPLETION = expectations.get_expectation()
diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py
index 59921594d691..423640cb31b1 100644
--- a/tests/models/gemma2/test_modeling_gemma2.py
+++ b/tests/models/gemma2/test_modeling_gemma2.py
@@ -255,6 +255,9 @@ def test_export_static_cache(self):
                 ("cuda", 8): [
                     "Hello I am doing a project for my class and I am having trouble with the code. I am trying to make a"
                 ],
+                ("rocm", (9, 5)): [
+                    "Hello I am doing a project for my school and I need to know how to make a program that will take a number"
+                ],
             }
         )
         EXPECTED_TEXT_COMPLETION = EXPECTED_TEXT_COMPLETIONS.get_expectation()
@@ -320,7 +323,7 @@ def test_export_hybrid_cache(self):
 
         # Export + hybrid cache
         model.eval()
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model, batch_size=1, max_cache_len=1024)
         exported_program = exportable_module.export(
             input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
             cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py
index f1e1a5a95fdd..99ea97561f78 100644
--- a/tests/models/gemma3/test_modeling_gemma3.py
+++ b/tests/models/gemma3/test_modeling_gemma3.py
@@ -733,7 +733,7 @@ def test_export_text_only_with_hybrid_cache(self):
 
         # Export + hybrid cache
         model.eval()
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model, batch_size=1, max_cache_len=1024)
         exported_program = exportable_module.export(
             input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
             cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py
index 74da981d092a..30c5082393ef 100644
--- a/tests/models/qwen2/test_modeling_qwen2.py
+++ b/tests/models/qwen2/test_modeling_qwen2.py
@@ -257,7 +257,7 @@ def test_export_static_cache(self):
                 "My favourite condiment is 100% natural, organic, gluten free, vegan, and vegetarian. I love to use"
             ],
             ("rocm", (9, 5)): [
-                "My favourite condiment is 100% natural, organic and vegan. I love to use it in my cooking, but"
+                "My favourite condiment is 100% natural, organic, gluten free, vegan, and vegetarian. I love to use"
             ]
         })  # fmt: off
         EXPECTED_TEXT_COMPLETION = expected_text_completions.get_expectation()
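
Usage sketch (not part of the diff): a minimal example of the new constructor signature, assuming a locally available decoder-only checkpoint; the model path and the cache sizes below are illustrative placeholders. The wrapper and export calls mirror the pattern used in the updated tests, and `batch_size`/`max_cache_len` fall back to `generation_config.cache_config` (and `device` to `model.device`) when omitted.

import torch

from transformers import AutoModelForCausalLM
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

# Illustrative checkpoint path; any cache-enabled decoder-only LM follows the same pattern.
model = AutoModelForCausalLM.from_pretrained("path/to/decoder-only-model")
model.eval()

# Cache dimensions are now passed explicitly instead of only via generation_config.cache_config.
exportable_module = TorchExportableModuleForDecoderOnlyLM(model, batch_size=1, max_cache_len=1024)
exported_program = exportable_module.export(
    input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
    cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
)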