Skip to content

Commit 7850291

Browse files
committed
Fixed keras and added a new test
1 parent d32c025 commit 7850291

File tree

2 files changed

+39
-14
lines changed

2 files changed

+39
-14
lines changed

mlprimitives/custom/timeseries_preprocessing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,9 @@ def cutoff_window_sequences(X, timeseries, window_size, cutoff_time=None, time_i
214214
``pandas.DataFrame`` containing the actual timeseries data. The time index
215215
can either be set as the DataFrame index or as a column.
216216
window_size (int, str or Timedelta):
217-
Numer of elements to take before the cutoff time for each sequence.
217+
If an integer is passed, it is the number of elements to take before the
218+
cutoff time for each sequence. If a string or a Timedelta object is passed,
219+
it is the period of time from which the elements are taken.
218220
cutoff_time (str):
219221
Optional. If given, the indicated column will be used as the cutoff time.
220222
Otherwise, the table index will be used.
@@ -245,7 +247,6 @@ def cutoff_window_sequences(X, timeseries, window_size, cutoff_time=None, time_i
245247
selected = timeseries[timeseries.index < row.Index]
246248

247249
mask = [True] * len(selected)
248-
249250
for column in columns:
250251
mask &= selected.pop(column) == getattr(row, column)
251252

@@ -258,7 +259,6 @@ def cutoff_window_sequences(X, timeseries, window_size, cutoff_time=None, time_i
258259
selected = selected.iloc[-window_size:]
259260

260261
len_selected = len(selected)
261-
262262
if (len_selected != window_size):
263263
warnings.warn((
264264
'Sequence shorter than window_size found: {} < {}. '

tests/custom/test_timeseries_preprocessing.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ class CutoffWindowSequencesTest(TestCase):
246246

247247
def setUp(self):
248248
self.X = pd.DataFrame({
249-
'id': [1, 2],
249+
'id1': [1, 2],
250250
'cutoff': pd.to_datetime(['2020-01-05', '2020-01-07'])
251251
}).set_index('cutoff')
252252
self.timeseries = pd.DataFrame({
@@ -257,11 +257,11 @@ def setUp(self):
257257
)) * 2,
258258
'value1': np.arange(1, 21),
259259
'value2': np.arange(21, 41),
260-
'id': [1] * 10 + [2] * 10
260+
'id1': [1] * 10 + [2] * 10
261261
}).set_index('timestamp')
262262

263-
"""Passing cutoff_time. The indicated column will be used as the cutoff time."""
264263
def test_cutoff_time_column(self):
264+
"""Passing cutoff_time. The indicated column will be used as the cutoff time."""
265265
# setup
266266
timeseries = self.timeseries
267267
X = self.X.reset_index()
@@ -286,8 +286,8 @@ def test_cutoff_time_column(self):
286286

287287
assert_allclose(array, expected_array)
288288

289-
"""Passing time_index. The indicated column will be used as the timeseries index."""
290289
def test_time_index_column(self):
290+
"""Passing time_index. The indicated column will be used as the timeseries index."""
291291
# setup
292292
X = self.X
293293
timeseries = self.timeseries.reset_index()
@@ -312,8 +312,8 @@ def test_time_index_column(self):
312312

313313
assert_allclose(array, expected_array)
314314

315-
"""window_size accepts integer."""
316315
def test_window_size_integer(self):
316+
"""window_size accepts integer."""
317317
# setup
318318
X = self.X
319319
timeseries = self.timeseries
@@ -337,8 +337,8 @@ def test_window_size_integer(self):
337337

338338
assert_allclose(array, expected_array)
339339

340-
"""window_size accepts string."""
341340
def test_window_size_string(self):
341+
"""window_size accepts string."""
342342
# setup
343343
X = self.X
344344
timeseries = self.timeseries
@@ -362,8 +362,8 @@ def test_window_size_string(self):
362362

363363
assert_allclose(array, expected_array)
364364

365-
"""window_size accepts Timedelta object."""
366365
def test_window_size_timedelta(self):
366+
"""window_size accepts Timedelta object."""
367367
# setup
368368
X = self.X
369369
timeseries = self.timeseries
@@ -387,8 +387,8 @@ def test_window_size_timedelta(self):
387387

388388
assert_allclose(array, expected_array)
389389

390-
"""If there is not enough data for the given window_size, shape changes."""
391390
def test_not_enough_data(self):
391+
"""If there is not enough data for the given window_size, shape changes."""
392392
# setup
393393
X = self.X
394394
timeseries = self.timeseries
@@ -429,13 +429,13 @@ def test_not_enough_data(self):
429429
expected_array[1]
430430
)
431431

432-
"""Test X without any other column than cutoff_time."""
433432
def test_cutoff_time_only(self):
433+
"""Test X without any other column than cutoff_time."""
434434
# setup
435435
X = self.X
436-
del X['id']
436+
del X['id1']
437437
timeseries = self.timeseries
438-
del timeseries['id']
438+
del timeseries['id1']
439439

440440
# run
441441
array = cutoff_window_sequences(
@@ -455,3 +455,28 @@ def test_cutoff_time_only(self):
455455
])
456456

457457
assert_allclose(array, expected_array)
458+
459+
def test_multiple_filter(self):
460+
"""Test X with two identifier columns."""
461+
# setup
462+
X = self.X
463+
X['id2'] = [3, 4]
464+
timeseries = self.timeseries
465+
timeseries['id2'] = [3, 4] * 10
466+
467+
# run
468+
array = cutoff_window_sequences(
469+
X,
470+
timeseries,
471+
window_size=2,
472+
)
473+
474+
# assert
475+
expected_array = np.array([
476+
[[1, 21],
477+
[3, 23]],
478+
[[14, 34],
479+
[16, 36]]
480+
])
481+
482+
assert_allclose(array, expected_array)

0 commit comments

Comments
 (0)