Skip to content

Commit b19f582

Browse files
committed
tests
1 parent acc2fbf commit b19f582

File tree

5 files changed

+603
-29
lines changed

5 files changed

+603
-29
lines changed

tests/reproduce_conv_issues.py

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
import os
2+
import sys
3+
4+
_here = os.path.abspath(os.path.dirname(__file__))
5+
_project_root = os.path.abspath(os.path.join(_here, os.pardir))
6+
if _project_root not in sys.path:
7+
sys.path.insert(0, _project_root)
8+
9+
_comfy_root = os.path.abspath(os.path.join(_here, "../../.."))
10+
if _comfy_root not in sys.path:
11+
sys.path.insert(0, _comfy_root)
12+
13+
import torch
14+
import pytest
15+
from more_math.Parser.UnifiedMathVisitor import UnifiedMathVisitor
16+
17+
from more_math.LatentMathNode import LatentMathNode
18+
19+
def test_conv_1d():
20+
"""
21+
Test 1D convolution.
22+
Input: [Batch, Length, Channels] = [1, 10, 4]
23+
Kernel: 1D size 3
24+
"""
25+
print("\n--- Testing 1D Conv ---")
26+
node = LatentMathNode()
27+
shape = (1, 10, 4)
28+
a_val = torch.randn(*shape)
29+
30+
# conv(a, 3, 1.0) -> implies kernel of ones, size 3
31+
# Result should correspond to 1D conv
32+
try:
33+
# LatentMathNode expects latent dicts usually
34+
input_dict = {"samples": a_val}
35+
res = node.execute("conv(a, 3, 1.0)", a=input_dict)
36+
37+
# LatentMathNode returns list of dicts
38+
res_tensor = res[0]["samples"]
39+
40+
print(f"1D Conv Result Shape: {res_tensor.shape}")
41+
# Expect (1, 10, 4)
42+
assert res_tensor.shape == shape
43+
except Exception as e:
44+
print(f"1D Conv Failed: {e}")
45+
raise
46+
47+
def test_conv_3d():
48+
"""
49+
Test 3D convolution.
50+
Input: [Batch, Depth, Height, Width, Channels] = [1, 5, 32, 32, 4]
51+
Kernel: 3D size 3x3x3
52+
"""
53+
print("\n--- Testing 3D Conv ---")
54+
node = LatentMathNode()
55+
shape = (1, 5, 32, 32, 4)
56+
a_val = torch.randn(*shape)
57+
58+
try:
59+
input_dict = {"samples": a_val}
60+
res = node.execute("conv(a, 3, 3, 3, 1.0)", a=input_dict)
61+
res_tensor = res[0]["samples"]
62+
print(f"3D Conv Result Shape: {res_tensor.shape}")
63+
assert res_tensor.shape == shape
64+
except Exception as e:
65+
print(f"3D Conv Failed: {e}")
66+
raise
67+
68+
def test_conv_arbitrary_batch():
69+
"""
70+
Test generic tensor with extra batch dims.
71+
Input: [B1, B2, H, W, C] = [2, 2, 16, 16, 4] -> Should be treated as Batch=4
72+
"""
73+
print("\n--- Testing Arbitrary Batch ---")
74+
node = LatentMathNode()
75+
shape = (2, 2, 16, 16, 4)
76+
a_val = torch.randn(*shape)
77+
78+
try:
79+
# conv(a, 3, 3, 1.0) -> 2D conv on (16,16)
80+
input_dict = {"samples": a_val}
81+
res = node.execute("conv(a, 3, 3, 1.0)", a=input_dict)
82+
res_tensor = res[0]["samples"]
83+
print(f"Arbitrary Batch Result Shape: {res_tensor.shape}")
84+
assert res_tensor.shape == shape
85+
except Exception as e:
86+
print(f"Arbitrary Batch Failed: {e}")
87+
raise
88+
89+
def test_conv_list_kernel():
90+
"""
91+
Test conv with list kernel (Regression test for float64 mismatch).
92+
Kernel: 3x3x3 list of floats.
93+
"""
94+
print("\n--- Testing List Kernel Conv ---")
95+
node = LatentMathNode()
96+
shape = (1, 5, 10, 10, 4) # [B, D, H, W, C]
97+
a_val = torch.randn(*shape).float()
98+
99+
# 3x3x3 kernel = 27 elements
100+
# Using the user's example kernel
101+
kernel_list = [1,1,1,1,0,1,1,1,1, 0,0,0,0,1,0,0,0,0, 1,1,1,1,0,1,1,1,1]
102+
kernel_str = str(kernel_list)
103+
expr = f"conv(a, 3, 3, 3, {kernel_str})/8"
104+
105+
try:
106+
input_dict = {"samples": a_val}
107+
res = node.execute(expr, a=input_dict)
108+
res_tensor = res[0]["samples"]
109+
print(f"List Kernel Result Shape: {res_tensor.shape}")
110+
assert res_tensor.shape == shape
111+
assert res_tensor.dtype == torch.float32
112+
except Exception as e:
113+
print(f"List Kernel Failed: {e}")
114+
raise
115+
116+
def test_conv_audio():
117+
"""
118+
Test 1D conv on Audio [B, C, L].
119+
Input: [1, 2, 100]. Kernel: 3.
120+
Should be treated as Channels First -> [B, L, C].
121+
Output should preserve Channels First [B, 2, 100].
122+
"""
123+
print("\n--- Testing Audio Conv [B, C, L] ---")
124+
node = LatentMathNode()
125+
shape = (1, 2, 100) # [B, C, L] (L >> C)
126+
a_val = torch.randn(*shape).float()
127+
128+
# conv(a, 3, 1.0) on last dim (L)
129+
# Expected: result shape same as input
130+
try:
131+
input_dict = {"samples": a_val}
132+
res = node.execute("conv(a, 3, 1.0)", a=input_dict)
133+
res_tensor = res[0]["samples"]
134+
print(f"Audio Result Shape: {res_tensor.shape}")
135+
136+
if res_tensor.shape != shape:
137+
print(f"Likely interpreted as Channels Last [B, L, C] where C is small? No.")
138+
# If interpreted as Channels last [..., C].
139+
# [1, 2, 100]. Spatial=[2]. Channel=100.
140+
# Output [1, 2, 100] (but confusing channels).
141+
pass
142+
143+
assert res_tensor.shape == shape
144+
except Exception as e:
145+
print(f"Audio Conv Failed: {e}")
146+
raise
147+
148+
def test_conv_deep_latent():
149+
"""
150+
Test 3D conv on Deep Latent [B, 32, H, W] (User request).
151+
Input: [1, 32, 16, 16]. Kernel: 3x3x3.
152+
Should be treated as Channels First -> [B, 32, 16, 16, 1].
153+
Depth=32. H=16. W=16.
154+
"""
155+
print("\n--- Testing Deep Latent Conv [B, 32, H, W] ---")
156+
node = LatentMathNode()
157+
shape = (1, 32, 16, 16)
158+
a_val = torch.randn(*shape).float()
159+
160+
# conv(a, 3, 3, 3, 1.0)
161+
# 3D kernels need D,H,W.
162+
# D=32 (Channel). H=16. W=16.
163+
try:
164+
input_dict = {"samples": a_val}
165+
res = node.execute("conv(a, 3, 3, 3, 1.0)", a=input_dict)
166+
res_tensor = res[0]["samples"]
167+
print(f"Deep Latent Result Shape: {res_tensor.shape}")
168+
assert res_tensor.shape == shape
169+
170+
# Identity check (ensure D neighbors engaged)
171+
# Using simple kernel, center only vs ones.
172+
# But this test just checks shape and execution path.
173+
except Exception as e:
174+
print(f"Deep Latent Failed: {e}")
175+
raise
176+
177+
def test_conv_padding():
178+
"""
179+
Test padding consistency, especially for even kernels.
180+
Input: [1, 10, 10, 1]. Kernel: 4x4.
181+
Should produce [1, 10, 10, 1] output (Same padding).
182+
"""
183+
print("\n--- Testing Padding (Even Kernel Size 4) ---")
184+
node = LatentMathNode()
185+
shape = (1, 10, 10, 1)
186+
a_val = torch.randn(*shape).float()
187+
188+
# conv(a, 4, 4, 1.0)
189+
# If padding is symmetric 2, result is 11x11.
190+
# If padding is symmetric 1, result is 9x9.
191+
# We need asymmetric pad (1, 2) to get 10x10.
192+
try:
193+
input_dict = {"samples": a_val}
194+
res = node.execute("conv(a, 4, 4, 1.0)", a=input_dict)
195+
res_tensor = res[0]["samples"]
196+
print(f"Padding Test Result Shape: {res_tensor.shape}")
197+
assert res_tensor.shape == shape
198+
except Exception as e:
199+
print(f"Padding Test Failed: {e}")
200+
raise
201+
202+
def test_conv_complex_padding():
203+
"""
204+
Test asymmetric padding with mixed odd/even kernel sizes.
205+
Kernel: (3, 4). Input: (1, 10, 10, 1).
206+
Should produce (1, 10, 10, 1).
207+
"""
208+
print("\n--- Testing Complex Padding (3, 4) ---")
209+
node = LatentMathNode()
210+
shape = (1, 10, 10, 1)
211+
a_val = torch.randn(*shape).float()
212+
213+
try:
214+
input_dict = {"samples": a_val}
215+
res = node.execute("conv(a, 3, 4, 1.0)", a=input_dict)
216+
res_tensor = res[0]["samples"]
217+
print(f"Complex Padding Result Shape: {res_tensor.shape}")
218+
assert res_tensor.shape == shape
219+
except Exception as e:
220+
print(f"Complex Padding Failed: {e}")
221+
raise
222+
223+
def test_conv_3d_asymmetric():
224+
"""
225+
Test 3D conv with asymmetric spatial dims.
226+
Input: [1, 5, 10, 20, 1]. Kernel: 3x3x3.
227+
"""
228+
print("\n--- Testing 3D Asymmetric Input ---")
229+
node = LatentMathNode()
230+
shape = (1, 5, 10, 20, 1)
231+
a_val = torch.randn(*shape).float()
232+
233+
try:
234+
input_dict = {"samples": a_val}
235+
res = node.execute("conv(a, 3, 3, 3, 1.0)", a=input_dict)
236+
res_tensor = res[0]["samples"]
237+
print(f"3D Asymmetric Result Shape: {res_tensor.shape}")
238+
assert res_tensor.shape == shape
239+
except Exception as e:
240+
print(f"3D Asymmetric Failed: {e}")
241+
raise
242+
243+
if __name__ == "__main__":
244+
try:
245+
test_conv_1d()
246+
test_conv_3d()
247+
test_conv_arbitrary_batch()
248+
test_conv_list_kernel()
249+
test_conv_audio()
250+
test_conv_deep_latent()
251+
test_conv_padding()
252+
test_conv_complex_padding()
253+
test_conv_3d_asymmetric()
254+
print("All Conv tests passed!")
255+
except Exception as e:
256+
import traceback
257+
traceback.print_exc()
258+
sys.exit(1)

0 commit comments

Comments
 (0)