
Commit 2b59207

Fix slow static cache export tests (#40261)
1 parent 56c4421 commit 2b59207

8 files changed: +16 −16 lines changed

src/transformers/integrations/executorch.py

Lines changed: 2 additions & 2 deletions
@@ -325,14 +325,14 @@ def export(
                 "input_ids": input_ids,
                 "cache_position": cache_position
                 if cache_position is not None
-                else torch.arange(input_ids.shape[-1], dtype=torch.long, model=model_device),
+                else torch.arange(input_ids.shape[-1], dtype=torch.long, device=model_device),
             }
         else:  # inputs_embeds
             input_kwargs = {
                 "inputs_embeds": inputs_embeds,
                 "cache_position": cache_position
                 if cache_position is not None
-                else torch.arange(inputs_embeds.shape[1], dtype=torch.long, model=model_device),
+                else torch.arange(inputs_embeds.shape[1], dtype=torch.long, device=model_device),
             }

         exported_program = torch.export.export(
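The change above fixes the keyword used when building the default cache_position: torch.arange has no model argument, so the old line would raise a TypeError whenever cache_position was left unset. A minimal standalone sketch of the corrected default (the tensor values and the model_device string below are illustrative, not taken from the module):

import torch

# Illustrative stand-ins for the values available inside export()
input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
model_device = "cpu"

# Corrected default: positions 0..seq_len-1, built on the model's device
cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long, device=model_device)
print(cache_position)  # tensor([0, 1, 2])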

tests/models/gemma/test_modeling_gemma.py

Lines changed: 2 additions & 2 deletions
@@ -463,8 +463,8 @@ def test_export_static_cache(self):

         exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
         exported_program = exportable_module.export(
-            input_ids=prompt_token_ids,
-            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+            input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+            cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
         )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
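This test change, repeated in each of the model test files below, exports the wrapper with a fixed single-token example input instead of the full prompt; the prompt itself is still fed through TorchExportableModuleWithStaticCache.generate afterwards. A minimal sketch of the resulting export-then-generate flow, assuming a small decoder-only checkpoint configured for a static cache (the checkpoint name, cache sizes, and max_new_tokens below are illustrative, not taken from the tests):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.integrations.executorch import (
    TorchExportableModuleForDecoderOnlyLM,
    TorchExportableModuleWithStaticCache,
)

# Illustrative checkpoint; the slow tests use full-size models instead.
model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    generation_config=GenerationConfig(
        use_cache=True,
        cache_implementation="static",
        max_length=32,
        cache_config={"batch_size": 1, "max_cache_len": 32},
    ),
)
prompt_token_ids = tokenizer("Hello", return_tensors="pt").input_ids

exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
# Export with a single dummy token, mirroring the updated tests: the traced
# example input stays at shape (1, 1) rather than the full prompt length.
exported_program = exportable_module.export(
    input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
    cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
)
# The real prompt is then replayed against the exported program.
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
    exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=8
)
print(ep_generated_ids)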

tests/models/gemma2/test_modeling_gemma2.py

Lines changed: 2 additions & 2 deletions
@@ -368,8 +368,8 @@ def test_export_static_cache(self):

         exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
         exported_program = exportable_module.export(
-            input_ids=prompt_token_ids,
-            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+            input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+            cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
         )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens

tests/models/llama/test_modeling_llama.py

Lines changed: 2 additions & 2 deletions
@@ -354,8 +354,8 @@ def test_export_static_cache(self):

         exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
         exported_program = exportable_module.export(
-            input_ids=prompt_token_ids,
-            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+            input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+            cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
         )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens

tests/models/olmo/test_modeling_olmo.py

Lines changed: 2 additions & 2 deletions
@@ -387,8 +387,8 @@ def test_export_static_cache(self):

         exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
         exported_program = exportable_module.export(
-            input_ids=prompt_token_ids,
-            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+            input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+            cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
         )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens

tests/models/phi3/test_modeling_phi3.py

Lines changed: 2 additions & 2 deletions
@@ -415,8 +415,8 @@ def test_export_static_cache(self):

         exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
         exported_program = exportable_module.export(
-            input_ids=prompt_token_ids,
-            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+            input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+            cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
         )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens

tests/models/qwen2/test_modeling_qwen2.py

Lines changed: 2 additions & 2 deletions
@@ -305,8 +305,8 @@ def test_export_static_cache(self):
             "2.7.0"
         )  # Due to https://github.com/pytorch/pytorch/issues/150994
         exported_program = exportable_module.export(
-            input_ids=prompt_token_ids,
-            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+            input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+            cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
             strict=strict,
         )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(

tests/models/qwen3/test_modeling_qwen3.py

Lines changed: 2 additions & 2 deletions
@@ -295,8 +295,8 @@ def test_export_static_cache(self):

         exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
         exported_program = exportable_module.export(
-            input_ids=prompt_token_ids,
-            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+            input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+            cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
             strict=strict,
         )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
