|
1 | 1 | from collections import defaultdict |
2 | | -from itertools import chain |
3 | | -from typing import DefaultDict, Dict |
| 2 | +from typing import DefaultDict, Dict, Optional, Tuple |
4 | 3 |
|
5 | | -from bioimageio.spec.model.v0_5 import ModelDescr |
| 4 | +from bioimageio.spec.model.v0_5 import ( |
| 5 | + InputAxis, |
| 6 | + ModelDescr, |
| 7 | + ParameterizedSize, |
| 8 | + SizeReference, |
| 9 | +) |
| 10 | +from torch.export import Dim |
| 11 | +from typing_extensions import assert_never |
6 | 12 |
|
7 | 13 |
|
8 | | -def get_dynamic_axes(model_descr: ModelDescr): |
9 | | - dynamic_axes: DefaultDict[str, Dict[int, str]] = defaultdict(dict) |
10 | | - for d in chain(model_descr.inputs, model_descr.outputs): |
| 14 | +def get_dynamic_shapes(model_descr: ModelDescr): |
| 15 | + dynamic_shapes: DefaultDict[str, Dict[int, Optional[Dim]]] = defaultdict(dict) |
| 16 | + potential_ref_axes: Dict[str, Tuple[InputAxis, int]] = {} |
| 17 | + # add dynamic dims from parameterized input sizes (and fixed sizes as None) |
| 18 | + for d in model_descr.inputs: |
11 | 19 | for i, ax in enumerate(d.axes): |
12 | | - if not isinstance(ax.size, int): |
13 | | - dynamic_axes[str(d.id)][i] = str(ax.id) |
| 20 | + dim_name = f"{d.id}_{ax.id}" |
| 21 | + if isinstance(ax.size, int): |
| 22 | + dim = None # fixed size (could also be left out) |
| 23 | + elif ax.size is None: |
| 24 | + dim = Dim(dim_name, min=1) |
| 25 | + elif isinstance(ax.size, ParameterizedSize): |
| 26 | + dim = Dim(dim_name, min=ax.size.min) |
| 27 | + elif isinstance(ax.size, SizeReference): |
| 28 | + continue # handled below |
| 29 | + else: |
| 30 | + assert_never(ax.size) |
14 | 31 |
|
15 | | - return dynamic_axes |
| 32 | + dynamic_shapes[str(d.id)][i] = dim |
| 33 | + potential_ref_axes[dim_name] = (ax, i) |
| 34 | + |
| 35 | + # add dynamic dims from size references |
| 36 | + for d in model_descr.inputs: |
| 37 | + for i, ax in enumerate(d.axes): |
| 38 | + if not isinstance(ax.size, SizeReference): |
| 39 | + continue # handled above |
| 40 | + |
| 41 | + dim_name_ref = f"{ax.size.tensor_id}_{ax.size.axis_id}" |
| 42 | + ax_ref, i_ref = potential_ref_axes[dim_name_ref] |
| 43 | + a = ax_ref.scale / ax.scale |
| 44 | + b = ax.size.offset |
| 45 | + dim_ref = dynamic_shapes[str(ax.size.tensor_id)][i_ref] |
| 46 | + dim = a * dim_ref + b if dim_ref is not None else None |
| 47 | + dynamic_shapes[str(d.id)][i] = dim |
| 48 | + |
| 49 | + return dict(dynamic_shapes) |
0 commit comments