
Commit c772664

fixes
1 parent 1d4a07d commit c772664

File tree

1 file changed: +31 -12 lines changed
  • onnxscript/function_libs/torch_lib/ops


onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 31 additions & 12 deletions
@@ -5085,40 +5085,59 @@ def aten_linear_backward(
 
 @torch_op("aten::linspace", trace_only=True)
 def aten_linspace(
-    start: TFloat,
-    end: TFloat,
+    start: float,
+    end: float,
     steps: int,
-    dtype: int = FLOAT.dtype,
+    dtype: int = -1,
     layout: str = "",
     device: str = "",
     pin_memory: bool = False,
 ) -> TensorType:
     """linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
+
     if dtype == -1 or dtype is None:
         dtype = FLOAT.dtype
-
+
     if steps == 0:
         return aten_full(op.Constant(value_ints=[0]), 0.0, dtype=dtype)
     if steps == 1:
         return aten_full(op.Constant(value_ints=[steps]), start, dtype=dtype)
-
-    compute_dtype = FLOAT.dtype
-
+
+    # Use double precision for computation to match PyTorch's internal precision
+    compute_dtype = DOUBLE.dtype
+
+    # For integer output dtypes, cast start/end to the target dtype first
+    # This matches PyTorch's behavior where fractional start/end values
+    # are truncated before computing the linspace
+    is_integer_dtype = dtype not in (FLOAT.dtype, DOUBLE.dtype, FLOAT16.dtype, COMPLEX64.dtype, COMPLEX128.dtype)
+
+    if is_integer_dtype:
+        # Cast to integer dtype first, then to compute dtype
+        # This ensures truncation happens before computation
+        start_int = op.Cast(start, to=dtype)
+        end_int = op.Cast(end, to=dtype)
+        start_f = op.Cast(start_int, to=compute_dtype)
+        end_f = op.Cast(end_int, to=compute_dtype)
+    else:
+        # For float dtypes, cast directly to compute dtype
+        start_f = op.Cast(start, to=compute_dtype)
+        end_f = op.Cast(end, to=compute_dtype)
+
     rg = aten_arange_start(0, steps, dtype=compute_dtype)
-    start_f = op.Cast(start, to=compute_dtype)
-    end_f = op.Cast(end, to=compute_dtype)
     steps_f = op.Cast(steps, to=compute_dtype)
     one = op.Cast(1.0, to=compute_dtype)
     two = op.Cast(2.0, to=compute_dtype)
     steps_minus_1 = op.Sub(steps_f, one)
     step = op.Div(op.Sub(end_f, start_f), steps_minus_1)
-
+
+    # Two-sided computation for numerical stability at endpoints
+    # Use forward computation for first half, backward for second half
     lin_vals = op.Where(
         rg < op.Div(steps_f, two),
         op.Add(start_f, op.Mul(step, rg)),
-        op.Sub(end_f, op.Mul(step, op.Sub(op.Sub(steps_f, one), rg))),
+        op.Sub(end_f, op.Mul(step, op.Sub(steps_minus_1, rg))),
     )
-
+
     return op.Cast(lin_vals, to=dtype)
 
 
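For intuition, here is a minimal NumPy sketch of the algorithm the new function body implements: truncate start/end for integer output dtypes, compute in double precision, and fill the range from both ends so the endpoints stay exact. The helper name linspace_reference is hypothetical and NumPy stands in for the ONNX ops (op.Cast, op.Where, etc.) used above; it is an illustration under those assumptions, not the implementation itself.

import numpy as np

def linspace_reference(start, end, steps, dtype=np.float32):
    # Degenerate cases, mirroring the early returns in the diff.
    if steps == 0:
        return np.empty(0, dtype=dtype)
    if steps == 1:
        return np.full((1,), start, dtype=dtype)

    # For integer output dtypes, truncate start/end before computing,
    # analogous to the Cast-to-target-dtype-first path above.
    if np.issubdtype(np.dtype(dtype), np.integer):
        start, end = int(start), int(end)

    # Compute in double precision, matching compute_dtype = DOUBLE.dtype.
    start_f, end_f = np.float64(start), np.float64(end)
    rg = np.arange(steps, dtype=np.float64)
    step = (end_f - start_f) / (steps - 1)

    # Two-sided fill: forward from start for the first half,
    # backward from end for the second half, keeping both endpoints exact.
    forward = start_f + step * rg
    backward = end_f - step * ((steps - 1) - rg)
    return np.where(rg < steps / 2, forward, backward).astype(dtype)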

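A quick spot check of the sketch (assuming NumPy is installed); the commented values are what the sketch above produces and are not asserted against PyTorch:

print(linspace_reference(0.0, 1.0, 5))
# [0.   0.25 0.5  0.75 1.  ]
print(linspace_reference(0.5, 5.5, 5, dtype=np.int64))
# [0 1 2 3 5]  -- start/end are truncated to 0 and 5 before the range is computed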