Merged

37 commits
- 25033da: add; utility to check if attn_procs,norms,acts are properly documented. (sayakpaul, Apr 24, 2024)
- 9398e0f: add support listing to the workflows. (sayakpaul, Apr 24, 2024)
- 132c68b: Merge branch 'main' into feat/check-doc-listing (sayakpaul, Apr 24, 2024)
- 57ca5be: change to 2024. (sayakpaul, Apr 24, 2024)
- 8532285: Merge branch 'main' into feat/check-doc-listing (sayakpaul, Apr 24, 2024)
- b5c9aeb: small fixes. (sayakpaul, Apr 24, 2024)
- 40128ac: Merge branch 'main' into feat/check-doc-listing (sayakpaul, Apr 25, 2024)
- c625166: does adding detailed docstrings help? (sayakpaul, Apr 25, 2024)
- 8b58696: Merge branch 'main' into feat/check-doc-listing (sayakpaul, Apr 25, 2024)
- 80d0a7f: Merge branch 'main' into feat/check-doc-listing (sayakpaul, Apr 29, 2024)
- d064b11: Merge branch 'main' into feat/check-doc-listing (sayakpaul, May 1, 2024)
- 45daa98: Merge branch 'main' into feat/check-doc-listing (sayakpaul, May 2, 2024)
- 5663ba5: fix (sayakpaul, May 10, 2024)
- 0653e2d: Merge branch 'main' into feat/check-doc-listing (sayakpaul, May 10, 2024)
- dac63dd: uncomment image processor check (sayakpaul, May 10, 2024)
- 900cd1c: quality (sayakpaul, May 10, 2024)
- 6dc3d19: Merge branch 'main' into feat/check-doc-listing (sayakpaul, May 14, 2024)
- 8449186: fix, thanks to @mishig. (sayakpaul, May 14, 2024)
- af2370b: Apply suggestions from code review (sayakpaul, May 15, 2024)
- c4c9fc4: Merge branch 'main' into feat/check-doc-listing (sayakpaul, May 15, 2024)
- 15b2f57: style (sayakpaul, May 15, 2024)
- 12c9ac4: Merge branch 'main' into feat/check-doc-listing (sayakpaul, May 21, 2024)
- f3443d0: Merge branch 'main' into feat/check-doc-listing (sayakpaul, May 22, 2024)
- b8b0fd1: Merge branch 'main' into feat/check-doc-listing (sayakpaul, May 28, 2024)
- 63989af: resolve conflicts. (sayakpaul, Dec 8, 2024)
- 4227392: JointAttnProcessor2_0 (sayakpaul, Dec 8, 2024)
- 0034db2: fixes (sayakpaul, Dec 8, 2024)
- eb5a8b2: resolve conflicts. (sayakpaul, Dec 16, 2024)
- a2aa752: fixes (sayakpaul, Dec 17, 2024)
- 005a2e9: fixes (sayakpaul, Dec 17, 2024)
- b653eaa: fixes (sayakpaul, Dec 17, 2024)
- 75136e6: fixes (sayakpaul, Dec 17, 2024)
- 7eb617a: fixes (sayakpaul, Dec 17, 2024)
- 80be186: Merge branch 'main' into feat/check-doc-listing (sayakpaul, Feb 19, 2025)
- ef03777: Merge branch 'main' into feat/check-doc-listing (sayakpaul, Feb 20, 2025)
- 53a3361: fixes (sayakpaul, Feb 20, 2025)
- be989a6: Update docs/source/en/api/normalization.md (sayakpaul, Feb 20, 2025)
1 change: 1 addition & 0 deletions .github/workflows/pr_tests.yml
@@ -64,6 +64,7 @@ jobs:
        run: |
          python utils/check_copies.py
          python utils/check_dummies.py
          python utils/check_support_list.py
          make deps_table_check_updated
      - name: Check if failure
        if: ${{ failure() }}
13 changes: 13 additions & 0 deletions docs/source/en/api/activations.md
@@ -25,3 +25,16 @@ Customized activation functions for supporting various models in 🤗 Diffusers.
## ApproximateGELU

[[autodoc]] models.activations.ApproximateGELU


## SwiGLU

[[autodoc]] models.activations.SwiGLU

## FP32SiLU

[[autodoc]] models.activations.FP32SiLU

## LinearActivation

[[autodoc]] models.activations.LinearActivation
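SwiGLU is new to this listing; as a reference point, the block below is a minimal sketch of the gated-linear-unit pattern the name refers to (a single projection split into value and gate). The class name and layer layout are assumptions for illustration, not the diffusers implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLUSketch(nn.Module):
    """Illustrative SwiGLU: one projection chunked into a value and a gate."""

    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        # Project to 2 * dim_out so the output can be split into (value, gate).
        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        value, gate = self.proj(x).chunk(2, dim=-1)
        # SiLU (swish) on the gate, multiplied elementwise with the value.
        return value * F.silu(gate)
```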
8 changes: 8 additions & 0 deletions docs/source/en/api/attnprocessor.md
@@ -147,3 +147,11 @@ An attention processor is a class for applying different types of attention mechanisms.
## XLAFlashAttnProcessor2_0

[[autodoc]] models.attention_processor.XLAFlashAttnProcessor2_0

## XFormersJointAttnProcessor

[[autodoc]] models.attention_processor.XFormersJointAttnProcessor

## IPAdapterXFormersAttnProcessor

[[autodoc]] models.attention_processor.IPAdapterXFormersAttnProcessor
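The two processors listed above follow the same callable protocol as the other entries on this page: a class whose `__call__` receives the parent `Attention` module plus the hidden states and returns the attended output. A minimal sketch of that protocol follows; the class name and the simplified single-head math are assumptions for illustration, not the implementation of either new processor.

```python
import torch
import torch.nn.functional as F


class SketchAttnProcessor:
    """Illustrative attention processor: plain scaled dot-product attention."""

    def __call__(self, attn, hidden_states: torch.Tensor, encoder_hidden_states=None, **kwargs):
        # Processors receive the parent `Attention` module and reuse its projections.
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        query = attn.to_q(hidden_states)
        key = attn.to_k(context)
        value = attn.to_v(context)

        # Real processors reshape to multiple heads; kept single-head here for brevity.
        hidden_states = F.scaled_dot_product_attention(query, key, value)
        return attn.to_out[0](hidden_states)
```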
37 changes: 37 additions & 0 deletions docs/source/en/api/normalization.md
@@ -29,3 +29,40 @@ Customized normalization layers for supporting various models in 🤗 Diffusers.
## AdaGroupNorm

[[autodoc]] models.normalization.AdaGroupNorm

## AdaLayerNormContinuous

[[autodoc]] models.normalization.AdaLayerNormContinuous

## RMSNorm

[[autodoc]] models.normalization.RMSNorm

## GlobalResponseNorm

[[autodoc]] models.normalization.GlobalResponseNorm


## LuminaLayerNormContinuous

[[autodoc]] models.normalization.LuminaLayerNormContinuous

## SD35AdaLayerNormZeroX

[[autodoc]] models.normalization.SD35AdaLayerNormZeroX

## AdaLayerNormZeroSingle

[[autodoc]] models.normalization.AdaLayerNormZeroSingle

## LuminaRMSNormZero

[[autodoc]] models.normalization.LuminaRMSNormZero

## LpNorm

[[autodoc]] models.normalization.LpNorm

## CogView3PlusAdaLayerNormZeroTextImage

[[autodoc]] models.normalization.CogView3PlusAdaLayerNormZeroTextImage

## CogVideoXLayerNormZero

[[autodoc]] models.normalization.CogVideoXLayerNormZero

## MochiRMSNormZero

[[autodoc]] models.normalization.MochiRMSNormZero
43 changes: 43 additions & 0 deletions src/diffusers/models/normalization.py
@@ -306,6 +306,20 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:


class AdaLayerNormContinuous(nn.Module):
r"""
Adaptive normalization layer with a norm layer (layer_norm or rms_norm).

Args:
embedding_dim (`int`): Embedding dimension to use during projection.
conditioning_embedding_dim (`int`): Dimension of the input condition.
elementwise_affine (`bool`, defaults to `True`):
Boolean flag to denote if affine transformation should be applied.
eps (`float`, defaults to 1e-5): Epsilon factor.
bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use.
norm_type (`str`, defaults to `"layer_norm"`):
Normalization layer to use. Values supported: "layer_norm", "rms_norm".
"""

def __init__(
self,
embedding_dim: int,
@@ -462,6 +476,17 @@ def forward(
# Has optional bias parameter compared to torch layer norm
# TODO: replace with torch layernorm once min required torch version >= 2.1
class LayerNorm(nn.Module):
r"""
LayerNorm with the bias parameter.

Args:
dim (`int`): Dimensionality to use for the parameters.
eps (`float`, defaults to 1e-5): Epsilon factor.
elementwise_affine (`bool`, defaults to `True`):
Boolean flag to denote if affine transformation should be applied.
bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use.
"""

def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
super().__init__()

@@ -484,6 +509,17 @@ def forward(self, input):


class RMSNorm(nn.Module):
r"""
RMS Norm as introduced in https://arxiv.org/abs/1910.07467 by Zhang et al.

Args:
dim (`int`): Number of dimensions to use for `weights`. Only effective when `elementwise_affine` is True.
eps (`float`): Small value to use when calculating the reciprocal of the square-root.
elementwise_affine (`bool`, defaults to `True`):
Boolean flag to denote if affine transformation should be applied.
bias (`bool`, defaults to False): If also training the `bias` param.
"""

def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False):
super().__init__()

@@ -573,6 +609,13 @@ def forward(self, hidden_states):


class GlobalResponseNorm(nn.Module):
r"""
Global response normalization as introduced in ConvNeXt-v2 (https://arxiv.org/abs/2301.00808).

Args:
dim (`int`): Number of dimensions to use for the `gamma` and `beta`.
"""

# Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
def __init__(self, dim):
super().__init__()
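For reference, the RMSNorm docstring added above points to Zhang et al. (arXiv 1910.07467). A minimal sketch of that computation follows; it is illustrative only, and the class name and dtype handling are assumptions rather than the diffusers source.

```python
import torch
import torch.nn as nn


class RMSNormSketch(nn.Module):
    """Illustrative RMS normalization: scale by the reciprocal root-mean-square."""

    def __init__(self, dim: int, eps: float, elementwise_affine: bool = True, bias: bool = False):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim)) if elementwise_affine else None
        self.bias = nn.Parameter(torch.zeros(dim)) if bias else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize by the RMS over the last dimension; no mean subtraction (unlike LayerNorm).
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        if self.weight is not None:
            x = x * self.weight
        if self.bias is not None:
            x = x + self.bias
        return x
```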
68 changes: 68 additions & 0 deletions tests/others/test_check_support_list.py
@@ -0,0 +1,68 @@
import os
import sys
import unittest
from unittest.mock import mock_open, patch


git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
sys.path.append(os.path.join(git_repo_path, "utils"))

from check_support_list import check_documentation # noqa: E402


class TestCheckSupportList(unittest.TestCase):
    def setUp(self):
        # Mock doc and source contents that we can reuse
        self.doc_content = """# Documentation
## FooProcessor

[[autodoc]] module.FooProcessor

## BarProcessor

[[autodoc]] module.BarProcessor
"""
        self.source_content = """
class FooProcessor(nn.Module):
    pass

class BarProcessor(nn.Module):
    pass
"""

    def test_check_documentation_all_documented(self):
        # In this test, both FooProcessor and BarProcessor are documented
        with patch("builtins.open", mock_open(read_data=self.doc_content)) as doc_file:
            doc_file.side_effect = [
                mock_open(read_data=self.doc_content).return_value,
                mock_open(read_data=self.source_content).return_value,
            ]

            undocumented = check_documentation(
                doc_path="fake_doc.md",
                src_path="fake_source.py",
                doc_regex=r"\[\[autodoc\]\]\s([^\n]+)",
                src_regex=r"class\s+(\w+Processor)\(.*?nn\.Module.*?\):",
            )
            self.assertEqual(len(undocumented), 0, f"Expected no undocumented classes, got {undocumented}")

    def test_check_documentation_missing_class(self):
        # In this test, only FooProcessor is documented, but BarProcessor is missing from the docs
        doc_content_missing = """# Documentation
## FooProcessor

[[autodoc]] module.FooProcessor
"""
        with patch("builtins.open", mock_open(read_data=doc_content_missing)) as doc_file:
            doc_file.side_effect = [
                mock_open(read_data=doc_content_missing).return_value,
                mock_open(read_data=self.source_content).return_value,
            ]

            undocumented = check_documentation(
                doc_path="fake_doc.md",
                src_path="fake_source.py",
                doc_regex=r"\[\[autodoc\]\]\s([^\n]+)",
                src_regex=r"class\s+(\w+Processor)\(.*?nn\.Module.*?\):",
            )
            self.assertIn("BarProcessor", undocumented, f"BarProcessor should be undocumented, got {undocumented}")
124 changes: 124 additions & 0 deletions utils/check_support_list.py
@@ -0,0 +1,124 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
"""
Utility that checks that modules like attention processors are listed in the documentation file.

```bash
python utils/check_support_list.py
```

It has no auto-fix mode.
"""

import os
import re


# All paths are set with the intent that you run this script from the root of the repo
REPO_PATH = "."


def read_documented_classes(doc_path, autodoc_regex=r"\[\[autodoc\]\]\s([^\n]+)"):
"""
Reads documented classes from a doc file using a regex to find lines like [[autodoc]] my.module.Class.
Returns a list of documented class names (just the class name portion).
"""
with open(os.path.join(REPO_PATH, doc_path), "r") as f:
doctext = f.read()
matches = re.findall(autodoc_regex, doctext)
return [match.split(".")[-1] for match in matches]


def read_source_classes(src_path, class_regex, exclude_conditions=None):
"""
Reads class names from a source file using a regex that captures class definitions.
Optionally exclude classes based on a list of conditions (functions that take class name and return bool).
"""
if exclude_conditions is None:
exclude_conditions = []
with open(os.path.join(REPO_PATH, src_path), "r") as f:
doctext = f.read()
classes = re.findall(class_regex, doctext)
# Filter out classes that meet any of the exclude conditions
filtered_classes = [c for c in classes if not any(cond(c) for cond in exclude_conditions)]
return filtered_classes


def check_documentation(doc_path, src_path, doc_regex, src_regex, exclude_conditions=None):
"""
Generic function to check if all classes defined in `src_path` are documented in `doc_path`.
Returns a set of undocumented class names.
"""
documented = set(read_documented_classes(doc_path, doc_regex))
source_classes = set(read_source_classes(src_path, src_regex, exclude_conditions=exclude_conditions))

# Find which classes in source are not documented in a deterministic way.
undocumented = sorted(source_classes - documented)
return undocumented


if __name__ == "__main__":
    # Define the checks we need to perform
    checks = {
        "Attention Processors": {
            "doc_path": "docs/source/en/api/attnprocessor.md",
            "src_path": "src/diffusers/models/attention_processor.py",
            "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
            "src_regex": r"class\s+(\w+Processor(?:\d*_?\d*))[:(]",
            "exclude_conditions": [lambda c: "LoRA" in c, lambda c: c == "Attention"],
        },
        "Image Processors": {
            "doc_path": "docs/source/en/api/image_processor.md",
            "src_path": "src/diffusers/image_processor.py",
            "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
            "src_regex": r"class\s+(\w+Processor(?:\d*_?\d*))[:(]",
        },
        "Activations": {
            "doc_path": "docs/source/en/api/activations.md",
            "src_path": "src/diffusers/models/activations.py",
            "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
            "src_regex": r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):",
        },
        "Normalizations": {
            "doc_path": "docs/source/en/api/normalization.md",
            "src_path": "src/diffusers/models/normalization.py",
            "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
            "src_regex": r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):",
            "exclude_conditions": [
                # Exclude LayerNorm as it's an intentional exception
                lambda c: c == "LayerNorm"
            ],
        },
        "LoRA Mixins": {
            "doc_path": "docs/source/en/api/loaders/lora.md",
            "src_path": "src/diffusers/loaders/lora_pipeline.py",
            "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
            "src_regex": r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):",
        },
    }

    missing_items = {}
    for category, params in checks.items():
        undocumented = check_documentation(
            doc_path=params["doc_path"],
            src_path=params["src_path"],
            doc_regex=params["doc_regex"],
            src_regex=params["src_regex"],
            exclude_conditions=params.get("exclude_conditions"),
        )
        if undocumented:
            missing_items[category] = undocumented

    # If we have any missing items, raise a single combined error
    if missing_items:
        error_msg = ["Some classes are not documented properly:\n"]
        for category, classes in missing_items.items():
            error_msg.append(f"- {category}: {', '.join(sorted(classes))}")
        raise ValueError("\n".join(error_msg))
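To see how the `doc_regex` and `src_regex` pairs in `checks` interact, here is a small standalone sketch that applies the same matching to inline strings; the class and module names are made up for illustration.

```python
import re

doc_text = """
## FancyAttnProcessor

[[autodoc]] models.attention_processor.FancyAttnProcessor
"""
src_text = """
class FancyAttnProcessor:
    pass

class ForgottenAttnProcessor:
    pass
"""

doc_regex = r"\[\[autodoc\]\]\s([^\n]+)"
src_regex = r"class\s+(\w+Processor(?:\d*_?\d*))[:(]"

# Documented names keep only the final class-name segment of the autodoc path.
documented = {m.split(".")[-1] for m in re.findall(doc_regex, doc_text)}
defined = set(re.findall(src_regex, src_text))

# ForgottenAttnProcessor is defined in the source but has no [[autodoc]] entry.
print(sorted(defined - documented))  # ['ForgottenAttnProcessor']
```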