
Commit 8436553: update docs

1 parent 8acaaed

File tree: 1 file changed (+79, -4 lines)


README.md

Lines changed: 79 additions & 4 deletions
````diff
@@ -7,9 +7,14 @@ are:
 - `optimize_optimal(inputs, output, size_dict, **kwargs)`
 - `optimize_greedy(inputs, output, size_dict, **kwargs)`
 
-The optimal algorithm is an optimized version of the `opt_einsum` 'dp'
+The optimal algorithm is an optimized version of the `opt_einsum` 'dp'
 path - itself an implementation of https://arxiv.org/abs/1304.6112.
 
+There is also a variant of the greedy algorithm, which runs `ntrials` of
+randomized greedy paths and simultaneously reports the best flops cost (log10):
+
+- `optimize_random_greedy_track_flops(inputs, output, size_dict, **kwargs)`
+
 
 ## Installation
 
````
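For orientation, the optimizers listed above all take the same `(inputs, output, size_dict)` argument format described in the docstring added by this commit. A minimal illustration with made-up index sizes (the data here is hypothetical, not from the commit):

```python
# hypothetical example arguments in the format the optimizers expect
inputs = (("a", "b"), ("b", "c"), ("c", "d"))  # indices of each input tensor
output = ("a", "d")                            # indices of the output tensor
size_dict = {"a": 2, "b": 3, "c": 4, "d": 2}   # dimension of each index

# e.g. cotengrust.optimize_greedy(inputs, output, size_dict) would then
# return a contraction path, such as [[0, 1], [0, 1]] in linear format
```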
````diff
@@ -20,7 +25,7 @@ path - itself an implementation of https://arxiv.org/abs/1304.6112.
 pip install cotengrust
 ```
 
-or if you want to develop locally (which requires [pyo3](https://github.com/PyO3/pyo3)
+or if you want to develop locally (which requires [pyo3](https://github.com/PyO3/pyo3)
 and [maturin](https://github.com/PyO3/maturin)):
 
 ```bash
````
````diff
@@ -34,8 +39,8 @@ maturin develop --release
 ## Usage
 
 If `cotengrust` is installed, then by default `cotengra` will use it for its
-greedy and optimal subroutines, notably subtree reconfiguration. You can also
-call the routines directly:
+greedy, random-greedy, and optimal subroutines, notably subtree
+reconfiguration. You can also call the routines directly:
 
 ```python
 import cotengra as ctg
````
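The random-greedy scoring rule described in the new docstring (local score `size_ab - costmod * (size_a + size_b)`, perturbed on a log scale by Gumbel noise) can be sketched in pure Python. This is an illustrative reimplementation under stated assumptions, not the actual Rust code; the function name and `rng` parameter are inventions for the sketch:

```python
import math
import random

def perturbed_score(size_ab, size_a, size_b, costmod=1.0,
                    temperature=0.01, rng=None):
    # greedy score: size of the tensor produced minus the weighted sizes
    # of the two tensors consumed (assumes the score is nonzero)
    rng = rng or random.Random(0)
    score = size_ab - costmod * (size_a + size_b)
    noise = 0.0
    if temperature:
        # Gumbel noise on the log-magnitude gives Boltzmann-like sampling
        # over candidate contractions
        noise = temperature * -math.log(-math.log(rng.random()))
    # sign(score) * log(|score|) - temperature * gumbel()
    return math.copysign(math.log(abs(score)), score) - noise
```

With `temperature=0.0` this reduces to the deterministic log-magnitude score, so repeated trials with different seeds only differ when `temperature > 0`.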
````diff
@@ -225,6 +230,76 @@ def optimize_simplify(
     """
     ...
 
+def optimize_random_greedy_track_flops(
+    inputs,
+    output,
+    size_dict,
+    ntrials=1,
+    costmod=1.0,
+    temperature=0.01,
+    seed=None,
+    simplify=True,
+    use_ssa=False,
+):
+    """Perform a batch of random greedy optimizations, simultaneously tracking
+    the best contraction path in terms of flops, so as to avoid constructing a
+    separate contraction tree.
+
+    Parameters
+    ----------
+    inputs : tuple[tuple[str]]
+        The indices of each input tensor.
+    output : tuple[str]
+        The indices of the output tensor.
+    size_dict : dict[str, int]
+        A dictionary mapping indices to their dimension.
+    ntrials : int, optional
+        The number of random greedy trials to perform. The default is 1.
+    costmod : float, optional
+        When assessing local greedy scores, how much to weight the size of the
+        tensors removed compared to the size of the tensor added::
+
+            score = size_ab - costmod * (size_a + size_b)
+
+        This can be a useful hyper-parameter to tune.
+    temperature : float, optional
+        When assessing local greedy scores, how much to randomly perturb the
+        score. This is implemented as::
+
+            score -> sign(score) * log(|score|) - temperature * gumbel()
+
+        which implements Boltzmann sampling.
+    seed : int, optional
+        The seed for the random number generator.
+    simplify : bool, optional
+        Whether to perform simplifications before optimizing. These are:
+
+        - ignore any indices that appear in all terms
+        - combine any repeated indices within a single term
+        - reduce any non-output indices that only appear on a single term
+        - combine any scalar terms
+        - combine any tensors with matching indices (Hadamard products)
+
+        Such simplifications may be required in the general case for the
+        proper functioning of the core optimization, but may be skipped if
+        the input indices are already in a simplified form.
+    use_ssa : bool, optional
+        Whether to return the contraction path in 'single static assignment'
+        (SSA) format (i.e. as if each intermediate is appended to the list of
+        inputs, without removals). This can be quicker and easier to work with
+        than the 'linear recycled' format that `numpy` and `opt_einsum` use.
+
+    Returns
+    -------
+    path : list[list[int]]
+        The best contraction path, given as a sequence of pairs of node
+        indices.
+    flops : float
+        The flops (/ contraction cost / number of multiplications) of the
+        best contraction path, given as log10.
+    """
+    ...
+
 def ssa_to_linear(ssa_path, n=None):
     """Convert a SSA path to linear format."""
     ...
````
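As a companion to the `use_ssa` option documented above, the conversion that `ssa_to_linear` performs can be sketched in pure Python. This is an illustrative version assuming binary contractions, not the shipped Rust implementation:

```python
def ssa_to_linear(ssa_path, n=None):
    """Convert a path of single-static-assignment ids (every intermediate
    gets a fresh id) into the 'linear recycled' format that numpy and
    opt_einsum use, where ids are positions in the shrinking tensor list."""
    if n is None:
        # infer the initial number of tensors, assuming binary contractions
        n = sum(map(len, ssa_path)) - len(ssa_path) + 1
    ids = list(range(n))  # ssa id currently held at each linear position
    path = []
    for pair in ssa_path:
        path.append(sorted(ids.index(s) for s in pair))
        for s in pair:
            ids.remove(s)
        ids.append(n)  # the new intermediate gets the next fresh ssa id
        n += 1
    return path
```

For example, `ssa_to_linear([[0, 1], [2, 3]])` gives `[[0, 1], [0, 1]]`: contracting the first two tensors leaves the intermediate at position 1, next to the remaining input at position 0.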
