@@ -245,6 +245,24 @@ impl<Ix: IndexType, Node: NodeType> ContractionProcessor<Ix, Node> {
245245 js
246246 }
247247
248+ /// like neighbors but skip edges with too many neighbors, for greedy
249+ fn neighbors_limit ( & self , i : Node , max_neighbors : usize ) -> BTreeSet < Node > {
250+ let mut js = BTreeSet :: default ( ) ;
251+ for ( ix, _) in self . nodes [ & i] . iter ( ) {
252+ if max_neighbors != 0 && self . edges [ & ix] . len ( ) > max_neighbors {
253+ // basically a batch index with too many combinations -> skip
254+ continue ;
255+ }
256+
257+ self . edges [ & ix] . iter ( ) . for_each ( |& j| {
258+ if j != i {
259+ js. insert ( j) ;
260+ } ;
261+ } ) ;
262+ }
263+ js
264+ }
265+
248266 /// remove an index from the graph, updating all legs
249267 fn remove_ix ( & mut self , ix : Ix ) {
250268 for j in self . edges . remove ( & ix) . unwrap ( ) {
@@ -439,10 +457,12 @@ impl<Ix: IndexType, Node: NodeType> ContractionProcessor<Ix, Node> {
439457 & mut self ,
440458 costmod : Option < f32 > ,
441459 temperature : Option < f32 > ,
460+ max_neighbors : Option < usize > ,
442461 seed : Option < u64 > ,
443462 ) -> bool {
444463 let coeff_t = temperature. unwrap_or ( 0.0 ) ;
445464 let log_coeff_a = f32:: ln ( costmod. unwrap_or ( 1.0 ) ) ;
465+ let max_neighbors = max_neighbors. unwrap_or ( 16 ) ;
446466
447467 let mut rng = if coeff_t != 0.0 {
448468 Some ( match seed {
@@ -479,6 +499,11 @@ impl<Ix: IndexType, Node: NodeType> ContractionProcessor<Ix, Node> {
479499
480500 // get the initial candidate contractions
481501 for ix_nodes in self . edges . values ( ) {
502+ if max_neighbors != 0 && ix_nodes. len ( ) > max_neighbors {
503+ // basically a batch index with too many combinations -> skip
504+ continue ;
505+ }
506+
482507 // convert to vector for combinational indexing
483508 let ix_nodes: Vec < Node > = ix_nodes. iter ( ) . cloned ( ) . collect ( ) ;
484509 // for all combinations of nodes with a connected edge
@@ -516,7 +541,7 @@ impl<Ix: IndexType, Node: NodeType> ContractionProcessor<Ix, Node> {
516541
517542 node_sizes. insert ( k, ksize) ;
518543
519- for l in self . neighbors ( k ) {
544+ for l in self . neighbors_limit ( k , max_neighbors ) {
520545 // assess all neighboring contractions of new node
521546 let llegs = & self . nodes [ & l] ;
522547 let lsize = node_sizes[ & l] ;
@@ -528,6 +553,22 @@ impl<Ix: IndexType, Node: NodeType> ContractionProcessor<Ix, Node> {
528553 contractions. insert ( c, ( k, l, msize, mlegs) ) ;
529554 c -= 1 ;
530555 }
556+
557+ // // potential queue pruning?
558+ // if queue.len() > 100_000 {
559+ // let mut valid_contractions = Vec::new();
560+ // for (score, cid) in queue.drain() {
561+ // if let Some((i, j, _, _)) = contractions.get(&cid) {
562+ // if self.nodes.contains_key(&i) && self.nodes.contains_key(&j) {
563+ // valid_contractions.push((score, cid));
564+ // } else {
565+ // // Remove stale contraction from map
566+ // contractions.remove(&cid);
567+ // }
568+ // }
569+ // }
570+ // queue = BinaryHeap::from(valid_contractions);
571+ // }
531572 }
532573 // success
533574 return true ;
@@ -941,6 +982,7 @@ fn run_greedy<Ix: IndexType, Node: NodeType>(
941982 size_dict : Dict < char , f32 > ,
942983 costmod : Option < f32 > ,
943984 temperature : Option < f32 > ,
985+ max_neighbors : Option < usize > ,
944986 seed : Option < u64 > ,
945987 simplify : bool ,
946988) -> SSAPath {
@@ -949,7 +991,7 @@ fn run_greedy<Ix: IndexType, Node: NodeType>(
949991 if simplify {
950992 cp. simplify ( ) ;
951993 }
952- cp. optimize_greedy ( costmod, temperature, seed) ;
994+ cp. optimize_greedy ( costmod, temperature, max_neighbors , seed) ;
953995 cp. optimize_remaining_by_size ( ) ;
954996 cp. ssa_path
955997}
@@ -987,6 +1029,7 @@ fn run_random_greedy_optimization<Ix: IndexType, Node: NodeType>(
9871029 log_temp_min : f32 ,
9881030 log_temp_diff : f32 ,
9891031 is_const_temp : bool ,
1032+ max_neighbors : Option < usize > ,
9901033 rng : & mut rand:: rngs:: StdRng ,
9911034) -> ( SSAPath , Score ) {
9921035 let mut cp0: ContractionProcessor < Ix , Node > =
@@ -1013,7 +1056,8 @@ fn run_random_greedy_optimization<Ix: IndexType, Node: NodeType>(
10131056 f32:: exp ( log_temp_min + rng. random :: < f32 > ( ) * log_temp_diff)
10141057 } ;
10151058
1016- let success = cp. optimize_greedy ( Some ( costmod) , Some ( temperature) , Some ( * seed) ) ;
1059+ let success =
1060+ cp. optimize_greedy ( Some ( costmod) , Some ( temperature) , max_neighbors, Some ( * seed) ) ;
10171061
10181062 if !success {
10191063 continue ;
@@ -1135,7 +1179,7 @@ fn optimize_simplify(
11351179}
11361180
11371181#[ pyfunction]
1138- #[ pyo3( signature = ( inputs, output, size_dict, costmod=None , temperature=None , seed=None , simplify=None , use_ssa=None ) ) ]
1182+ #[ pyo3( signature = ( inputs, output, size_dict, costmod=None , temperature=None , max_neighbors= None , seed=None , simplify=None , use_ssa=None ) ) ]
11391183/// Find a contraction path using a (randomizable) greedy algorithm.
11401184///
11411185/// Parameters
@@ -1160,6 +1204,12 @@ fn optimize_simplify(
11601204/// score -> sign(score) * log(|score|) - temperature * gumbel()
11611205///
11621206/// which implements boltzmann sampling.
1207+ /// max_neighbors : int, optional
1208+ /// If non-zero, skip any index that connects to more than this many
1209+ /// nodes. This is useful to avoid combinatorial explosions when
1210+ /// dealing with essentially batch indices. Default: 16.
1211+ /// seed : int, optional
1212+ /// The seed for the random number generator.
11631213/// simplify : bool, optional
11641214/// Whether to perform simplifications before optimizing. These are:
11651215///
@@ -1190,6 +1240,7 @@ fn optimize_greedy(
11901240 size_dict : Dict < char , f32 > ,
11911241 costmod : Option < f32 > ,
11921242 temperature : Option < f32 > ,
1243+ max_neighbors : Option < usize > ,
11931244 seed : Option < u64 > ,
11941245 simplify : Option < bool > ,
11951246 use_ssa : Option < bool > ,
@@ -1208,6 +1259,7 @@ fn optimize_greedy(
12081259 size_dict,
12091260 costmod,
12101261 temperature,
1262+ max_neighbors,
12111263 seed,
12121264 simplify,
12131265 )
@@ -1219,6 +1271,7 @@ fn optimize_greedy(
12191271 size_dict,
12201272 costmod,
12211273 temperature,
1274+ max_neighbors,
12221275 seed,
12231276 simplify,
12241277 )
@@ -1229,6 +1282,7 @@ fn optimize_greedy(
12291282 size_dict,
12301283 costmod,
12311284 temperature,
1285+ max_neighbors,
12321286 seed,
12331287 simplify,
12341288 ) ,
@@ -1243,7 +1297,7 @@ fn optimize_greedy(
12431297}
12441298
12451299#[ pyfunction]
1246- #[ pyo3( signature = ( inputs, output, size_dict, ntrials, costmod=None , temperature=None , seed=None , simplify=None , use_ssa=None ) ) ]
1300+ #[ pyo3( signature = ( inputs, output, size_dict, ntrials, costmod=None , temperature=None , max_neighbors= None , seed=None , simplify=None , use_ssa=None ) ) ]
12471301/// Perform a batch of random greedy optimizations, simulteneously tracking
12481302/// the best contraction path in terms of flops, so as to avoid constructing a
12491303/// separate contraction tree.
@@ -1273,6 +1327,10 @@ fn optimize_greedy(
12731327///
12741328/// which implements boltzmann sampling. It is sampled log-uniformly from
12751329/// the given range.
1330+ /// max_neighbors : int, optional
1331+ /// If non-zero, skip any index that connects to more than this many
1332+ /// nodes. This is useful to avoid combinatorial explosions when
1333+ /// dealing with essentially batch indices. Default: 16.
12761334/// seed : int, optional
12771335/// The seed for the random number generator.
12781336/// simplify : bool, optional
@@ -1309,6 +1367,7 @@ fn optimize_random_greedy_track_flops(
13091367 ntrials : usize ,
13101368 costmod : Option < ( f32 , f32 ) > ,
13111369 temperature : Option < ( f32 , f32 ) > ,
1370+ max_neighbors : Option < usize > ,
13121371 seed : Option < u64 > ,
13131372 simplify : Option < bool > ,
13141373 use_ssa : Option < bool > ,
@@ -1350,6 +1409,7 @@ fn optimize_random_greedy_track_flops(
13501409 log_temp_min,
13511410 log_temp_diff,
13521411 is_const_temp,
1412+ max_neighbors,
13531413 & mut rng,
13541414 )
13551415 }
@@ -1367,6 +1427,7 @@ fn optimize_random_greedy_track_flops(
13671427 log_temp_min,
13681428 log_temp_diff,
13691429 is_const_temp,
1430+ max_neighbors,
13701431 & mut rng,
13711432 )
13721433 }
@@ -1383,6 +1444,7 @@ fn optimize_random_greedy_track_flops(
13831444 log_temp_min,
13841445 log_temp_diff,
13851446 is_const_temp,
1447+ max_neighbors,
13861448 & mut rng,
13871449 ) ,
13881450 } ;
0 commit comments