Skip to content

Commit 67d21c4

Browse files
Updates arange to avoid immediate evaluation of constants
1 parent 7b40747 commit 67d21c4

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

tripy/nvtripy/frontend/ops/arange.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,8 @@ def arange(
8888
size = op_utils.tensor_from_shape_like([size])
8989

9090
linspace_dtype = Linspace.get_closest_dtype(dtype)
91-
start = Tensor(start, dtype=linspace_dtype) if not isinstance(start, DimensionSize) else cast(start, linspace_dtype)
92-
step = (
93-
Tensor([step], dtype=linspace_dtype)
94-
if not isinstance(step, DimensionSize)
95-
else cast(reshape(step, (1,)), linspace_dtype)
96-
)
91+
start = cast(Tensor(start) if not isinstance(start, DimensionSize) else start, linspace_dtype)
92+
step = cast(Tensor([step]) if not isinstance(step, DimensionSize) else reshape(step, (1,)), linspace_dtype)
9793
out = op_utils.create_op(Linspace, [size, start, step], dtype=linspace_dtype)
9894
return cast(out, dtype)
9995

0 commit comments

Comments
 (0)