 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
 """
 Utility that checks that modules like attention processors are listed in the documentation file.


 It has no auto-fix mode.
 """
+
 import os
 import re


-# All paths are set with the intent you should run this script from the root of the repo with the command
-# python utils/check_doctest_list.py
+# All paths are set with the intent that you run this script from the root of the repo
 REPO_PATH = "."


-def check_attention_processors():
-    with open(os.path.join(REPO_PATH, "docs/source/en/api/attnprocessor.md"), "r") as f:
-        doctext = f.read()
-        matches = re.findall(r"\[\[autodoc\]\]\s([^\n]+)", doctext)
-        documented_attention_processors = [match.split(".")[-1] for match in matches]
-
-    with open(os.path.join(REPO_PATH, "src/diffusers/models/attention_processor.py"), "r") as f:
-        doctext = f.read()
-        processor_classes = re.findall(r"class\s+(\w+Processor(?:\d*_?\d*))[(:]", doctext)
-        processor_classes = [proc for proc in processor_classes if "LoRA" not in proc and proc != "Attention"]
-
-    undocumented_attn_processors = set()
-    for processor in processor_classes:
-        if processor not in documented_attention_processors:
-            undocumented_attn_processors.add(processor)
-
-    if undocumented_attn_processors:
-        raise ValueError(
-            f"The following attention processors should be in listed in the attention processor documentation but are not: {list(undocumented_attn_processors)}. Please update the documentation."
-        )
-
-
-def check_image_processors():
-    with open(os.path.join(REPO_PATH, "docs/source/en/api/image_processor.md"), "r") as f:
-        doctext = f.read()
-        matches = re.findall(r"\[\[autodoc\]\]\s([^\n]+)", doctext)
-        documented_image_processors = [match.split(".")[-1] for match in matches]
-
-    with open(os.path.join(REPO_PATH, "src/diffusers/image_processor.py"), "r") as f:
+def read_documented_classes(doc_path, autodoc_regex=r"\[\[autodoc\]\]\s([^\n]+)"):
+    """
+    Reads documented classes from a doc file using a regex to find lines like [[autodoc]] my.module.Class.
+    Returns a list of documented class names (just the class name portion).
+    """
+    with open(os.path.join(REPO_PATH, doc_path), "r") as f:
         doctext = f.read()
-        processor_classes = re.findall(r"class\s+(\w+Processor(?:\d*_?\d*))[(:]", doctext)
-
-    undocumented_img_processors = set()
-    for processor in processor_classes:
-        if processor not in documented_image_processors:
-            undocumented_img_processors.add(processor)
-            raise ValueError(
-                f"The following image processors should be in listed in the image processor documentation but are not: {list(undocumented_img_processors)}. Please update the documentation."
-            )
-
-
-def check_activations():
-    with open(os.path.join(REPO_PATH, "docs/source/en/api/activations.md"), "r") as f:
+    matches = re.findall(autodoc_regex, doctext)
+    return [match.split(".")[-1] for match in matches]
+
+
+def read_source_classes(src_path, class_regex, exclude_conditions=None):
+    """
+    Reads class names from a source file using a regex that captures class definitions.
+    Optionally exclude classes based on a list of conditions (functions that take class name and return bool).
+    """
+    if exclude_conditions is None:
+        exclude_conditions = []
+    with open(os.path.join(REPO_PATH, src_path), "r") as f:
         doctext = f.read()
-        matches = re.findall(r"\[\[autodoc\]\]\s([^\n]+)", doctext)
-        documented_activations = [match.split(".")[-1] for match in matches]
+    classes = re.findall(class_regex, doctext)
+    # Filter out classes that meet any of the exclude conditions
+    filtered_classes = [c for c in classes if not any(cond(c) for cond in exclude_conditions)]
+    return filtered_classes

-    with open(os.path.join(REPO_PATH, "src/diffusers/models/activations.py"), "r") as f:
-        doctext = f.read()
-        activation_classes = re.findall(r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):", doctext)

-    undocumented_activations = set()
-    for activation in activation_classes:
-        if activation not in documented_activations:
-            undocumented_activations.add(activation)
+def check_documentation(doc_path, src_path, doc_regex, src_regex, exclude_conditions=None):
+    """
+    Generic function to check if all classes defined in `src_path` are documented in `doc_path`.
+    Returns a set of undocumented class names.
+    """
+    documented = set(read_documented_classes(doc_path, doc_regex))
+    source_classes = set(read_source_classes(src_path, src_regex, exclude_conditions=exclude_conditions))

-    if undocumented_activations:
-        raise ValueError(
-            f"The following activations should be in listed in the activations documentation but are not: {list(undocumented_activations)}. Please update the documentation."
-        )
-
-
-def check_normalizations():
-    with open(os.path.join(REPO_PATH, "docs/source/en/api/normalization.md"), "r") as f:
-        doctext = f.read()
-        matches = re.findall(r"\[\[autodoc\]\]\s([^\n]+)", doctext)
-        documented_normalizations = [match.split(".")[-1] for match in matches]
-
-    with open(os.path.join(REPO_PATH, "src/diffusers/models/normalization.py"), "r") as f:
-        doctext = f.read()
-        normalization_classes = re.findall(r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):", doctext)
-        # LayerNorm is an exception because adding doc for is confusing.
-        normalization_classes = [norm for norm in normalization_classes if norm != "LayerNorm"]
-
-    undocumented_norms = set()
-    for norm in normalization_classes:
-        if norm not in documented_normalizations:
-            undocumented_norms.add(norm)
-
-    if undocumented_norms:
-        raise ValueError(
-            f"The following norms should be in listed in the normalizations documentation but are not: {list(undocumented_norms)}. Please update the documentation."
-        )
-
-
-def check_lora_mixins():
-    with open(os.path.join(REPO_PATH, "docs/source/en/api/loaders/lora.md"), "r") as f:
-        doctext = f.read()
-        matches = re.findall(r"\[\[autodoc\]\]\s([^\n]+)", doctext)
-        documented_loras = [match.split(".")[-1] for match in matches]
-
-    with open(os.path.join(REPO_PATH, "src/diffusers/loaders/lora_pipeline.py"), "r") as f:
-        doctext = f.read()
-        lora_classes = re.findall(r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):", doctext)
-
-    undocumented_loras = set()
-    for lora in lora_classes:
-        if lora not in documented_loras:
-            undocumented_loras.add(lora)
-
-    if undocumented_loras:
-        raise ValueError(
-            f"The following LoRA mixins should be in listed in the LoRA loader documentation but are not: {list(undocumented_loras)}. Please update the documentation."
-        )
+    # Find which classes in source are not documented in a deterministic way.
+    undocumented = sorted(source_classes - documented)
+    return undocumented


 if __name__ == "__main__":
-    check_attention_processors()
-    check_image_processors()
-    check_activations()
-    check_normalizations()
-    check_lora_mixins()
+    # Define the checks we need to perform
+    checks = {
+        "Attention Processors": {
+            "doc_path": "docs/source/en/api/attnprocessor.md",
+            "src_path": "src/diffusers/models/attention_processor.py",
+            "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
+            "src_regex": r"class\s+(\w+Processor(?:\d*_?\d*))[:(]",
+            "exclude_conditions": [lambda c: "LoRA" in c, lambda c: c == "Attention"],
+        },
+        "Image Processors": {
+            "doc_path": "docs/source/en/api/image_processor.md",
+            "src_path": "src/diffusers/image_processor.py",
+            "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
+            "src_regex": r"class\s+(\w+Processor(?:\d*_?\d*))[:(]",
+        },
+        "Activations": {
+            "doc_path": "docs/source/en/api/activations.md",
+            "src_path": "src/diffusers/models/activations.py",
+            "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
+            "src_regex": r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):",
+        },
+        "Normalizations": {
+            "doc_path": "docs/source/en/api/normalization.md",
+            "src_path": "src/diffusers/models/normalization.py",
+            "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
+            "src_regex": r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):",
+            "exclude_conditions": [
+                # Exclude LayerNorm as it's an intentional exception
+                lambda c: c == "LayerNorm"
+            ],
+        },
+        "LoRA Mixins": {
+            "doc_path": "docs/source/en/api/loaders/lora.md",
+            "src_path": "src/diffusers/loaders/lora_pipeline.py",
+            "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
+            "src_regex": r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):",
+        },
+    }
+
+    missing_items = {}
+    for category, params in checks.items():
+        undocumented = check_documentation(
+            doc_path=params["doc_path"],
+            src_path=params["src_path"],
+            doc_regex=params["doc_regex"],
+            src_regex=params["src_regex"],
+            exclude_conditions=params.get("exclude_conditions"),
+        )
+        if undocumented:
+            missing_items[category] = undocumented
+
+    # If we have any missing items, raise a single combined error
+    if missing_items:
+        error_msg = ["Some classes are not documented properly:\n"]
+        for category, classes in missing_items.items():
+            error_msg.append(f"- {category}: {', '.join(sorted(classes))}")
+        raise ValueError("\n".join(error_msg))
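
For reference, a minimal standalone sketch of how the two kinds of regexes used in this diff behave. Only the patterns and exclude conditions are taken from the change itself; the toy doc line and class definitions below are invented for illustration and stand in for the real repository files:

import re

# Doc side: an [[autodoc]] directive is followed by a dotted path; only the
# final component (the class name) is kept for comparison.
doc_sample = "[[autodoc]] models.attention_processor.AttnProcessor2_0\n"
documented = [m.split(".")[-1] for m in re.findall(r"\[\[autodoc\]\]\s([^\n]+)", doc_sample)]
print(documented)  # ['AttnProcessor2_0']

# Source side: class names ending in "Processor" (optionally suffixed like "2_0")
# are captured, then the exclude conditions drop LoRA variants and the bare "Attention".
src_sample = "class AttnProcessor2_0:\n    pass\n\nclass LoRAAttnProcessor:\n    pass\n"
exclude_conditions = [lambda c: "LoRA" in c, lambda c: c == "Attention"]
classes = [
    c
    for c in re.findall(r"class\s+(\w+Processor(?:\d*_?\d*))[:(]", src_sample)
    if not any(cond(c) for cond in exclude_conditions)
]
print(classes)  # ['AttnProcessor2_0']

# A class captured on the source side but absent from the doc side would be reported.
print(sorted(set(classes) - set(documented)))  # []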
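And a quick sketch of the new combined failure mode: this is roughly what the aggregated error message looks like when some checks fail, with hypothetical placeholder class names standing in for real undocumented classes:

# Sketch of the aggregated message built at the end of the new __main__ block
# (the script wraps this string in a ValueError).
missing_items = {"Activations": ["SomeActivation"], "Normalizations": ["SomeNorm"]}
error_msg = ["Some classes are not documented properly:\n"]
for category, classes in missing_items.items():
    error_msg.append(f"- {category}: {', '.join(sorted(classes))}")
print("\n".join(error_msg))
# Some classes are not documented properly:
#
# - Activations: SomeActivation
# - Normalizations: SomeNorm

Compared with the previous per-category functions, which raised as soon as one check failed, the refactored script collects missing classes across all categories and reports them in a single run.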