Skip to content

Commit 9aadf4f

Browse files
committed
greedy: add max_neighbors to skip large batch indices
1 parent 51d6dda commit 9aadf4f

File tree

1 file changed

+67
-5
lines changed

1 file changed

+67
-5
lines changed

src/lib.rs

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,24 @@ impl<Ix: IndexType, Node: NodeType> ContractionProcessor<Ix, Node> {
245245
js
246246
}
247247

248+
/// Like `neighbors`, but ignores any index that is shared by more than
/// `max_neighbors` nodes; used by the greedy optimizer when assessing the
/// contractions available to a newly created node.
///
/// Returns the set of nodes `j != i` that share at least one non-skipped
/// index with node `i`. A `max_neighbors` of 0 disables the limit, so every
/// index of `i` is followed.
///
/// NOTE(review): the size compared against `max_neighbors` is the full node
/// set of the index, which presumably includes `i` itself - confirm this is
/// the intended threshold.
249+
fn neighbors_limit(&self, i: Node, max_neighbors: usize) -> BTreeSet<Node> {
250+
let mut js = BTreeSet::default();
251+
for (ix, _) in self.nodes[&i].iter() {
252+
if max_neighbors != 0 && self.edges[&ix].len() > max_neighbors {
253+
// index touches too many nodes (effectively a batch index), so
// following it would generate too many candidate pairs -> skip
254+
continue;
255+
}
256+
257+
self.edges[&ix].iter().for_each(|&j| {
258+
// a node is never reported as its own neighbor
if j != i {
259+
js.insert(j);
260+
};
261+
});
262+
}
263+
js
264+
}
265+
248266
/// remove an index from the graph, updating all legs
249267
fn remove_ix(&mut self, ix: Ix) {
250268
for j in self.edges.remove(&ix).unwrap() {
@@ -439,10 +457,12 @@ impl<Ix: IndexType, Node: NodeType> ContractionProcessor<Ix, Node> {
439457
&mut self,
440458
costmod: Option<f32>,
441459
temperature: Option<f32>,
460+
max_neighbors: Option<usize>,
442461
seed: Option<u64>,
443462
) -> bool {
444463
let coeff_t = temperature.unwrap_or(0.0);
445464
let log_coeff_a = f32::ln(costmod.unwrap_or(1.0));
465+
let max_neighbors = max_neighbors.unwrap_or(16);
446466

447467
let mut rng = if coeff_t != 0.0 {
448468
Some(match seed {
@@ -479,6 +499,11 @@ impl<Ix: IndexType, Node: NodeType> ContractionProcessor<Ix, Node> {
479499

480500
// get the initial candidate contractions
481501
for ix_nodes in self.edges.values() {
502+
if max_neighbors != 0 && ix_nodes.len() > max_neighbors {
503+
// basically a batch index with too many combinations -> skip
504+
continue;
505+
}
506+
482507
// convert to vector for combinational indexing
483508
let ix_nodes: Vec<Node> = ix_nodes.iter().cloned().collect();
484509
// for all combinations of nodes with a connected edge
@@ -516,7 +541,7 @@ impl<Ix: IndexType, Node: NodeType> ContractionProcessor<Ix, Node> {
516541

517542
node_sizes.insert(k, ksize);
518543

519-
for l in self.neighbors(k) {
544+
for l in self.neighbors_limit(k, max_neighbors) {
520545
// assess all neighboring contractions of new node
521546
let llegs = &self.nodes[&l];
522547
let lsize = node_sizes[&l];
@@ -528,6 +553,22 @@ impl<Ix: IndexType, Node: NodeType> ContractionProcessor<Ix, Node> {
528553
contractions.insert(c, (k, l, msize, mlegs));
529554
c -= 1;
530555
}
556+
557+
// // potential queue pruning?
558+
// if queue.len() > 100_000 {
559+
// let mut valid_contractions = Vec::new();
560+
// for (score, cid) in queue.drain() {
561+
// if let Some((i, j, _, _)) = contractions.get(&cid) {
562+
// if self.nodes.contains_key(&i) && self.nodes.contains_key(&j) {
563+
// valid_contractions.push((score, cid));
564+
// } else {
565+
// // Remove stale contraction from map
566+
// contractions.remove(&cid);
567+
// }
568+
// }
569+
// }
570+
// queue = BinaryHeap::from(valid_contractions);
571+
// }
531572
}
532573
// success
533574
return true;
@@ -941,6 +982,7 @@ fn run_greedy<Ix: IndexType, Node: NodeType>(
941982
size_dict: Dict<char, f32>,
942983
costmod: Option<f32>,
943984
temperature: Option<f32>,
985+
max_neighbors: Option<usize>,
944986
seed: Option<u64>,
945987
simplify: bool,
946988
) -> SSAPath {
@@ -949,7 +991,7 @@ fn run_greedy<Ix: IndexType, Node: NodeType>(
949991
if simplify {
950992
cp.simplify();
951993
}
952-
cp.optimize_greedy(costmod, temperature, seed);
994+
cp.optimize_greedy(costmod, temperature, max_neighbors, seed);
953995
cp.optimize_remaining_by_size();
954996
cp.ssa_path
955997
}
@@ -987,6 +1029,7 @@ fn run_random_greedy_optimization<Ix: IndexType, Node: NodeType>(
9871029
log_temp_min: f32,
9881030
log_temp_diff: f32,
9891031
is_const_temp: bool,
1032+
max_neighbors: Option<usize>,
9901033
rng: &mut rand::rngs::StdRng,
9911034
) -> (SSAPath, Score) {
9921035
let mut cp0: ContractionProcessor<Ix, Node> =
@@ -1013,7 +1056,8 @@ fn run_random_greedy_optimization<Ix: IndexType, Node: NodeType>(
10131056
f32::exp(log_temp_min + rng.random::<f32>() * log_temp_diff)
10141057
};
10151058

1016-
let success = cp.optimize_greedy(Some(costmod), Some(temperature), Some(*seed));
1059+
let success =
1060+
cp.optimize_greedy(Some(costmod), Some(temperature), max_neighbors, Some(*seed));
10171061

10181062
if !success {
10191063
continue;
@@ -1135,7 +1179,7 @@ fn optimize_simplify(
11351179
}
11361180

11371181
#[pyfunction]
1138-
#[pyo3(signature = (inputs, output, size_dict, costmod=None, temperature=None, seed=None, simplify=None, use_ssa=None))]
1182+
#[pyo3(signature = (inputs, output, size_dict, costmod=None, temperature=None, max_neighbors=None, seed=None, simplify=None, use_ssa=None))]
11391183
/// Find a contraction path using a (randomizable) greedy algorithm.
11401184
///
11411185
/// Parameters
@@ -1160,6 +1204,12 @@ fn optimize_simplify(
11601204
/// score -> sign(score) * log(|score|) - temperature * gumbel()
11611205
///
11621206
/// which implements boltzmann sampling.
1207+
/// max_neighbors : int, optional
1208+
/// If non-zero, skip any index that connects to more than this many
1209+
/// nodes. This is useful to avoid combinatorial explosions when
1210+
/// dealing with essentially batch indices. Default: 16.
1211+
/// seed : int, optional
1212+
/// The seed for the random number generator.
11631213
/// simplify : bool, optional
11641214
/// Whether to perform simplifications before optimizing. These are:
11651215
///
@@ -1190,6 +1240,7 @@ fn optimize_greedy(
11901240
size_dict: Dict<char, f32>,
11911241
costmod: Option<f32>,
11921242
temperature: Option<f32>,
1243+
max_neighbors: Option<usize>,
11931244
seed: Option<u64>,
11941245
simplify: Option<bool>,
11951246
use_ssa: Option<bool>,
@@ -1208,6 +1259,7 @@ fn optimize_greedy(
12081259
size_dict,
12091260
costmod,
12101261
temperature,
1262+
max_neighbors,
12111263
seed,
12121264
simplify,
12131265
)
@@ -1219,6 +1271,7 @@ fn optimize_greedy(
12191271
size_dict,
12201272
costmod,
12211273
temperature,
1274+
max_neighbors,
12221275
seed,
12231276
simplify,
12241277
)
@@ -1229,6 +1282,7 @@ fn optimize_greedy(
12291282
size_dict,
12301283
costmod,
12311284
temperature,
1285+
max_neighbors,
12321286
seed,
12331287
simplify,
12341288
),
@@ -1243,7 +1297,7 @@ fn optimize_greedy(
12431297
}
12441298

12451299
#[pyfunction]
1246-
#[pyo3(signature = (inputs, output, size_dict, ntrials, costmod=None, temperature=None, seed=None, simplify=None, use_ssa=None))]
1300+
#[pyo3(signature = (inputs, output, size_dict, ntrials, costmod=None, temperature=None, max_neighbors=None, seed=None, simplify=None, use_ssa=None))]
12471301
/// Perform a batch of random greedy optimizations, simultaneously tracking
12481302
/// the best contraction path in terms of flops, so as to avoid constructing a
12491303
/// separate contraction tree.
@@ -1273,6 +1327,10 @@ fn optimize_greedy(
12731327
///
12741328
/// which implements boltzmann sampling. It is sampled log-uniformly from
12751329
/// the given range.
1330+
/// max_neighbors : int, optional
1331+
/// If non-zero, skip any index that connects to more than this many
1332+
/// nodes. This is useful to avoid combinatorial explosions when
1333+
/// dealing with essentially batch indices. Default: 16.
12761334
/// seed : int, optional
12771335
/// The seed for the random number generator.
12781336
/// simplify : bool, optional
@@ -1309,6 +1367,7 @@ fn optimize_random_greedy_track_flops(
13091367
ntrials: usize,
13101368
costmod: Option<(f32, f32)>,
13111369
temperature: Option<(f32, f32)>,
1370+
max_neighbors: Option<usize>,
13121371
seed: Option<u64>,
13131372
simplify: Option<bool>,
13141373
use_ssa: Option<bool>,
@@ -1350,6 +1409,7 @@ fn optimize_random_greedy_track_flops(
13501409
log_temp_min,
13511410
log_temp_diff,
13521411
is_const_temp,
1412+
max_neighbors,
13531413
&mut rng,
13541414
)
13551415
}
@@ -1367,6 +1427,7 @@ fn optimize_random_greedy_track_flops(
13671427
log_temp_min,
13681428
log_temp_diff,
13691429
is_const_temp,
1430+
max_neighbors,
13701431
&mut rng,
13711432
)
13721433
}
@@ -1383,6 +1444,7 @@ fn optimize_random_greedy_track_flops(
13831444
log_temp_min,
13841445
log_temp_diff,
13851446
is_const_temp,
1447+
max_neighbors,
13861448
&mut rng,
13871449
),
13881450
};

0 commit comments

Comments
 (0)