Skip to content

Commit 4444bd8

Browse files
authored
Merge pull request #3 from LeonEthan/main
FIX: Add exclude_insample_y param to TimeXer for model loading (Nixtla#1306)
2 parents 4ccf66a + 00531d1 commit 4444bd8

File tree

5 files changed

+7
-4
lines changed

5 files changed

+7
-4
lines changed

.circleci/config.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ jobs:
1111
command: micromamba install -n base -c conda-forge -y python=3.10 git -f environment-cpu.yml
1212
- run:
1313
name: Run nbdev tests
14+
no_output_timeout: 20m
1415
command: |
1516
eval "$(micromamba shell hook --shell bash)"
1617
micromamba activate base

.github/workflows/lint.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
python-version: "3.10"
2020

2121
- name: Install dependencies
22-
run: pip install black nbdev==2.3.25 pre-commit
22+
run: pip install black "fastcore<1.8.0" nbdev==2.3.25 pre-commit
2323

2424
- name: Run pre-commit
2525
run: pre-commit run --files neuralforecast/*

nbs/models.timexer.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@
293293
" futr_exog_list = None,\n",
294294
" hist_exog_list = None,\n",
295295
" stat_exog_list = None,\n",
296+
" exclude_insample_y: bool = False,\n",
296297
" patch_len: int = 16,\n",
297298
" hidden_size: int = 512,\n",
298299
" n_heads: int = 8,\n",
@@ -332,7 +333,7 @@
332333
" futr_exog_list=futr_exog_list,\n",
333334
" hist_exog_list=hist_exog_list,\n",
334335
" stat_exog_list=stat_exog_list,\n",
335-
" exclude_insample_y = False,\n",
336+
" exclude_insample_y=exclude_insample_y,\n",
336337
" loss=loss,\n",
337338
" valid_loss=valid_loss,\n",
338339
" max_steps=max_steps,\n",

neuralforecast/models/timexer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ def __init__(
199199
futr_exog_list=None,
200200
hist_exog_list=None,
201201
stat_exog_list=None,
202+
exclude_insample_y: bool = False,
202203
patch_len: int = 16,
203204
hidden_size: int = 512,
204205
n_heads: int = 8,
@@ -240,7 +241,7 @@ def __init__(
240241
futr_exog_list=futr_exog_list,
241242
hist_exog_list=hist_exog_list,
242243
stat_exog_list=stat_exog_list,
243-
exclude_insample_y=False,
244+
exclude_insample_y=exclude_insample_y,
244245
loss=loss,
245246
valid_loss=valid_loss,
246247
max_steps=max_steps,

settings.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ status = 2
1818
requirements = coreforecast>=0.0.6 fsspec numpy>=1.21.6 pandas>=1.3.5 torch>=2.0.0 pytorch-lightning>=2.0.0 ray[tune]>=2.2.0 optuna utilsforecast>=0.2.3
1919
spark_requirements = fugue pyspark>=3.5
2020
aws_requirements = fsspec[s3]
21-
dev_requirements = black gitpython hyperopt ipython<=8.32.0 matplotlib mypy nbdev==2.3.25 polars pre-commit pyarrow ruff s3fs transformers
21+
dev_requirements = black fastcore<=1.7.29 gitpython hyperopt ipython<=8.32.0 matplotlib mypy nbdev==2.3.25 polars pre-commit pyarrow ruff s3fs transformers
2222
nbs_path = nbs
2323
doc_path = _docs
2424
recursive = True

0 commit comments

Comments
 (0)