@@ -5085,40 +5085,59 @@ def aten_linear_backward(
50855085
50865086@torch_op ("aten::linspace" , trace_only = True )
50875087def aten_linspace (
5088- start : TFloat ,
5089- end : TFloat ,
5088+ start : float ,
5089+ end : float ,
50905090 steps : int ,
5091- dtype : int = FLOAT . dtype ,
5091+ dtype : int = - 1 ,
50925092 layout : str = "" ,
50935093 device : str = "" ,
50945094 pin_memory : bool = False ,
50955095) -> TensorType :
50965096 """linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
5097+
50975098 if dtype == - 1 or dtype is None :
50985099 dtype = FLOAT .dtype
5099-
5100+
51005101 if steps == 0 :
51015102 return aten_full (op .Constant (value_ints = [0 ]), 0.0 , dtype = dtype )
51025103 if steps == 1 :
51035104 return aten_full (op .Constant (value_ints = [steps ]), start , dtype = dtype )
5104-
5105- compute_dtype = FLOAT .dtype
5106-
5105+
5106+ # Use double precision for computation to match PyTorch's internal precision
5107+ compute_dtype = DOUBLE .dtype
5108+
5109+ # For integer output dtypes, cast start/end to the target dtype first
5110+ # This matches PyTorch's behavior where fractional start/end values
5111+ # are truncated before computing the linspace
5112+ is_integer_dtype = dtype not in (FLOAT .dtype , DOUBLE .dtype , FLOAT16 .dtype , COMPLEX64 .dtype , COMPLEX128 .dtype )
5113+
5114+ if is_integer_dtype :
5115+ # Cast to integer dtype first, then to compute dtype
5116+ # This ensures truncation happens before computation
5117+ start_int = op .Cast (start , to = dtype )
5118+ end_int = op .Cast (end , to = dtype )
5119+ start_f = op .Cast (start_int , to = compute_dtype )
5120+ end_f = op .Cast (end_int , to = compute_dtype )
5121+ else :
5122+ # For float dtypes, cast directly to compute dtype
5123+ start_f = op .Cast (start , to = compute_dtype )
5124+ end_f = op .Cast (end , to = compute_dtype )
5125+
51075126 rg = aten_arange_start (0 , steps , dtype = compute_dtype )
5108- start_f = op .Cast (start , to = compute_dtype )
5109- end_f = op .Cast (end , to = compute_dtype )
51105127 steps_f = op .Cast (steps , to = compute_dtype )
51115128 one = op .Cast (1.0 , to = compute_dtype )
51125129 two = op .Cast (2.0 , to = compute_dtype )
51135130 steps_minus_1 = op .Sub (steps_f , one )
51145131 step = op .Div (op .Sub (end_f , start_f ), steps_minus_1 )
5115-
5132+
5133+ # Two-sided computation for numerical stability at endpoints
5134+ # Use forward computation for first half, backward for second half
51165135 lin_vals = op .Where (
51175136 rg < op .Div (steps_f , two ),
51185137 op .Add (start_f , op .Mul (step , rg )),
5119- op .Sub (end_f , op .Mul (step , op .Sub (op . Sub ( steps_f , one ) , rg ))),
5138+ op .Sub (end_f , op .Mul (step , op .Sub (steps_minus_1 , rg ))),
51205139 )
5121-
5140+
51225141 return op .Cast (lin_vals , to = dtype )
51235142
51245143
0 commit comments