1616
1717@pytest .mark .parametrize ("which" , ["greedy" , "optimal" ])
1818def test_basic_call (which ):
19- inputs = [('a' , 'b' ), ('b' , 'c' ), ('c' , 'd' ), ('d' , 'a' )]
20- output = ('b' , 'd' )
21- size_dict = {'a' : 2 , 'b' : 3 , 'c' : 4 , 'd' : 5 }
19+ inputs = [("a" , "b" ), ("b" , "c" ), ("c" , "d" ), ("d" , "a" )]
20+ output = ("b" , "d" )
21+ size_dict = {"a" : 2 , "b" : 3 , "c" : 4 , "d" : 5 }
2222 path = {
2323 "greedy" : ctgr .optimize_greedy ,
2424 "optimal" : ctgr .optimize_optimal ,
25- }[
26- which
27- ](inputs , output , size_dict )
25+ }[which ](inputs , output , size_dict )
2826 assert all (len (con ) <= 2 for con in path )
2927
3028
29+ @pytest .mark .parametrize (
30+ "which" ,
31+ ["simplify" , "greedy" , "optimal" , "random_greedy" ],
32+ )
33+ def test_single_input (which ):
34+ inputs = [("a" , "b" )]
35+ output = ("a" , "b" )
36+ size_dict = {"a" : 2 , "b" : 3 }
37+ if which == "random_greedy" :
38+ path , flops = ctgr .optimize_random_greedy_track_flops (
39+ inputs , output , size_dict , ntrials = 1
40+ )
41+ assert flops == 0.0
42+ else :
43+ path = {
44+ "simplify" : ctgr .optimize_simplify ,
45+ "greedy" : ctgr .optimize_greedy ,
46+ "optimal" : ctgr .optimize_optimal ,
47+ }[which ](inputs , output , size_dict )
48+ assert path == [[0 ]]
49+
50+
51+ @pytest .mark .parametrize ("which" , ["greedy" , "optimal" , "random_greedy" ])
52+ def test_two_inputs (which ):
53+ inputs = [("a" , "b" ), ("b" , "c" )]
54+ output = ("a" , "c" )
55+ size_dict = {"a" : 2 , "b" : 3 , "c" : 4 }
56+ if which == "random_greedy" :
57+ path , flops = ctgr .optimize_random_greedy_track_flops (
58+ inputs , output , size_dict , ntrials = 1
59+ )
60+ else :
61+ path = {
62+ "greedy" : ctgr .optimize_greedy ,
63+ "optimal" : ctgr .optimize_optimal ,
64+ }[which ](inputs , output , size_dict )
65+ assert path == [[0 , 1 ]]
66+
67+
68+ @pytest .mark .parametrize (
69+ "which" ,
70+ ["simplify" , "greedy" , "optimal" , "random_greedy" ],
71+ )
72+ def test_two_inputs_with_simplification (which ):
73+ """Two inputs where each term has indices needing simplification first.
74+
75+ For 'ab,cd->', both terms have non-output, single-term indices that
76+ should be reduced before the final contraction, producing a path like
77+ [(0,), (1,), (0, 1)] rather than just [(0, 1)].
78+ """
79+ inputs = [("a" , "b" ), ("c" , "d" )]
80+ output = ()
81+ size_dict = {"a" : 2 , "b" : 3 , "c" : 4 , "d" : 5 }
82+ if which == "random_greedy" :
83+ path , _ = ctgr .optimize_random_greedy_track_flops (
84+ inputs , output , size_dict , ntrials = 1
85+ )
86+ else :
87+ path = {
88+ "simplify" : ctgr .optimize_simplify ,
89+ "greedy" : ctgr .optimize_greedy ,
90+ "optimal" : ctgr .optimize_optimal ,
91+ }[which ](inputs , output , size_dict )
92+ # simplification should reduce each term independently first,
93+ # producing two single-term contractions before the final pair
94+ assert len (path ) == 3
95+ singles = [con for con in path if len (con ) == 1 ]
96+ pairs = [con for con in path if len (con ) == 2 ]
97+ assert len (singles ) == 2
98+ assert len (pairs ) == 1
99+
100+
31101def find_output_str (lhs ):
32102 tmp_lhs = lhs .replace ("," , "" )
33103 return "" .join (s for s in sorted (set (tmp_lhs )) if tmp_lhs .count (s ) == 1 )
@@ -157,9 +227,7 @@ def test_manual_cases(eq, which):
157227 path = {
158228 "greedy" : ctgr .optimize_greedy ,
159229 "optimal" : ctgr .optimize_optimal ,
160- }[
161- which
162- ](inputs , output , size_dict )
230+ }[which ](inputs , output , size_dict )
163231 assert all (len (con ) <= 2 for con in path )
164232 tree = ctg .ContractionTree .from_path (
165233 inputs , output , size_dict , path = path , check = True
@@ -184,9 +252,7 @@ def test_basic_rand(seed, which):
184252 path = {
185253 "greedy" : ctgr .optimize_greedy ,
186254 "optimal" : ctgr .optimize_optimal ,
187- }[
188- which
189- ](inputs , output , size_dict )
255+ }[which ](inputs , output , size_dict )
190256 assert all (len (con ) <= 2 for con in path )
191257 tree = ctg .ContractionTree .from_path (
192258 inputs , output , size_dict , path = path , check = True
@@ -196,22 +262,16 @@ def test_basic_rand(seed, which):
196262
197263@requires_cotengra
198264def test_optimal_lattice_eq ():
199- inputs , output , _ , size_dict = ctg .utils .lattice_equation (
200- [4 , 5 ], d_max = 2 , seed = 42
201- )
265+ inputs , output , _ , size_dict = ctg .utils .lattice_equation ([4 , 5 ], d_max = 2 , seed = 42 )
202266
203- path = ctgr .optimize_optimal (inputs , output , size_dict , minimize = 'flops' )
204- tree = ctg .ContractionTree .from_path (
205- inputs , output , size_dict , path = path
206- )
267+ path = ctgr .optimize_optimal (inputs , output , size_dict , minimize = "flops" )
268+ tree = ctg .ContractionTree .from_path (inputs , output , size_dict , path = path )
207269 assert tree .is_complete ()
208270 assert tree .contraction_cost () == 964
209271
210- path = ctgr .optimize_optimal (inputs , output , size_dict , minimize = ' size' )
272+ path = ctgr .optimize_optimal (inputs , output , size_dict , minimize = " size" )
211273 assert all (len (con ) <= 2 for con in path )
212- tree = ctg .ContractionTree .from_path (
213- inputs , output , size_dict , path = path
214- )
274+ tree = ctg .ContractionTree .from_path (inputs , output , size_dict , path = path )
215275 assert tree .contraction_width () == pytest .approx (5 )
216276
217277
@@ -228,8 +288,6 @@ def test_optimize_random_greedy_log_flops():
228288 inputs , output , size_dict , ntrials = 4 , seed = 42
229289 )
230290 assert cost1 == cost2
231- tree = ctg .ContractionTree .from_path (
232- inputs , output , size_dict , path = path
233- )
291+ tree = ctg .ContractionTree .from_path (inputs , output , size_dict , path = path )
234292 assert tree .is_complete ()
235- assert tree .contraction_cost (log = 10 ) == pytest .approx (cost1 )
293+ assert tree .contraction_cost (log = 10 ) == pytest .approx (cost1 )
0 commit comments