Commit 6121e9e

Support input_embeds in torch exportable decoders (#39836)
* Support input_embeds in torch exportable decoders
* Hybrid cache update
* Manually change some callsites
* AI changes the rest of the call sites
* Make either input_ids/inputs_embeds mandatory
* Clean up
* Ruff check --fix
* Fix test
* pr review
* Revert config/generation_config changes
* Ruff check
1 parent cdeaad9 commit 6121e9e
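
For context, the call pattern this commit introduces looks roughly like the sketch below: export() now takes exactly one of input_ids or inputs_embeds, plus a cache_position tensor. This is a minimal illustration based on the test changes in this commit; the tiny checkpoint and tensor shapes are assumptions made for the example, not part of the change.

import torch

from transformers import AutoModelForCausalLM
from transformers.generation.configuration_utils import GenerationConfig
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

# Tiny checkpoint chosen for illustration (also used by the new tests/test_executorch.py).
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
model.eval()
model.generation_config = GenerationConfig(
    use_cache=True,
    cache_implementation="static",
    cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"},
)

exportable_module = TorchExportableModuleForDecoderOnlyLM(model)

# Export traced on token ids ...
input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long)
ep_from_ids = exportable_module.export(input_ids=input_ids, cache_position=cache_position)

# ... or on precomputed embeddings (newly supported by this commit).
inputs_embeds = torch.randn(1, 3, model.config.hidden_size)
ep_from_embeds = exportable_module.export(inputs_embeds=inputs_embeds, cache_position=cache_position)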

File tree: 11 files changed, 324 additions & 84 deletions

src/transformers/integrations/executorch.py

Lines changed: 139 additions & 73 deletions
Large diffs are not rendered by default.
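
Since the main implementation diff is collapsed here, the new requirement can only be inferred from the commit message ("Make either input_ids/inputs_embeds mandatory") and the new tests below, which expect a ValueError when both or neither input is given. The following is a hedged sketch of that validation with a hypothetical helper name; the real code in this file is not shown in this view.

from typing import Optional

import torch


def _check_exactly_one_input(input_ids: Optional[torch.Tensor], inputs_embeds: Optional[torch.Tensor]) -> None:
    # Hypothetical helper: mirrors the ValueError cases exercised in tests/test_executorch.py.
    if (input_ids is None) == (inputs_embeds is None):
        raise ValueError("Exactly one of `input_ids` or `inputs_embeds` must be provided.")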

tests/models/gemma/test_modeling_gemma.py

Lines changed: 4 additions & 1 deletion
@@ -460,7 +460,10 @@ def test_export_static_cache(self):
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

         exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
-        exported_program = exportable_module.export()
+        exported_program = exportable_module.export(
+            input_ids=prompt_token_ids,
+            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )

tests/models/gemma2/test_modeling_gemma2.py

Lines changed: 8 additions & 2 deletions
@@ -365,7 +365,10 @@ def test_export_static_cache(self):
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

         exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
-        exported_program = exportable_module.export()
+        exported_program = exportable_module.export(
+            input_ids=prompt_token_ids,
+            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
@@ -389,7 +392,10 @@ def test_export_hybrid_cache(self):
         # Export + HybridCache
         model.eval()
         exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
-        exported_program = exportable_module.export()
+        exported_program = exportable_module.export(
+            input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+            cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
+        )

         # Test generation with the exported model
         prompt = "What is the capital of France?"

tests/models/gemma3/test_modeling_gemma3.py

Lines changed: 4 additions & 1 deletion
@@ -822,7 +822,10 @@ def test_export_text_only_with_hybrid_cache(self):
         # Export + HybridCache
         model.eval()
         exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
-        exported_program = exportable_module.export()
+        exported_program = exportable_module.export(
+            input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+            cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
+        )
         logging.info(f"\nExported program: {exported_program}")

         # Test generation with the exported model

tests/models/llama/test_modeling_llama.py

Lines changed: 4 additions & 1 deletion
@@ -353,7 +353,10 @@ def test_export_static_cache(self):
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

         exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
-        exported_program = exportable_module.export()
+        exported_program = exportable_module.export(
+            input_ids=prompt_token_ids,
+            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )

tests/models/olmo/test_modeling_olmo.py

Lines changed: 4 additions & 1 deletion
@@ -384,7 +384,10 @@ def test_export_static_cache(self):
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

         exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
-        exported_program = exportable_module.export()
+        exported_program = exportable_module.export(
+            input_ids=prompt_token_ids,
+            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )

tests/models/phi3/test_modeling_phi3.py

Lines changed: 4 additions & 1 deletion
@@ -417,7 +417,10 @@ def test_export_static_cache(self):
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

         exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
-        exported_program = exportable_module.export()
+        exported_program = exportable_module.export(
+            input_ids=prompt_token_ids,
+            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )

tests/models/qwen2/test_modeling_qwen2.py

Lines changed: 5 additions & 1 deletion
@@ -303,7 +303,11 @@ def test_export_static_cache(self):
         strict = version.parse(torch.__version__) != version.parse(
             "2.7.0"
         )  # Due to https://github.com/pytorch/pytorch/issues/150994
-        exported_program = exportable_module.export(strict=strict)
+        exported_program = exportable_module.export(
+            input_ids=prompt_token_ids,
+            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+            strict=strict,
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
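
The qwen2 and qwen3 tests keep forwarding a strict flag; presumably export() passes it through to torch.export. Below is a rough sketch of such a pass-through under that assumption; the wrapper internals are not shown in this commit view.

import torch


def export_with_strict(wrapper: torch.nn.Module, example_kwargs: dict, strict: bool = True):
    # Assumption: the module's export() ultimately calls torch.export.export and forwards
    # `strict`, which the qwen2 test disables on torch 2.7.0 to work around
    # https://github.com/pytorch/pytorch/issues/150994.
    return torch.export.export(wrapper, args=(), kwargs=example_kwargs, strict=strict)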

tests/models/qwen3/test_modeling_qwen3.py

Lines changed: 5 additions & 1 deletion
@@ -293,7 +293,11 @@ def test_export_static_cache(self):
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

         exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
-        exported_program = exportable_module.export(strict=strict)
+        exported_program = exportable_module.export(
+            input_ids=prompt_token_ids,
+            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+            strict=strict,
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )

tests/test_executorch.py

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@ (new file)

# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from transformers import AutoModelForCausalLM, set_seed
from transformers.generation.configuration_utils import GenerationConfig
from transformers.integrations.executorch import (
    TorchExportableModuleForDecoderOnlyLM,
    TorchExportableModuleWithHybridCache,
    TorchExportableModuleWithStaticCache,
)
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
from transformers.testing_utils import require_torch


@require_torch
class ExecutorchTest(unittest.TestCase):
    def setUp(self):
        if not is_torch_greater_or_equal_than_2_3:
            self.skipTest("torch >= 2.3 is required")

        set_seed(0)
        self.model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
        self.model.eval()

        # Create generation config with static cache for the model
        self.model.generation_config = GenerationConfig(
            use_cache=True,
            cache_implementation="static",
            cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"},
        )

        self.input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
        self.inputs_embeds = torch.randn(1, 3, self.model.config.hidden_size)
        self.cache_position = torch.arange(3, dtype=torch.long)

    def test_static_cache_module_forward(self):
        """Test TorchExportableModuleWithStaticCache forward with both input types"""
        generation_config = GenerationConfig(
            use_cache=True,
            cache_implementation="static",
            cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"},
        )

        # Set generation config on model
        self.model.generation_config = generation_config
        module = TorchExportableModuleWithStaticCache(self.model)

        # Test with input_ids
        eager_output_ids = self.model(input_ids=self.input_ids, use_cache=False).logits
        wrapped_output_ids = module.forward(input_ids=self.input_ids, cache_position=self.cache_position)
        torch.testing.assert_close(eager_output_ids, wrapped_output_ids, atol=1e-4, rtol=1e-4)

        # Test with inputs_embeds
        eager_output_embeds = self.model(inputs_embeds=self.inputs_embeds, use_cache=False).logits
        wrapped_output_embeds = module.forward(inputs_embeds=self.inputs_embeds, cache_position=self.cache_position)
        torch.testing.assert_close(eager_output_embeds, wrapped_output_embeds, atol=1e-4, rtol=1e-4)

    def test_hybrid_cache_module_forward(self):
        """Test TorchExportableModuleWithHybridCache forward with both input types"""
        config = self.model.config
        config.sliding_window = 16
        config.layer_types = ["full_attention"] * config.num_hidden_layers

        generation_config = GenerationConfig(
            use_cache=True,
            cache_implementation="hybrid",
            cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"},
        )

        # Set generation config on model
        self.model.generation_config = generation_config
        module = TorchExportableModuleWithHybridCache(self.model)

        # Test with input_ids
        eager_output_ids = self.model(input_ids=self.input_ids, use_cache=False).logits
        wrapped_output_ids = module.forward(input_ids=self.input_ids, cache_position=self.cache_position)
        torch.testing.assert_close(eager_output_ids, wrapped_output_ids, atol=1e-4, rtol=1e-4)

        # Test with inputs_embeds
        eager_output_embeds = self.model(inputs_embeds=self.inputs_embeds, use_cache=False).logits
        wrapped_output_embeds = module.forward(inputs_embeds=self.inputs_embeds, cache_position=self.cache_position)
        torch.testing.assert_close(eager_output_embeds, wrapped_output_embeds, atol=1e-4, rtol=1e-4)

    def test_decoder_only_lm_export_validation(self):
        """Test TorchExportableModuleForDecoderOnlyLM export validation"""
        module = TorchExportableModuleForDecoderOnlyLM(self.model)

        # Should fail with both input_ids and inputs_embeds
        with self.assertRaises(ValueError):
            module.export(input_ids=self.input_ids, inputs_embeds=self.inputs_embeds)

        # Should fail with neither
        with self.assertRaises(ValueError):
            module.export()

    def test_decoder_only_lm_export(self):
        """Test TorchExportableModuleForDecoderOnlyLM export with both input types"""
        module = TorchExportableModuleForDecoderOnlyLM(self.model)

        # Test export with input_ids
        exported_program_ids = module.export(input_ids=self.input_ids, cache_position=self.cache_position)
        eager_output_ids = self.model(input_ids=self.input_ids, use_cache=False).logits
        exported_output_ids = exported_program_ids.module()(
            input_ids=self.input_ids, cache_position=self.cache_position
        )
        torch.testing.assert_close(eager_output_ids, exported_output_ids, atol=1e-4, rtol=1e-4)

        # Test export with inputs_embeds
        exported_program_embeds = module.export(inputs_embeds=self.inputs_embeds, cache_position=self.cache_position)
        eager_output_embeds = self.model(inputs_embeds=self.inputs_embeds, use_cache=False).logits
        exported_output_embeds = exported_program_embeds.module()(
            inputs_embeds=self.inputs_embeds, cache_position=self.cache_position
        )
        torch.testing.assert_close(eager_output_embeds, exported_output_embeds, atol=1e-4, rtol=1e-4)
