Skip to content

Commit 480bcff

Browse files
committed
fix shift default fill_value
1 parent 7135041 commit 480bcff

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

pymove/tests/test_utils_trajectories.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,17 @@ def test_shift():
9696

9797
expected = [4, 5, 0, 0, 0]
9898
array_ = [1, 2, 3, 4, 5]
99-
shifted_array = trajectories.shift(arr=array_, num=-3, fill_value=0)
99+
shifted_array = trajectories.shift(arr=array_, num=-3)
100+
assert_array_equal(shifted_array, expected)
101+
102+
expected = [False, False, False, True, True]
103+
array_ = [True, True, True, True, True]
104+
shifted_array = trajectories.shift(arr=array_, num=3)
105+
assert_array_equal(shifted_array, expected)
106+
107+
expected = ['dewberry', 'eggplant', 'nan', 'nan', 'nan']
108+
array_ = ['apple', 'banana', 'coconut', 'dewberry', 'eggplant']
109+
shifted_array = trajectories.shift(arr=array_, num=-3, fill_value=np.nan)
100110
assert_array_equal(shifted_array, expected)
101111

102112

pymove/utils/trajectories.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,10 @@ def shift(
210210

211211
result = np.empty_like(arr)
212212
if fill_value is None:
213-
if isinstance(result.dtype, int):
213+
dtype = result.dtype
214+
if np.issubdtype(dtype, np.bool_):
215+
fill_value = False
216+
elif np.issubdtype(dtype, np.integer):
214217
fill_value = 0
215218
else:
216219
fill_value = np.nan

0 commit comments

Comments
 (0)