Commit 80b06b0

update

1 parent 42c19fd commit 80b06b0

2 files changed: +143 -2 lines changed

src/diffusers/utils/testing_utils.py

Lines changed: 4 additions & 1 deletion
@@ -66,7 +66,10 @@
 global_rng = random.Random()

 logger = get_logger(__name__)
-
+logger.warning(
+    "`diffusers.utils.testing_utils` is deprecated and will be removed in a future version. "
+    "Please use `diffusers.utils.torch_utils` instead."
+)
 _required_peft_version = is_peft_available() and version.parse(
     version.parse(importlib.metadata.version("peft")).base_version
 ) > version.parse("0.5")
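
For context, a minimal sketch of what this deprecation means for callers (the module paths and warning text come from the diff above; everything else in the snippet is illustrative, not part of the commit):

# Importing the deprecated module still works at this commit, but logs the warning added above.
from diffusers.utils import testing_utils  # noqa: F401

# The suggested replacement path for the same helpers, e.g. randn_tensor:
from diffusers.utils.torch_utils import randn_tensor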

src/diffusers/utils/torch_utils.py

Lines changed: 139 additions & 1 deletion
@@ -16,7 +16,8 @@
 """

 import functools
-from typing import List, Optional, Tuple, Union
+import os
+from typing import Callable, Dict, List, Optional, Tuple, Union

 from . import logging
 from .import_utils import is_torch_available, is_torch_npu_available, is_torch_version
@@ -36,6 +37,116 @@ def maybe_allow_in_graph(cls):
     return cls


+# Behaviour flags
+BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True}
+# Function definitions
+BACKEND_EMPTY_CACHE = {
+    "cuda": torch.cuda.empty_cache,
+    "xpu": torch.xpu.empty_cache,
+    "cpu": None,
+    "mps": torch.mps.empty_cache,
+    "default": None,
+}
+BACKEND_DEVICE_COUNT = {
+    "cuda": torch.cuda.device_count,
+    "xpu": torch.xpu.device_count,
+    "cpu": lambda: 0,
+    "mps": lambda: 0,
+    "default": 0,
+}
+BACKEND_MANUAL_SEED = {
+    "cuda": torch.cuda.manual_seed,
+    "xpu": torch.xpu.manual_seed,
+    "cpu": torch.manual_seed,
+    "mps": torch.mps.manual_seed,
+    "default": torch.manual_seed,
+}
+BACKEND_RESET_PEAK_MEMORY_STATS = {
+    "cuda": torch.cuda.reset_peak_memory_stats,
+    "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
+    "cpu": None,
+    "mps": None,
+    "default": None,
+}
+BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
+    "cuda": torch.cuda.reset_max_memory_allocated,
+    "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
+    "cpu": None,
+    "mps": None,
+    "default": None,
+}
+BACKEND_MAX_MEMORY_ALLOCATED = {
+    "cuda": torch.cuda.max_memory_allocated,
+    "xpu": getattr(torch.xpu, "max_memory_allocated", None),
+    "cpu": 0,
+    "mps": 0,
+    "default": 0,
+}
+BACKEND_SYNCHRONIZE = {
+    "cuda": torch.cuda.synchronize,
+    "xpu": getattr(torch.xpu, "synchronize", None),
+    "cpu": None,
+    "mps": None,
+    "default": None,
+}
+
+
+# This dispatches a defined function according to the accelerator from the function definitions.
+def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs):
+    if device not in dispatch_table:
+        return dispatch_table["default"](*args, **kwargs)
+
+    fn = dispatch_table[device]
+
+    # Some device agnostic functions return values. Need to guard against 'None' instead at
+    # user level
+    if not callable(fn):
+        return fn
+
+    return fn(*args, **kwargs)
+
+
+# These are callables which automatically dispatch the function specific to the accelerator
+def backend_manual_seed(device: str, seed: int):
+    return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
+
+
+def backend_synchronize(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_SYNCHRONIZE)
+
+
+def backend_empty_cache(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
+
+
+def backend_device_count(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
+
+
+def backend_reset_peak_memory_stats(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_RESET_PEAK_MEMORY_STATS)
+
+
+def backend_reset_max_memory_allocated(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)
+
+
+def backend_max_memory_allocated(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)
+
+
+# These are callables which return boolean behaviour flags and can be used to specify some
+# device agnostic alternative where the feature is unsupported.
+def backend_supports_training(device: str):
+    if not is_torch_available():
+        return False
+
+    if device not in BACKEND_SUPPORTS_TRAINING:
+        device = "default"
+
+    return BACKEND_SUPPORTS_TRAINING[device]
+
+
 def randn_tensor(
     shape: Union[Tuple, List],
     generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
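
As a rough illustration, here is how the new device-agnostic dispatch helpers might be used in test code (the helper names come from the diff above; the surrounding snippet is a hypothetical sketch, not part of the commit):

import torch

from diffusers.utils.torch_utils import (
    backend_empty_cache,
    backend_manual_seed,
    backend_supports_training,
)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the active accelerator without branching on the device type by hand.
backend_manual_seed(device, 0)

# Release cached memory where the backend defines empty_cache; the "cpu" entry is None,
# so the dispatcher simply returns None instead of calling anything.
backend_empty_cache(device)

# Guard training-only code on backends that cannot train (e.g. "mps" is flagged False).
if backend_supports_training(device):
    pass  # run the training portion of the test here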
@@ -197,3 +308,30 @@ def device_synchronize(device_type: Optional[str] = None):
         device_type = get_device()
     device_mod = getattr(torch, device_type, torch.cuda)
     device_mod.synchronize()
+
+
+def enable_full_determinism():
+    """
+    Helper function for reproducible behavior during distributed training. See
+    - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
+    """
+    # Enable PyTorch deterministic mode. This potentially requires either the environment
+    # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
+    # depending on the CUDA version, so we set them both here
+    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+    torch.use_deterministic_algorithms(True)
+
+    # Enable CUDNN deterministic mode
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+    torch.backends.cuda.matmul.allow_tf32 = False
+
+
+def disable_full_determinism():
+    os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
+    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
+    torch.use_deterministic_algorithms(False)
+
+
+torch_device = get_device()
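
Finally, a short sketch of how the determinism toggles might be wrapped around a reproducibility check (the two function names come from the diff above; the test body is a hypothetical example):

import torch

from diffusers.utils.torch_utils import disable_full_determinism, enable_full_determinism

enable_full_determinism()
try:
    # With deterministic algorithms enabled, the same seed must yield identical tensors.
    first = torch.randn(4, 4, generator=torch.manual_seed(0))
    second = torch.randn(4, 4, generator=torch.manual_seed(0))
    assert torch.equal(first, second)
finally:
    # Restore the default (faster, non-deterministic) settings.
    disable_full_determinism()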
