@@ -1914,21 +1914,32 @@ def reject_outlier_keypoints(cls, keypoints, threshold_in_stds=2):
19141914 return temp
19151915
19161916 @classmethod
1917- def ast_fillna_2d (cls , arr ):
1917+ def ast_fillna_2d (cls , arr : np .ndarray ) -> np .ndarray :
1918+ """
1919+ Fills NaN values in a 4D keypoints array using linear interpolation.
1920+
1921+ Parameters:
1922+ arr (np.ndarray): A 4D numpy array of shape (n_frames, n_individuals, n_kpts, n_dims).
1923+
1924+ Returns:
1925+ np.ndarray: The 4D array with NaN values filled.
1926+ """
19181927 n_frames , n_individuals , n_kpts , n_dims = arr .shape
19191928 arr_reshaped = arr .reshape (n_frames , - 1 )
19201929 x = np .arange (n_frames )
19211930 for i in range (arr_reshaped .shape [1 ]):
19221931 valid_mask = ~ np .isnan (arr_reshaped [:, i ])
19231932 if np .all (valid_mask ):
19241933 continue
1925- arr_reshaped [:, i ] = np .interp (
1926- x , x [valid_mask ], arr_reshaped [valid_mask , i ]
1927- )
1928- # Reshape the array back to 4D
1929- arr = arr_reshaped .reshape (n_frames , n_individuals , n_kpts , n_dims )
1934+ elif np .any (valid_mask ):
1935+ # Perform interpolation when there are some valid points
1936+ arr_reshaped [:, i ] = np .interp (x , x [valid_mask ], arr_reshaped [valid_mask , i ])
1937+ else :
1938+ # Handle the case where all values are NaN
1939+ # Replace with a default value or another suitable handling
1940+ arr_reshaped [:, i ].fill (0 ) # Example: filling with 0
19301941
1931- return arr
1942+ return arr_reshaped . reshape ( n_frames , n_individuals , n_kpts , n_dims )
19321943
19331944 @classmethod
19341945 @timer_decorator
0 commit comments