
Commit a393860

update device functions
1 parent cb7d9d5 commit a393860

2 files changed (+54, -16 lines)

src/diffusers/utils/testing_utils.py

Lines changed: 47 additions & 5 deletions
@@ -1,4 +1,5 @@
 import functools
+import gc
 import importlib
 import importlib.metadata
 import inspect
@@ -86,7 +87,12 @@
             ) from e
         logger.info(f"torch_device overrode to {torch_device}")
     else:
-        torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+        if torch.cuda.is_available():
+            torch_device = "cuda"
+        elif torch.xpu.is_available():
+            torch_device = "xpu"
+        else:
+            torch_device = "cpu"
     is_torch_higher_equal_than_1_12 = version.parse(
         version.parse(torch.__version__).base_version
     ) >= version.parse("1.12")
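Note: the new elif branch assumes the installed PyTorch build exposes a torch.xpu module; on builds without XPU support, accessing torch.xpu raises an AttributeError before is_available() is ever called. A guarded variant of the same selection, shown only as a sketch and not what this commit ships, would look like this:

import torch

if torch.cuda.is_available():
    torch_device = "cuda"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    torch_device = "xpu"
else:
    torch_device = "cpu"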
@@ -1055,12 +1061,34 @@ def _is_torch_fp64_available(device):
 # Guard these lookups for when Torch is not used - alternative accelerator support is for PyTorch
 if is_torch_available():
     # Behaviour flags
-    BACKEND_SUPPORTS_TRAINING = {"cuda": True, "cpu": True, "mps": False, "default": True}
+    BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True}

     # Function definitions
-    BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "mps": None, "default": None}
-    BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "mps": lambda: 0, "default": 0}
-    BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed}
+    BACKEND_EMPTY_CACHE = {
+        "cuda": torch.cuda.empty_cache,
+        "xpu": torch.xpu.empty_cache,
+        "cpu": None,
+        "mps": None,
+        "default": None,
+    }
+    BACKEND_DEVICE_COUNT = {
+        "cuda": torch.cuda.device_count,
+        "xpu": torch.xpu.device_count,
+        "cpu": lambda: 0,
+        "mps": lambda: 0,
+        "default": 0,
+    }
+    BACKEND_MANUAL_SEED = {
+        "cuda": torch.cuda.manual_seed,
+        "xpu": torch.xpu.manual_seed,
+        "cpu": torch.manual_seed,
+        "default": torch.manual_seed,
+    }
+    BACKEND_RESET_PEAK_MEMORY_STATS = {
+        "cuda": torch.cuda.reset_peak_memory_stats,
+        "xpu": torch.xpu.reset_peak_memory_stats,
+        "default": None,
+    }


 # This dispatches a defined function according to the accelerator from the function definitions.
@@ -1091,6 +1119,10 @@ def backend_device_count(device: str):
     return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)


+def backend_reset_peak_memory(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_RESET_PEAK_MEMORY_STATS)
+
+
 # These are callables which return boolean behaviour flags and can be used to specify some
 # device agnostic alternative where the feature is unsupported.
 def backend_supports_training(device: str):
@@ -1147,3 +1179,13 @@ def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name
         update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
         update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
         update_mapping_from_spec(BACKEND_SUPPORTS_TRAINING, "SUPPORTS_TRAINING")
+        update_mapping_from_spec(BACKEND_RESET_PEAK_MEMORY_STATS, "RESET_PEAK_MEM_STATS")
+
+
+@require_torch
+def flush_memory(device: str, gc_collect=False, reset_mem_stats=False):
+    if gc_collect:
+        gc.collect()
+    if reset_mem_stats:
+        backend_reset_peak_memory(device)
+    backend_empty_cache(device)
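Taken together, the BACKEND_* mappings are keyed by device type and resolved through _device_agnostic_dispatch, so test code can call one helper regardless of whether it runs on CUDA, XPU, or CPU. A rough usage sketch of the helpers touched above (assuming a diffusers checkout with this change and an available CUDA or XPU device, since the reset-peak-memory table has no CPU entry):

from diffusers.utils.testing_utils import (
    backend_device_count,
    backend_empty_cache,
    flush_memory,
    torch_device,
)

# Number of visible devices for the detected backend (0 on CPU).
print(backend_device_count(torch_device))

# Release cached allocator memory on whichever backend torch_device names.
backend_empty_cache(torch_device)

# One call that optionally runs gc, resets peak-memory stats, then empties the cache.
flush_memory(torch_device, gc_collect=True, reset_mem_stats=True)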

tests/pipelines/deepfloyd_if/test_if.py

Lines changed: 7 additions & 11 deletions
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import gc
 import unittest

 import torch
@@ -24,9 +23,10 @@
 from diffusers.models.attention_processor import AttnAddedKVProcessor
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
+    flush_memory,
     load_numpy,
     require_accelerator,
-    require_torch_gpu,
+    require_torch_accelerator,
     skip_mps,
     slow,
     torch_device,
@@ -91,28 +91,24 @@ def test_xformers_attention_forwardGenerator_pass(self):


 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class IFPipelineSlowTests(unittest.TestCase):
     def setUp(self):
         # clean up the VRAM before each test
         super().setUp()
-        gc.collect()
-        torch.cuda.empty_cache()
+        flush_memory(torch_device, gc_collect=True)

     def tearDown(self):
         # clean up the VRAM after each test
         super().tearDown()
-        gc.collect()
-        torch.cuda.empty_cache()
+        flush_memory(torch_device, gc_collect=True)

     def test_if_text_to_image(self):
         pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
         pipe.unet.set_attn_processor(AttnAddedKVProcessor())
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)

-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats()
+        flush_memory(torch_device, reset_mem_stats=True)

         generator = torch.Generator(device="cpu").manual_seed(0)
         output = pipe(
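After pipe(...) runs, these slow tests typically assert on peak memory usage, and with the device made configurable that check also needs to read the statistics from the right backend. A hedged sketch of such a check (assert_peak_memory_below is a hypothetical helper, not part of this commit, and the XPU branch assumes a PyTorch build whose torch.xpu mirrors the torch.cuda memory-statistics API):

def assert_peak_memory_below(limit_bytes: int):
    # Peak allocation recorded since flush_memory(torch_device, reset_mem_stats=True).
    if torch_device == "cuda":
        peak = torch.cuda.max_memory_allocated()
    elif torch_device == "xpu":
        peak = torch.xpu.max_memory_allocated()
    else:
        peak = 0  # no allocator statistics tracked for CPU/MPS in this sketch
    assert peak < limit_bytes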
