
Commit 5a338ad

Aravind-11, titaiwangms, and justinchuby authored
[torchlib] Fix linspace implementation for int64 (#2693)
## Description

Fixes #854: linspace now correctly handles the int64 dtype.

## Changes

- Modified `aten_linspace` to compute in floating point and then cast to the target dtype.
- This matches PyTorch's behavior and fixes the precision loss caused by integer division.

## Testing

Manually verified: `linspace(0, 10, 5, dtype=int64)` now produces the correct output `[0, 2, 5, 7, 10]`.

## Questions

Where should I add automated test cases for this fix? Happy to add tests wherever you suggest!

---------

Co-authored-by: Ti-Tai Wang <[email protected]>
Co-authored-by: Justin Chu <[email protected]>
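For a concrete sense of the precision issue, here is a simplified, forward-only sketch in plain Python/NumPy. This is illustrative only, not the torchlib code itself (the actual old implementation was two-sided, but it lost precision the same way):

```python
import numpy as np

start, end, steps = 0, 10, 5

# Old behavior (simplified): the step is computed in the target integer
# dtype, so 10 / 4 truncates to 2 before any values are generated.
step_int = (end - start) // (steps - 1)
print([start + step_int * i for i in range(steps)])  # [0, 2, 4, 6, 8]  (wrong)

# Fixed behavior: compute in floating point, cast to int64 once at the end.
step_f = (end - start) / (steps - 1)  # 2.5
vals = np.array([start + step_f * i for i in range(steps)])
print(vals.astype(np.int64).tolist())  # [0, 2, 5, 7, 10]  (matches torch.linspace)
```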
1 parent 565b8e5 commit 5a338ad

2 files changed (+35, -24 lines)


onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 35 additions & 16 deletions
@@ -5494,10 +5494,10 @@ def aten_linear_backward(
 
 @torch_op("aten::linspace", trace_only=True)
 def aten_linspace(
-    start: TFloat,
-    end: TFloat,
+    start: TensorType,
+    end: TensorType,
     steps: int,
-    dtype: int = FLOAT.dtype,
+    dtype: int = -1,
     layout: str = "",
     device: str = "",
     pin_memory: bool = False,
@@ -5507,26 +5507,45 @@ def aten_linspace(
     if dtype == -1 or dtype is None:
         dtype = FLOAT.dtype
 
-    # Reference: https://github.com/pytorch/pytorch/blob/b35ca2cb941b5ba90858322810ca85c31e4541fd/torch/_refs/__init__.py#L4896
     if steps == 0:
         return aten_full(op.Constant(value_ints=[0]), 0.0, dtype=dtype)
     if steps == 1:
         return aten_full(op.Constant(value_ints=[steps]), start, dtype=dtype)
 
-    rg = aten_arange_start(0, steps, dtype=dtype)
-    start = op.Cast(start, to=dtype)
-    end = op.Cast(end, to=dtype)
-    steps_float = op.Cast(steps, to=dtype)
-    one = op.Cast(1.0, to=dtype)
-    two = op.Cast(2.0, to=dtype)
-    steps_minus_1 = op.Cast(steps - 1, to=dtype)
-    step = op.Div(op.Sub(end, start), steps_minus_1)
-    return op.Where(
-        rg < op.Div(steps_float, two),
-        start + step * rg,
-        end - step * (steps_float - one - rg),
+    # For integer output dtypes, cast start/end to the target dtype first.
+    # This matches PyTorch's behavior, where fractional start/end values
+    # are truncated before computing the linspace.
+    dtype = ir.DataType(dtype)
+    if dtype.is_integer():
+        # Use double precision for the computation to match PyTorch's internal precision
+        compute_dtype = ir.DataType.DOUBLE
+        # Cast to the integer dtype first (truncation), then to the compute dtype
+        start_int = op.Cast(start, to=dtype)  # Truncate to int32/int64
+        end_int = op.Cast(end, to=dtype)
+        start_f = op.Cast(start_int, to=compute_dtype)  # Then to double
+        end_f = op.Cast(end_int, to=compute_dtype)
+    else:
+        compute_dtype = dtype
+        start_f = op.Cast(start, to=compute_dtype)
+        end_f = op.Cast(end, to=compute_dtype)
+
+    rg = aten_arange_start(0, steps, dtype=compute_dtype)
+    steps_f = op.Cast(steps, to=compute_dtype)
+    one = op.Constant(value=ir.tensor(1, dtype=compute_dtype))
+    two = op.Constant(value=ir.tensor(2, dtype=compute_dtype))
+    steps_minus_1 = op.Sub(steps_f, one)
+    step = op.Div(op.Sub(end_f, start_f), steps_minus_1)
+
+    # Two-sided computation for numerical stability at the endpoints:
+    # forward from start for the first half, backward from end for the second.
+    lin_vals = op.Where(
+        rg < op.Div(steps_f, two),
+        op.Add(start_f, op.Mul(step, rg)),
+        op.Sub(end_f, op.Mul(step, op.Sub(steps_minus_1, rg))),
     )
 
+    return op.Cast(lin_vals, to=dtype)
+
 
 @torch_op("aten::log", trace_only=True)
 def aten_log(self: TFloat) -> TFloat:
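To trace the new control flow end to end, the following is a rough NumPy mirror of the integer-dtype path above. It is an illustrative sketch, not the torchlib implementation, under two assumptions that match the code: ONNX `Cast` from float to an integer type truncates, and the compute dtype is float64.

```python
import numpy as np

def linspace_like(start, end, steps, int_dtype=np.int64):
    """Illustrative mirror of the new aten_linspace graph for integer dtypes."""
    if steps == 0:
        return np.zeros(0, dtype=int_dtype)
    if steps == 1:
        return np.full(1, start, dtype=int_dtype)
    # Truncate the endpoints to the integer dtype first, then widen to double
    start_f = np.float64(int_dtype(start))
    end_f = np.float64(int_dtype(end))
    rg = np.arange(steps, dtype=np.float64)
    step = (end_f - start_f) / (steps - 1)
    # Two-sided formula: forward from start for the first half,
    # backward from end for the second half
    lin = np.where(rg < steps / 2,
                   start_f + step * rg,
                   end_f - step * (steps - 1 - rg))
    return lin.astype(int_dtype)  # the final Cast(to=dtype) truncates

print(linspace_like(0, 10, 5).tolist())  # [0, 2, 5, 7, 10]
```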

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 0 additions & 8 deletions
@@ -779,14 +779,6 @@ def _where_input_wrangler(
         "linspace",
         core_ops.aten_linspace,
         tolerance={torch.float16: (2e-2, 2e-3)},
-    )
-    .xfail(
-        dtypes=(torch.int64, torch.int32),
-        reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854",
-    )
-    .skip(
-        matcher=lambda sample: sample.kwargs.get("dtype") in (torch.int64, torch.int32),
-        reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854",
     ),
     TorchLibOpInfo("log", core_ops.aten_log),
     TorchLibOpInfo("le", core_ops.aten_le),
