
Commit 153ca9a

Merge branch 'main' into inplace_sum_and_remove_padding_and_better_memory_count
2 parents: 272537b + a1f9a71

28 files changed: +299 −209 lines

src/diffusers/models/transformers/sana_transformer.py

Lines changed: 17 additions & 9 deletions
```diff
@@ -82,6 +82,20 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states
 
 
+class SanaModulatedNorm(nn.Module):
+    def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim, elementwise_affine=elementwise_affine, eps=eps)
+
+    def forward(
+        self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor
+    ) -> torch.Tensor:
+        hidden_states = self.norm(hidden_states)
+        shift, scale = (scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)).chunk(2, dim=1)
+        hidden_states = hidden_states * (1 + scale) + shift
+        return hidden_states
+
+
 class SanaTransformerBlock(nn.Module):
     r"""
     Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
@@ -221,7 +235,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
     """
 
     _supports_gradient_checkpointing = True
-    _no_split_modules = ["SanaTransformerBlock", "PatchEmbed"]
+    _no_split_modules = ["SanaTransformerBlock", "PatchEmbed", "SanaModulatedNorm"]
 
     @register_to_config
     def __init__(
@@ -288,8 +302,7 @@ def __init__(
 
         # 4. Output blocks
         self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
-
-        self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+        self.norm_out = SanaModulatedNorm(inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
 
         self.gradient_checkpointing = False
@@ -462,13 +475,8 @@ def custom_forward(*inputs):
             )
 
         # 3. Normalization
-        shift, scale = (
-            self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
-        ).chunk(2, dim=1)
-        hidden_states = self.norm_out(hidden_states)
+        hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)
 
-        # 4. Modulation
-        hidden_states = hidden_states * (1 + scale) + shift
         hidden_states = self.proj_out(hidden_states)
 
         # 5. Unpatchify
```
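Net effect: the final `nn.LayerNorm` and the AdaLN-style shift/scale modulation, previously inlined in `forward`, now live in a single `SanaModulatedNorm` module, which can then be listed in `_no_split_modules` so `device_map`-based loading keeps the output norm on one device. A standalone sketch of the modulation math; the batch/token/feature sizes here are hypothetical, not taken from a real Sana config:

```python
import torch
import torch.nn as nn

# Sketch of the SanaModulatedNorm math above; sizes are hypothetical.
batch, tokens, dim = 2, 16, 8
norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
scale_shift_table = torch.randn(2, dim) / dim**0.5  # learned rows: (shift, scale)
hidden_states = torch.randn(batch, tokens, dim)
temb = torch.randn(batch, dim)  # embedded timestep, one vector per sample

# (1, 2, dim) + (batch, 1, dim) broadcasts to (batch, 2, dim); chunking
# along dim=1 yields shift and scale, each of shape (batch, 1, dim).
shift, scale = (scale_shift_table[None] + temb[:, None]).chunk(2, dim=1)
out = norm(hidden_states) * (1 + scale) + shift  # AdaLN-style modulation
assert out.shape == (batch, tokens, dim)
```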

src/diffusers/utils/testing_utils.py

Lines changed: 64 additions & 5 deletions
```diff
@@ -86,7 +86,12 @@
         ) from e
     logger.info(f"torch_device overrode to {torch_device}")
 else:
-    torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+    if torch.cuda.is_available():
+        torch_device = "cuda"
+    elif torch.xpu.is_available():
+        torch_device = "xpu"
+    else:
+        torch_device = "cpu"
 is_torch_higher_equal_than_1_12 = version.parse(
     version.parse(torch.__version__).base_version
 ) >= version.parse("1.12")
@@ -1067,12 +1072,51 @@ def _is_torch_fp64_available(device):
 # Guard these lookups for when Torch is not used - alternative accelerator support is for PyTorch
 if is_torch_available():
     # Behaviour flags
-    BACKEND_SUPPORTS_TRAINING = {"cuda": True, "cpu": True, "mps": False, "default": True}
+    BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True}
 
     # Function definitions
-    BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "mps": None, "default": None}
-    BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "mps": lambda: 0, "default": 0}
-    BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed}
+    BACKEND_EMPTY_CACHE = {
+        "cuda": torch.cuda.empty_cache,
+        "xpu": torch.xpu.empty_cache,
+        "cpu": None,
+        "mps": torch.mps.empty_cache,
+        "default": None,
+    }
+    BACKEND_DEVICE_COUNT = {
+        "cuda": torch.cuda.device_count,
+        "xpu": torch.xpu.device_count,
+        "cpu": lambda: 0,
+        "mps": lambda: 0,
+        "default": 0,
+    }
+    BACKEND_MANUAL_SEED = {
+        "cuda": torch.cuda.manual_seed,
+        "xpu": torch.xpu.manual_seed,
+        "cpu": torch.manual_seed,
+        "mps": torch.mps.manual_seed,
+        "default": torch.manual_seed,
+    }
+    BACKEND_RESET_PEAK_MEMORY_STATS = {
+        "cuda": torch.cuda.reset_peak_memory_stats,
+        "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
+        "cpu": None,
+        "mps": None,
+        "default": None,
+    }
+    BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
+        "cuda": torch.cuda.reset_max_memory_allocated,
+        "xpu": None,
+        "cpu": None,
+        "mps": None,
+        "default": None,
+    }
+    BACKEND_MAX_MEMORY_ALLOCATED = {
+        "cuda": torch.cuda.max_memory_allocated,
+        "xpu": getattr(torch.xpu, "max_memory_allocated", None),
+        "cpu": 0,
+        "mps": 0,
+        "default": 0,
+    }
 
 
 # This dispatches a defined function according to the accelerator from the function definitions.
@@ -1103,6 +1147,18 @@ def backend_device_count(device: str):
     return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
 
 
+def backend_reset_peak_memory_stats(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_RESET_PEAK_MEMORY_STATS)
+
+
+def backend_reset_max_memory_allocated(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)
+
+
+def backend_max_memory_allocated(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)
+
+
 # These are callables which return boolean behaviour flags and can be used to specify some
 # device agnostic alternative where the feature is unsupported.
 def backend_supports_training(device: str):
@@ -1159,3 +1215,6 @@ def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name
     update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
     update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
     update_mapping_from_spec(BACKEND_SUPPORTS_TRAINING, "SUPPORTS_TRAINING")
+    update_mapping_from_spec(BACKEND_RESET_PEAK_MEMORY_STATS, "RESET_PEAK_MEMORY_STATS_FN")
+    update_mapping_from_spec(BACKEND_RESET_MAX_MEMORY_ALLOCATED, "RESET_MAX_MEMORY_ALLOCATED_FN")
+    update_mapping_from_spec(BACKEND_MAX_MEMORY_ALLOCATED, "MAX_MEMORY_ALLOCATED_FN")
```
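The three new `backend_*` wrappers ride on the existing `_device_agnostic_dispatch` helper, so tests can track peak memory without hard-coding `torch.cuda` calls. A sketch of how a test might use them; the `measure_peak_memory` helper is illustrative, not part of the commit:

```python
import torch

from diffusers.utils.testing_utils import (
    backend_empty_cache,
    backend_max_memory_allocated,
    backend_reset_peak_memory_stats,
    torch_device,
)


def measure_peak_memory(model, inputs):
    """Peak accelerator memory (bytes) used by one forward pass."""
    backend_empty_cache(torch_device)
    backend_reset_peak_memory_stats(torch_device)
    with torch.no_grad():
        model(**inputs)
    return backend_max_memory_allocated(torch_device)
```

Note the dispatcher returns non-callable table entries as-is, so on `cpu` the helper simply yields the table's `0` rather than raising.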

tests/models/test_modeling_common.py

Lines changed: 14 additions & 13 deletions
```diff
@@ -29,7 +29,7 @@
 import requests_mock
 import torch
 import torch.nn as nn
-from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size
+from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
 from huggingface_hub import ModelCard, delete_repo, snapshot_download
 from huggingface_hub.utils import is_jinja_available
 from parameterized import parameterized
@@ -57,8 +57,8 @@
     get_python_version,
     is_torch_compile,
     require_torch_2,
+    require_torch_accelerator,
     require_torch_accelerator_with_training,
-    require_torch_gpu,
     require_torch_multi_gpu,
     run_test_in_subprocess,
     torch_all_close,
@@ -543,7 +543,7 @@ def test_set_xformers_attn_processor_for_determinism(self):
         assert torch.allclose(output, output_3, atol=self.base_precision)
         assert torch.allclose(output_2, output_3, atol=self.base_precision)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_set_attn_processor_for_determinism(self):
         if self.uses_custom_attn_processor:
             return
@@ -1068,7 +1068,7 @@ def test_wrong_adapter_name_raises_error(self):
 
         self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_cpu_offload(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
@@ -1080,7 +1080,7 @@ def test_cpu_offload(self):
         torch.manual_seed(0)
         base_output = model(**inputs_dict)
 
-        model_size = compute_module_persistent_sizes(model)[""]
+        model_size = compute_module_sizes(model)[""]
         # We test several splits of sizes to make sure it works.
         max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1098,7 +1098,7 @@ def test_cpu_offload(self):
 
         self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_disk_offload_without_safetensors(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
@@ -1110,7 +1110,7 @@ def test_disk_offload_without_safetensors(self):
         torch.manual_seed(0)
         base_output = model(**inputs_dict)
 
-        model_size = compute_module_persistent_sizes(model)[""]
+        model_size = compute_module_sizes(model)[""]
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
 
@@ -1132,7 +1132,7 @@ def test_disk_offload_without_safetensors(self):
 
         self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_disk_offload_with_safetensors(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
@@ -1144,7 +1144,7 @@ def test_disk_offload_with_safetensors(self):
         torch.manual_seed(0)
         base_output = model(**inputs_dict)
 
-        model_size = compute_module_persistent_sizes(model)[""]
+        model_size = compute_module_sizes(model)[""]
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir)
 
@@ -1172,7 +1172,7 @@ def test_model_parallelism(self):
         torch.manual_seed(0)
         base_output = model(**inputs_dict)
 
-        model_size = compute_module_persistent_sizes(model)[""]
+        model_size = compute_module_sizes(model)[""]
         # We test several splits of sizes to make sure it works.
         max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1183,6 +1183,7 @@ def test_model_parallelism(self):
             new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
             # Making sure part of the model will actually end up offloaded
             self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
+            print(f" new_model.hf_device_map:{new_model.hf_device_map}")
 
             self.check_device_map_is_respected(new_model, new_model.hf_device_map)
 
@@ -1191,7 +1192,7 @@ def test_model_parallelism(self):
 
         self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_sharded_checkpoints(self):
         torch.manual_seed(0)
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1223,7 +1224,7 @@ def test_sharded_checkpoints(self):
 
         self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_sharded_checkpoints_with_variant(self):
         torch.manual_seed(0)
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1261,7 +1262,7 @@ def test_sharded_checkpoints_with_variant(self):
 
         self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_sharded_checkpoints_device_map(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
```
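The sizing switch matters for these tests: accelerate's `compute_module_sizes` counts all parameters and buffers, whereas `compute_module_persistent_sizes` skips non-persistent buffers, so the new figure better reflects what must fit on a device at runtime. A sketch of how the derived size feeds the `max_memory` splits; the toy model and fractions below are illustrative, not from the test suite:

```python
import torch.nn as nn
from accelerate.utils.modeling import compute_module_sizes

# Toy stand-in for self.model_class(**config); sizes are in bytes and the
# empty-string key addresses the root module.
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 32))
model_size = compute_module_sizes(model)[""]
split_percents = [0.5, 0.7, 0.9]  # plays the role of model_split_percents
max_gpu_sizes = [int(p * model_size) for p in split_percents[1:]]
# Each cap forces part of the model off device 0 when passed as
# from_pretrained(..., device_map="auto", max_memory=max_memory).
max_memory = {0: max_gpu_sizes[0], "cpu": model_size * 2}
```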

tests/models/transformers/test_models_transformer_sana.py

Lines changed: 1 addition & 25 deletions
```diff
@@ -14,7 +14,6 @@
 
 import unittest
 
-import pytest
 import torch
 
 from diffusers import SanaTransformer2DModel
@@ -33,6 +32,7 @@ class SanaTransformerTests(ModelTesterMixin, unittest.TestCase):
     model_class = SanaTransformer2DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True
+    model_split_percents = [0.7, 0.7, 0.9]
 
     @property
     def dummy_input(self):
@@ -81,27 +81,3 @@ def prepare_init_args_and_inputs_for_common(self):
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"SanaTransformer2DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
-
-    @pytest.mark.xfail(
-        condition=torch.device(torch_device).type == "cuda",
-        reason="Test currently fails.",
-        strict=True,
-    )
-    def test_cpu_offload(self):
-        return super().test_cpu_offload()
-
-    @pytest.mark.xfail(
-        condition=torch.device(torch_device).type == "cuda",
-        reason="Test currently fails.",
-        strict=True,
-    )
-    def test_disk_offload_with_safetensors(self):
-        return super().test_disk_offload_with_safetensors()
-
-    @pytest.mark.xfail(
-        condition=torch.device(torch_device).type == "cuda",
-        reason="Test currently fails.",
-        strict=True,
-    )
-    def test_disk_offload_without_safetensors(self):
-        return super().test_disk_offload_without_safetensors()
```
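This cleanup follows from the transformer change above: with `SanaModulatedNorm` registered in `_no_split_modules` and the looser `model_split_percents`, the CPU/disk offload tests no longer fail on CUDA, so the strict `xfail` markers (and the now-unused `pytest` import) are dropped.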

tests/pipelines/allegro/test_allegro.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -27,7 +27,7 @@
     enable_full_determinism,
     numpy_cosine_similarity_distance,
     require_hf_hub_version_greater,
-    require_torch_gpu,
+    require_torch_accelerator,
     require_transformers_version_greater,
     slow,
     torch_device,
@@ -332,7 +332,7 @@ def test_save_load_dduf(self):
 
 
 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class AllegroPipelineIntegrationTests(unittest.TestCase):
     prompt = "A painting of a squirrel eating a burger."
 
@@ -350,7 +350,7 @@ def test_allegro(self):
         generator = torch.Generator("cpu").manual_seed(0)
 
         pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", torch_dtype=torch.float16)
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         prompt = self.prompt
 
         videos = pipe(
```

tests/pipelines/animatediff/test_animatediff.py

Lines changed: 6 additions & 5 deletions
```diff
@@ -20,9 +20,10 @@
 from diffusers.models.attention import FreeNoiseTransformerBlock
 from diffusers.utils import is_xformers_available, logging
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     numpy_cosine_similarity_distance,
     require_accelerator,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
     torch_device,
 )
@@ -547,19 +548,19 @@ def test_vae_slicing(self):
 
 
 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class AnimateDiffPipelineSlowTests(unittest.TestCase):
     def setUp(self):
         # clean up the VRAM before each test
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def tearDown(self):
         # clean up the VRAM after each test
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def test_animatediff(self):
         adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
@@ -573,7 +574,7 @@ def test_animatediff(self):
             clip_sample=False,
         )
         pipe.enable_vae_slicing()
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)
 
         prompt = "night, b&w photo of old house, post apocalypse, forest, storm weather, wind, rocks, 8k uhd, dslr, soft lighting, high quality, film grain"
```
