diff --git a/examples/dreambooth/accelerate_config.yaml b/examples/dreambooth/accelerate_config.yaml
new file mode 100644
index 000000000000..30b1c1cb9ff0
--- /dev/null
+++ b/examples/dreambooth/accelerate_config.yaml
@@ -0,0 +1,14 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+  deepspeed_config_file: ds_config_zero2.json
+  zero3_init_flag: false
+distributed_type: DEEPSPEED
+machine_rank: 0
+main_training_function: main
+mixed_precision: fp16
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+use_cpu: false
\ No newline at end of file
diff --git a/examples/dreambooth/ds_config_zero2.json b/examples/dreambooth/ds_config_zero2.json
new file mode 100644
index 000000000000..41227d7011ea
--- /dev/null
+++ b/examples/dreambooth/ds_config_zero2.json
@@ -0,0 +1,38 @@
+{
+    "train_batch_size": 2,
+    "gradient_accumulation_steps": 1,
+    "gradient_clipping": 1.0,
+    "zero_optimization": {
+        "stage": 2,
+        "offload_optimizer": {
+            "device": "cpu"
+        },
+        "offload_param": {
+            "device": "cpu"
+        },
+        "overlap_comm": true,
+        "contiguous_gradients": true,
+        "reduce_bucket_size": 50000000,
+        "allgather_bucket_size": 50000000
+    },
+    "optimizer": {
+        "type": "AdamW",
+        "params": {
+            "lr": 1e-5,
+            "betas": [0.9, 0.999],
+            "eps": 1e-8,
+            "weight_decay": 1e-2
+        }
+    },
+    "scheduler": {
+        "type": "WarmupLR",
+        "params": {
+            "warmup_min_lr": 0,
+            "warmup_max_lr": 1e-5,
+            "warmup_num_steps": 100
+        }
+    },
+    "steps_per_print": 10,
+    "wall_clock_breakdown": false,
+    "communication_data_type": "fp16"
+}
\ No newline at end of file
diff --git a/examples/dreambooth/setup_training_env.sh b/examples/dreambooth/setup_training_env.sh
new file mode 100755
index 000000000000..a4d2752ab02d
--- /dev/null
+++ b/examples/dreambooth/setup_training_env.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+# Extend NCCL timeouts
+export NCCL_SOCKET_TIMEOUT=7200000
+export DEEPSPEED_TIMEOUT=7200000
+
+# Set CPU threading optimizations
+export OMP_NUM_THREADS=1
+export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb=512
+
+# Increase system shared memory limits
+sudo sysctl -w kernel.shmmax=85899345920
+sudo sysctl -w kernel.shmall=2097152
+
+# Enable NCCL debugging for diagnostics
+export NCCL_DEBUG=INFO
+
+# Optional: Set NCCL topology optimization
+# Uncomment if needed after checking nvidia-smi topo -m
+# export NCCL_P2P_LEVEL=PHB
+
+# Persist changes to sysctl
+echo "kernel.shmmax=85899345920" | sudo tee -a /etc/sysctl.conf
+echo "kernel.shmall=2097152" | sudo tee -a /etc/sysctl.conf
+sudo sysctl -p
\ No newline at end of file
diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
index a1ab8cda431f..b9edd203d087
--- a/src/diffusers/utils/torch_utils.py
+++ b/src/diffusers/utils/torch_utils.py
@@ -18,6 +18,7 @@
 import functools
 import os
 from typing import Callable, Dict, List, Optional, Tuple, Union
+import warnings
 
 from . import logging
 from .import_utils import is_torch_available, is_torch_npu_available, is_torch_version
@@ -222,8 +223,23 @@ def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.T
         x = x.to(dtype=torch.float32)
 
     # FFT
-    x_freq = fftn(x, dim=(-2, -1))
-    x_freq = fftshift(x_freq, dim=(-2, -1))
+    # When running with torch.float16, PyTorch may emit a UserWarning about
+    # ComplexHalf (experimental) support when performing FFTs. This warning is
+    # noisy for users of the FreeU feature and doesn't change the behaviour of
+    # the algorithm here. We therefore locally suppress that specific warning
+    # around the FFT calls when the input dtype is float16.
+    if x.dtype == torch.float16:
+        with warnings.catch_warnings():
+            warnings.filterwarnings(
+                "ignore",
+                message="ComplexHalf support is experimental and many operators don't support it yet.*",
+                category=UserWarning,
+            )
+            x_freq = fftn(x, dim=(-2, -1))
+            x_freq = fftshift(x_freq, dim=(-2, -1))
+    else:
+        x_freq = fftn(x, dim=(-2, -1))
+        x_freq = fftshift(x_freq, dim=(-2, -1))
 
     B, C, H, W = x_freq.shape
     mask = torch.ones((B, C, H, W), device=x.device)
@@ -234,7 +250,16 @@ def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.T
 
     # IFFT
     x_freq = ifftshift(x_freq, dim=(-2, -1))
-    x_filtered = ifftn(x_freq, dim=(-2, -1)).real
+    if x.dtype == torch.float16:
+        with warnings.catch_warnings():
+            warnings.filterwarnings(
+                "ignore",
+                message="ComplexHalf support is experimental and many operators don't support it yet.*",
+                category=UserWarning,
+            )
+            x_filtered = ifftn(x_freq, dim=(-2, -1)).real
+    else:
+        x_filtered = ifftn(x_freq, dim=(-2, -1)).real
 
     return x_filtered.to(dtype=x_in.dtype)
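For reference (not part of the patch): a minimal, self-contained sketch of the warning-suppression pattern the torch_utils.py change introduces. `fft_roundtrip` is a hypothetical helper written only for illustration; the tensor shape is arbitrary, and because half-precision FFTs are only supported on CUDA, the example falls back to float32 on CPU.

```python
import warnings

import torch
from torch.fft import fftn, fftshift, ifftn, ifftshift


def fft_roundtrip(x: torch.Tensor) -> torch.Tensor:
    """Round-trip a tensor through fftn/ifftn, mirroring fourier_filter's FFT steps."""
    if x.dtype == torch.float16:
        # Same filter as in the patch: silence only the experimental ComplexHalf
        # UserWarning, and only around the FFT calls.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="ComplexHalf support is experimental and many operators don't support it yet.*",
                category=UserWarning,
            )
            x_freq = fftshift(fftn(x, dim=(-2, -1)), dim=(-2, -1))
            return ifftn(ifftshift(x_freq, dim=(-2, -1)), dim=(-2, -1)).real
    x_freq = fftshift(fftn(x, dim=(-2, -1)), dim=(-2, -1))
    return ifftn(ifftshift(x_freq, dim=(-2, -1)), dim=(-2, -1)).real


if __name__ == "__main__":
    # float16 FFTs require CUDA (and power-of-two spatial sizes), so fall back
    # to float32 when no GPU is available.
    if torch.cuda.is_available():
        sample = torch.randn(1, 4, 64, 64, dtype=torch.float16, device="cuda")
    else:
        sample = torch.randn(1, 4, 64, 64, dtype=torch.float32)
    print(fft_roundtrip(sample).dtype, fft_roundtrip(sample).shape)
```

Presumably the new accelerate config would be exercised from `examples/dreambooth/` (so the relative `ds_config_zero2.json` path resolves), e.g. with `accelerate launch --config_file accelerate_config.yaml train_dreambooth.py ...`; the exact training flags depend on the run and are not part of this patch.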