|  | 
|  | 1 | +# coding=utf-8 | 
|  | 2 | +# Copyright 2024 The HuggingFace Inc. team. | 
|  | 3 | +# | 
|  | 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | 5 | +# you may not use this file except in compliance with the License. | 
|  | 6 | +# You may obtain a copy of the License at | 
|  | 7 | +# | 
|  | 8 | +#     http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 9 | +# | 
|  | 10 | +""" | 
|  | 11 | +Utility that checks that modules like attention processors are listed in the documentation file. | 
|  | 12 | +
 | 
|  | 13 | +```bash | 
|  | 14 | +python utils/check_support_list.py | 
|  | 15 | +``` | 
|  | 16 | +
 | 
|  | 17 | +It has no auto-fix mode. | 
|  | 18 | +""" | 
|  | 19 | + | 
|  | 20 | +import os | 
|  | 21 | +import re | 
|  | 22 | + | 
|  | 23 | + | 
|  | 24 | +# All paths are set with the intent that you run this script from the root of the repo | 
|  | 25 | +REPO_PATH = "." | 
|  | 26 | + | 
|  | 27 | + | 
def read_documented_classes(doc_path, autodoc_regex=r"\[\[autodoc\]\]\s([^\n]+)"):
    """
    Reads documented classes from a doc file using a regex to find lines like `[[autodoc]] my.module.Class`.

    Args:
        doc_path: Path of the documentation file, joined onto `REPO_PATH`.
        autodoc_regex: Pattern with a single capture group holding the documented (possibly dotted) object path.

    Returns:
        A list with the last dotted component of every match (just the class name portion), in file order.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the platform's default encoding.
    with open(os.path.join(REPO_PATH, doc_path), "r", encoding="utf-8") as f:
        doctext = f.read()
    matches = re.findall(autodoc_regex, doctext)
    # The default capture group runs to the end of the line, so trailing spaces (or a doubled space after
    # `[[autodoc]]`) would otherwise end up inside the class name and break later comparisons.
    return [match.strip().split(".")[-1] for match in matches]
|  | 37 | + | 
|  | 38 | + | 
def read_source_classes(src_path, class_regex, exclude_conditions=None):
    """
    Reads class names from a source file using a regex that captures class definitions.

    Args:
        src_path: Path of the source file, joined onto `REPO_PATH`.
        class_regex: Pattern whose capture group yields the class name.
        exclude_conditions: Optional list of predicates (functions that take a class name and return a bool).
            A class is dropped as soon as any predicate returns a truthy value for it.

    Returns:
        A list of the matched class names that survive the exclusion predicates, in file order.
    """
    if exclude_conditions is None:
        exclude_conditions = []
    # Decode explicitly as UTF-8 so non-ASCII characters in the source do not depend on the platform locale.
    with open(os.path.join(REPO_PATH, src_path), "r", encoding="utf-8") as f:
        source_text = f.read()
    classes = re.findall(class_regex, source_text)
    # Filter out classes that meet any of the exclude conditions
    return [c for c in classes if not any(cond(c) for cond in exclude_conditions)]
|  | 52 | + | 
|  | 53 | + | 
def check_documentation(doc_path, src_path, doc_regex, src_regex, exclude_conditions=None):
    """
    Compares the classes defined in `src_path` against the classes documented in `doc_path`.

    `doc_regex` and `src_regex` are forwarded to the readers of the doc file and the source file respectively;
    `exclude_conditions` is forwarded to the source reader.

    Returns a sorted list of the class names found in the source but not in the docs (empty when nothing is
    missing).
    """
    listed_in_docs = read_documented_classes(doc_path, doc_regex)
    defined_in_source = read_source_classes(src_path, src_regex, exclude_conditions=exclude_conditions)

    # The set difference removes duplicates; sorting makes the result deterministic.
    missing = set(defined_in_source).difference(listed_in_docs)
    return sorted(missing)
|  | 65 | + | 
|  | 66 | + | 
def main():
    """
    Runs every documentation check and reports all undocumented classes at once.

    Raises:
        ValueError: If at least one category has classes that are defined in the source file but not listed in
            the corresponding doc file. The message lists the missing classes per category.
    """
    # Patterns shared by several checks, named once so the copies cannot drift apart.
    autodoc_regex = r"\[\[autodoc\]\]\s([^\n]+)"
    processor_class_regex = r"class\s+(\w+Processor(?:\d*_?\d*))[:(]"
    nn_module_class_regex = r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):"

    # Define the checks we need to perform
    checks = {
        "Attention Processors": {
            "doc_path": "docs/source/en/api/attnprocessor.md",
            "src_path": "src/diffusers/models/attention_processor.py",
            "doc_regex": autodoc_regex,
            "src_regex": processor_class_regex,
            "exclude_conditions": [lambda c: "LoRA" in c, lambda c: c == "Attention"],
        },
        "Image Processors": {
            "doc_path": "docs/source/en/api/image_processor.md",
            "src_path": "src/diffusers/image_processor.py",
            "doc_regex": autodoc_regex,
            "src_regex": processor_class_regex,
        },
        "Activations": {
            "doc_path": "docs/source/en/api/activations.md",
            "src_path": "src/diffusers/models/activations.py",
            "doc_regex": autodoc_regex,
            "src_regex": nn_module_class_regex,
        },
        "Normalizations": {
            "doc_path": "docs/source/en/api/normalization.md",
            "src_path": "src/diffusers/models/normalization.py",
            "doc_regex": autodoc_regex,
            "src_regex": nn_module_class_regex,
            "exclude_conditions": [
                # Exclude LayerNorm as it's an intentional exception
                lambda c: c == "LayerNorm"
            ],
        },
        "LoRA Mixins": {
            "doc_path": "docs/source/en/api/loaders/lora.md",
            "src_path": "src/diffusers/loaders/lora_pipeline.py",
            "doc_regex": autodoc_regex,
            # NOTE(review): this pattern only matches classes whose base list mentions `nn.Module`. Confirm
            # that the mixins in lora_pipeline.py do; otherwise this check matches nothing and always passes.
            "src_regex": nn_module_class_regex,
        },
    }

    missing_items = {}
    for category, params in checks.items():
        undocumented = check_documentation(
            doc_path=params["doc_path"],
            src_path=params["src_path"],
            doc_regex=params["doc_regex"],
            src_regex=params["src_regex"],
            exclude_conditions=params.get("exclude_conditions"),
        )
        if undocumented:
            missing_items[category] = undocumented

    # If we have any missing items, raise a single combined error
    if missing_items:
        error_msg = ["Some classes are not documented properly:\n"]
        for category, classes in missing_items.items():
            error_msg.append(f"- {category}: {', '.join(sorted(classes))}")
        raise ValueError("\n".join(error_msg))


if __name__ == "__main__":
    main()
0 commit comments