Skip to content

Commit 7eb617a

Browse files
committed
fixes
1 parent 75136e6 commit 7eb617a

File tree

2 files changed

+159
-111
lines changed

2 files changed

+159
-111
lines changed
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import os
2+
import sys
3+
import unittest
4+
from unittest.mock import mock_open, patch
5+
6+
7+
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
8+
sys.path.append(os.path.join(git_repo_path, "utils"))
9+
10+
from check_support_list import check_documentation # noqa: E402
11+
12+
13+
class TestCheckSupportList(unittest.TestCase):
14+
def setUp(self):
15+
# Mock doc and source contents that we can reuse
16+
self.doc_content = """# Documentation
17+
## FooProcessor
18+
19+
[[autodoc]] module.FooProcessor
20+
21+
## BarProcessor
22+
23+
[[autodoc]] module.BarProcessor
24+
"""
25+
self.source_content = """
26+
class FooProcessor(nn.Module):
27+
pass
28+
29+
class BarProcessor(nn.Module):
30+
pass
31+
"""
32+
33+
def test_check_documentation_all_documented(self):
34+
# In this test, both FooProcessor and BarProcessor are documented
35+
with patch("builtins.open", mock_open(read_data=self.doc_content)) as doc_file:
36+
doc_file.side_effect = [
37+
mock_open(read_data=self.doc_content).return_value,
38+
mock_open(read_data=self.source_content).return_value,
39+
]
40+
41+
undocumented = check_documentation(
42+
doc_path="fake_doc.md",
43+
src_path="fake_source.py",
44+
doc_regex=r"\[\[autodoc\]\]\s([^\n]+)",
45+
src_regex=r"class\s+(\w+Processor)\(.*?nn\.Module.*?\):",
46+
)
47+
self.assertEqual(len(undocumented), 0, f"Expected no undocumented classes, got {undocumented}")
48+
49+
def test_check_documentation_missing_class(self):
50+
# In this test, only FooProcessor is documented, but BarProcessor is missing from the docs
51+
doc_content_missing = """# Documentation
52+
## FooProcessor
53+
54+
[[autodoc]] module.FooProcessor
55+
"""
56+
with patch("builtins.open", mock_open(read_data=doc_content_missing)) as doc_file:
57+
doc_file.side_effect = [
58+
mock_open(read_data=doc_content_missing).return_value,
59+
mock_open(read_data=self.source_content).return_value,
60+
]
61+
62+
undocumented = check_documentation(
63+
doc_path="fake_doc.md",
64+
src_path="fake_source.py",
65+
doc_regex=r"\[\[autodoc\]\]\s([^\n]+)",
66+
src_regex=r"class\s+(\w+Processor)\(.*?nn\.Module.*?\):",
67+
)
68+
self.assertIn("BarProcessor", undocumented, f"BarProcessor should be undocumented, got {undocumented}")

utils/check_support_list.py

Lines changed: 91 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,6 @@
77
#
88
# http://www.apache.org/licenses/LICENSE-2.0
99
#
10-
# Unless required by applicable law or agreed to in writing, software
11-
# distributed under the License is distributed on an "AS IS" BASIS,
12-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13-
# See the License for the specific language governing permissions and
14-
# limitations under the License.
1510
"""
1611
Utility that checks that modules like attention processors are listed in the documentation file.
1712
@@ -21,124 +16,109 @@
2116
2217
It has no auto-fix mode.
2318
"""
19+
2420
import os
2521
import re
2622

2723

28-
# All paths are set with the intent you should run this script from the root of the repo with the command
29-
# python utils/check_doctest_list.py
24+
# All paths are set with the intent that you run this script from the root of the repo
3025
REPO_PATH = "."
3126

3227

33-
def check_attention_processors():
34-
with open(os.path.join(REPO_PATH, "docs/source/en/api/attnprocessor.md"), "r") as f:
35-
doctext = f.read()
36-
matches = re.findall(r"\[\[autodoc\]\]\s([^\n]+)", doctext)
37-
documented_attention_processors = [match.split(".")[-1] for match in matches]
38-
39-
with open(os.path.join(REPO_PATH, "src/diffusers/models/attention_processor.py"), "r") as f:
40-
doctext = f.read()
41-
processor_classes = re.findall(r"class\s+(\w+Processor(?:\d*_?\d*))[(:]", doctext)
42-
processor_classes = [proc for proc in processor_classes if "LoRA" not in proc and proc != "Attention"]
43-
44-
undocumented_attn_processors = set()
45-
for processor in processor_classes:
46-
if processor not in documented_attention_processors:
47-
undocumented_attn_processors.add(processor)
48-
49-
if undocumented_attn_processors:
50-
raise ValueError(
51-
f"The following attention processors should be in listed in the attention processor documentation but are not: {list(undocumented_attn_processors)}. Please update the documentation."
52-
)
53-
54-
55-
def check_image_processors():
56-
with open(os.path.join(REPO_PATH, "docs/source/en/api/image_processor.md"), "r") as f:
57-
doctext = f.read()
58-
matches = re.findall(r"\[\[autodoc\]\]\s([^\n]+)", doctext)
59-
documented_image_processors = [match.split(".")[-1] for match in matches]
60-
61-
with open(os.path.join(REPO_PATH, "src/diffusers/image_processor.py"), "r") as f:
28+
def read_documented_classes(doc_path, autodoc_regex=r"\[\[autodoc\]\]\s([^\n]+)"):
29+
"""
30+
Reads documented classes from a doc file using a regex to find lines like [[autodoc]] my.module.Class.
31+
Returns a list of documented class names (just the class name portion).
32+
"""
33+
with open(os.path.join(REPO_PATH, doc_path), "r") as f:
6234
doctext = f.read()
63-
processor_classes = re.findall(r"class\s+(\w+Processor(?:\d*_?\d*))[(:]", doctext)
64-
65-
undocumented_img_processors = set()
66-
for processor in processor_classes:
67-
if processor not in documented_image_processors:
68-
undocumented_img_processors.add(processor)
69-
raise ValueError(
70-
f"The following image processors should be in listed in the image processor documentation but are not: {list(undocumented_img_processors)}. Please update the documentation."
71-
)
72-
73-
74-
def check_activations():
75-
with open(os.path.join(REPO_PATH, "docs/source/en/api/activations.md"), "r") as f:
35+
matches = re.findall(autodoc_regex, doctext)
36+
return [match.split(".")[-1] for match in matches]
37+
38+
39+
def read_source_classes(src_path, class_regex, exclude_conditions=None):
40+
"""
41+
Reads class names from a source file using a regex that captures class definitions.
42+
Optionally exclude classes based on a list of conditions (functions that take class name and return bool).
43+
"""
44+
if exclude_conditions is None:
45+
exclude_conditions = []
46+
with open(os.path.join(REPO_PATH, src_path), "r") as f:
7647
doctext = f.read()
77-
matches = re.findall(r"\[\[autodoc\]\]\s([^\n]+)", doctext)
78-
documented_activations = [match.split(".")[-1] for match in matches]
48+
classes = re.findall(class_regex, doctext)
49+
# Filter out classes that meet any of the exclude conditions
50+
filtered_classes = [c for c in classes if not any(cond(c) for cond in exclude_conditions)]
51+
return filtered_classes
7952

80-
with open(os.path.join(REPO_PATH, "src/diffusers/models/activations.py"), "r") as f:
81-
doctext = f.read()
82-
activation_classes = re.findall(r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):", doctext)
8353

84-
undocumented_activations = set()
85-
for activation in activation_classes:
86-
if activation not in documented_activations:
87-
undocumented_activations.add(activation)
54+
def check_documentation(doc_path, src_path, doc_regex, src_regex, exclude_conditions=None):
55+
"""
56+
Generic function to check if all classes defined in `src_path` are documented in `doc_path`.
57+
Returns a set of undocumented class names.
58+
"""
59+
documented = set(read_documented_classes(doc_path, doc_regex))
60+
source_classes = set(read_source_classes(src_path, src_regex, exclude_conditions=exclude_conditions))
8861

89-
if undocumented_activations:
90-
raise ValueError(
91-
f"The following activations should be in listed in the activations documentation but are not: {list(undocumented_activations)}. Please update the documentation."
92-
)
93-
94-
95-
def check_normalizations():
96-
with open(os.path.join(REPO_PATH, "docs/source/en/api/normalization.md"), "r") as f:
97-
doctext = f.read()
98-
matches = re.findall(r"\[\[autodoc\]\]\s([^\n]+)", doctext)
99-
documented_normalizations = [match.split(".")[-1] for match in matches]
100-
101-
with open(os.path.join(REPO_PATH, "src/diffusers/models/normalization.py"), "r") as f:
102-
doctext = f.read()
103-
normalization_classes = re.findall(r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):", doctext)
104-
# LayerNorm is an exception because adding doc for is confusing.
105-
normalization_classes = [norm for norm in normalization_classes if norm != "LayerNorm"]
106-
107-
undocumented_norms = set()
108-
for norm in normalization_classes:
109-
if norm not in documented_normalizations:
110-
undocumented_norms.add(norm)
111-
112-
if undocumented_norms:
113-
raise ValueError(
114-
f"The following norms should be in listed in the normalizations documentation but are not: {list(undocumented_norms)}. Please update the documentation."
115-
)
116-
117-
118-
def check_lora_mixins():
119-
with open(os.path.join(REPO_PATH, "docs/source/en/api/loaders/lora.md"), "r") as f:
120-
doctext = f.read()
121-
matches = re.findall(r"\[\[autodoc\]\]\s([^\n]+)", doctext)
122-
documented_loras = [match.split(".")[-1] for match in matches]
123-
124-
with open(os.path.join(REPO_PATH, "src/diffusers/loaders/lora_pipeline.py"), "r") as f:
125-
doctext = f.read()
126-
lora_classes = re.findall(r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):", doctext)
127-
128-
undocumented_loras = set()
129-
for lora in lora_classes:
130-
if lora not in documented_loras:
131-
undocumented_loras.add(lora)
132-
133-
if undocumented_loras:
134-
raise ValueError(
135-
f"The following LoRA mixins should be in listed in the LoRA loader documentation but are not: {list(undocumented_loras)}. Please update the documentation."
136-
)
62+
# Find which classes in source are not documented in a deterministic way.
63+
undocumented = sorted(source_classes - documented)
64+
return undocumented
13765

13866

13967
if __name__ == "__main__":
140-
check_attention_processors()
141-
check_image_processors()
142-
check_activations()
143-
check_normalizations()
144-
check_lora_mixins()
68+
# Define the checks we need to perform
69+
checks = {
70+
"Attention Processors": {
71+
"doc_path": "docs/source/en/api/attnprocessor.md",
72+
"src_path": "src/diffusers/models/attention_processor.py",
73+
"doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
74+
"src_regex": r"class\s+(\w+Processor(?:\d*_?\d*))[:(]",
75+
"exclude_conditions": [lambda c: "LoRA" in c, lambda c: c == "Attention"],
76+
},
77+
"Image Processors": {
78+
"doc_path": "docs/source/en/api/image_processor.md",
79+
"src_path": "src/diffusers/image_processor.py",
80+
"doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
81+
"src_regex": r"class\s+(\w+Processor(?:\d*_?\d*))[:(]",
82+
},
83+
"Activations": {
84+
"doc_path": "docs/source/en/api/activations.md",
85+
"src_path": "src/diffusers/models/activations.py",
86+
"doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
87+
"src_regex": r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):",
88+
},
89+
"Normalizations": {
90+
"doc_path": "docs/source/en/api/normalization.md",
91+
"src_path": "src/diffusers/models/normalization.py",
92+
"doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
93+
"src_regex": r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):",
94+
"exclude_conditions": [
95+
# Exclude LayerNorm as it's an intentional exception
96+
lambda c: c == "LayerNorm"
97+
],
98+
},
99+
"LoRA Mixins": {
100+
"doc_path": "docs/source/en/api/loaders/lora.md",
101+
"src_path": "src/diffusers/loaders/lora_pipeline.py",
102+
"doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
103+
"src_regex": r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):",
104+
},
105+
}
106+
107+
missing_items = {}
108+
for category, params in checks.items():
109+
undocumented = check_documentation(
110+
doc_path=params["doc_path"],
111+
src_path=params["src_path"],
112+
doc_regex=params["doc_regex"],
113+
src_regex=params["src_regex"],
114+
exclude_conditions=params.get("exclude_conditions"),
115+
)
116+
if undocumented:
117+
missing_items[category] = undocumented
118+
119+
# If we have any missing items, raise a single combined error
120+
if missing_items:
121+
error_msg = ["Some classes are not documented properly:\n"]
122+
for category, classes in missing_items.items():
123+
error_msg.append(f"- {category}: {', '.join(sorted(classes))}")
124+
raise ValueError("\n".join(error_msg))

0 commit comments

Comments
 (0)