Skip to content

Commit 7524286

Browse files
committed
add robust shortcuts for 1 and 2 inputs
1 parent fa84ab0 commit 7524286

File tree

2 files changed

+112
-32
lines changed

2 files changed

+112
-32
lines changed

src/lib.rs

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -991,7 +991,9 @@ fn run_greedy<Ix: IndexType, Node: NodeType>(
991991
if simplify {
992992
cp.simplify();
993993
}
994-
cp.optimize_greedy(costmod, temperature, max_neighbors, seed);
994+
if cp.nodes.len() > 2 {
995+
cp.optimize_greedy(costmod, temperature, max_neighbors, seed);
996+
}
995997
cp.optimize_remaining_by_size();
996998
cp.ssa_path
997999
}
@@ -1010,7 +1012,9 @@ fn run_optimal<Ix: IndexType, Node: NodeType>(
10101012
if simplify {
10111013
cp.simplify();
10121014
}
1013-
cp.optimize_optimal(minimize, cost_cap, search_outer);
1015+
if cp.nodes.len() > 2 {
1016+
cp.optimize_optimal(minimize, cost_cap, search_outer);
1017+
}
10141018
cp.optimize_remaining_by_size();
10151019
cp.ssa_path
10161020
}
@@ -1038,6 +1042,12 @@ fn run_random_greedy_optimization<Ix: IndexType, Node: NodeType>(
10381042
cp0.simplify();
10391043
}
10401044

1045+
if cp0.nodes.len() <= 2 {
1046+
cp0.optimize_remaining_by_size();
1047+
let flops = cp0.flops * f32::consts::LOG10_E;
1048+
return (cp0.ssa_path, flops);
1049+
}
1050+
10411051
let mut best_path: Option<SSAPath> = None;
10421052
let mut best_flops = f32::INFINITY;
10431053

@@ -1158,6 +1168,9 @@ fn optimize_simplify(
11581168
use_ssa: Option<bool>,
11591169
) -> SSAPath {
11601170
let n = inputs.len();
1171+
if n <= 1 {
1172+
return vec![(0..n as u32).collect()];
1173+
}
11611174
let num_indices = size_dict.len();
11621175
let max_nodes = 2 * n;
11631176

@@ -1245,8 +1258,11 @@ fn optimize_greedy(
12451258
simplify: Option<bool>,
12461259
use_ssa: Option<bool>,
12471260
) -> SSAPath {
1261+
let n = inputs.len();
1262+
if n <= 1 {
1263+
return vec![(0..n as u32).collect()];
1264+
}
12481265
py.detach(|| {
1249-
let n = inputs.len();
12501266
let num_indices = size_dict.len();
12511267
let max_nodes = 2 * n;
12521268
let simplify = simplify.unwrap_or(true);
@@ -1372,6 +1388,10 @@ fn optimize_random_greedy_track_flops(
13721388
simplify: Option<bool>,
13731389
use_ssa: Option<bool>,
13741390
) -> (SSAPath, Score) {
1391+
let n = inputs.len();
1392+
if n <= 1 {
1393+
return (vec![(0..n as u32).collect()], 0.0);
1394+
}
13751395
py.detach(|| {
13761396
let (costmod_min, costmod_max) = costmod.unwrap_or((0.1, 4.0));
13771397
let costmod_diff = (costmod_max - costmod_min).abs();
@@ -1389,7 +1409,6 @@ fn optimize_random_greedy_track_flops(
13891409
};
13901410
let seeds = (0..ntrials).map(|_| rng.random()).collect::<Vec<u64>>();
13911411

1392-
let n: usize = inputs.len();
13931412
let num_indices = size_dict.len();
13941413
let max_nodes = 2 * n;
13951414

@@ -1530,8 +1549,11 @@ fn optimize_optimal(
15301549
simplify: Option<bool>,
15311550
use_ssa: Option<bool>,
15321551
) -> SSAPath {
1552+
let n = inputs.len();
1553+
if n <= 1 {
1554+
return vec![(0..n as u32).collect()];
1555+
}
15331556
py.detach(|| {
1534-
let n = inputs.len();
15351557
let num_indices = size_dict.len();
15361558
let max_nodes = 2 * n;
15371559
let simplify = simplify.unwrap_or(true);

tests/test_cotengrust.py

Lines changed: 85 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,88 @@
1616

1717
@pytest.mark.parametrize("which", ["greedy", "optimal"])
1818
def test_basic_call(which):
19-
inputs = [('a', 'b'), ('b', 'c'), ('c', 'd'), ('d', 'a')]
20-
output = ('b', 'd')
21-
size_dict = {'a': 2, 'b': 3, 'c': 4, 'd': 5}
19+
inputs = [("a", "b"), ("b", "c"), ("c", "d"), ("d", "a")]
20+
output = ("b", "d")
21+
size_dict = {"a": 2, "b": 3, "c": 4, "d": 5}
2222
path = {
2323
"greedy": ctgr.optimize_greedy,
2424
"optimal": ctgr.optimize_optimal,
25-
}[
26-
which
27-
](inputs, output, size_dict)
25+
}[which](inputs, output, size_dict)
2826
assert all(len(con) <= 2 for con in path)
2927

3028

29+
@pytest.mark.parametrize(
30+
"which",
31+
["simplify", "greedy", "optimal", "random_greedy"],
32+
)
33+
def test_single_input(which):
34+
inputs = [("a", "b")]
35+
output = ("a", "b")
36+
size_dict = {"a": 2, "b": 3}
37+
if which == "random_greedy":
38+
path, flops = ctgr.optimize_random_greedy_track_flops(
39+
inputs, output, size_dict, ntrials=1
40+
)
41+
assert flops == 0.0
42+
else:
43+
path = {
44+
"simplify": ctgr.optimize_simplify,
45+
"greedy": ctgr.optimize_greedy,
46+
"optimal": ctgr.optimize_optimal,
47+
}[which](inputs, output, size_dict)
48+
assert path == [[0]]
49+
50+
51+
@pytest.mark.parametrize("which", ["greedy", "optimal", "random_greedy"])
52+
def test_two_inputs(which):
53+
inputs = [("a", "b"), ("b", "c")]
54+
output = ("a", "c")
55+
size_dict = {"a": 2, "b": 3, "c": 4}
56+
if which == "random_greedy":
57+
path, flops = ctgr.optimize_random_greedy_track_flops(
58+
inputs, output, size_dict, ntrials=1
59+
)
60+
else:
61+
path = {
62+
"greedy": ctgr.optimize_greedy,
63+
"optimal": ctgr.optimize_optimal,
64+
}[which](inputs, output, size_dict)
65+
assert path == [[0, 1]]
66+
67+
68+
@pytest.mark.parametrize(
69+
"which",
70+
["simplify", "greedy", "optimal", "random_greedy"],
71+
)
72+
def test_two_inputs_with_simplification(which):
73+
"""Two inputs where each term has indices needing simplification first.
74+
75+
For 'ab,cd->', both terms have non-output, single-term indices that
76+
should be reduced before the final contraction, producing a path like
77+
[(0,), (1,), (0, 1)] rather than just [(0, 1)].
78+
"""
79+
inputs = [("a", "b"), ("c", "d")]
80+
output = ()
81+
size_dict = {"a": 2, "b": 3, "c": 4, "d": 5}
82+
if which == "random_greedy":
83+
path, _ = ctgr.optimize_random_greedy_track_flops(
84+
inputs, output, size_dict, ntrials=1
85+
)
86+
else:
87+
path = {
88+
"simplify": ctgr.optimize_simplify,
89+
"greedy": ctgr.optimize_greedy,
90+
"optimal": ctgr.optimize_optimal,
91+
}[which](inputs, output, size_dict)
92+
# simplification should reduce each term independently first,
93+
# producing two single-term contractions before the final pair
94+
assert len(path) == 3
95+
singles = [con for con in path if len(con) == 1]
96+
pairs = [con for con in path if len(con) == 2]
97+
assert len(singles) == 2
98+
assert len(pairs) == 1
99+
100+
31101
def find_output_str(lhs):
32102
tmp_lhs = lhs.replace(",", "")
33103
return "".join(s for s in sorted(set(tmp_lhs)) if tmp_lhs.count(s) == 1)
@@ -157,9 +227,7 @@ def test_manual_cases(eq, which):
157227
path = {
158228
"greedy": ctgr.optimize_greedy,
159229
"optimal": ctgr.optimize_optimal,
160-
}[
161-
which
162-
](inputs, output, size_dict)
230+
}[which](inputs, output, size_dict)
163231
assert all(len(con) <= 2 for con in path)
164232
tree = ctg.ContractionTree.from_path(
165233
inputs, output, size_dict, path=path, check=True
@@ -184,9 +252,7 @@ def test_basic_rand(seed, which):
184252
path = {
185253
"greedy": ctgr.optimize_greedy,
186254
"optimal": ctgr.optimize_optimal,
187-
}[
188-
which
189-
](inputs, output, size_dict)
255+
}[which](inputs, output, size_dict)
190256
assert all(len(con) <= 2 for con in path)
191257
tree = ctg.ContractionTree.from_path(
192258
inputs, output, size_dict, path=path, check=True
@@ -196,22 +262,16 @@ def test_basic_rand(seed, which):
196262

197263
@requires_cotengra
198264
def test_optimal_lattice_eq():
199-
inputs, output, _, size_dict = ctg.utils.lattice_equation(
200-
[4, 5], d_max=2, seed=42
201-
)
265+
inputs, output, _, size_dict = ctg.utils.lattice_equation([4, 5], d_max=2, seed=42)
202266

203-
path = ctgr.optimize_optimal(inputs, output, size_dict, minimize='flops')
204-
tree = ctg.ContractionTree.from_path(
205-
inputs, output, size_dict, path=path
206-
)
267+
path = ctgr.optimize_optimal(inputs, output, size_dict, minimize="flops")
268+
tree = ctg.ContractionTree.from_path(inputs, output, size_dict, path=path)
207269
assert tree.is_complete()
208270
assert tree.contraction_cost() == 964
209271

210-
path = ctgr.optimize_optimal(inputs, output, size_dict, minimize='size')
272+
path = ctgr.optimize_optimal(inputs, output, size_dict, minimize="size")
211273
assert all(len(con) <= 2 for con in path)
212-
tree = ctg.ContractionTree.from_path(
213-
inputs, output, size_dict, path=path
214-
)
274+
tree = ctg.ContractionTree.from_path(inputs, output, size_dict, path=path)
215275
assert tree.contraction_width() == pytest.approx(5)
216276

217277

@@ -228,8 +288,6 @@ def test_optimize_random_greedy_log_flops():
228288
inputs, output, size_dict, ntrials=4, seed=42
229289
)
230290
assert cost1 == cost2
231-
tree = ctg.ContractionTree.from_path(
232-
inputs, output, size_dict, path=path
233-
)
291+
tree = ctg.ContractionTree.from_path(inputs, output, size_dict, path=path)
234292
assert tree.is_complete()
235-
assert tree.contraction_cost(log=10) == pytest.approx(cost1)
293+
assert tree.contraction_cost(log=10) == pytest.approx(cost1)

0 commit comments

Comments
 (0)