Skip to content

Commit e5ea287

Browse files
committed
remove k diffusion tests
1 parent 4c20624 commit e5ea287

File tree

4 files changed

+87
-325
lines changed

4 files changed

+87
-325
lines changed

remove_stuff.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# list_deprecated_test_classes.py
2+
import ast
3+
import importlib
4+
import inspect
5+
from pathlib import Path
6+
7+
# --- Configuration ---
8+
PROJECT_ROOT = Path(__file__).parent.resolve()
9+
TESTS_DIR = PROJECT_ROOT / "tests" / "pipelines"
10+
PIPELINE_PACKAGE = "diffusers"
11+
12+
def find_deprecated_test_classes():
13+
"""
14+
Finds test files and the specific test class names that test deprecated pipelines.
15+
"""
16+
print(f"🔍 Project root set to: {PROJECT_ROOT}")
17+
print(f"🔍 Recursively searching for deprecated test classes in: {TESTS_DIR}\n")
18+
deprecated_tests = {}
19+
20+
# 1. Get the DeprecatedPipelineMixin class to check against.
21+
try:
22+
mixin_module = importlib.import_module(f"{PIPELINE_PACKAGE}.pipelines.pipeline_utils")
23+
DeprecatedPipelineMixin = getattr(mixin_module, 'DeprecatedPipelineMixin')
24+
except (ImportError, AttributeError):
25+
print("❌ Error: Could not import DeprecatedPipelineMixin.")
26+
return {}
27+
28+
if not TESTS_DIR.is_dir():
29+
print(f"❌ Error: The directory '{TESTS_DIR}' does not exist.")
30+
return {}
31+
32+
# 2. Recursively find all test files.
33+
for filepath in sorted(TESTS_DIR.rglob("test_*.py")):
34+
try:
35+
with open(filepath, 'r', encoding='utf-8') as f:
36+
source_code = f.read()
37+
tree = ast.parse(source_code)
38+
39+
relative_path = str(filepath.relative_to(PROJECT_ROOT))
40+
deprecated_pipelines_in_file = set()
41+
42+
# 3. First pass: find all imported pipelines in the file that are deprecated.
43+
for node in ast.walk(tree):
44+
if isinstance(node, ast.ImportFrom) and node.module and node.module.startswith(PIPELINE_PACKAGE):
45+
for alias in node.names:
46+
try:
47+
pipeline_module = importlib.import_module(node.module)
48+
pipeline_class = getattr(pipeline_module, alias.name)
49+
if inspect.isclass(pipeline_class) and issubclass(pipeline_class, DeprecatedPipelineMixin):
50+
deprecated_pipelines_in_file.add(pipeline_class.__name__)
51+
except (AttributeError, ImportError, TypeError):
52+
continue
53+
54+
if not deprecated_pipelines_in_file:
55+
continue
56+
57+
# 4. Second pass: find test classes that use these deprecated pipelines.
58+
for node in ast.walk(tree):
59+
if isinstance(node, ast.ClassDef):
60+
# Heuristic: Check if the class name contains the name of a deprecated pipeline.
61+
# This is a robust way to link a test class to the pipeline it tests.
62+
for pipeline_name in deprecated_pipelines_in_file:
63+
# e.g., Pipeline: BlipDiffusionPipeline, Test Class: BlipDiffusionPipelineTests
64+
if pipeline_name in node.name:
65+
if relative_path not in deprecated_tests:
66+
deprecated_tests[relative_path] = []
67+
deprecated_tests[relative_path].append(node.name)
68+
69+
except Exception as e:
70+
print(f"⚠️ Could not process {filepath.name}: {e}")
71+
continue
72+
73+
return deprecated_tests
74+
75+
if __name__ == "__main__":
76+
found_tests = find_deprecated_test_classes()
77+
78+
if found_tests:
79+
print("\n" + "="*70)
80+
print("✅ Found Test Files and Classes for Deprecated Pipelines:")
81+
print("="*70)
82+
for file, classes in found_tests.items():
83+
print(f"\n📄 File: {file}")
84+
for cls in sorted(list(set(classes))):
85+
print(f" - Test Class: {cls}")
86+
else:
87+
print("\nNo deprecated pipeline test classes were found.")

tests/pipelines/stable_diffusion_k_diffusion/__init__.py

Whitespace-only changes.

tests/pipelines/stable_diffusion_k_diffusion/test_stable_diffusion_k_diffusion.py

Lines changed: 0 additions & 147 deletions
This file was deleted.

0 commit comments

Comments
 (0)