@@ -172,7 +172,7 @@ def get_interpolated_poses(pose_a: NDArray, pose_b: NDArray, steps: int = 10) ->
172172 quat_b = quaternion_from_matrix (pose_b [:3 , :3 ])
173173
174174 ts = np .linspace (0 , 1 , steps )
175- quats = [quaternion_slerp (quat_a , quat_b , t ) for t in ts ]
175+ quats = [quaternion_slerp (quat_a , quat_b , float ( t ) ) for t in ts ]
176176 trans = [(1 - t ) * pose_a [:3 , 3 ] + t * pose_b [:3 , 3 ] for t in ts ]
177177
178178 poses_ab = []
@@ -199,7 +199,7 @@ def get_interpolated_k(
199199 List of interpolated camera poses
200200 """
201201 Ks : List [Float [Tensor , "3 3" ]] = []
202- ts = np .linspace (0 , 1 , steps )
202+ ts = torch .linspace (0 , 1 , steps , dtype = k_a . dtype , device = k_a . device )
203203 for t in ts :
204204 new_k = k_a * (1.0 - t ) + k_b * t
205205 Ks .append (new_k )
@@ -218,7 +218,7 @@ def get_interpolated_time(
218218 steps: number of steps the interpolated pose path should contain
219219 """
220220 times : List [Float [Tensor , "1" ]] = []
221- ts = np .linspace (0 , 1 , steps )
221+ ts = torch .linspace (0 , 1 , steps , dtype = time_a . dtype , device = time_a . device )
222222 for t in ts :
223223 new_t = time_a * (1.0 - t ) + time_b * t
224224 times .append (new_t )
0 commit comments