Commit 2fa876d

[tests] make cuda-only tests device-agnostic (#35607)
* initial commit
* remove unrelated files
* further remove
* Update test_trainer.py
* fix style
1 parent e6f9b03 commit 2fa876d

File tree

18 files changed: +57 additions, -47 deletions

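The common thread across these diffs is that CUDA-only markers and hard-coded "cuda" strings are replaced with the backend-neutral helpers from transformers.testing_utils: the require_torch_accelerator decorator and the torch_device string. Below is a minimal sketch of the resulting pattern; the class and test names are made up for illustration, while the decorators, torch_device, and the tiny checkpoint are the ones that appear in the changed files.

# Illustrative sketch (not part of the commit) of the device-agnostic test pattern.
import unittest

import torch

from transformers import AutoModelForCausalLM
from transformers.testing_utils import require_torch_accelerator, slow, torch_device


@require_torch_accelerator  # skips unless some torch accelerator backend is available, not only CUDA
class DeviceAgnosticExampleTest(unittest.TestCase):  # hypothetical test class
    @slow
    def test_generate_on_any_accelerator(self):
        model = AutoModelForCausalLM.from_pretrained(
            "hf-internal-testing/tiny-random-MistralForCausalLM"
        ).to(torch_device)  # torch_device instead of a hard-coded "cuda"
        input_ids = torch.tensor([[1, 2, 3]], device=torch_device)
        out = model.generate(input_ids, max_new_tokens=5)
        # generation adds at most max_new_tokens tokens to the prompt
        self.assertLessEqual(out.shape[-1], input_ids.shape[-1] + 5)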

tests/fsdp/test_fsdp.py

Lines changed: 1 addition & 2 deletions
@@ -32,7 +32,6 @@
     require_accelerate,
     require_fsdp,
     require_torch_accelerator,
-    require_torch_gpu,
     require_torch_multi_accelerator,
     slow,
     torch_device,
@@ -288,7 +287,7 @@ def test_training_and_can_resume_normally(self, state_dict_type):

     @require_torch_multi_accelerator
     @slow
-    @require_torch_gpu
+    @require_torch_accelerator
     @require_fsdp
     def test_fsdp_cpu_offloading(self):
         try:

tests/generation/test_utils.py

Lines changed: 6 additions & 3 deletions
@@ -33,6 +33,7 @@
     require_flash_attn,
     require_optimum_quanto,
     require_torch,
+    require_torch_accelerator,
     require_torch_gpu,
     require_torch_multi_accelerator,
     require_torch_multi_gpu,
@@ -2043,7 +2044,7 @@ def test_generate_with_quant_cache(self):
         model.generate(**generation_kwargs, **inputs_dict)

     @pytest.mark.generate
-    @require_torch_gpu
+    @require_torch_accelerator
     @slow
     def test_generate_compile_model_forward(self):
         """
@@ -3791,10 +3792,12 @@ def test_assisted_decoding_in_different_gpu(self):
         self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)

     @slow
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_assisted_decoding_model_in_gpu_assistant_in_cpu(self):
         # PT-only test: TF doesn't support assisted decoding yet.
-        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda")
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
+            torch_device
+        )
         assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
             "cpu"
         )
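The `.to(torch_device)` change above is what lets the assisted-decoding test run on whichever backend the suite was started on. As a rough illustration of the idea only, not the actual transformers.testing_utils logic (the real torch_device selection is more involved and can be overridden externally), a test-wide device string can be resolved once and reused everywhere:

# Simplified sketch of resolving a single test-wide device string.
# resolve_test_device is a hypothetical helper, not a transformers API;
# transformers exposes the already-resolved string as torch_device.
import torch


def resolve_test_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"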

tests/models/blip_2/test_modeling_blip_2.py

Lines changed: 6 additions & 5 deletions
@@ -27,6 +27,7 @@
 from transformers import CONFIG_MAPPING, Blip2Config, Blip2QFormerConfig, Blip2VisionConfig
 from transformers.testing_utils import (
     require_torch,
+    require_torch_accelerator,
     require_torch_fp16,
     require_torch_gpu,
     require_torch_multi_accelerator,
@@ -1565,7 +1566,7 @@ def test_forward_signature(self):
         self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

     @slow
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_model_from_pretrained(self):
         model_name = "Salesforce/blip2-itm-vit-g"
         model = Blip2TextModelWithProjection.from_pretrained(model_name)
@@ -2191,7 +2192,7 @@ def test_expansion_in_processing(self):

         self.assertTrue(generated_text_expanded == generated_text)

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_inference_itm(self):
         model_name = "Salesforce/blip2-itm-vit-g"
         processor = Blip2Processor.from_pretrained(model_name)
@@ -2210,7 +2211,7 @@ def test_inference_itm(self):
         self.assertTrue(torch.allclose(torch.nn.Softmax()(out_itm[0].cpu()), expected_scores, rtol=1e-3, atol=1e-3))
         self.assertTrue(torch.allclose(out[0].cpu(), torch.Tensor([[0.4406]]), rtol=1e-3, atol=1e-3))

-    @require_torch_gpu
+    @require_torch_accelerator
     @require_torch_fp16
     def test_inference_itm_fp16(self):
         model_name = "Salesforce/blip2-itm-vit-g"
@@ -2232,7 +2233,7 @@ def test_inference_itm_fp16(self):
         )
         self.assertTrue(torch.allclose(out[0].cpu().float(), torch.Tensor([[0.4406]]), rtol=1e-3, atol=1e-3))

-    @require_torch_gpu
+    @require_torch_accelerator
     @require_torch_fp16
     def test_inference_vision_with_projection_fp16(self):
         model_name = "Salesforce/blip2-itm-vit-g"
@@ -2256,7 +2257,7 @@ def test_inference_vision_with_projection_fp16(self):
         ]
         self.assertTrue(np.allclose(out.image_embeds[0][0][:6].tolist(), expected_image_embeds, atol=1e-3))

-    @require_torch_gpu
+    @require_torch_accelerator
     @require_torch_fp16
     def test_inference_text_with_projection_fp16(self):
         model_name = "Salesforce/blip2-itm-vit-g"

tests/models/diffllama/test_modeling_diffllama.py

Lines changed: 2 additions & 2 deletions
@@ -676,7 +676,7 @@ def test_eager_matches_sdpa_generate(self):
         )


-@require_torch_gpu
+@require_torch_accelerator
 class DiffLlamaIntegrationTest(unittest.TestCase):
     # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
     # Depending on the hardware we get different logits / generations
@@ -689,7 +689,7 @@ def setUpClass(cls):
         cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]

     @slow
-    @require_torch_gpu
+    @require_torch_accelerator
     @require_read_token
     def test_compile_static_cache(self):
         # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
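Note that the setUpClass context line above still calls torch.cuda.get_device_capability(), which is CUDA-specific, even though the class decorator is now require_torch_accelerator; on a non-CUDA accelerator that call would need a guard. A possible guard, stated purely as an assumption and not something this commit adds, could look like:

# Hypothetical guard (not part of this commit): only query the CUDA compute
# capability when the resolved test device is actually a CUDA GPU.
import torch

from transformers.testing_utils import torch_device

cuda_compute_capability_major_version = None
if torch_device == "cuda":
    cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]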

tests/models/falcon_mamba/test_modeling_falcon_mamba.py

Lines changed: 2 additions & 2 deletions
@@ -23,7 +23,7 @@
 from transformers.testing_utils import (
     require_bitsandbytes,
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     require_torch_multi_gpu,
     slow,
     torch_device,
@@ -426,7 +426,7 @@ def recursive_check(tuple_object, dict_object):


 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @slow
 class FalconMambaIntegrationTests(unittest.TestCase):
     def setUp(self):

tests/models/fuyu/test_modeling_fuyu.py

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@
 from parameterized import parameterized

 from transformers import FuyuConfig, is_torch_available, is_vision_available
-from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
+from transformers.testing_utils import require_torch, require_torch_accelerator, slow, torch_device
 from transformers.utils import cached_property

 from ...generation.test_utils import GenerationTesterMixin
@@ -327,7 +327,7 @@ def test_model_parallelism(self):


 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class FuyuModelIntegrationTest(unittest.TestCase):
     @cached_property
     def default_processor(self):

tests/models/llama/test_modeling_llama.py

Lines changed: 2 additions & 3 deletions
@@ -26,7 +26,6 @@
     require_read_token,
     require_torch,
     require_torch_accelerator,
-    require_torch_gpu,
     slow,
     torch_device,
 )
@@ -541,7 +540,7 @@ def _reinitialize_config(base_config, new_kwargs):
         config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}})  # missing "factor"


-@require_torch_gpu
+@require_torch_accelerator
 class LlamaIntegrationTest(unittest.TestCase):
     # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
     # Depending on the hardware we get different logits / generations
@@ -695,7 +694,7 @@ def test_model_7b_dola_generation(self):
         self.assertEqual(EXPECTED_TEXT_COMPLETION, text)

     @slow
-    @require_torch_gpu
+    @require_torch_accelerator
     @require_read_token
     def test_compile_static_cache(self):
         # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2

tests/models/mistral/test_modeling_mistral.py

Lines changed: 1 addition & 1 deletion
@@ -424,7 +424,7 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
         self.skipTest(reason="Mistral flash attention does not support right padding")


-@require_torch_gpu
+@require_torch_accelerator
 class MistralIntegrationTest(unittest.TestCase):
     # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
     # Depending on the hardware we get different logits / generations

tests/models/mixtral/test_modeling_mixtral.py

Lines changed: 3 additions & 2 deletions
@@ -22,6 +22,7 @@
 from transformers.testing_utils import (
     require_flash_attn,
     require_torch,
+    require_torch_accelerator,
     require_torch_gpu,
     slow,
     torch_device,
@@ -471,7 +472,7 @@ def setUpClass(cls):
         cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]

     @slow
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_small_model_logits(self):
         model_id = "hf-internal-testing/Mixtral-tiny"
         dummy_input = torch.LongTensor([[0, 1, 0], [0, 1, 0]]).to(torch_device)
@@ -507,7 +508,7 @@ def test_small_model_logits(self):
         )

     @slow
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_small_model_logits_batched(self):
         model_id = "hf-internal-testing/Mixtral-tiny"
         dummy_input = torch.LongTensor([[0, 0, 0, 0, 0, 0, 1, 2, 3], [1, 1, 2, 3, 4, 5, 6, 7, 8]]).to(torch_device)

tests/models/nemotron/test_modeling_nemotron.py

Lines changed: 2 additions & 1 deletion
@@ -26,6 +26,7 @@
     require_flash_attn,
     require_read_token,
     require_torch,
+    require_torch_accelerator,
     require_torch_gpu,
     require_torch_sdpa,
     slow,
@@ -103,7 +104,7 @@ def test_model_outputs_equivalence(self, **kwargs):
         pass

     @require_torch_sdpa
-    @require_torch_gpu
+    @require_torch_accelerator
     @slow
     def test_sdpa_equivalence(self):
         for model_class in self.all_model_classes:
