Skip to content

Commit 10412f5

Browse files
Rikliamarcopeix
andauthored
Fix type of n_windows in cross_validation (#1378)
Co-authored-by: Marco <marco@nixtla.io>
1 parent 3bb21b6 commit 10412f5

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

neuralforecast/core.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1337,7 +1337,7 @@ def cross_validation(
13371337
self,
13381338
df: Optional[DataFrame] = None,
13391339
static_df: Optional[DataFrame] = None,
1340-
n_windows: int = 1,
1340+
n_windows: Optional[int] = 1,
13411341
step_size: int = 1,
13421342
val_size: Optional[int] = 0,
13431343
test_size: Optional[int] = None,
@@ -1362,8 +1362,8 @@ def cross_validation(
13621362
df (pandas or polars DataFrame, optional): DataFrame with columns [`unique_id`, `ds`, `y`] and exogenous variables.
13631363
If None, a previously stored dataset is required.
13641364
static_df (pandas or polars DataFrame, optional): DataFrame with columns [`unique_id`] and static exogenous. Defaults to None.
1365-
n_windows (int): Number of windows used for cross validation.
1366-
step_size (int): Step size between each window.
1365+
n_windows (int, None): Number of windows used for cross validation. If None, define `test_size`.
1366+
step_size (int): Step size between each window.
13671367
val_size (int, optional): Length of validation size. If passed, set `n_windows=None`. Defaults to 0.
13681368
test_size (int, optional): Length of test size. If passed, set `n_windows=None`.
13691369
use_init_models (bool, optional): Use initial model passed when object was instantiated.
@@ -1418,13 +1418,19 @@ def cross_validation(
14181418
if n_windows is None and test_size is None:
14191419
raise Exception("you must define `n_windows` or `test_size`.")
14201420
if test_size is None and h is not None:
1421+
assert n_windows is not None
14211422
test_size = h + step_size * (n_windows - 1)
14221423
elif n_windows is None:
1424+
assert test_size is not None
1425+
assert h is not None
14231426
if (test_size - h) % step_size:
1424-
raise Exception("`test_size - h` should be module `step_size`")
1427+
raise Exception("`test_size - h` must be divisible by `step_size`")
14251428
n_windows = int((test_size - h) / step_size) + 1
14261429
else:
14271430
raise Exception("you must define `n_windows` or `test_size` but not both")
1431+
1432+
assert n_windows is not None
1433+
assert test_size is not None
14281434

14291435
# Recover initial model if use_init_models.
14301436
if use_init_models:

0 commit comments

Comments
 (0)