Skip to content

Commit beacf26

Browse files
authored
Update core.py
1 parent 52a9172 commit beacf26

File tree

1 file changed

+6
-15
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+6
-15
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5103,29 +5103,20 @@ def aten_linspace(
51035103
if steps == 1:
51045104
return aten_full(op.Constant(value_ints=[steps]), start, dtype=dtype)
51055105

5106-
# Use double precision for computation to match PyTorch's internal precision
5107-
compute_dtype = DOUBLE.dtype
5108-
51095106
# For integer output dtypes, cast start/end to the target dtype first
51105107
# This matches PyTorch's behavior where fractional start/end values
51115108
# are truncated before computing the linspace
5112-
is_integer_dtype = dtype not in (
5113-
FLOAT.dtype,
5114-
DOUBLE.dtype,
5115-
FLOAT16.dtype,
5116-
COMPLEX64.dtype,
5117-
COMPLEX128.dtype,
5118-
)
5119-
51205109
if ir.DataType(dtype).is_integer():
5110+
# Use double precision for computation to match PyTorch's internal precision
5111+
compute_dtype = ir.DataType.DOUBLE
51215112
# Cast to integer dtype first, then to compute dtype
51225113
# This ensures truncation happens before computation
5123-
start_int = op.Cast(start, to=dtype)
5124-
end_int = op.Cast(end, to=dtype)
5125-
start = op.Cast(start_int, to=compute_dtype)
5126-
end = op.Cast(end_int, to=compute_dtype)
5114+
start_f = op.Constant(value=ir.tensor(int(start), dtype=compute_dtype))
5115+
end_f = op.Constant(value=ir.tensor(int(end), dtype=compute_dtype))
51275116
else:
51285117
compute_dtype = dtype
5118+
start_f = op.Constant(value=ir.tensor(start, dtype=compute_dtype))
5119+
end_f = op.Constant(value=ir.tensor(end, dtype=compute_dtype))
51295120

51305121
rg = aten_arange_start(0, steps, dtype=compute_dtype)
51315122
steps_f = op.Constant(value=ir.tensor(steps, dtype=compute_dtype))

0 commit comments

Comments
 (0)