2727from transformers import CONFIG_MAPPING , Blip2Config , Blip2QFormerConfig , Blip2VisionConfig
2828from transformers .testing_utils import (
2929 require_torch ,
30+ require_torch_accelerator ,
3031 require_torch_fp16 ,
3132 require_torch_gpu ,
3233 require_torch_multi_accelerator ,
@@ -1565,7 +1566,7 @@ def test_forward_signature(self):
15651566 self .assertListEqual (arg_names [: len (expected_arg_names )], expected_arg_names )
15661567
15671568 @slow
1568- @require_torch_gpu
1569+ @require_torch_accelerator
15691570 def test_model_from_pretrained (self ):
15701571 model_name = "Salesforce/blip2-itm-vit-g"
15711572 model = Blip2TextModelWithProjection .from_pretrained (model_name )
@@ -2191,7 +2192,7 @@ def test_expansion_in_processing(self):
21912192
21922193 self .assertTrue (generated_text_expanded == generated_text )
21932194
2194- @require_torch_gpu
2195+ @require_torch_accelerator
21952196 def test_inference_itm (self ):
21962197 model_name = "Salesforce/blip2-itm-vit-g"
21972198 processor = Blip2Processor .from_pretrained (model_name )
@@ -2210,7 +2211,7 @@ def test_inference_itm(self):
22102211 self .assertTrue (torch .allclose (torch .nn .Softmax ()(out_itm [0 ].cpu ()), expected_scores , rtol = 1e-3 , atol = 1e-3 ))
22112212 self .assertTrue (torch .allclose (out [0 ].cpu (), torch .Tensor ([[0.4406 ]]), rtol = 1e-3 , atol = 1e-3 ))
22122213
2213- @require_torch_gpu
2214+ @require_torch_accelerator
22142215 @require_torch_fp16
22152216 def test_inference_itm_fp16 (self ):
22162217 model_name = "Salesforce/blip2-itm-vit-g"
@@ -2232,7 +2233,7 @@ def test_inference_itm_fp16(self):
22322233 )
22332234 self .assertTrue (torch .allclose (out [0 ].cpu ().float (), torch .Tensor ([[0.4406 ]]), rtol = 1e-3 , atol = 1e-3 ))
22342235
2235- @require_torch_gpu
2236+ @require_torch_accelerator
22362237 @require_torch_fp16
22372238 def test_inference_vision_with_projection_fp16 (self ):
22382239 model_name = "Salesforce/blip2-itm-vit-g"
@@ -2256,7 +2257,7 @@ def test_inference_vision_with_projection_fp16(self):
22562257 ]
22572258 self .assertTrue (np .allclose (out .image_embeds [0 ][0 ][:6 ].tolist (), expected_image_embeds , atol = 1e-3 ))
22582259
2259- @require_torch_gpu
2260+ @require_torch_accelerator
22602261 @require_torch_fp16
22612262 def test_inference_text_with_projection_fp16 (self ):
22622263 model_name = "Salesforce/blip2-itm-vit-g"
0 commit comments