@@ -206,48 +206,74 @@ def get_interpolated_k(
206206 return Ks
207207
208208
209- def get_ordered_poses_and_k (
209+ def get_interpolated_time (
210+ time_a : Float [Tensor , "1" ], time_b : Float [Tensor , "1" ], steps : int = 10
211+ ) -> List [Float [Tensor , "1" ]]:
212+ """
213+ Returns interpolated time between two camera poses with specified number of steps.
214+
215+ Args:
216+ time_a: camera time 1
217+ time_b: camera time 2
218+ steps: number of steps the interpolated pose path should contain
219+ """
220+ times : List [Float [Tensor , "1" ]] = []
221+ ts = np .linspace (0 , 1 , steps )
222+ for t in ts :
223+ new_t = time_a * (1.0 - t ) + time_b * t
224+ times .append (new_t )
225+ return times
226+
227+
228+ def get_ordered_poses_and_k_and_time (
210229 poses : Float [Tensor , "num_poses 3 4" ],
211230 Ks : Float [Tensor , "num_poses 3 3" ],
212- ) -> Tuple [Float [Tensor , "num_poses 3 4" ], Float [Tensor , "num_poses 3 3" ]]:
231+ times : Optional [Float [Tensor , "num_poses 1" ]] = None ,
232+ ) -> Tuple [Float [Tensor , "num_poses 3 4" ], Float [Tensor , "num_poses 3 3" ], Optional [Float [Tensor , "num_poses 1" ]]]:
213233 """
214234 Returns ordered poses and intrinsics by euclidian distance between poses.
215235
216236 Args:
217237 poses: list of camera poses
218238 Ks: list of camera intrinsics
239+ times: list of camera times
219240
220241 Returns:
221- tuple of ordered poses and intrinsics
242+ tuple of ordered poses, intrinsics and times
222243
223244 """
224245
225246 poses_num = len (poses )
226247
227248 ordered_poses = torch .unsqueeze (poses [0 ], 0 )
228249 ordered_ks = torch .unsqueeze (Ks [0 ], 0 )
250+ ordered_times = torch .unsqueeze (times [0 ], 0 ) if times is not None else None
229251
230252 # remove the first pose from poses
231253 poses = poses [1 :]
232254 Ks = Ks [1 :]
255+ times = times [1 :] if times is not None else None
233256
234257 for _ in range (poses_num - 1 ):
235258 distances = torch .norm (ordered_poses [- 1 ][:, 3 ] - poses [:, :, 3 ], dim = 1 )
236259 idx = torch .argmin (distances )
237260 ordered_poses = torch .cat ((ordered_poses , torch .unsqueeze (poses [idx ], 0 )), dim = 0 )
238261 ordered_ks = torch .cat ((ordered_ks , torch .unsqueeze (Ks [idx ], 0 )), dim = 0 )
262+ ordered_times = torch .cat ((ordered_times , torch .unsqueeze (times [idx ], 0 )), dim = 0 ) if times is not None else None # type: ignore
239263 poses = torch .cat ((poses [0 :idx ], poses [idx + 1 :]), dim = 0 )
240264 Ks = torch .cat ((Ks [0 :idx ], Ks [idx + 1 :]), dim = 0 )
265+ times = torch .cat ((times [0 :idx ], times [idx + 1 :]), dim = 0 ) if times is not None else None
241266
242- return ordered_poses , ordered_ks
267+ return ordered_poses , ordered_ks , ordered_times
243268
244269
245270def get_interpolated_poses_many (
246271 poses : Float [Tensor , "num_poses 3 4" ],
247272 Ks : Float [Tensor , "num_poses 3 3" ],
273+ times : Optional [Float [Tensor , "num_poses 1" ]] = None ,
248274 steps_per_transition : int = 10 ,
249275 order_poses : bool = False ,
250- ) -> Tuple [Float [Tensor , "num_poses 3 4" ], Float [Tensor , "num_poses 3 3" ]]:
276+ ) -> Tuple [Float [Tensor , "num_poses 3 4" ], Float [Tensor , "num_poses 3 3" ], Optional [ Float [ Tensor , "num_poses 1" ]] ]:
251277 """Return interpolated poses for many camera poses.
252278
253279 Args:
@@ -261,21 +287,27 @@ def get_interpolated_poses_many(
261287 """
262288 traj = []
263289 k_interp = []
290+ time_interp = [] if times is not None else None
264291
265292 if order_poses :
266- poses , Ks = get_ordered_poses_and_k (poses , Ks )
293+ poses , Ks , times = get_ordered_poses_and_k_and_time (poses , Ks , times )
267294
268295 for idx in range (poses .shape [0 ] - 1 ):
269296 pose_a = poses [idx ].cpu ().numpy ()
270297 pose_b = poses [idx + 1 ].cpu ().numpy ()
271- poses_ab = get_interpolated_poses (pose_a , pose_b , steps = steps_per_transition )
272- traj += poses_ab
298+ traj += get_interpolated_poses (pose_a , pose_b , steps = steps_per_transition )
273299 k_interp += get_interpolated_k (Ks [idx ], Ks [idx + 1 ], steps = steps_per_transition )
300+ if times is not None :
301+ time_interp += get_interpolated_time (times [idx ], times [idx + 1 ], steps = steps_per_transition ) # type: ignore
274302
275303 traj = np .stack (traj , axis = 0 )
276304 k_interp = torch .stack (k_interp , dim = 0 )
277-
278- return torch .tensor (traj , dtype = torch .float32 ), torch .tensor (k_interp , dtype = torch .float32 )
305+ time_interp = torch .stack (time_interp , dim = 0 ) if time_interp is not None else None
306+ return (
307+ torch .tensor (traj , dtype = torch .float32 ),
308+ torch .tensor (k_interp , dtype = torch .float32 ),
309+ torch .tensor (time_interp , dtype = torch .float32 ) if time_interp is not None else None ,
310+ )
279311
280312
281313def normalize (x : torch .Tensor ) -> Float [Tensor , "*batch" ]:
0 commit comments