tests/hooks/test_group_offloading.py: 125 changes (55 additions, 70 deletions)
@@ -14,16 +14,15 @@
 
 import contextlib
 import gc
-import unittest
+import logging
 
+import pytest
 import torch
-from parameterized import parameterized
 
 from diffusers import AutoencoderKL
 from diffusers.hooks import HookRegistry, ModelHook
 from diffusers.models import ModelMixin
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
-from diffusers.utils import get_logger
 from diffusers.utils.import_utils import compare_versions
 
 from ..testing_utils import (
@@ -219,20 +218,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 @require_torch_accelerator
-class GroupOffloadTests(unittest.TestCase):
+class TestGroupOffload:
     in_features = 64
     hidden_features = 256
     out_features = 64
     num_layers = 4
 
-    def setUp(self):
+    def setup_method(self):
         with torch.no_grad():
             self.model = self.get_model()
             self.input = torch.randn((4, self.in_features)).to(torch_device)
 
-    def tearDown(self):
-        super().tearDown()
-
+    def teardown_method(self):
         del self.model
         del self.input
         gc.collect()
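Note on the pytest convention adopted above: in a plain test class with no `unittest.TestCase` base, pytest invokes `setup_method` before and `teardown_method` after every test, so the `super().tearDown()` call disappears. A minimal self-contained sketch of the hook pair, using a hypothetical test class:

    class TestSketch:
        def setup_method(self):
            # Runs before each test; replaces unittest's setUp.
            self.values = [1, 2, 3]

        def teardown_method(self):
            # Runs after each test, even when it fails; replaces tearDown.
            del self.values

        def test_sum(self):
            assert sum(self.values) == 6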
@@ -248,18 +245,20 @@ def get_model(self):
             num_layers=self.num_layers,
         )
 
+    @pytest.mark.skipif(
+        torch.device(torch_device).type not in ["cuda", "xpu"],
+        reason="Test requires a CUDA or XPU device.",
+    )
     def test_offloading_forward_pass(self):
         @torch.no_grad()
         def run_forward(model):
             gc.collect()
             backend_empty_cache(torch_device)
             backend_reset_peak_memory_stats(torch_device)
-            self.assertTrue(
-                all(
-                    module._diffusers_hook.get_hook("group_offloading") is not None
-                    for module in model.modules()
-                    if hasattr(module, "_diffusers_hook")
-                )
-            )
+            assert all(
+                module._diffusers_hook.get_hook("group_offloading") is not None
+                for module in model.modules()
+                if hasattr(module, "_diffusers_hook")
+            )
             model.eval()
             output = model(self.input)[0].cpu()
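The `self.assertTrue(all(...))` block becomes a bare `assert all(...)`: pytest rewrites `assert` statements at collection time and reports intermediate values on failure, instead of unittest's opaque "False is not true". A hypothetical standalone illustration:

    # Sketch only: the hooks dict stands in for the module scan above.
    hooks = {"linear_1": "group_offloading", "linear_2": None}

    def test_all_hooks_registered():
        # On failure, pytest shows which values were None, not just "False".
        assert all(v is not None for v in hooks.values())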
@@ -291,73 +290,69 @@ def run_forward(model):
         output_with_group_offloading5, mem5 = run_forward(model)
 
         # Precision assertions - offloading should not impact the output
-        self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5))
-        self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5))
-        self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
-        self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
-        self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading5, atol=1e-5))
+        assert torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5)
+        assert torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5)
+        assert torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5)
+        assert torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5)
+        assert torch.allclose(output_without_group_offloading, output_with_group_offloading5, atol=1e-5)
 
         # Memory assertions - offloading should reduce memory usage
-        self.assertTrue(mem4 <= mem5 < mem2 <= mem3 < mem1 < mem_baseline)
+        assert mem4 <= mem5 < mem2 <= mem3 < mem1 < mem_baseline
 
-    def test_warning_logged_if_group_offloaded_module_moved_to_accelerator(self):
+    def test_warning_logged_if_group_offloaded_module_moved_to_accelerator(self, caplog):
         if torch.device(torch_device).type not in ["cuda", "xpu"]:
             return
         self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
-        logger = get_logger("diffusers.models.modeling_utils")
-        logger.setLevel("INFO")
-        with self.assertLogs(logger, level="WARNING") as cm:
+        with caplog.at_level(logging.WARNING, logger="diffusers.models.modeling_utils"):
             self.model.to(torch_device)
-        self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])
+        assert f"The module '{self.model.__class__.__name__}' is group offloaded" in caplog.text
 
-    def test_warning_logged_if_group_offloaded_pipe_moved_to_accelerator(self):
+    def test_warning_logged_if_group_offloaded_pipe_moved_to_accelerator(self, caplog):
         if torch.device(torch_device).type not in ["cuda", "xpu"]:
             return
         pipe = DummyPipeline(self.model)
         self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
-        logger = get_logger("diffusers.pipelines.pipeline_utils")
-        logger.setLevel("INFO")
-        with self.assertLogs(logger, level="WARNING") as cm:
+        with caplog.at_level(logging.WARNING, logger="diffusers.pipelines.pipeline_utils"):
            pipe.to(torch_device)
-        self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])
+        assert f"The module '{self.model.__class__.__name__}' is group offloaded" in caplog.text
 
     def test_error_raised_if_streams_used_and_no_accelerator_device(self):
         torch_accelerator_module = getattr(torch, torch_device, torch.cuda)
         original_is_available = torch_accelerator_module.is_available
         torch_accelerator_module.is_available = lambda: False
-        with self.assertRaises(ValueError):
+        with pytest.raises(ValueError):
            self.model.enable_group_offload(
                onload_device=torch.device(torch_device), offload_type="leaf_level", use_stream=True
            )
         torch_accelerator_module.is_available = original_is_available
 
     def test_error_raised_if_supports_group_offloading_false(self):
         self.model._supports_group_offloading = False
-        with self.assertRaisesRegex(ValueError, "does not support group offloading"):
+        with pytest.raises(ValueError, match="does not support group offloading"):
             self.model.enable_group_offload(onload_device=torch.device(torch_device))
 
     def test_error_raised_if_model_offloading_applied_on_group_offloaded_module(self):
         pipe = DummyPipeline(self.model)
         pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
-        with self.assertRaisesRegex(ValueError, "You are trying to apply model/sequential CPU offloading"):
+        with pytest.raises(ValueError, match="You are trying to apply model/sequential CPU offloading"):
             pipe.enable_model_cpu_offload()
 
     def test_error_raised_if_sequential_offloading_applied_on_group_offloaded_module(self):
         pipe = DummyPipeline(self.model)
         pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
-        with self.assertRaisesRegex(ValueError, "You are trying to apply model/sequential CPU offloading"):
+        with pytest.raises(ValueError, match="You are trying to apply model/sequential CPU offloading"):
             pipe.enable_sequential_cpu_offload()
 
     def test_error_raised_if_group_offloading_applied_on_model_offloaded_module(self):
         pipe = DummyPipeline(self.model)
         pipe.enable_model_cpu_offload()
-        with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"):
+        with pytest.raises(ValueError, match="Cannot apply group offloading"):
             pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
 
     def test_error_raised_if_group_offloading_applied_on_sequential_offloaded_module(self):
         pipe = DummyPipeline(self.model)
         pipe.enable_sequential_cpu_offload()
-        with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"):
+        with pytest.raises(ValueError, match="Cannot apply group offloading"):
             pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
 
     def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
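The log checks above now rely on pytest's built-in `caplog` fixture: `caplog.at_level` temporarily sets the level on the named logger, and `caplog.text` aggregates the captured output, replacing the `get_logger`/`setLevel`/`assertLogs` sequence. A self-contained sketch of the pattern (logger name and message are illustrative only):

    import logging

    def emit_warning():
        # Stand-in for model.to(...) on a group-offloaded module.
        logging.getLogger("example.module").warning("module is group offloaded")

    def test_warning_is_captured(caplog):
        with caplog.at_level(logging.WARNING, logger="example.module"):
            emit_warning()
        assert "group offloaded" in caplog.text

Relatedly, the manual save/restore of `is_available` in `test_error_raised_if_streams_used_and_no_accelerator_device` could arguably use pytest's `monkeypatch` fixture instead (`monkeypatch.setattr(torch_accelerator_module, "is_available", lambda: False)`), which undoes the patch automatically even if the test raises.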
@@ -376,12 +371,12 @@ def test_block_level_stream_with_invocation_order_different_from_initialization_
         context = contextlib.nullcontext()
         if compare_versions("diffusers", "<=", "0.33.0"):
             # Will raise a device mismatch RuntimeError mentioning weights are on CPU but input is on device
-            context = self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device")
+            context = pytest.raises(RuntimeError, match="Expected all tensors to be on the same device")
 
         with context:
             model(self.input)
 
-    @parameterized.expand([("block_level",), ("leaf_level",)])
+    @pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
     def test_block_level_offloading_with_parameter_only_module_group(self, offload_type: str):
         if torch.device(torch_device).type not in ["cuda", "xpu"]:
             return
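`@pytest.mark.parametrize` replaces `parameterized.expand` here: the string argument names the injected parameter, the one-element tuples go away, and each case gets its own test id. A minimal sketch:

    import pytest

    @pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
    def test_offload_type_param(offload_type: str):
        # Collected as test_offload_type_param[block_level] and [leaf_level].
        assert offload_type in {"block_level", "leaf_level"}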
@@ -407,14 +402,14 @@ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
 
         out_ref = model_ref(x)
         out = model(x)
-        self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match.")
+        assert torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match."
 
         num_repeats = 2
         for i in range(num_repeats):
             out_ref = model_ref(x)
             out = model(x)
 
-        self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match after multiple invocations.")
+        assert torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match after multiple invocations."
 
         for (ref_name, ref_module), (name, module) in zip(model_ref.named_modules(), model.named_modules()):
             assert ref_name == name
@@ -428,9 +423,7 @@ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
                 absdiff = diff.abs()
                 absmax = absdiff.max().item()
                 cumulated_absmax += absmax
-            self.assertLess(
-                cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
-            )
+            assert cumulated_absmax < 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
 
     def test_vae_like_model_without_streams(self):
         """Test VAE-like model with block-level offloading but without streams."""
@@ -452,9 +445,7 @@ def test_vae_like_model_without_streams(self):
         out_ref = model_ref(x).sample
         out = model(x).sample
 
-        self.assertTrue(
-            torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model without streams."
-        )
+        assert torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model without streams."
 
     def test_model_with_only_standalone_layers(self):
         """Test that models with only standalone layers (no ModuleList/Sequential) work with block-level offloading."""
@@ -475,12 +466,11 @@ def test_model_with_only_standalone_layers(self):
         for i in range(2):
             out_ref = model_ref(x)
             out = model(x)
-            self.assertTrue(
-                torch.allclose(out_ref, out, atol=1e-5),
-                f"Outputs do not match at iteration {i} for model with standalone layers.",
+            assert torch.allclose(out_ref, out, atol=1e-5), (
+                f"Outputs do not match at iteration {i} for model with standalone layers."
             )
 
-    @parameterized.expand([("block_level",), ("leaf_level",)])
+    @pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
     def test_standalone_conv_layers_with_both_offload_types(self, offload_type: str):
         """Test that standalone Conv2d layers work correctly with both block-level and leaf-level offloading."""
         if torch.device(torch_device).type not in ["cuda", "xpu"]:
@@ -501,9 +491,8 @@ def test_standalone_conv_layers_with_both_offload_types(self, offload_type: str)
         out_ref = model_ref(x).sample
         out = model(x).sample
 
-        self.assertTrue(
-            torch.allclose(out_ref, out, atol=1e-5),
-            f"Outputs do not match for standalone Conv layers with {offload_type}.",
+        assert torch.allclose(out_ref, out, atol=1e-5), (
+            f"Outputs do not match for standalone Conv layers with {offload_type}."
         )
 
     def test_multiple_invocations_with_vae_like_model(self):
@@ -526,7 +515,7 @@ def test_multiple_invocations_with_vae_like_model(self):
         for i in range(2):
             out_ref = model_ref(x).sample
             out = model(x).sample
-            self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i}.")
+            assert torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i}."
 
     def test_nested_container_parameters_offloading(self):
         """Test that parameters from non-computational layers in nested containers are handled correctly."""
@@ -547,9 +536,8 @@ def test_nested_container_parameters_offloading(self):
         for i in range(2):
             out_ref = model_ref(x)
             out = model(x)
-            self.assertTrue(
-                torch.allclose(out_ref, out, atol=1e-5),
-                f"Outputs do not match at iteration {i} for nested parameters.",
+            assert torch.allclose(out_ref, out, atol=1e-5), (
+                f"Outputs do not match at iteration {i} for nested parameters."
             )
 
     def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
@@ -602,7 +590,7 @@ def forward(self, x: torch.Tensor, optional_input: torch.Tensor | None = None) -
         return x
 
 
-class ConditionalModuleGroupOffloadTests(GroupOffloadTests):
+class TestConditionalModuleGroupOffload(TestGroupOffload):
     """Tests for conditionally-executed modules under group offloading with streams.
 
     Regression tests for the case where a module is not executed during the first forward pass
Expand All @@ -620,10 +608,10 @@ def get_model(self):
num_layers=self.num_layers,
)

@parameterized.expand([("leaf_level",), ("block_level",)])
@unittest.skipIf(
@pytest.mark.parametrize("offload_type", ["leaf_level", "block_level"])
@pytest.mark.skipif(
torch.device(torch_device).type not in ["cuda", "xpu"],
"Test requires a CUDA or XPU device.",
reason="Test requires a CUDA or XPU device.",
)
def test_conditional_modules_with_stream(self, offload_type: str):
"""Regression test: conditionally-executed modules must not cause device mismatch when using streams.
@@ -670,23 +658,20 @@ def test_conditional_modules_with_stream(self, offload_type: str):
         # execution order is traced. optional_proj_1/2 are NOT in the traced order.
         out_ref_no_opt = model_ref(x, optional_input=None)
         out_no_opt = model(x, optional_input=None)
-        self.assertTrue(
-            torch.allclose(out_ref_no_opt, out_no_opt, atol=1e-5),
-            f"[{offload_type}] Outputs do not match on first pass (no optional_input).",
+        assert torch.allclose(out_ref_no_opt, out_no_opt, atol=1e-5), (
+            f"[{offload_type}] Outputs do not match on first pass (no optional_input)."
         )
 
         # Second forward pass WITH optional_input — optional_proj_1/2 ARE now called.
         out_ref_with_opt = model_ref(x, optional_input=optional_input)
         out_with_opt = model(x, optional_input=optional_input)
-        self.assertTrue(
-            torch.allclose(out_ref_with_opt, out_with_opt, atol=1e-5),
-            f"[{offload_type}] Outputs do not match on second pass (with optional_input).",
+        assert torch.allclose(out_ref_with_opt, out_with_opt, atol=1e-5), (
+            f"[{offload_type}] Outputs do not match on second pass (with optional_input)."
        )
 
         # Third pass again without optional_input — verify stable behavior.
         out_ref_no_opt2 = model_ref(x, optional_input=None)
         out_no_opt2 = model(x, optional_input=None)
-        self.assertTrue(
-            torch.allclose(out_ref_no_opt2, out_no_opt2, atol=1e-5),
-            f"[{offload_type}] Outputs do not match on third pass (back to no optional_input).",
+        assert torch.allclose(out_ref_no_opt2, out_no_opt2, atol=1e-5), (
+            f"[{offload_type}] Outputs do not match on third pass (back to no optional_input)."
         )
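For context, the API these tests exercise is `enable_group_offload` on a diffusers model: it takes an onload device, an `offload_type` of "block_level" or "leaf_level", and optionally `num_blocks_per_group` and `use_stream`, as seen throughout the hunks above. A hedged usage sketch (checkpoint and device are illustrative; keyword defaults may vary across diffusers versions):

    import torch
    from diffusers import AutoencoderKL

    # Illustrative checkpoint; any diffusers model that supports group offloading works.
    model = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

    # Keep groups of three blocks on the accelerator and the rest offloaded,
    # mirroring the block_level configuration used in the tests.
    model.enable_group_offload(
        onload_device=torch.device("cuda"),
        offload_type="block_level",
        num_blocks_per_group=3,
    )

    with torch.no_grad():
        sample = torch.randn(1, 3, 64, 64, device="cuda")
        out = model(sample).sample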