|
21 | 21 | Join, |
22 | 22 | MakeVector, |
23 | 23 | ScalarFromTensor, |
| 24 | + Split, |
24 | 25 | TensorFromScalar, |
25 | 26 | alloc, |
26 | 27 | as_tensor, |
@@ -616,6 +617,39 @@ def local_subtensor_remove_broadcastable_index(fgraph, node): |
616 | 617 | return [node.inputs[0].dimshuffle(tuple(remain_dim))] |
617 | 618 |
|
618 | 619 |
|
@register_specialize("shape_unsafe")
@node_rewriter(tracks=[Split])
def split_to_subtensor(fgraph, node):
    """Rewrite split(2)(x, 0) -> (x[:split_sizes[0]], x[split_sizes[0]:]).

    Every output of the ``Split`` node is replaced by a basic slice of ``x``
    along the split axis. This allows lifting the underlying split close to
    the inputs, and increases fusion opportunities. Because each output
    becomes an independent slice, unused split outputs need no special
    handling.

    The last output is taken as an open-ended slice (``start:``), so the
    rewrite does not check that the split sizes add up to the length of the
    split axis.

    Parameters
    ----------
    fgraph
        The graph being rewritten (not used by this rewrite).
    node
        A ``Split`` node whose inputs are ``(x, axis, split_sizes)``.

    Returns
    -------
    list or None
        One replacement per split output, or ``None`` when the rewrite does
        not apply (non-constant axis, or a constant axis that is out of
        bounds for ``x``).
    """
    x, axis, split_sizes = node.inputs

    n_splits = node.op.len_splits
    if n_splits <= 1:
        return [x]

    # The position of the slice inside the index tuple must be known while
    # the graph is being built, so only constant axes are supported.
    if not isinstance(axis, Constant):
        return None

    ndim = x.type.ndim
    axis_idx = int(axis.data)
    if not -ndim <= axis_idx < ndim:
        # Invalid axis: leave the node untouched rather than build a wrong graph.
        return None
    # Normalize negative axes. Multiplying a tuple by a negative integer
    # yields an empty tuple, which would silently slice along axis 0.
    axis_idx %= ndim

    leading_slices = (slice(None),) * axis_idx

    ys = []
    start = None  # The first chunk starts at the beginning of the axis
    for i in range(n_splits - 1):
        size = split_sizes[i]
        stop = size if start is None else start + size
        ys.append(x[(*leading_slices, slice(start, stop))])
        start = stop
    # The last chunk takes whatever remains along the axis
    ys.append(x[(*leading_slices, slice(start, None))])
    return ys
| 651 | + |
| 652 | + |
619 | 653 | @register_infer_shape |
620 | 654 | @register_useless |
621 | 655 | @register_canonicalize |
|
0 commit comments