Commit 9d39ab2

Merge branch 'main' into tests-encode-prompt
2 parents a33ac2f + f550745 commit 9d39ab2

7 files changed: +306 −0 lines changed

.github/workflows/pr_tests.yml

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ jobs:
         run: |
           python utils/check_copies.py
           python utils/check_dummies.py
+          python utils/check_support_list.py
           make deps_table_check_updated
       - name: Check if failure
         if: ${{ failure() }}
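For context, this step fails the quality job whenever the new checker exits non-zero. A minimal local equivalent, sketched in Python (not part of this commit; assumes a diffusers checkout as the working directory):

```python
# Hypothetical local runner for the same check the CI step performs.
# check_support_list.py raises ValueError (non-zero exit) when classes are undocumented.
import subprocess

result = subprocess.run(
    ["python", "utils/check_support_list.py"],
    capture_output=True,
    text=True,
)
if result.returncode != 0:
    # The traceback lists the undocumented classes grouped by category.
    print(result.stderr)
```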

docs/source/en/api/activations.md

Lines changed: 13 additions & 0 deletions
@@ -25,3 +25,16 @@ Customized activation functions for supporting various models in 🤗 Diffusers.
 ## ApproximateGELU
 
 [[autodoc]] models.activations.ApproximateGELU
+
+
+## SwiGLU
+
+[[autodoc]] models.activations.SwiGLU
+
+## FP32SiLU
+
+[[autodoc]] models.activations.FP32SiLU
+
+## LinearActivation
+
+[[autodoc]] models.activations.LinearActivation

docs/source/en/api/attnprocessor.md

Lines changed: 17 additions & 0 deletions
@@ -147,3 +147,20 @@ An attention processor is a class for applying different types of attention mech
 ## XLAFlashAttnProcessor2_0
 
 [[autodoc]] models.attention_processor.XLAFlashAttnProcessor2_0
+
+## XFormersJointAttnProcessor
+
+[[autodoc]] models.attention_processor.XFormersJointAttnProcessor
+
+## IPAdapterXFormersAttnProcessor
+
+[[autodoc]] models.attention_processor.IPAdapterXFormersAttnProcessor
+
+## FluxIPAdapterJointAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.FluxIPAdapterJointAttnProcessor2_0
+
+
+## XLAFluxFlashAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.XLAFluxFlashAttnProcessor2_0

docs/source/en/api/normalization.md

Lines changed: 40 additions & 0 deletions
@@ -29,3 +29,43 @@ Customized normalization layers for supporting various models in 🤗 Diffusers.
 ## AdaGroupNorm
 
 [[autodoc]] models.normalization.AdaGroupNorm
+
+## AdaLayerNormContinuous
+
+[[autodoc]] models.normalization.AdaLayerNormContinuous
+
+## RMSNorm
+
+[[autodoc]] models.normalization.RMSNorm
+
+## GlobalResponseNorm
+
+[[autodoc]] models.normalization.GlobalResponseNorm
+
+
+## LuminaLayerNormContinuous
+[[autodoc]] models.normalization.LuminaLayerNormContinuous
+
+## SD35AdaLayerNormZeroX
+[[autodoc]] models.normalization.SD35AdaLayerNormZeroX
+
+## AdaLayerNormZeroSingle
+[[autodoc]] models.normalization.AdaLayerNormZeroSingle
+
+## LuminaRMSNormZero
+[[autodoc]] models.normalization.LuminaRMSNormZero
+
+## LpNorm
+[[autodoc]] models.normalization.LpNorm
+
+## CogView3PlusAdaLayerNormZeroTextImage
+[[autodoc]] models.normalization.CogView3PlusAdaLayerNormZeroTextImage
+
+## CogVideoXLayerNormZero
+[[autodoc]] models.normalization.CogVideoXLayerNormZero
+
+## MochiRMSNormZero
+[[autodoc]] models.transformers.transformer_mochi.MochiRMSNormZero
+
+## MochiRMSNorm
+[[autodoc]] models.normalization.MochiRMSNorm

src/diffusers/models/normalization.py

Lines changed: 43 additions & 0 deletions
@@ -306,6 +306,20 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
 
 
 class AdaLayerNormContinuous(nn.Module):
+    r"""
+    Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
+
+    Args:
+        embedding_dim (`int`): Embedding dimension to use during projection.
+        conditioning_embedding_dim (`int`): Dimension of the input condition.
+        elementwise_affine (`bool`, defaults to `True`):
+            Boolean flag to denote if affine transformation should be applied.
+        eps (`float`, defaults to 1e-5): Epsilon factor.
+        bias (`bool`, defaults to `True`): Boolean flag to denote if bias should be used.
+        norm_type (`str`, defaults to `"layer_norm"`):
+            Normalization layer to use. Values supported: "layer_norm", "rms_norm".
+    """
+
     def __init__(
         self,
         embedding_dim: int,
@@ -462,6 +476,17 @@ def forward(
 # Has optional bias parameter compared to torch layer norm
 # TODO: replace with torch layernorm once min required torch version >= 2.1
 class LayerNorm(nn.Module):
+    r"""
+    LayerNorm with the bias parameter.
+
+    Args:
+        dim (`int`): Dimensionality to use for the parameters.
+        eps (`float`, defaults to 1e-5): Epsilon factor.
+        elementwise_affine (`bool`, defaults to `True`):
+            Boolean flag to denote if affine transformation should be applied.
+        bias (`bool`, defaults to `True`): Boolean flag to denote if bias should be used.
+    """
+
     def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
         super().__init__()
 
@@ -484,6 +509,17 @@ def forward(self, input):
 
 
 class RMSNorm(nn.Module):
+    r"""
+    RMS Norm as introduced in https://arxiv.org/abs/1910.07467 by Zhang et al.
+
+    Args:
+        dim (`int`): Number of dimensions to use for `weights`. Only effective when `elementwise_affine` is True.
+        eps (`float`): Small value to use when calculating the reciprocal of the square-root.
+        elementwise_affine (`bool`, defaults to `True`):
+            Boolean flag to denote if affine transformation should be applied.
+        bias (`bool`, defaults to `False`): Whether to also learn a `bias` parameter.
+    """
+
     def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False):
         super().__init__()
 
@@ -573,6 +609,13 @@ def forward(self, hidden_states):
 
 
 class GlobalResponseNorm(nn.Module):
+    r"""
+    Global response normalization as introduced in ConvNeXt-v2 (https://arxiv.org/abs/2301.00808).
+
+    Args:
+        dim (`int`): Number of dimensions to use for the `gamma` and `beta`.
+    """
+
     # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
     def __init__(self, dim):
         super().__init__()
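The docstrings above describe constructor arguments only; the following is a minimal usage sketch of the newly documented layers (not part of the commit, shapes are illustrative, and it assumes torch plus a diffusers install where these classes live in `diffusers.models.normalization`):

```python
# Hedged usage sketch of the layers documented in this commit.
import torch
from diffusers.models.normalization import AdaLayerNormContinuous, GlobalResponseNorm, RMSNorm

hidden = torch.randn(2, 16, 64)   # (batch, sequence, embedding_dim)
cond = torch.randn(2, 32)         # (batch, conditioning_embedding_dim)

ada_norm = AdaLayerNormContinuous(embedding_dim=64, conditioning_embedding_dim=32)
modulated = ada_norm(hidden, cond)        # scale/shift predicted from the condition

rms_norm = RMSNorm(dim=64, eps=1e-5)      # note that eps has no default here
normed = rms_norm(hidden)

grn = GlobalResponseNorm(dim=64)          # expects channels-last input (B, H, W, C)
out = grn(torch.randn(2, 8, 8, 64))
```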
Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
import os
import sys
import unittest
from unittest.mock import mock_open, patch


git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
sys.path.append(os.path.join(git_repo_path, "utils"))

from check_support_list import check_documentation  # noqa: E402


class TestCheckSupportList(unittest.TestCase):
    def setUp(self):
        # Mock doc and source contents that we can reuse
        self.doc_content = """# Documentation
## FooProcessor

[[autodoc]] module.FooProcessor

## BarProcessor

[[autodoc]] module.BarProcessor
"""
        self.source_content = """
class FooProcessor(nn.Module):
    pass

class BarProcessor(nn.Module):
    pass
"""

    def test_check_documentation_all_documented(self):
        # In this test, both FooProcessor and BarProcessor are documented
        with patch("builtins.open", mock_open(read_data=self.doc_content)) as doc_file:
            doc_file.side_effect = [
                mock_open(read_data=self.doc_content).return_value,
                mock_open(read_data=self.source_content).return_value,
            ]

            undocumented = check_documentation(
                doc_path="fake_doc.md",
                src_path="fake_source.py",
                doc_regex=r"\[\[autodoc\]\]\s([^\n]+)",
                src_regex=r"class\s+(\w+Processor)\(.*?nn\.Module.*?\):",
            )
            self.assertEqual(len(undocumented), 0, f"Expected no undocumented classes, got {undocumented}")

    def test_check_documentation_missing_class(self):
        # In this test, only FooProcessor is documented, but BarProcessor is missing from the docs
        doc_content_missing = """# Documentation
## FooProcessor

[[autodoc]] module.FooProcessor
"""
        with patch("builtins.open", mock_open(read_data=doc_content_missing)) as doc_file:
            doc_file.side_effect = [
                mock_open(read_data=doc_content_missing).return_value,
                mock_open(read_data=self.source_content).return_value,
            ]

            undocumented = check_documentation(
                doc_path="fake_doc.md",
                src_path="fake_source.py",
                doc_regex=r"\[\[autodoc\]\]\s([^\n]+)",
                src_regex=r"class\s+(\w+Processor)\(.*?nn\.Module.*?\):",
            )
            self.assertIn("BarProcessor", undocumented, f"BarProcessor should be undocumented, got {undocumented}")
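The tests lean on one slightly subtle mocking pattern: `builtins.open` is patched once, and `side_effect` makes consecutive `open()` calls hand back different in-memory files, first the doc file and then the source file read by `check_documentation`. A standalone sketch of just that trick (illustrative file names, not part of the commit):

```python
# Minimal demonstration of the mock_open + side_effect pattern used above.
from unittest.mock import mock_open, patch

with patch("builtins.open", mock_open()) as fake_open:
    fake_open.side_effect = [
        mock_open(read_data="doc contents").return_value,      # first open() call
        mock_open(read_data="source contents").return_value,   # second open() call
    ]
    with open("fake_doc.md") as f:
        assert f.read() == "doc contents"
    with open("fake_source.py") as f:
        assert f.read() == "source contents"
```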

utils/check_support_list.py

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
"""
Utility that checks that modules like attention processors are listed in the documentation file.

```bash
python utils/check_support_list.py
```

It has no auto-fix mode.
"""

import os
import re


# All paths are set with the intent that you run this script from the root of the repo.
REPO_PATH = "."


def read_documented_classes(doc_path, autodoc_regex=r"\[\[autodoc\]\]\s([^\n]+)"):
    """
    Reads documented classes from a doc file using a regex to find lines like `[[autodoc]] my.module.Class`.
    Returns a list of documented class names (just the class name portion).
    """
    with open(os.path.join(REPO_PATH, doc_path), "r") as f:
        doctext = f.read()
    matches = re.findall(autodoc_regex, doctext)
    return [match.split(".")[-1] for match in matches]


def read_source_classes(src_path, class_regex, exclude_conditions=None):
    """
    Reads class names from a source file using a regex that captures class definitions.
    Optionally excludes classes based on a list of conditions (functions that take a class name and return a bool).
    """
    if exclude_conditions is None:
        exclude_conditions = []
    with open(os.path.join(REPO_PATH, src_path), "r") as f:
        doctext = f.read()
    classes = re.findall(class_regex, doctext)
    # Filter out classes that meet any of the exclude conditions
    filtered_classes = [c for c in classes if not any(cond(c) for cond in exclude_conditions)]
    return filtered_classes


def check_documentation(doc_path, src_path, doc_regex, src_regex, exclude_conditions=None):
    """
    Generic function to check if all classes defined in `src_path` are documented in `doc_path`.
    Returns a sorted list of undocumented class names.
    """
    documented = set(read_documented_classes(doc_path, doc_regex))
    source_classes = set(read_source_classes(src_path, src_regex, exclude_conditions=exclude_conditions))

    # Find which classes in source are not documented, in a deterministic way.
    undocumented = sorted(source_classes - documented)
    return undocumented


if __name__ == "__main__":
    # Define the checks we need to perform
    checks = {
        "Attention Processors": {
            "doc_path": "docs/source/en/api/attnprocessor.md",
            "src_path": "src/diffusers/models/attention_processor.py",
            "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
            "src_regex": r"class\s+(\w+Processor(?:\d*_?\d*))[:(]",
            "exclude_conditions": [lambda c: "LoRA" in c, lambda c: c == "Attention"],
        },
        "Image Processors": {
            "doc_path": "docs/source/en/api/image_processor.md",
            "src_path": "src/diffusers/image_processor.py",
            "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
            "src_regex": r"class\s+(\w+Processor(?:\d*_?\d*))[:(]",
        },
        "Activations": {
            "doc_path": "docs/source/en/api/activations.md",
            "src_path": "src/diffusers/models/activations.py",
            "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
            "src_regex": r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):",
        },
        "Normalizations": {
            "doc_path": "docs/source/en/api/normalization.md",
            "src_path": "src/diffusers/models/normalization.py",
            "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
            "src_regex": r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):",
            "exclude_conditions": [
                # Exclude LayerNorm as it's an intentional exception
                lambda c: c == "LayerNorm"
            ],
        },
        "LoRA Mixins": {
            "doc_path": "docs/source/en/api/loaders/lora.md",
            "src_path": "src/diffusers/loaders/lora_pipeline.py",
            "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
            "src_regex": r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):",
        },
    }

    missing_items = {}
    for category, params in checks.items():
        undocumented = check_documentation(
            doc_path=params["doc_path"],
            src_path=params["src_path"],
            doc_regex=params["doc_regex"],
            src_regex=params["src_regex"],
            exclude_conditions=params.get("exclude_conditions"),
        )
        if undocumented:
            missing_items[category] = undocumented

    # If we have any missing items, raise a single combined error
    if missing_items:
        error_msg = ["Some classes are not documented properly:\n"]
        for category, classes in missing_items.items():
            error_msg.append(f"- {category}: {', '.join(sorted(classes))}")
        raise ValueError("\n".join(error_msg))
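To make the matching concrete, here is a small illustration (not part of the commit) of what the two regexes capture and how the doc-side match is reduced to a bare class name before the set comparison:

```python
# Illustrative only: the same patterns check_support_list.py feeds to re.findall.
import re

doc_regex = r"\[\[autodoc\]\]\s([^\n]+)"
src_regex = r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):"

doc_line = "[[autodoc]] models.normalization.RMSNorm"
src_line = "class RMSNorm(nn.Module):"

documented = [m.split(".")[-1] for m in re.findall(doc_regex, doc_line)]
defined = re.findall(src_regex, src_line)

print(documented)                       # ['RMSNorm'] -- dotted path reduced to the class name
print(defined)                          # ['RMSNorm'] -- captured from the class definition
print(set(defined) - set(documented))   # set() -> nothing undocumented
```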
