Skip to content

Commit 6c6cade

Browse files
committed
migrate lora pipeline tests to pytest
1 parent 7242b5f commit 6c6cade

19 files changed

+4104
-1359
lines changed

fix_asserts_lora.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Fix F631-style asserts of the form:
4+
assert (<expr>, "message")
5+
…into:
6+
assert <expr>, "message"
7+
8+
Scans recursively under tests/lora/.
9+
10+
Usage:
11+
python fix_assert_tuple.py [--root tests/lora] [--dry-run]
12+
"""
13+
14+
import argparse
15+
import ast
16+
from pathlib import Path
17+
from typing import Tuple, List, Optional
18+
19+
20+
class AssertTupleFixer(ast.NodeTransformer):
21+
"""
22+
Transform `assert (<expr>, <msg>)` into `assert <expr>, <msg>`.
23+
We only rewrite when the assert test is a Tuple with exactly 2 elements.
24+
"""
25+
def __init__(self):
26+
super().__init__()
27+
self.fixed_locs: List[Tuple[int, int]] = []
28+
29+
def visit_Assert(self, node: ast.Assert) -> ast.AST:
30+
self.generic_visit(node)
31+
if isinstance(node.test, ast.Tuple) and len(node.test.elts) == 2:
32+
cond, msg = node.test.elts
33+
# Convert only if this *looks* like a real assert-with-message tuple,
34+
# i.e. keep anything as msg (string, f-string, name, call, etc.)
35+
new_node = ast.Assert(test=cond, msg=msg)
36+
ast.copy_location(new_node, node)
37+
ast.fix_missing_locations(new_node)
38+
self.fixed_locs.append((node.lineno, node.col_offset))
39+
return new_node
40+
return node
41+
42+
43+
def fix_file(path: Path, dry_run: bool = False) -> int:
44+
"""
45+
Returns number of fixes applied.
46+
"""
47+
try:
48+
src = path.read_text(encoding="utf-8")
49+
except Exception as e:
50+
print(f"Could not read {path}: {e}")
51+
return 0
52+
53+
try:
54+
tree = ast.parse(src, filename=str(path))
55+
except SyntaxError:
56+
# Skip files that don’t parse (partial edits, etc.)
57+
return 0
58+
59+
fixer = AssertTupleFixer()
60+
new_tree = fixer.visit(tree)
61+
fixes = len(fixer.fixed_locs)
62+
if fixes == 0:
63+
return 0
64+
65+
try:
66+
new_src = ast.unparse(new_tree) # Python 3.9+
67+
except Exception as e:
68+
print(f"Failed to unparse {path}: {e}")
69+
return 0
70+
71+
if dry_run:
72+
for (lineno, col) in fixer.fixed_locs:
73+
print(f"[DRY-RUN] {path}:{lineno}:{col} -> fixed assert tuple")
74+
return fixes
75+
76+
# Backup and write
77+
backup = path.with_suffix(path.suffix + ".bak")
78+
try:
79+
if not backup.exists():
80+
backup.write_text(src, encoding="utf-8")
81+
path.write_text(new_src, encoding="utf-8")
82+
for (lineno, col) in fixer.fixed_locs:
83+
print(f"Fixed {path}:{lineno}:{col}")
84+
except Exception as e:
85+
print(f"Failed to write {path}: {e}")
86+
return 0
87+
88+
return fixes
89+
90+
91+
def main():
92+
ap = argparse.ArgumentParser(description="Fix F631-style tuple asserts.")
93+
ap.add_argument("--root", default="tests/lora", help="Root directory to scan")
94+
ap.add_argument("--dry-run", action="store_true", help="Report changes but don't write")
95+
args = ap.parse_args()
96+
97+
root = Path(args.root)
98+
if not root.exists():
99+
print(f"{root} does not exist.")
100+
return
101+
102+
total_files = 0
103+
total_fixes = 0
104+
for pyfile in root.rglob("*.py"):
105+
total_files += 1
106+
total_fixes += fix_file(pyfile, dry_run=args.dry_run)
107+
108+
print(f"\nScanned {total_files} file(s). Applied {total_fixes} fix(es).")
109+
if args.dry_run:
110+
print("Run again without --dry-run to apply changes.")
111+
112+
113+
if __name__ == "__main__":
114+
main()

tests/lora/test_lora_layers_auraflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040

4141

4242
@require_peft_backend
43-
class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
43+
class TestAuraFlowLoRA(PeftLoraLoaderMixinTests):
4444
pipeline_class = AuraFlowPipeline
4545
scheduler_cls = FlowMatchEulerDiscreteScheduler
4646
scheduler_kwargs = {}

tests/lora/test_lora_layers_cogvideox.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040

4141
@require_peft_backend
42-
class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
42+
class TestCogVideoXLoRA(PeftLoraLoaderMixinTests):
4343
pipeline_class = CogVideoXPipeline
4444
scheduler_cls = CogVideoXDPMScheduler
4545
scheduler_kwargs = {"timestep_spacing": "trailing"}

tests/lora/test_lora_layers_cogview4.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def from_pretrained(*args, **kwargs):
4747

4848
@require_peft_backend
4949
@skip_mps
50-
class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
50+
class TestCogView4LoRA(PeftLoraLoaderMixinTests):
5151
pipeline_class = CogView4Pipeline
5252
scheduler_cls = FlowMatchEulerDiscreteScheduler
5353
scheduler_kwargs = {}

0 commit comments

Comments
 (0)