Commit ce12925

update

1 parent 80b06b0 commit ce12925

File tree

  examples/conftest.py
  examples/controlnet/train_controlnet_sd3.py
  examples/vqgan/test_vqgan.py
  src/diffusers/utils/testing_utils.py

4 files changed: +87 -34 lines changed

examples/conftest.py

Lines changed: 7 additions & 2 deletions

@@ -25,20 +25,25 @@
 git_repo_path = abspath(join(dirname(dirname(dirname(__file__))), "src"))
 sys.path.insert(1, git_repo_path)
 
+# Add parent directory to path so we can import from tests
+repo_root = abspath(dirname(dirname(__file__)))
+if repo_root not in sys.path:
+    sys.path.insert(0, repo_root)
+
 
 # silence FutureWarning warnings in tests since often we can't act on them until
 # they become normal warnings - i.e. the tests still need to test the current functionality
 warnings.simplefilter(action="ignore", category=FutureWarning)
 
 
 def pytest_addoption(parser):
-    from diffusers.utils.testing_utils import pytest_addoption_shared
+    from tests.testing_utils import pytest_addoption_shared
 
     pytest_addoption_shared(parser)
 
 
 def pytest_terminal_summary(terminalreporter):
-    from diffusers.utils.testing_utils import pytest_terminal_summary_main
+    from tests.testing_utils import pytest_terminal_summary_main
 
     make_reports = terminalreporter.config.getoption("--make-reports")
     if make_reports:
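A minimal sketch of what the new conftest.py bootstrap enables, assuming the repository keeps tests/testing_utils.py at the repo root, as the new import targets imply (names taken from the diff above):

# sketch: how the added sys.path entry makes the tests package importable
# from examples/conftest.py (repo layout assumed from the diff above)
import sys
from os.path import abspath, dirname

repo_root = abspath(dirname(dirname(__file__)))  # examples/ -> <repo_root>
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)

# with <repo_root> on sys.path, this resolves to <repo_root>/tests/testing_utils.py
from tests.testing_utils import pytest_addoption_shared, pytest_terminal_summary_main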

examples/controlnet/train_controlnet_sd3.py

Lines changed: 3 additions & 2 deletions

@@ -24,6 +24,8 @@
 import os
 import random
 import shutil
+
+# Add repo root to path to import from tests
 from pathlib import Path
 
 import accelerate
@@ -54,8 +56,7 @@
 from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
 from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
-from diffusers.utils.testing_utils import backend_empty_cache
-from diffusers.utils.torch_utils import is_compiled_module
+from diffusers.utils.torch_utils import backend_empty_cache, is_compiled_module
 
 
 if is_wandb_available():
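With the consolidated import, cache clearing in the training script goes through diffusers.utils.torch_utils. A minimal usage sketch (the device string and the surrounding training-loop context are illustrative assumptions):

# sketch: clearing the accelerator cache after validation, using the
# relocated helper; the device string is an assumption for illustration
import torch
from diffusers.utils.torch_utils import backend_empty_cache

device = "cuda" if torch.cuda.is_available() else "cpu"
# ... validation images generated, large intermediates released ...
backend_empty_cache(device)  # dispatches to the device-specific empty_cache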

examples/vqgan/test_vqgan.py

Lines changed: 7 additions & 1 deletion

@@ -24,12 +24,18 @@
 import torch
 
 from diffusers import VQModel
-from diffusers.utils.testing_utils import require_timm
 
 
+# Add parent directories to path to import from tests
 sys.path.append("..")
+repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
+if repo_root not in sys.path:
+    sys.path.insert(0, repo_root)
+
 from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
 
+from tests.testing_utils import require_timm  # noqa
+
 
 logging.basicConfig(level=logging.DEBUG)
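require_timm is now imported from tests.testing_utils; the decorator pattern it supports looks like this (the test case below is a hypothetical sketch, not part of the diff):

# sketch: gating a test on timm availability via the relocated decorator
# (hypothetical test case, for illustration only)
import unittest

from tests.testing_utils import require_timm


@require_timm
class TimmDependentTests(unittest.TestCase):
    def test_smoke(self):
        import timm  # safe: the decorator skips this class when timm is absent

        self.assertTrue(hasattr(timm, "create_model"))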

src/diffusers/utils/testing_utils.py

Lines changed: 70 additions & 29 deletions

@@ -68,7 +68,7 @@
 logger = get_logger(__name__)
 logger.warning(
     "`diffusers.utils.testing_utils` is deprecated and will be removed in a future version. "
-    "Please use `diffusers.utils.torch_utils` instead. "
+    "Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. "
 )
 _required_peft_version = is_peft_available() and version.parse(
     version.parse(importlib.metadata.version("peft")).base_version
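The rest of this file applies one pattern repeatedly: keep the old public name, lazily import the relocated implementation, warn, and forward. A generic sketch of that shim shape (old_module and new_module are placeholder names, not the literal diffusers paths):

# sketch: generic deprecation shim that forwards to a relocated implementation
# (old_module/new_module are placeholders)
import logging

logger = logging.getLogger(__name__)
logger.warning(
    "`old_module` is deprecated and will be removed in a future version. "
    "Utilities have been moved to `new_module`."
)


def moved_function(*args, **kwargs):
    # lazy import keeps module import cheap and avoids circular imports
    from new_module import moved_function as _moved_function

    logger.warning("moved_function has been moved to new_module.")
    return _moved_function(*args, **kwargs)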
@@ -804,10 +804,9 @@ def export_to_ply(mesh, output_ply_path: str = None):
                 f.write(format.pack(*vertex))
 
         if faces is not None:
             format = struct.Struct("<B3I")
             for tri in faces.tolist():
                 f.write(format.pack(len(tri), *tri))
-
     return output_ply_path
 
 
@@ -1147,23 +1146,23 @@ def enable_full_determinism():
     Helper function for reproducible behavior during distributed training. See
     - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
     """
-    # Enable PyTorch deterministic mode. This potentially requires either the environment
-    # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
-    # depending on the CUDA version, so we set them both here
-    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
-    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
-    torch.use_deterministic_algorithms(True)
+    from .torch_utils import enable_full_determinism as _enable_full_determinism
 
-    # Enable CUDNN deterministic mode
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.benchmark = False
-    torch.backends.cuda.matmul.allow_tf32 = False
+    logger.warning(
+        "enable_full_determinism has been moved to diffusers.utils.torch_utils. "
+        "Importing from diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _enable_full_determinism()
 
 
 def disable_full_determinism():
-    os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
-    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
-    torch.use_deterministic_algorithms(False)
+    from .torch_utils import disable_full_determinism as _disable_full_determinism
+
+    logger.warning(
+        "disable_full_determinism has been moved to diffusers.utils.torch_utils. "
+        "Importing from diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _disable_full_determinism()
 
 
 # Utils for custom and alternative accelerator devices
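Call sites migrating off the shim only change the import path; the forwarding bodies above show the behavior is otherwise identical. A minimal sketch:

# sketch: importing the determinism helpers from their new home
from diffusers.utils.torch_utils import disable_full_determinism, enable_full_determinism

enable_full_determinism()   # deterministic algorithms, cuDNN determinism, env vars
# ... run a reproducibility-sensitive check ...
disable_full_determinism()  # back to the default, faster nondeterministic kernels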
@@ -1285,43 +1284,85 @@ def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable],
 
 # These are callables which automatically dispatch the function specific to the accelerator
 def backend_manual_seed(device: str, seed: int):
-    return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
+    from .torch_utils import backend_manual_seed as _backend_manual_seed
+
+    logger.warning(
+        "backend_manual_seed has been moved to diffusers.utils.torch_utils. "
+        "diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _backend_manual_seed(device, seed)
 
 
 def backend_synchronize(device: str):
-    return _device_agnostic_dispatch(device, BACKEND_SYNCHRONIZE)
+    from .torch_utils import backend_synchronize as _backend_synchronize
+
+    logger.warning(
+        "backend_synchronize has been moved to diffusers.utils.torch_utils. "
+        "diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _backend_synchronize(device)
 
 
 def backend_empty_cache(device: str):
-    return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
+    from .torch_utils import backend_empty_cache as _backend_empty_cache
+
+    logger.warning(
+        "backend_empty_cache has been moved to diffusers.utils.torch_utils. "
+        "diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _backend_empty_cache(device)
 
 
 def backend_device_count(device: str):
-    return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
+    from .torch_utils import backend_device_count as _backend_device_count
+
+    logger.warning(
+        "backend_device_count has been moved to diffusers.utils.torch_utils. "
+        "diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _backend_device_count(device)
 
 
 def backend_reset_peak_memory_stats(device: str):
-    return _device_agnostic_dispatch(device, BACKEND_RESET_PEAK_MEMORY_STATS)
+    from .torch_utils import backend_reset_peak_memory_stats as _backend_reset_peak_memory_stats
+
+    logger.warning(
+        "backend_reset_peak_memory_stats has been moved to diffusers.utils.torch_utils. "
+        "diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _backend_reset_peak_memory_stats(device)
 
 
 def backend_reset_max_memory_allocated(device: str):
-    return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)
+    from .torch_utils import backend_reset_max_memory_allocated as _backend_reset_max_memory_allocated
+
+    logger.warning(
+        "backend_reset_max_memory_allocated has been moved to diffusers.utils.torch_utils. "
+        "diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _backend_reset_max_memory_allocated(device)
 
 
 def backend_max_memory_allocated(device: str):
-    return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)
+    from .torch_utils import backend_max_memory_allocated as _backend_max_memory_allocated
+
+    logger.warning(
+        "backend_max_memory_allocated has been moved to diffusers.utils.torch_utils. "
+        "diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _backend_max_memory_allocated(device)
 
 
 # These are callables which return boolean behaviour flags and can be used to specify some
 # device agnostic alternative where the feature is unsupported.
 def backend_supports_training(device: str):
-    if not is_torch_available():
-        return False
-
-    if device not in BACKEND_SUPPORTS_TRAINING:
-        device = "default"
+    from .torch_utils import backend_supports_training as _backend_supports_training
 
-    return BACKEND_SUPPORTS_TRAINING[device]
+    logger.warning(
+        "backend_supports_training has been moved to diffusers.utils.torch_utils. "
+        "diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _backend_supports_training(device)
 
 
 # Guard for when Torch is not available
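Each wrapper above preserves its original signature, so existing call sites keep working, with a warning per call. A device-agnostic usage sketch against the new module (the device string is an illustrative assumption):

# sketch: device-agnostic seeding, capability check, and cache clearing
# via the relocated helpers (device string assumed for illustration)
from diffusers.utils.torch_utils import (
    backend_empty_cache,
    backend_manual_seed,
    backend_supports_training,
)

device = "cuda"  # or "mps", "xpu", "cpu"
backend_manual_seed(device, 0)  # seeds the backend matching the device
if backend_supports_training(device):
    pass  # run a training smoke test here
backend_empty_cache(device)  # release cached allocator blocks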
