
Commit bffa3a9 ("update")

1 parent 1c55871 commit bffa3a9

16 files changed: +2752 / -121 lines

tests/conftest.py

Lines changed: 14 additions & 0 deletions

@@ -32,6 +32,20 @@
 
 def pytest_configure(config):
     config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources")
+    config.addinivalue_line("markers", "lora: marks tests for LoRA/PEFT functionality")
+    config.addinivalue_line("markers", "ip_adapter: marks tests for IP Adapter functionality")
+    config.addinivalue_line("markers", "training: marks tests for training functionality")
+    config.addinivalue_line("markers", "attention: marks tests for attention processor functionality")
+    config.addinivalue_line("markers", "memory: marks tests for memory optimization functionality")
+    config.addinivalue_line("markers", "cpu_offload: marks tests for CPU offloading functionality")
+    config.addinivalue_line("markers", "group_offload: marks tests for group offloading functionality")
+    config.addinivalue_line("markers", "compile: marks tests for torch.compile functionality")
+    config.addinivalue_line("markers", "single_file: marks tests for single file checkpoint loading")
+    config.addinivalue_line("markers", "bitsandbytes: marks tests for BitsAndBytes quantization functionality")
+    config.addinivalue_line("markers", "quanto: marks tests for Quanto quantization functionality")
+    config.addinivalue_line("markers", "torchao: marks tests for TorchAO quantization functionality")
+    config.addinivalue_line("markers", "gguf: marks tests for GGUF quantization functionality")
+    config.addinivalue_line("markers", "modelopt: marks tests for NVIDIA ModelOpt quantization functionality")
 
 
 def pytest_addoption(parser):
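Registering these markers lets whole categories of model tests be selected or deselected from the command line, e.g. `pytest -m lora` or `pytest -m "not compile"`. A minimal sketch of how one of the new markers would typically be applied; the test name and body are illustrative and not part of this commit:

import pytest


# Illustrative only: opting a test into the newly registered "lora" marker so it can be
# selected with `pytest -m lora` or skipped with `pytest -m "not lora"`.
@pytest.mark.lora
def test_lora_marker_example():
    assert True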
Lines changed: 35 additions & 0 deletions

@@ -1,2 +1,37 @@
+from .attention import AttentionTesterMixin
 from .common import ModelTesterMixin
+from .compile import TorchCompileTesterMixin
+from .ip_adapter import IPAdapterTesterMixin
+from .lora import LoraTesterMixin
+from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
+from .quantization import (
+    BitsAndBytesTesterMixin,
+    GGUFTesterMixin,
+    ModelOptTesterMixin,
+    QuantizationTesterMixin,
+    QuantoTesterMixin,
+    TorchAoTesterMixin,
+)
 from .single_file import SingleFileTesterMixin
+from .training import TrainingTesterMixin
+
+
+__all__ = [
+    "AttentionTesterMixin",
+    "BitsAndBytesTesterMixin",
+    "CPUOffloadTesterMixin",
+    "GGUFTesterMixin",
+    "GroupOffloadTesterMixin",
+    "IPAdapterTesterMixin",
+    "LayerwiseCastingTesterMixin",
+    "LoraTesterMixin",
+    "MemoryTesterMixin",
+    "ModelOptTesterMixin",
+    "ModelTesterMixin",
+    "QuantizationTesterMixin",
+    "QuantoTesterMixin",
+    "SingleFileTesterMixin",
+    "TorchAoTesterMixin",
+    "TorchCompileTesterMixin",
+    "TrainingTesterMixin",
+]
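These mixins are meant to be composed into per-model test classes. A rough sketch of what such a class could look like; the import path, model class, and config values below are hypothetical stand-ins, since the package path of this `__init__.py` and the concrete models are not shown in the diff:

import torch

# Hypothetical imports: the real mixin package path and the model under test are assumptions.
from model_test_mixins import AttentionTesterMixin, ModelTesterMixin, TrainingTesterMixin
from my_library.models import MyTinyTransformer


class MyTinyTransformerTests(ModelTesterMixin, AttentionTesterMixin, TrainingTesterMixin):
    model_class = MyTinyTransformer

    def get_init_dict(self):
        # Smallest config that still builds attention layers (values are illustrative).
        return {"hidden_size": 8, "num_attention_heads": 2, "num_layers": 1}

    def get_dummy_inputs(self):
        return {"hidden_states": torch.randn(1, 4, 8)}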
Lines changed: 180 additions & 0 deletions

@@ -0,0 +1,180 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import (
    AttnProcessor,
)

from ...testing_utils import is_attention, require_accelerator, torch_device


@is_attention
@require_accelerator
class AttentionTesterMixin:
    """
    Mixin class for testing attention processor and module functionality on models.

    Tests functionality from AttentionModuleMixin including:
    - Attention processor management (set/get)
    - QKV projection fusion/unfusion
    - Attention backends (XFormers, NPU, etc.)

    Expected class attributes to be set by subclasses:
    - model_class: The model class to test
    - base_precision: Tolerance for floating point comparisons (default: 1e-3)
    - uses_custom_attn_processor: Whether model uses custom attention processors (default: False)

    Expected methods to be implemented by subclasses:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Pytest mark: attention
    Use `pytest -m "not attention"` to skip these tests
    """

    base_precision = 1e-3

    def test_fuse_unfuse_qkv_projections(self):
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        if not hasattr(model, "fuse_qkv_projections"):
            pytest.skip("Model does not support QKV projection fusion.")

        # Get output before fusion
        with torch.no_grad():
            output_before_fusion = model(**inputs_dict)
            if isinstance(output_before_fusion, dict):
                output_before_fusion = output_before_fusion.to_tuple()[0]

        # Fuse projections
        model.fuse_qkv_projections()

        # Verify fusion occurred by checking for fused attributes
        has_fused_projections = False
        for module in model.modules():
            if isinstance(module, AttentionModuleMixin):
                if hasattr(module, "to_qkv") or hasattr(module, "to_kv"):
                    has_fused_projections = True
                    assert module.fused_projections, "fused_projections flag should be True"
                    break

        if has_fused_projections:
            # Get output after fusion
            with torch.no_grad():
                output_after_fusion = model(**inputs_dict)
                if isinstance(output_after_fusion, dict):
                    output_after_fusion = output_after_fusion.to_tuple()[0]

            # Verify outputs match
            assert torch.allclose(
                output_before_fusion, output_after_fusion, atol=self.base_precision
            ), "Output should not change after fusing projections"

            # Unfuse projections
            model.unfuse_qkv_projections()

            # Verify unfusion occurred
            for module in model.modules():
                if isinstance(module, AttentionModuleMixin):
                    assert not hasattr(module, "to_qkv"), "to_qkv should be removed after unfusing"
                    assert not hasattr(module, "to_kv"), "to_kv should be removed after unfusing"
                    assert not module.fused_projections, "fused_projections flag should be False"

            # Get output after unfusion
            with torch.no_grad():
                output_after_unfusion = model(**inputs_dict)
                if isinstance(output_after_unfusion, dict):
                    output_after_unfusion = output_after_unfusion.to_tuple()[0]

            # Verify outputs still match
            assert torch.allclose(
                output_before_fusion, output_after_unfusion, atol=self.base_precision
            ), "Output should match original after unfusing projections"

    def test_get_set_processor(self):
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        # Check if model has attention processors
        if not hasattr(model, "attn_processors"):
            pytest.skip("Model does not have attention processors.")

        # Test getting processors
        processors = model.attn_processors
        assert isinstance(processors, dict), "attn_processors should return a dict"
        assert len(processors) > 0, "Model should have at least one attention processor"

        # Test that all processors can be retrieved via get_processor
        for module in model.modules():
            if isinstance(module, AttentionModuleMixin):
                processor = module.get_processor()
                assert processor is not None, "get_processor should return a processor"

                # Test setting a new processor
                new_processor = AttnProcessor()
                module.set_processor(new_processor)
                retrieved_processor = module.get_processor()
                assert retrieved_processor is new_processor, "Retrieved processor should be the same as the one set"

    def test_attention_processor_dict(self):
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        if not hasattr(model, "set_attn_processor"):
            pytest.skip("Model does not support setting attention processors.")

        # Get current processors
        current_processors = model.attn_processors

        # Create a dict of new processors
        new_processors = {key: AttnProcessor() for key in current_processors.keys()}

        # Set processors using dict
        model.set_attn_processor(new_processors)

        # Verify all processors were set
        updated_processors = model.attn_processors
        for key in current_processors.keys():
            assert type(updated_processors[key]) == AttnProcessor, f"Processor {key} should be AttnProcessor"

    def test_attention_processor_count_mismatch_raises_error(self):
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        if not hasattr(model, "set_attn_processor"):
            pytest.skip("Model does not support setting attention processors.")

        # Get current processors
        current_processors = model.attn_processors

        # Create a dict with wrong number of processors
        wrong_processors = {list(current_processors.keys())[0]: AttnProcessor()}

        # Verify error is raised
        with pytest.raises(ValueError) as exc_info:
            model.set_attn_processor(wrong_processors)

        assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"
