Skip to content

Commit d314dfe

Browse files
committed
Modify arange to directly determine the shape from the arguments if possible
1 parent 2a7f3e1 commit d314dfe

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

pytensor/tensor/basic.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3224,13 +3224,29 @@ def __init__(self, dtype):
32243224
self.dtype = dtype
32253225

32263226
def make_node(self, start, stop, step):
3227+
types = TensorConstant
3228+
shape = (None,)
3229+
# if it is possible to directly determine the shape i.e static shape is present, we find it.
3230+
if (
3231+
isinstance(start, types)
3232+
and isinstance(stop, types)
3233+
and isinstance(step, types)
3234+
):
3235+
length = max(
3236+
np.max(np.ceil((stop.value - start.value) / step.value))
3237+
.astype(int)
3238+
.item(),
3239+
0,
3240+
)
3241+
shape = (length,)
3242+
32273243
start, stop, step = map(as_tensor_variable, (start, stop, step))
32283244
assert start.ndim == 0
32293245
assert stop.ndim == 0
32303246
assert step.ndim == 0
32313247

32323248
inputs = [start, stop, step]
3233-
outputs = [tensor(dtype=self.dtype, shape=(None,))]
3249+
outputs = [tensor(dtype=self.dtype, shape=shape)]
32343250

32353251
return Apply(self, inputs, outputs)
32363252

tests/tensor/test_basic.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2861,6 +2861,12 @@ def test_infer_shape(self, cast_policy):
28612861
assert np.all(f(2) == len(np.arange(0, 2)))
28622862
assert np.all(f(0) == len(np.arange(0, 0)))
28632863

2864+
def test_static_shape(self):
2865+
assert np.arange(1, 10).shape == arange(1, 10).type.shape
2866+
assert np.arange(10, 1, -1).shape == arange(10, 1, -1).type.shape
2867+
assert np.arange(1, -9, 2).shape == arange(1, -9, 2).type.shape
2868+
assert np.arange(1.3, 17.48, 2.67).shape == arange(1.3, 17.48, 2.67).type.shape
2869+
28642870

28652871
class TestNdGrid:
28662872
def setup_method(self):

0 commit comments

Comments
 (0)