@@ -5117,24 +5117,22 @@ def aten_linspace(
51175117 COMPLEX128 .dtype ,
51185118 )
51195119
5120- if is_integer_dtype :
5120+ if ir . DataType ( dtype ). is_integer () :
51215121 # Cast to integer dtype first, then to compute dtype
51225122 # This ensures truncation happens before computation
51235123 start_int = op .Cast (start , to = dtype )
51245124 end_int = op .Cast (end , to = dtype )
5125- start_f = op .Cast (start_int , to = compute_dtype )
5126- end_f = op .Cast (end_int , to = compute_dtype )
5125+ start = op .Cast (start_int , to = compute_dtype )
5126+ end = op .Cast (end_int , to = compute_dtype )
51275127 else :
5128- # For float dtypes, cast directly to compute dtype
5129- start_f = op .Cast (start , to = compute_dtype )
5130- end_f = op .Cast (end , to = compute_dtype )
5128+ compute_dtype = dtype
51315129
51325130 rg = aten_arange_start (0 , steps , dtype = compute_dtype )
5133- steps_f = op .Cast ( steps , to = compute_dtype )
5134- one = op .Cast ( 1.0 , to = compute_dtype )
5135- two = op .Cast ( 2.0 , to = compute_dtype )
5136- steps_minus_1 = op .Sub ( steps_f , one )
5137- step = op .Div ( op . Sub ( end_f , start_f ), steps_minus_1 )
5131+ steps_f = op .Constant ( value = ir . tensor ( steps , dtype = compute_dtype ) )
5132+ one = op .Constant ( value = ir . tensor ( 1 , dtype = compute_dtype ) )
5133+ two = op .Constant ( value = ir . tensor ( 2 , dtype = compute_dtype ) )
5134+ steps_minus_1 = op .Constant ( value = ir . tensor ( steps - 1 , dtype = compute_dtype ) )
5135+ step = op .Constant ( value = ir . tensor (( end - start ) / ( steps - 1 ), dtype = compute_dtype ) )
51385136
51395137 # Two-sided computation for numerical stability at endpoints
51405138 # Use forward computation for first half, backward for second half
0 commit comments