Skip to content

Commit 52a9172

Browse files
authored
Refactor casting logic for arange function
1 parent 1e37116 commit 52a9172

File tree

1 file changed

+9
-11
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+9
-11
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5117,24 +5117,22 @@ def aten_linspace(
51175117
COMPLEX128.dtype,
51185118
)
51195119

5120-
if is_integer_dtype:
5120+
if ir.DataType(dtype).is_integer():
51215121
# Cast to integer dtype first, then to compute dtype
51225122
# This ensures truncation happens before computation
51235123
start_int = op.Cast(start, to=dtype)
51245124
end_int = op.Cast(end, to=dtype)
5125-
start_f = op.Cast(start_int, to=compute_dtype)
5126-
end_f = op.Cast(end_int, to=compute_dtype)
5125+
start = op.Cast(start_int, to=compute_dtype)
5126+
end = op.Cast(end_int, to=compute_dtype)
51275127
else:
5128-
# For float dtypes, cast directly to compute dtype
5129-
start_f = op.Cast(start, to=compute_dtype)
5130-
end_f = op.Cast(end, to=compute_dtype)
5128+
compute_dtype = dtype
51315129

51325130
rg = aten_arange_start(0, steps, dtype=compute_dtype)
5133-
steps_f = op.Cast(steps, to=compute_dtype)
5134-
one = op.Cast(1.0, to=compute_dtype)
5135-
two = op.Cast(2.0, to=compute_dtype)
5136-
steps_minus_1 = op.Sub(steps_f, one)
5137-
step = op.Div(op.Sub(end_f, start_f), steps_minus_1)
5131+
steps_f = op.Constant(value=ir.tensor(steps, dtype=compute_dtype))
5132+
one = op.Constant(value=ir.tensor(1, dtype=compute_dtype))
5133+
two = op.Constant(value=ir.tensor(2, dtype=compute_dtype))
5134+
steps_minus_1 = op.Constant(value=ir.tensor(steps - 1, dtype=compute_dtype))
5135+
step = op.Constant(value=ir.tensor((end - start) / (steps - 1), dtype=compute_dtype))
51385136

51395137
# Two-sided computation for numerical stability at endpoints
51405138
# Use forward computation for first half, backward for second half

0 commit comments

Comments
 (0)