diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index 629f80637503..7ca04314ec3d 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -64,6 +64,7 @@ jobs: run: | python utils/check_copies.py python utils/check_dummies.py + python utils/check_support_list.py make deps_table_check_updated - name: Check if failure if: ${{ failure() }} diff --git a/docs/source/en/api/activations.md b/docs/source/en/api/activations.md index 3bef28a5ab0d..140a2ae1a1b2 100644 --- a/docs/source/en/api/activations.md +++ b/docs/source/en/api/activations.md @@ -25,3 +25,16 @@ Customized activation functions for supporting various models in 🤗 Diffusers. ## ApproximateGELU [[autodoc]] models.activations.ApproximateGELU + + +## SwiGLU + +[[autodoc]] models.activations.SwiGLU + +## FP32SiLU + +[[autodoc]] models.activations.FP32SiLU + +## LinearActivation + +[[autodoc]] models.activations.LinearActivation diff --git a/docs/source/en/api/attnprocessor.md b/docs/source/en/api/attnprocessor.md index 8bdffc330567..638ecb973e5d 100644 --- a/docs/source/en/api/attnprocessor.md +++ b/docs/source/en/api/attnprocessor.md @@ -147,3 +147,20 @@ An attention processor is a class for applying different types of attention mech ## XLAFlashAttnProcessor2_0 [[autodoc]] models.attention_processor.XLAFlashAttnProcessor2_0 + +## XFormersJointAttnProcessor + +[[autodoc]] models.attention_processor.XFormersJointAttnProcessor + +## IPAdapterXFormersAttnProcessor + +[[autodoc]] models.attention_processor.IPAdapterXFormersAttnProcessor + +## FluxIPAdapterJointAttnProcessor2_0 + +[[autodoc]] models.attention_processor.FluxIPAdapterJointAttnProcessor2_0 + + +## XLAFluxFlashAttnProcessor2_0 + +[[autodoc]] models.attention_processor.XLAFluxFlashAttnProcessor2_0 \ No newline at end of file diff --git a/docs/source/en/api/normalization.md b/docs/source/en/api/normalization.md index ef4b694a4d85..05ae92a28dc8 100644 --- a/docs/source/en/api/normalization.md +++ b/docs/source/en/api/normalization.md @@ -29,3 +29,43 @@ Customized normalization layers for supporting various models in 🤗 Diffusers. 
 ## AdaGroupNorm
 
 [[autodoc]] models.normalization.AdaGroupNorm
+
+## AdaLayerNormContinuous
+
+[[autodoc]] models.normalization.AdaLayerNormContinuous
+
+## RMSNorm
+
+[[autodoc]] models.normalization.RMSNorm
+
+## GlobalResponseNorm
+
+[[autodoc]] models.normalization.GlobalResponseNorm
+
+
+## LuminaLayerNormContinuous
+[[autodoc]] models.normalization.LuminaLayerNormContinuous
+
+## SD35AdaLayerNormZeroX
+[[autodoc]] models.normalization.SD35AdaLayerNormZeroX
+
+## AdaLayerNormZeroSingle
+[[autodoc]] models.normalization.AdaLayerNormZeroSingle
+
+## LuminaRMSNormZero
+[[autodoc]] models.normalization.LuminaRMSNormZero
+
+## LpNorm
+[[autodoc]] models.normalization.LpNorm
+
+## CogView3PlusAdaLayerNormZeroTextImage
+[[autodoc]] models.normalization.CogView3PlusAdaLayerNormZeroTextImage
+
+## CogVideoXLayerNormZero
+[[autodoc]] models.normalization.CogVideoXLayerNormZero
+
+## MochiRMSNormZero
+[[autodoc]] models.transformers.transformer_mochi.MochiRMSNormZero
+
+## MochiRMSNorm
+[[autodoc]] models.normalization.MochiRMSNorm
\ No newline at end of file
diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py
index c31fd91ab433..383388ca543f 100644
--- a/src/diffusers/models/normalization.py
+++ b/src/diffusers/models/normalization.py
@@ -306,6 +306,20 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
 
 
 class AdaLayerNormContinuous(nn.Module):
+    r"""
+    Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
+
+    Args:
+        embedding_dim (`int`): Embedding dimension to use during projection.
+        conditioning_embedding_dim (`int`): Dimension of the input condition.
+        elementwise_affine (`bool`, defaults to `True`):
+            Boolean flag to denote if affine transformation should be applied.
+        eps (`float`, defaults to 1e-5): Epsilon factor.
+        bias (`bool`, defaults to `True`): Boolean flag to denote if bias should be used.
+        norm_type (`str`, defaults to `"layer_norm"`):
+            Normalization layer to use. Values supported: "layer_norm", "rms_norm".
+    """
+
     def __init__(
         self,
         embedding_dim: int,
@@ -462,6 +476,17 @@ def forward(
 # Has optional bias parameter compared to torch layer norm
 # TODO: replace with torch layernorm once min required torch version >= 2.1
 class LayerNorm(nn.Module):
+    r"""
+    LayerNorm with an optional bias parameter.
+
+    Args:
+        dim (`int`): Dimensionality to use for the parameters.
+        eps (`float`, defaults to 1e-5): Epsilon factor.
+        elementwise_affine (`bool`, defaults to `True`):
+            Boolean flag to denote if affine transformation should be applied.
+        bias (`bool`, defaults to `True`): Boolean flag to denote if bias should be used.
+    """
+
     def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
         super().__init__()
 
@@ -484,6 +509,17 @@ def forward(self, input):
 
 
 class RMSNorm(nn.Module):
+    r"""
+    RMS Norm as introduced in https://arxiv.org/abs/1910.07467 by Zhang et al.
+
+    Args:
+        dim (`int`): Number of dimensions to use for `weight`. Only effective when `elementwise_affine` is `True`.
+        eps (`float`): Small value to use when calculating the reciprocal of the square-root.
+        elementwise_affine (`bool`, defaults to `True`):
+            Boolean flag to denote if affine transformation should be applied.
+        bias (`bool`, defaults to `False`): Whether to also train a `bias` parameter.
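+
+    Example (an illustrative usage sketch; the tensor shapes below are arbitrary):
+
+    ```py
+    >>> import torch
+    >>> from diffusers.models.normalization import RMSNorm
+    >>> norm = RMSNorm(dim=64, eps=1e-6)
+    >>> hidden_states = torch.randn(2, 8, 64)
+    >>> output = norm(hidden_states)  # normalized tensor with the same shape as the input
+    ```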
+ """ + def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False): super().__init__() @@ -573,6 +609,13 @@ def forward(self, hidden_states): class GlobalResponseNorm(nn.Module): + r""" + Global response normalization as introduced in ConvNeXt-v2 (https://arxiv.org/abs/2301.00808). + + Args: + dim (`int`): Number of dimensions to use for the `gamma` and `beta`. + """ + # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105 def __init__(self, dim): super().__init__() diff --git a/tests/others/test_check_support_list.py b/tests/others/test_check_support_list.py new file mode 100644 index 000000000000..0f6b134aad49 --- /dev/null +++ b/tests/others/test_check_support_list.py @@ -0,0 +1,68 @@ +import os +import sys +import unittest +from unittest.mock import mock_open, patch + + +git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) +sys.path.append(os.path.join(git_repo_path, "utils")) + +from check_support_list import check_documentation # noqa: E402 + + +class TestCheckSupportList(unittest.TestCase): + def setUp(self): + # Mock doc and source contents that we can reuse + self.doc_content = """# Documentation +## FooProcessor + +[[autodoc]] module.FooProcessor + +## BarProcessor + +[[autodoc]] module.BarProcessor +""" + self.source_content = """ +class FooProcessor(nn.Module): + pass + +class BarProcessor(nn.Module): + pass +""" + + def test_check_documentation_all_documented(self): + # In this test, both FooProcessor and BarProcessor are documented + with patch("builtins.open", mock_open(read_data=self.doc_content)) as doc_file: + doc_file.side_effect = [ + mock_open(read_data=self.doc_content).return_value, + mock_open(read_data=self.source_content).return_value, + ] + + undocumented = check_documentation( + doc_path="fake_doc.md", + src_path="fake_source.py", + doc_regex=r"\[\[autodoc\]\]\s([^\n]+)", + src_regex=r"class\s+(\w+Processor)\(.*?nn\.Module.*?\):", + ) + self.assertEqual(len(undocumented), 0, f"Expected no undocumented classes, got {undocumented}") + + def test_check_documentation_missing_class(self): + # In this test, only FooProcessor is documented, but BarProcessor is missing from the docs + doc_content_missing = """# Documentation +## FooProcessor + +[[autodoc]] module.FooProcessor +""" + with patch("builtins.open", mock_open(read_data=doc_content_missing)) as doc_file: + doc_file.side_effect = [ + mock_open(read_data=doc_content_missing).return_value, + mock_open(read_data=self.source_content).return_value, + ] + + undocumented = check_documentation( + doc_path="fake_doc.md", + src_path="fake_source.py", + doc_regex=r"\[\[autodoc\]\]\s([^\n]+)", + src_regex=r"class\s+(\w+Processor)\(.*?nn\.Module.*?\):", + ) + self.assertIn("BarProcessor", undocumented, f"BarProcessor should be undocumented, got {undocumented}") diff --git a/utils/check_support_list.py b/utils/check_support_list.py new file mode 100644 index 000000000000..89cfce62de0b --- /dev/null +++ b/utils/check_support_list.py @@ -0,0 +1,124 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +""" +Utility that checks that modules like attention processors are listed in the documentation file. 
+
+```bash
+python utils/check_support_list.py
+```
+
+It has no auto-fix mode.
+"""
+
+import os
+import re
+
+
+# All paths are set with the intent that you run this script from the root of the repo
+REPO_PATH = "."
+
+
+def read_documented_classes(doc_path, autodoc_regex=r"\[\[autodoc\]\]\s([^\n]+)"):
+    """
+    Reads documented classes from a doc file using a regex to find lines like `[[autodoc]] my.module.Class`.
+    Returns a list of documented class names (just the class name portion).
+    """
+    with open(os.path.join(REPO_PATH, doc_path), "r") as f:
+        doctext = f.read()
+    matches = re.findall(autodoc_regex, doctext)
+    return [match.split(".")[-1] for match in matches]
+
+
+def read_source_classes(src_path, class_regex, exclude_conditions=None):
+    """
+    Reads class names from a source file using a regex that captures class definitions.
+    Optionally excludes classes based on a list of conditions (functions that take a class name and return a bool).
+    """
+    if exclude_conditions is None:
+        exclude_conditions = []
+    with open(os.path.join(REPO_PATH, src_path), "r") as f:
+        doctext = f.read()
+    classes = re.findall(class_regex, doctext)
+    # Filter out classes that meet any of the exclude conditions
+    filtered_classes = [c for c in classes if not any(cond(c) for cond in exclude_conditions)]
+    return filtered_classes
+
+
+def check_documentation(doc_path, src_path, doc_regex, src_regex, exclude_conditions=None):
+    """
+    Generic function to check if all classes defined in `src_path` are documented in `doc_path`.
+    Returns a sorted list of undocumented class names.
+    """
+    documented = set(read_documented_classes(doc_path, doc_regex))
+    source_classes = set(read_source_classes(src_path, src_regex, exclude_conditions=exclude_conditions))
+
+    # Find which classes in the source are not documented, in a deterministic (sorted) order.
+ undocumented = sorted(source_classes - documented) + return undocumented + + +if __name__ == "__main__": + # Define the checks we need to perform + checks = { + "Attention Processors": { + "doc_path": "docs/source/en/api/attnprocessor.md", + "src_path": "src/diffusers/models/attention_processor.py", + "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)", + "src_regex": r"class\s+(\w+Processor(?:\d*_?\d*))[:(]", + "exclude_conditions": [lambda c: "LoRA" in c, lambda c: c == "Attention"], + }, + "Image Processors": { + "doc_path": "docs/source/en/api/image_processor.md", + "src_path": "src/diffusers/image_processor.py", + "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)", + "src_regex": r"class\s+(\w+Processor(?:\d*_?\d*))[:(]", + }, + "Activations": { + "doc_path": "docs/source/en/api/activations.md", + "src_path": "src/diffusers/models/activations.py", + "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)", + "src_regex": r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):", + }, + "Normalizations": { + "doc_path": "docs/source/en/api/normalization.md", + "src_path": "src/diffusers/models/normalization.py", + "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)", + "src_regex": r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):", + "exclude_conditions": [ + # Exclude LayerNorm as it's an intentional exception + lambda c: c == "LayerNorm" + ], + }, + "LoRA Mixins": { + "doc_path": "docs/source/en/api/loaders/lora.md", + "src_path": "src/diffusers/loaders/lora_pipeline.py", + "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)", + "src_regex": r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):", + }, + } + + missing_items = {} + for category, params in checks.items(): + undocumented = check_documentation( + doc_path=params["doc_path"], + src_path=params["src_path"], + doc_regex=params["doc_regex"], + src_regex=params["src_regex"], + exclude_conditions=params.get("exclude_conditions"), + ) + if undocumented: + missing_items[category] = undocumented + + # If we have any missing items, raise a single combined error + if missing_items: + error_msg = ["Some classes are not documented properly:\n"] + for category, classes in missing_items.items(): + error_msg.append(f"- {category}: {', '.join(sorted(classes))}") + raise ValueError("\n".join(error_msg))
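For reference, the new `check_documentation` helper can also be driven directly, the same way the unit test does. A minimal sketch, assuming it is run from the repository root (so that `REPO_PATH = "."` resolves) and reusing the "Normalizations" parameters from the `checks` dict above; the `print` call is only for illustration:

```python
import sys

sys.path.append("utils")  # make utils/check_support_list.py importable, as the test does

from check_support_list import check_documentation

# Mirrors the "Normalizations" entry: every nn.Module subclass in normalization.py,
# except the intentional LayerNorm exception, must have an [[autodoc]] entry in the docs.
undocumented = check_documentation(
    doc_path="docs/source/en/api/normalization.md",
    src_path="src/diffusers/models/normalization.py",
    doc_regex=r"\[\[autodoc\]\]\s([^\n]+)",
    src_regex=r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):",
    exclude_conditions=[lambda c: c == "LayerNorm"],
)
print(undocumented or "All normalization classes are documented.")
```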