- `optimize_optimal(inputs, output, size_dict, **kwargs)`
- `optimize_greedy(inputs, output, size_dict, **kwargs)`

The optimal algorithm is an optimized version of the `opt_einsum` 'dp'
path - itself an implementation of https://arxiv.org/abs/1304.6112.

There is also a variant of the greedy algorithm, which runs `ntrials`
randomized greedy paths and simultaneously computes and reports their flops
cost (log10):

- `optimize_random_greedy_track_flops(inputs, output, size_dict, **kwargs)`

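The flops cost reported here is the log10 of the total number of scalar multiplications along a pairwise contraction path. As a rough, library-independent illustration of that quantity, it can be tallied by hand for a tiny contraction (`path_flops_log10` below is a hypothetical helper for this sketch, not part of `cotengrust`):

```python
from math import log10, prod

def path_flops_log10(inputs, output, size_dict, path):
    """Tally the multiplication count of a pairwise contraction path,
    reported as log10 (hypothetical helper, not part of cotengrust)."""
    terms = [set(t) for t in inputs]
    output = set(output)
    total = 0
    for i, j in path:
        # pop the higher position first so the lower one stays valid
        a = terms.pop(max(i, j))
        b = terms.pop(min(i, j))
        involved = a | b
        # the cost of one pairwise contraction is the product of all
        # dimensions involved in it
        total += prod(size_dict[ix] for ix in involved)
        # keep only indices still needed by the output or a remaining term
        keep = {
            ix for ix in involved
            if ix in output or any(ix in t for t in terms)
        }
        terms.append(keep)
    return log10(total)

# contract "ab,bc->ac", then "ac,cd->ad", all dimensions of size 8:
# 8**3 + 8**3 = 1024 multiplications -> log10(1024) ~ 3.01
cost = path_flops_log10(
    [("a", "b"), ("b", "c"), ("c", "d")],
    ("a", "d"),
    {"a": 8, "b": 8, "c": 8, "d": 8},
    [(0, 1), (0, 1)],
)
```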
## Installation

```
pip install cotengrust
```

or if you want to develop locally (which requires [pyo3](https://github.com/PyO3/pyo3)
and [maturin](https://github.com/PyO3/maturin)):

```bash
maturin develop --release
```

## Usage

If `cotengrust` is installed, then by default `cotengra` will use it for its
greedy, random-greedy, and optimal subroutines, notably subtree
reconfiguration. You can also call the routines directly:
4045``` python
4146import cotengra as ctg
@@ -225,6 +230,76 @@ def optimize_simplify(
225230 """
226231 ...
227232
def optimize_random_greedy_track_flops(
    inputs,
    output,
    size_dict,
    ntrials=1,
    costmod=1.0,
    temperature=0.01,
    seed=None,
    simplify=True,
    use_ssa=False,
):
    """Perform a batch of random greedy optimizations, simultaneously tracking
    the best contraction path in terms of flops, so as to avoid constructing a
    separate contraction tree.

    Parameters
    ----------
    inputs : tuple[tuple[str]]
        The indices of each input tensor.
    output : tuple[str]
        The indices of the output tensor.
    size_dict : dict[str, int]
        A dictionary mapping indices to their dimension.
    ntrials : int, optional
        The number of random greedy trials to perform. The default is 1.
    costmod : float, optional
        When assessing local greedy scores, how much to weight the size of the
        tensors removed compared to the size of the tensor added::

            score = size_ab - costmod * (size_a + size_b)

        This can be a useful hyper-parameter to tune.
    temperature : float, optional
        When assessing local greedy scores, how much to randomly perturb the
        score. This is implemented as::

            score -> sign(score) * log(|score|) - temperature * gumbel()

        which implements Boltzmann sampling.
    seed : int, optional
        The seed for the random number generator.
    simplify : bool, optional
        Whether to perform simplifications before optimizing. These are:

        - ignore any indices that appear in all terms
        - combine any repeated indices within a single term
        - reduce any non-output indices that only appear on a single term
        - combine any scalar terms
        - combine any tensors with matching indices (hadamard products)

        Such simplifications may be required in the general case for the
        proper functioning of the core optimization, but may be skipped if the
        input indices are already in a simplified form.
    use_ssa : bool, optional
        Whether to return the contraction path in 'single static assignment'
        (SSA) format (i.e. as if each intermediate is appended to the list of
        inputs, without removals). This can be quicker and easier to work with
        than the 'linear recycled' format that `numpy` and `opt_einsum` use.

    Returns
    -------
    path : list[list[int]]
        The best contraction path, given as a sequence of pairs of node
        indices.
    flops : float
        The flops (i.e. contraction cost, number of scalar multiplications) of
        the best contraction path, given as log10.
    """
    ...

def ssa_to_linear(ssa_path, n=None):
    """Convert a SSA path to linear format."""
    ...
```
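Since `ssa_to_linear` is only stubbed above, here is a pure-Python sketch of the conversion it describes, assuming the standard SSA convention (an editor's illustration; the library's Rust implementation is authoritative):

```python
def ssa_to_linear_sketch(ssa_path, n):
    """Convert an SSA-format path (tensor ids never reused) to the 'linear
    recycled' format that numpy/opt_einsum use, given n original inputs."""
    ids = list(range(n))  # SSA ids of the currently live tensors, in order
    linear = []
    for pair in ssa_path:
        # positions of the two operands in the current live list
        pos = sorted(ids.index(s) for s in pair)
        linear.append(tuple(pos))
        for p in reversed(pos):   # remove operands, highest position first
            ids.pop(p)
        ids.append(n)             # the new intermediate gets the next SSA id
        n += 1
    return linear

# three inputs contracted pairwise:
# SSA [(0, 1), (2, 3)] -> linear [(0, 1), (0, 1)]
print(ssa_to_linear_sketch([(0, 1), (2, 3)], 3))
```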