Merged
101 changes: 63 additions & 38 deletions src/transformers/integrations/executorch.py
@@ -201,7 +201,10 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
def __init__(
self,
model: PreTrainedModel,
):
batch_size: Optional[int] = None,
max_cache_len: Optional[int] = None,
device: Optional[torch.device] = None,
) -> None:
"""
Initializes the exportable module.

@@ -214,20 +217,19 @@ def __init__(
super().__init__()

config = model.config.get_text_config()
_generation_config = model.generation_config

if not hasattr(config, "use_cache") or config.use_cache is False:
raise ValueError("The model must have caching enabled to be performant.")

if hasattr(config, "layer_types") and getattr(config, "sliding_window", None) is not None:
self.model = TorchExportableModuleWithHybridCache(model)
self.model = TorchExportableModuleWithHybridCache(model, batch_size, max_cache_len, device)
else:
# If `layer_types` is not specified explicitly in the config or `sliding_window` is null,
# there is only 1 type of layers, so export will use `StaticCache` by default.
logging.info(
"Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
)
self.model = TorchExportableModuleWithStaticCache(model)
self.model = TorchExportableModuleWithStaticCache(model, batch_size, max_cache_len, device)
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
@@ -471,17 +473,27 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
def __init__(
self,
model: PreTrainedModel,
):
batch_size: Optional[int] = None,
max_cache_len: Optional[int] = None,
device: Optional[torch.device] = None,
) -> None:
"""
Initializes the wrapper module with the pretrained model.

Args:
model (`PreTrainedModel`): The pretrained model to wrap. The model must have caching
enabled and use a 'static' caching implementation.
batch_size (`Optional[int]`): The batch size of the model. If not provided, we check if a value can be found
in `generation_config.cache_config` and otherwise we raise a ValueError.
max_cache_len (`Optional[int]`): The maximum cache length for generation. Same mechanism as `batch_size` if
not provided.
device (`Optional[torch.device]`): The device to use. If not provided, we check if a value can be found
in `generation_config.cache_config` and otherwise we use `model.device` (no error is raised).

Raises:
AssertionError: If the pretrained model does not have caching enabled or if it does
not use a 'static' caching implementation in `model.generation_config`.
ValueError: If `batch_size` or `max_cache_len` is not provided, either as an argument or in `cache_config`.
"""
super().__init__()

@@ -494,16 +506,6 @@ def __init__(
"The model must have a generation config to be exported with static caching. "
"Please set `generation_config` in `model`."
)
if "batch_size" not in generation_config.cache_config:
raise ValueError(
"The model's generation config must specify a batch_size in its cache_config. "
'Try GenerationConfig( ... cache_config={"batch_size": 1, ...} ...)'
)
if "max_cache_len" not in generation_config.cache_config:
raise ValueError(
"The model's generation config must specify a max_cache_len in its cache_config. "
'Try GenerationConfig( ... cache_config={"max_cache_len": 4096, ...} ...)'
)
if not generation_config.use_cache:
raise AssertionError(
"The model must have caching enabled to be exported with static caching. "
@@ -515,15 +517,26 @@ def __init__(
"Please set `generation_config.cache_implementation='static'`."
)

cache_config = {} if generation_config.cache_config is None else generation_config.cache_config

# Ensure batch_size and max_cache_len are set
if batch_size is None:
batch_size = cache_config.get("batch_size", None)
if batch_size is None:
raise ValueError("batch_size must be provided, either as an argument or in cache_config.")
if max_cache_len is None:
max_cache_len = cache_config.get("max_cache_len", None)
if max_cache_len is None:
raise ValueError("max_cache_len must be provided, either as an argument or in cache_config.")
# Infer device if not provided
if device is None:
device = cache_config.get("device", model.device)

# Initialize the static cache
self.model = model
self.static_cache = StaticCache(
max_cache_len=generation_config.cache_config.get("max_cache_len"),
config=config,
)
batch_size = generation_config.cache_config.get("batch_size")
self.static_cache = StaticCache(max_cache_len=max_cache_len, config=config)
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
device = generation_config.cache_config.get("device")
dtype = self.model.dtype
# We need this call to initialize all the layers (otherwise it's done lazily, which is not exportable)
self.static_cache.early_initialization(batch_size, num_heads, head_dim, dtype, device)
@@ -639,48 +652,60 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module):
def __init__(
self,
model: PreTrainedModel,
):
batch_size: Optional[int] = None,
max_cache_len: Optional[int] = None,
device: Optional[torch.device] = None,
) -> None:
"""
Initializes the exportable module.

Args:
model (`PreTrainedModel`): The pretrained model to wrap.

batch_size (`Optional[int]`): The batch size of the model. If not provided, we check if a value can be found
in `generation_config.cache_config` and otherwise we raise a ValueError.
max_cache_len (`Optional[int]`): The maximum cache length for generation. Same mechanism as `batch_size` if
not provided.
device (`Optional[torch.device]`): The device to use. If not provided, we check if a value can be found
in `generation_config.cache_config` and otherwise we use `model.device` (no error is raised).

Raises:
AssertionError: If the model doesn't have the expected configuration for an hybrid StaticCache.
AssertionError: If the model doesn't have the expected configuration for hybrid StaticCache.
ValueError: If `batch_size` or `max_cache_len` is not provided, either as an argument or in `cache_config`.
"""
super().__init__()
self.model = model
config = model.config.get_text_config()
generation_config = model.generation_config

# Sanity checks
if generation_config is None:
raise AssertionError(
"The model must have a generation config to be exported with static caching. "
"Please set `generation_config` in `model`."
)
if "batch_size" not in generation_config.cache_config:
raise ValueError(
"The model's generation config must specify a batch_size in its cache_config. "
'Try GenerationConfig( ... cache_config={"batch_size": 1, ...} ...)'
)
if "max_cache_len" not in generation_config.cache_config:
raise ValueError(
"The model's generation config must specify a max_cache_len in its cache_config. "
'Try GenerationConfig( ... cache_config={"max_cache_len": 4096, ...} ...)'
)
if not config.use_cache:
raise AssertionError("Model must have caching enabled.")

cache_config = {} if generation_config.cache_config is None else generation_config.cache_config
# Ensure batch_size and max_cache_len are set
if batch_size is None:
batch_size = cache_config.get("batch_size", None)
if batch_size is None:
raise ValueError("batch_size must be provided, either as an argument or in cache_config.")
if max_cache_len is None:
max_cache_len = cache_config.get("max_cache_len", None)
if max_cache_len is None:
raise ValueError("max_cache_len must be provided, either as an argument or in cache_config.")
# Infer device if not provided
if device is None:
device = cache_config.get("device", model.device)

# Initialize the cache
self.cache = StaticCache(config=config, max_cache_len=generation_config.cache_config.get("max_cache_len"))
self.cache = StaticCache(config=config, max_cache_len=max_cache_len)
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
max_batch_size = generation_config.cache_config.get("batch_size")
device = generation_config.cache_config.get("device")
dtype = self.model.dtype
# We need this call to initialize all the layers (otherwise it's done lazily, which is not exportable)
self.cache.early_initialization(max_batch_size, num_heads, head_dim, dtype, device)
self.cache.early_initialization(batch_size, num_heads, head_dim, dtype, device)

# Register all key and value cache tensors as buffers
for i in range(len(self.cache)):
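Usage note: the sketch below exercises the new `batch_size` / `max_cache_len` arguments the same way the updated tests in this diff do; the checkpoint name is an illustrative assumption, and any decoder-only model with caching enabled should behave similarly. Passing the values through `generation_config.cache_config` (e.g. `GenerationConfig(..., cache_config={"batch_size": 1, "max_cache_len": 1024})`) remains a supported fallback per the docstrings above.

import torch
from transformers import AutoModelForCausalLM
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

# Illustrative checkpoint; substitute any decoder-only LM with use_cache=True.
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
model.eval()

# Cache dimensions can now be passed directly to the wrapper instead of having to
# live in generation_config.cache_config (which is still honored as a fallback).
exportable_module = TorchExportableModuleForDecoderOnlyLM(model, batch_size=1, max_cache_len=1024)
exported_program = exportable_module.export(
    input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
    cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
)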
3 changes: 3 additions & 0 deletions tests/models/gemma/test_modeling_gemma.py
@@ -416,6 +416,9 @@ def test_export_static_cache(self):
("cuda", 8): [
"Hello I am doing a project on the 1990s and I need to know what the most popular music was in the 1990s. I have been looking on the internet and I have"
],
("rocm", (9, 5)): [
"Hello I am doing a project on the 1990s and I need to know what the most popular music was in the 1990s. I have been looking on the internet and I have"
],
}
)
EXPECTED_TEXT_COMPLETION = expectations.get_expectation()
5 changes: 4 additions & 1 deletion tests/models/gemma2/test_modeling_gemma2.py
@@ -255,6 +255,9 @@ def test_export_static_cache(self):
("cuda", 8): [
"Hello I am doing a project for my class and I am having trouble with the code. I am trying to make a"
],
("rocm", (9, 5)): [
"Hello I am doing a project for my school and I need to know how to make a program that will take a number"
],
}
)
EXPECTED_TEXT_COMPLETION = EXPECTED_TEXT_COMPLETIONS.get_expectation()
@@ -320,7 +323,7 @@ def test_export_hybrid_cache(self):

# Export + hybrid cache
model.eval()
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exportable_module = TorchExportableModuleForDecoderOnlyLM(model, batch_size=1, max_cache_len=1024)
exported_program = exportable_module.export(
input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
2 changes: 1 addition & 1 deletion tests/models/gemma3/test_modeling_gemma3.py
@@ -733,7 +733,7 @@ def test_export_text_only_with_hybrid_cache(self):

# Export + hybrid cache
model.eval()
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exportable_module = TorchExportableModuleForDecoderOnlyLM(model, batch_size=1, max_cache_len=1024)
exported_program = exportable_module.export(
input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
2 changes: 1 addition & 1 deletion tests/models/qwen2/test_modeling_qwen2.py
@@ -257,7 +257,7 @@ def test_export_static_cache(self):
"My favourite condiment is 100% natural, organic, gluten free, vegan, and vegetarian. I love to use"
],
("rocm", (9, 5)): [
"My favourite condiment is 100% natural, organic and vegan. I love to use it in my cooking, but"
"My favourite condiment is 100% natural, organic, gluten free, vegan, and vegetarian. I love to use"
]
}) # fmt: off
EXPECTED_TEXT_COMPLETION = expected_text_completions.get_expectation()