Skip to content

Commit 2ebe9ca

Browse files
author
sdp
committed
enable xpu
1 parent c934720 commit 2ebe9ca

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

tests/models/unets/test_models_unet_2d_condition.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
require_peft_backend,
4747
require_torch_accelerator,
4848
require_torch_accelerator_with_fp16,
49-
require_torch_gpu,
5049
skip_mps,
5150
slow,
5251
torch_all_close,
@@ -980,7 +979,7 @@ def test_ip_adapter_plus(self):
980979
assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
981980
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
982981

983-
@require_torch_gpu
982+
@require_torch_accelerator
984983
@parameterized.expand(
985984
[
986985
("hf-internal-testing/unet2d-sharded-dummy", None),
@@ -996,7 +995,7 @@ def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
996995
assert loaded_model
997996
assert new_output.sample.shape == (4, 4, 16, 16)
998997

999-
@require_torch_gpu
998+
@require_torch_accelerator
1000999
@parameterized.expand(
10011000
[
10021001
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),

0 commit comments

Comments
 (0)