195 changes: 139 additions & 56 deletions src/transformers/integrations/executorch.py

Large diffs are not rendered by default.
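The executorch.py diff itself is collapsed, but the caller-facing change is visible in the test updates below: `TorchExportableModuleForDecoderOnlyLM` now takes an explicit `config` and `generation_config`, and `export()` takes example `input_ids` and `cache_position`. A minimal sketch of the new module-based flow, inferred from those tests — the checkpoint name, cache sizes, and the static-cache `GenerationConfig` (mirroring the hybrid one built in tests/utils/test_cache_utils.py at the bottom of this diff) are placeholders, not part of the PR:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.configuration_utils import GenerationConfig
from transformers.integrations.executorch import (
    TorchExportableModuleForDecoderOnlyLM,
    TorchExportableModuleWithStaticCache,
)

# Placeholder checkpoint; any of the decoder-only models exercised in the tests below would do.
checkpoint = "hf-internal-testing/tiny-random-LlamaForCausalLM"
model = AutoModelForCausalLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model.eval()

prompt_token_ids = tokenizer("Hello", return_tensors="pt").input_ids.to(model.device)

# A static-cache generation config; keys mirror the hybrid-cache config built in
# tests/utils/test_cache_utils.py below. Sizes are illustrative only.
generation_config = GenerationConfig(
    use_cache=True,
    cache_implementation="static",
    max_length=32,
    cache_config={"batch_size": 1, "max_cache_len": 32, "device": model.device},
)

# config and generation_config are now passed explicitly instead of being read off the model.
exportable_module = TorchExportableModuleForDecoderOnlyLM(
    model, config=model.config, generation_config=generation_config
)

# export() now takes example inputs: the prompt ids and matching cache positions.
exported_program = exportable_module.export(
    input_ids=prompt_token_ids,
    cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
)

# Generation against the exported program is unchanged.
generated_ids = TorchExportableModuleWithStaticCache.generate(
    exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=8
)
```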

4 changes: 3 additions & 1 deletion tests/models/cohere2/test_modeling_cohere2.py
@@ -275,7 +275,9 @@ def test_export_static_cache(self):
         max_new_tokens = 30 - prompt_token_ids.shape[-1]
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model)
+        exported_program = convert_and_export_with_cache(
+            model, config=model.config, generation_config=model.generation_config
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
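As the cohere2 hunk above shows, the `convert_and_export_with_cache` wrapper gets the same treatment: the config and generation config are now passed explicitly (and it still accepts `strict`, as the qwen2/smollm3 hunks further down show). A minimal standalone sketch, reusing the placeholder `model` from the sketch at the top:

```python
from transformers.integrations.executorch import convert_and_export_with_cache

# config/generation_config are passed explicitly; this assumes the model's own
# generation_config is already set up for an exportable (static) cache, as in the tests.
exported_program = convert_and_export_with_cache(
    model, config=model.config, generation_config=model.generation_config
)
```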
4 changes: 3 additions & 1 deletion tests/models/exaone4/test_modeling_exaone4.py
@@ -400,7 +400,9 @@ def test_export_static_cache(self):
         max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model)
+        exported_program = convert_and_export_with_cache(
+            model, config=model.config, generation_config=model.generation_config
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
9 changes: 7 additions & 2 deletions tests/models/gemma/test_modeling_gemma.py
@@ -459,8 +459,13 @@ def test_export_static_cache(self):
         # Static Cache + export
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
 
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
-        exported_program = exportable_module.export()
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(
+            model, config=model.config, generation_config=model.generation_config
+        )
+        exported_program = exportable_module.export(
+            input_ids=prompt_token_ids,
+            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
18 changes: 14 additions & 4 deletions tests/models/gemma2/test_modeling_gemma2.py
@@ -364,8 +364,13 @@ def test_export_static_cache(self):
         # Static Cache + export
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
 
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
-        exported_program = exportable_module.export()
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(
+            model, config=model.config, generation_config=model.generation_config
+        )
+        exported_program = exportable_module.export(
+            input_ids=prompt_token_ids,
+            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
@@ -388,8 +393,13 @@ def test_export_hybrid_cache(self):
 
         # Export + HybridCache
         model.eval()
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
-        exported_program = exportable_module.export()
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(
+            model, config=model.config, generation_config=model.generation_config
+        )
+        exported_program = exportable_module.export(
+            input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+            cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
+        )
 
         # Test generation with the exported model
         prompt = "What is the capital of France?"
9 changes: 7 additions & 2 deletions tests/models/gemma3/test_modeling_gemma3.py
@@ -808,8 +808,13 @@ def test_export_text_only_with_hybrid_cache(self):
 
         # Export + HybridCache
         model.eval()
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
-        exported_program = exportable_module.export()
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(
+            model, config=model.config, generation_config=model.generation_config
+        )
+        exported_program = exportable_module.export(
+            input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+            cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
+        )
         logging.info(f"\nExported program: {exported_program}")
 
         # Test generation with the exported model
9 changes: 7 additions & 2 deletions tests/models/llama/test_modeling_llama.py
@@ -352,8 +352,13 @@ def test_export_static_cache(self):
         # Static Cache + export
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
 
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
-        exported_program = exportable_module.export()
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(
+            model, config=model.config, generation_config=model.generation_config
+        )
+        exported_program = exportable_module.export(
+            input_ids=prompt_token_ids,
+            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
9 changes: 7 additions & 2 deletions tests/models/olmo/test_modeling_olmo.py
@@ -383,8 +383,13 @@ def test_export_static_cache(self):
         # Static Cache + export
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
 
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
-        exported_program = exportable_module.export()
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(
+            model, config=model.config, generation_config=model.generation_config
+        )
+        exported_program = exportable_module.export(
+            input_ids=prompt_token_ids,
+            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
4 changes: 3 additions & 1 deletion tests/models/olmo2/test_modeling_olmo2.py
@@ -383,7 +383,9 @@ def test_export_static_cache(self):
         self.assertEqual(EXPECTED_TEXT_COMPLETION, eager_generated_text)
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model)
+        exported_program = convert_and_export_with_cache(
+            model, config=model.config, generation_config=model.generation_config
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
9 changes: 7 additions & 2 deletions tests/models/phi3/test_modeling_phi3.py
@@ -416,8 +416,13 @@ def test_export_static_cache(self):
         # Static Cache + export
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
 
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
-        exported_program = exportable_module.export()
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(
+            model, config=model.config, generation_config=model.generation_config
+        )
+        exported_program = exportable_module.export(
+            input_ids=prompt_token_ids,
+            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
10 changes: 8 additions & 2 deletions tests/models/qwen2/test_modeling_qwen2.py
@@ -299,11 +299,17 @@ def test_export_static_cache(self):
         # Static Cache + export
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
 
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(
+            model, config=model.config, generation_config=model.generation_config
+        )
         strict = version.parse(torch.__version__) != version.parse(
             "2.7.0"
         )  # Due to https://github.com/pytorch/pytorch/issues/150994
-        exported_program = exportable_module.export(strict=strict)
+        exported_program = exportable_module.export(
+            input_ids=prompt_token_ids,
+            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+            strict=strict,
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
10 changes: 8 additions & 2 deletions tests/models/qwen3/test_modeling_qwen3.py
@@ -292,8 +292,14 @@ def test_export_static_cache(self):
         # Static Cache + export
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
 
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
-        exported_program = exportable_module.export(strict=strict)
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(
+            model, config=model.config, generation_config=model.generation_config
+        )
+        exported_program = exportable_module.export(
+            input_ids=prompt_token_ids,
+            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+            strict=strict,
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
4 changes: 3 additions & 1 deletion tests/models/smollm3/test_modeling_smollm3.py
@@ -219,7 +219,9 @@ def test_export_static_cache(self):
 
         # Static Cache + export
         strict = is_torch_greater_or_equal("2.7.0")  # Due to https://github.com/pytorch/pytorch/issues/150994
-        exported_program = convert_and_export_with_cache(model, strict=strict)
+        exported_program = convert_and_export_with_cache(
+            model, config=model.config, generation_config=model.generation_config, strict=strict
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
26 changes: 23 additions & 3 deletions tests/utils/test_cache_utils.py
@@ -813,7 +813,9 @@ def test_static_cache_exportability(self):
 
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
 
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(
+            model, config=model.config, generation_config=model.generation_config
+        )
         exported_program = exportable_module.export(
             input_ids=input_ids,
             cache_position=cache_position,
@@ -841,8 +843,26 @@ def test_hybrid_cache_exportability(self):
         model.eval()
         max_batch_size = 1
         max_cache_len = 23
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size, max_cache_len)
-        exported_program = exportable_module.export()
+        # Create generation config for the hybrid cache model
+        from transformers.generation.configuration_utils import GenerationConfig
+
+        generation_config = GenerationConfig(
+            use_cache=True,
+            cache_implementation="hybrid",
+            max_length=max_cache_len,
+            cache_config={
+                "batch_size": max_batch_size,
+                "max_cache_len": max_cache_len,
+                "device": model.device,
+            },
+        )
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(
+            model, config=model.config, generation_config=generation_config
+        )
+        exported_program = exportable_module.export(
+            input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+            cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
+        )
         n_g_key_caches = n_g_value_caches = 0
         for buffer_name, buffer in exported_program.named_buffers():
             if buffer_name.startswith("key_cache"):
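The hybrid-cache hunk above is truncated after the start of the buffer loop. For reference, a standalone sketch (not the test's exact continuation) of the same inspection pattern, given an `exported_program` produced by either path shown earlier:

```python
# Count the per-layer cache buffers registered on the exported program.
n_key_caches = n_value_caches = 0
for buffer_name, buffer in exported_program.named_buffers():
    if buffer_name.startswith("key_cache"):
        n_key_caches += 1
    elif buffer_name.startswith("value_cache"):
        n_value_caches += 1
print(f"{n_key_caches} key cache buffers, {n_value_caches} value cache buffers")
```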