@@ -5103,29 +5103,20 @@ def aten_linspace(
51035103 if steps == 1 :
51045104 return aten_full (op .Constant (value_ints = [steps ]), start , dtype = dtype )
51055105
5106- # Use double precision for computation to match PyTorch's internal precision
5107- compute_dtype = DOUBLE .dtype
5108-
51095106 # For integer output dtypes, cast start/end to the target dtype first
51105107 # This matches PyTorch's behavior where fractional start/end values
51115108 # are truncated before computing the linspace
5112- is_integer_dtype = dtype not in (
5113- FLOAT .dtype ,
5114- DOUBLE .dtype ,
5115- FLOAT16 .dtype ,
5116- COMPLEX64 .dtype ,
5117- COMPLEX128 .dtype ,
5118- )
5119-
51205109 if ir .DataType (dtype ).is_integer ():
5110+ # Use double precision for computation to match PyTorch's internal precision
5111+ compute_dtype = ir .DataType .DOUBLE
51215112 # Cast to integer dtype first, then to compute dtype
51225113 # This ensures truncation happens before computation
5123- start_int = op .Cast (start , to = dtype )
5124- end_int = op .Cast (end , to = dtype )
5125- start = op .Cast (start_int , to = compute_dtype )
5126- end = op .Cast (end_int , to = compute_dtype )
5114+ start_f = op .Constant (value = ir .tensor (int (start ), dtype = compute_dtype ))
5115+ end_f = op .Constant (value = ir .tensor (int (end ), dtype = compute_dtype ))
51275116 else :
51285117 compute_dtype = dtype
5118+ start_f = op .Constant (value = ir .tensor (start , dtype = compute_dtype ))
5119+ end_f = op .Constant (value = ir .tensor (end , dtype = compute_dtype ))
51295120
51305121 rg = aten_arange_start (0 , steps , dtype = compute_dtype )
51315122 steps_f = op .Constant (value = ir .tensor (steps , dtype = compute_dtype ))
0 commit comments