|
59 | 59 | zscalar,
|
60 | 60 | )
|
61 | 61 | from pytensor.tensor.type_other import (
|
| 62 | + MakeSlice, |
62 | 63 | NoneConst,
|
63 | 64 | NoneTypeT,
|
64 | 65 | SliceConstant,
|
@@ -527,11 +528,20 @@ def basic_shape(shape, indices):
|
527 | 528 | if isinstance(idx, slice):
|
528 | 529 | res_shape += (slice_len(idx, n),)
|
529 | 530 | elif isinstance(getattr(idx, "type", None), SliceType):
|
530 |
| - if idx.owner: |
531 |
| - idx_inputs = idx.owner.inputs |
| 531 | + if idx.owner is None: |
| 532 | + if not isinstance(idx, Constant): |
| 533 | + # This is an input slice, we can't reason symbolically on it. |
| 534 | + # We don't even know if we will get None entries or integers |
| 535 | + res_shape += (None,) |
| 536 | + continue |
| 537 | + else: |
| 538 | + sl: slice = idx.data |
| 539 | + slice_inputs = (sl.start, sl.stop, sl.step) |
| 540 | + elif isinstance(idx.owner.op, MakeSlice): |
| 541 | + slice_inputs = idx.owner.inputs |
532 | 542 | else:
|
533 |
| - idx_inputs = (None,) |
534 |
| - res_shape += (slice_len(slice(*idx_inputs), n),) |
| 543 | + raise ValueError(f"Unexpected Slice producing Op {idx.owner.op}") |
| 544 | + res_shape += (slice_len(slice(*slice_inputs), n),) |
535 | 545 | elif idx is None:
|
536 | 546 | res_shape += (ps.ScalarConstant(ps.int64, 1),)
|
537 | 547 | elif isinstance(getattr(idx, "type", None), NoneTypeT):
|
@@ -2728,6 +2738,11 @@ def is_bool_index(idx):
|
2728 | 2738 | res_shape = list(
|
2729 | 2739 | indexed_result_shape(ishapes[0], index_shapes, indices_are_shapes=True)
|
2730 | 2740 | )
|
| 2741 | + for i, res_dim_length in enumerate(res_shape): |
| 2742 | + if res_dim_length is None: |
| 2743 | + # This can happen when we have a Slice provided by the user (not a constant nor the result of MakeSlice) |
| 2744 | + # We must compute the Op to find its shape |
| 2745 | + res_shape[i] = Shape_i(i)(node.out) |
2731 | 2746 |
|
2732 | 2747 | adv_indices = [idx for idx in indices if not is_basic_idx(idx)]
|
2733 | 2748 | bool_indices = [idx for idx in adv_indices if is_bool_index(idx)]
|
|
0 commit comments