Skip to content

Commit 2de5a82

Browse files
tblume1992, marcopeix, elephaint
authored
DynamicNBEATs model (#1191)
Co-authored-by: marcopeix <marco@nixtla.io> Co-authored-by: Olivier Sprangers <osprangers@gmail.com>
1 parent 00531d1 commit 2de5a82

File tree

4 files changed

+806
-59
lines changed

4 files changed

+806
-59
lines changed

experiments/nbeats_basis/nbeats_basis_experiment.ipynb

Lines changed: 419 additions & 0 deletions
Large diffs are not rendered by default.

nbs/models.nbeats.ipynb

Lines changed: 189 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,15 @@
5959
"outputs": [],
6060
"source": [
6161
"#| export\n",
62+
"import warnings\n",
6263
"from typing import Tuple, Optional\n",
6364
"\n",
6465
"import numpy as np\n",
66+
"from numpy.polynomial.legendre import Legendre\n",
67+
"from numpy.polynomial.chebyshev import Chebyshev\n",
6568
"import torch\n",
6669
"import torch.nn as nn\n",
70+
"from scipy.interpolate import BSpline\n",
6771
"\n",
6872
"from neuralforecast.losses.pytorch import MAE\n",
6973
"from neuralforecast.common._base_model import BaseModel"
@@ -87,6 +91,143 @@
8791
"import matplotlib.pyplot as plt"
8892
]
8993
},
94+
{
95+
"cell_type": "code",
96+
"execution_count": null,
97+
"id": "b3b21a80",
98+
"metadata": {},
99+
"outputs": [],
100+
"source": [
101+
"#| exporti\n",
102+
"def generate_legendre_basis(length, n_basis):\n",
103+
" \"\"\"\n",
104+
" Generates Legendre polynomial basis functions.\n",
105+
"\n",
106+
" Parameters:\n",
107+
" - n_points (int): Number of data points.\n",
108+
" - n_functions (int): Number of basis functions to generate.\n",
109+
"\n",
110+
" Returns:\n",
111+
" - legendre_basis (ndarray): An array of Legendre basis functions.\n",
112+
" \"\"\"\n",
113+
" x = np.linspace(-1, 1, length) # Legendre polynomials are defined on [-1, 1]\n",
114+
" legendre_basis = np.zeros((length, n_basis))\n",
115+
" for i in range(n_basis):\n",
116+
" # Legendre polynomial of degree i\n",
117+
" P_i = Legendre.basis(i)\n",
118+
" legendre_basis[:, i] = P_i(x)\n",
119+
" return legendre_basis\n",
120+
"\n",
121+
"def generate_polynomial_basis(length, n_basis):\n",
122+
" \"\"\"\n",
123+
" Generates standard polynomial basis functions.\n",
124+
"\n",
125+
" Parameters:\n",
126+
" - n_points (int): Number of data points.\n",
127+
" - n_functions (int): Number of polynomial functions to generate.\n",
128+
"\n",
129+
" Returns:\n",
130+
" - poly_basis (ndarray): An array of polynomial basis functions.\n",
131+
" \"\"\"\n",
132+
" return np.concatenate([np.power(np.arange(length, dtype=float) / length, i)[None, :]\n",
133+
" for i in range(n_basis)]).T\n",
134+
"\n",
135+
"\n",
136+
"def generate_changepoint_basis(length, n_basis):\n",
137+
" \"\"\"\n",
138+
" Generates changepoint basis functions with automatically spaced changepoints.\n",
139+
"\n",
140+
" Parameters:\n",
141+
" - n_points (int): Number of data points.\n",
142+
" - n_functions (int): Number of changepoint functions to generate.\n",
143+
"\n",
144+
" Returns:\n",
145+
" - changepoint_basis (ndarray): An array of changepoint basis functions.\n",
146+
" \"\"\"\n",
147+
" x = np.linspace(0, 1, length)[:, None] # Shape: (length, 1)\n",
148+
" changepoint_locations = np.linspace(0, 1, n_basis + 1)[1:][None, :] # Shape: (1, n_basis)\n",
149+
" return np.maximum(0, x - changepoint_locations)\n",
150+
"\n",
151+
"def generate_piecewise_linear_basis(length, n_basis):\n",
152+
" \"\"\"\n",
153+
" Generates piecewise linear basis functions (linear splines).\n",
154+
"\n",
155+
" Parameters:\n",
156+
" - n_points (int): Number of data points.\n",
157+
" - n_functions (int): Number of piecewise linear basis functions to generate.\n",
158+
"\n",
159+
" Returns:\n",
160+
" - pw_linear_basis (ndarray): An array of piecewise linear basis functions.\n",
161+
" \"\"\"\n",
162+
" x = np.linspace(0, 1, length)\n",
163+
" knots = np.linspace(0, 1, n_basis+1)\n",
164+
" pw_linear_basis = np.zeros((length, n_basis))\n",
165+
" for i in range(1, n_basis):\n",
166+
" pw_linear_basis[:, i] = np.maximum(0, np.minimum((x - knots[i-1]) / (knots[i] - knots[i-1]), (knots[i+1] - x) / (knots[i+1] - knots[i])))\n",
167+
" return pw_linear_basis\n",
168+
"\n",
169+
"def generate_linear_hat_basis(length, n_basis):\n",
170+
" x = np.linspace(0, 1, length)[:, None] # Shape: (length, 1)\n",
171+
" centers = np.linspace(0, 1, n_basis)[None, :] # Shape: (1, n_basis)\n",
172+
" width = 1.0 / (n_basis - 1)\n",
173+
" \n",
174+
" # Create triangular functions using piecewise linear equations\n",
175+
" return np.maximum(0, 1 - np.abs(x - centers) / width)\n",
176+
"\n",
177+
"def generate_spline_basis(length, n_basis):\n",
178+
" \"\"\"\n",
179+
" Generates cubic spline basis functions.\n",
180+
"\n",
181+
" Parameters:\n",
182+
" - n_points (int): Number of data points.\n",
183+
" - n_functions (int): Number of basis functions.\n",
184+
"\n",
185+
" Returns:\n",
186+
" - spline_basis (ndarray): An array of cubic spline basis functions.\n",
187+
" \"\"\"\n",
188+
" if n_basis < 4:\n",
189+
" raise ValueError(f\"To use the spline basis, n_basis must be set to 4 or more. Current value is {n_basis}\")\n",
190+
" x = np.linspace(0, 1, length)\n",
191+
" knots = np.linspace(0, 1, n_basis - 2)\n",
192+
" t = np.concatenate(([0, 0, 0], knots, [1, 1, 1]))\n",
193+
" degree = 3\n",
194+
" # Create basis coefficient matrix once\n",
195+
" coefficients = np.eye(n_basis)\n",
196+
" # Create single BSpline object with all coefficients\n",
197+
" spline = BSpline(t, coefficients.T, degree)\n",
198+
" return spline(x)\n",
199+
"\n",
200+
"def generate_chebyshev_basis(length, n_basis):\n",
201+
" \"\"\"\n",
202+
" Generates Chebyshev polynomial basis functions.\n",
203+
"\n",
204+
" Parameters:\n",
205+
" - n_points (int): Number of data points.\n",
206+
" - n_functions (int): Number of Chebyshev polynomials to generate.\n",
207+
"\n",
208+
" Returns:\n",
209+
" - chebyshev_basis (ndarray): An array of Chebyshev polynomial basis functions.\n",
210+
" \"\"\"\n",
211+
" x = np.linspace(-1, 1, length)\n",
212+
" chebyshev_basis = np.zeros((length, n_basis))\n",
213+
" for i in range(n_basis):\n",
214+
" T_i = Chebyshev.basis(i)\n",
215+
" chebyshev_basis[:, i] = T_i(x)\n",
216+
" return chebyshev_basis\n",
217+
"\n",
218+
"def get_basis(length, n_basis, basis):\n",
219+
" basis_dict = {\n",
220+
" 'legendre': generate_legendre_basis,\n",
221+
" 'polynomial': generate_polynomial_basis,\n",
222+
" 'changepoint': generate_changepoint_basis,\n",
223+
" 'piecewise_linear': generate_piecewise_linear_basis,\n",
224+
" 'linear_hat': generate_linear_hat_basis,\n",
225+
" 'spline': generate_spline_basis,\n",
226+
" 'chebyshev': generate_chebyshev_basis\n",
227+
" }\n",
228+
" return basis_dict[basis](length, n_basis+1)"
229+
]
230+
},
90231
{
91232
"cell_type": "code",
92233
"execution_count": null,
@@ -110,19 +251,19 @@
110251
" return backcast, forecast\n",
111252
"\n",
112253
"class TrendBasis(nn.Module):\n",
113-
" def __init__(self, degree_of_polynomial: int,\n",
114-
" backcast_size: int, forecast_size: int,\n",
115-
" out_features: int=1):\n",
254+
" def __init__(self, \n",
255+
" n_basis: int,\n",
256+
" backcast_size: int,\n",
257+
" forecast_size: int,\n",
258+
" out_features: int=1,\n",
259+
" basis='polynomial'):\n",
116260
" super().__init__()\n",
117261
" self.out_features = out_features\n",
118-
" polynomial_size = degree_of_polynomial + 1\n",
119262
" self.backcast_basis = nn.Parameter(\n",
120-
" torch.tensor(np.concatenate([np.power(np.arange(backcast_size, dtype=float) / backcast_size, i)[None, :]\n",
121-
" for i in range(polynomial_size)]), dtype=torch.float32), requires_grad=False)\n",
263+
" torch.tensor(get_basis(backcast_size, n_basis, basis).T, dtype=torch.float32), requires_grad=False)\n",
122264
" self.forecast_basis = nn.Parameter(\n",
123-
" torch.tensor(np.concatenate([np.power(np.arange(forecast_size, dtype=float) / forecast_size, i)[None, :]\n",
124-
" for i in range(polynomial_size)]), dtype=torch.float32), requires_grad=False)\n",
125-
" \n",
265+
" torch.tensor(get_basis(forecast_size, n_basis, basis).T, dtype=torch.float32), requires_grad=False)\n",
266+
"\n",
126267
" def forward(self, theta: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n",
127268
" polynomial_size = self.forecast_basis.shape[0] # [polynomial_size, L+H]\n",
128269
" backcast_theta = theta[:, :polynomial_size]\n",
@@ -133,8 +274,10 @@
133274
" return backcast, forecast\n",
134275
"\n",
135276
"class SeasonalityBasis(nn.Module):\n",
136-
" def __init__(self, harmonics: int, \n",
137-
" backcast_size: int, forecast_size: int,\n",
277+
" def __init__(self, \n",
278+
" harmonics: int, \n",
279+
" backcast_size: int, \n",
280+
" forecast_size: int,\n",
138281
" out_features: int=1):\n",
139282
" super().__init__()\n",
140283
" self.out_features = out_features\n",
@@ -194,8 +337,6 @@
194337
" basis: nn.Module, \n",
195338
" dropout_prob: float, \n",
196339
" activation: str):\n",
197-
" \"\"\"\n",
198-
" \"\"\"\n",
199340
" super().__init__()\n",
200341
"\n",
201342
" self.dropout_prob = dropout_prob\n",
@@ -212,7 +353,6 @@
212353
"\n",
213354
" if self.dropout_prob>0:\n",
214355
" raise NotImplementedError('dropout')\n",
215-
" #hidden_layers.append(nn.Dropout(p=self.dropout_prob))\n",
216356
"\n",
217357
" output_layer = [nn.Linear(in_features=mlp_units[-1][1], out_features=n_theta)]\n",
218358
" layers = hidden_layers + output_layer\n",
@@ -248,7 +388,9 @@
248388
" `h`: int, forecast horizon.<br>\n",
249389
" `input_size`: int, considered autoregressive inputs (lags), y=[1,2,3,4] input_size=2 -> lags=[1,2].<br>\n",
250390
" `n_harmonics`: int, Number of harmonic terms for seasonality stack type. Note that len(n_harmonics) = len(stack_types). Note that it will only be used if a seasonality stack is used.<br>\n",
251-
" `n_polynomials`: int, polynomial degree for trend stack. Note that len(n_polynomials) = len(stack_types). Note that it will only be used if a trend stack is used.<br>\n",
391+
" `n_polynomials`: int, DEPRECATED - polynomial degree for trend stack. Note that len(n_polynomials) = len(stack_types). Note that it will only be used if a trend stack is used.<br>\n",
392+
" `basis`: str, Type of basis function to use in the trend stack. Choose one from ['legendre', 'polynomial', 'changepoint', 'piecewise_linear', 'linear_hat', 'spline', 'chebyshev']<br>\n",
393+
" `n_basis`: int, the degree of the basis function for the trend stack. Note that it will only be used if a trend stack is used.<br>\n",
252394
" `stack_types`: List[str], List of stack types. Subset from ['seasonality', 'trend', 'identity'].<br>\n",
253395
" `n_blocks`: List[int], Number of blocks for each stack. Note that len(n_blocks) = len(stack_types).<br>\n",
254396
" `mlp_units`: List[List[int]], Structure of hidden layers for each stack type. Each internal list should contain the number of units of each hidden layer. Note that len(n_hidden) = len(stack_types).<br>\n",
@@ -294,7 +436,9 @@
294436
" h,\n",
295437
" input_size,\n",
296438
" n_harmonics: int = 2,\n",
297-
" n_polynomials: int = 2,\n",
439+
" n_polynomials: Optional[int] = None,\n",
440+
" n_basis: int = 2,\n",
441+
" basis: str = 'polynomial',\n",
298442
" stack_types: list = ['identity', 'trend', 'seasonality'],\n",
299443
" n_blocks: list = [1, 1, 1],\n",
300444
" mlp_units: list = 3 * [[512, 512]],\n",
@@ -358,6 +502,15 @@
358502
" dataloader_kwargs=dataloader_kwargs,\n",
359503
" **trainer_kwargs)\n",
360504
"\n",
505+
" # Raise deprecation warning\n",
506+
" if n_polynomials is not None:\n",
507+
" warnings.warn(\n",
508+
" \"The parameter n_polynomials will be deprecated in favor of n_basis and basis and it is currently ignored.\\n\"\n",
509+
" \"The basis parameter defines the basis function to be used in the trend stack.\\n\"\n",
510+
" \"The n_basis defines the degree of the basis function used in the trend stack.\",\n",
511+
" DeprecationWarning\n",
512+
" )\n",
513+
" \n",
361514
" # Architecture\n",
362515
" blocks = self.create_stack(h=h,\n",
363516
" input_size=input_size,\n",
@@ -367,18 +520,23 @@
367520
" dropout_prob_theta=dropout_prob_theta,\n",
368521
" activation=activation,\n",
369522
" shared_weights=shared_weights,\n",
370-
" n_polynomials=n_polynomials, \n",
371-
" n_harmonics=n_harmonics)\n",
523+
" n_harmonics=n_harmonics,\n",
524+
" n_basis=n_basis,\n",
525+
" basis_type=basis)\n",
372526
" self.blocks = torch.nn.ModuleList(blocks)\n",
373527
"\n",
374-
" def create_stack(self, stack_types, \n",
528+
" def create_stack(self, \n",
529+
" stack_types, \n",
375530
" n_blocks, \n",
376531
" input_size, \n",
377532
" h, \n",
378533
" mlp_units, \n",
379534
" dropout_prob_theta, \n",
380-
" activation, shared_weights,\n",
381-
" n_polynomials, n_harmonics): \n",
535+
" activation, \n",
536+
" shared_weights,\n",
537+
" n_harmonics, \n",
538+
" n_basis, \n",
539+
" basis_type): \n",
382540
"\n",
383541
" block_list = []\n",
384542
" for i in range(len(stack_types)):\n",
@@ -392,14 +550,17 @@
392550
" n_theta = 2 * (self.loss.outputsize_multiplier + 1) * \\\n",
393551
" int(np.ceil(n_harmonics / 2 * h) - (n_harmonics - 1))\n",
394552
" basis = SeasonalityBasis(harmonics=n_harmonics,\n",
395-
" backcast_size=input_size,forecast_size=h,\n",
553+
" backcast_size=input_size,\n",
554+
" forecast_size=h,\n",
396555
" out_features=self.loss.outputsize_multiplier)\n",
397556
"\n",
398557
" elif stack_types[i] == 'trend':\n",
399-
" n_theta = (self.loss.outputsize_multiplier + 1) * (n_polynomials + 1)\n",
400-
" basis = TrendBasis(degree_of_polynomial=n_polynomials,\n",
401-
" backcast_size=input_size,forecast_size=h,\n",
402-
" out_features=self.loss.outputsize_multiplier)\n",
558+
" n_theta = (self.loss.outputsize_multiplier + 1) * (n_basis + 1)\n",
559+
" basis = TrendBasis(n_basis=n_basis,\n",
560+
" backcast_size=input_size,\n",
561+
" forecast_size=h,\n",
562+
" out_features=self.loss.outputsize_multiplier,\n",
563+
" basis=basis_type)\n",
403564
"\n",
404565
" elif stack_types[i] == 'identity':\n",
405566
" n_theta = input_size + self.loss.outputsize_multiplier * h\n",
@@ -674,6 +835,8 @@
674835
"Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test\n",
675836
"\n",
676837
"model = NBEATS(h=12, input_size=24,\n",
838+
" basis='changepoint',\n",
839+
" n_basis=2,\n",
677840
" loss=DistributionLoss(distribution='Poisson', level=[80, 90]),\n",
678841
" stack_types = ['identity', 'trend', 'seasonality'],\n",
679842
" max_steps=100,\n",

neuralforecast/_modidx.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -871,7 +871,23 @@
871871
'neuralforecast.models.nbeats.TrendBasis.__init__': ( 'models.nbeats.html#trendbasis.__init__',
872872
'neuralforecast/models/nbeats.py'),
873873
'neuralforecast.models.nbeats.TrendBasis.forward': ( 'models.nbeats.html#trendbasis.forward',
874-
'neuralforecast/models/nbeats.py')},
874+
'neuralforecast/models/nbeats.py'),
875+
'neuralforecast.models.nbeats.generate_changepoint_basis': ( 'models.nbeats.html#generate_changepoint_basis',
876+
'neuralforecast/models/nbeats.py'),
877+
'neuralforecast.models.nbeats.generate_chebyshev_basis': ( 'models.nbeats.html#generate_chebyshev_basis',
878+
'neuralforecast/models/nbeats.py'),
879+
'neuralforecast.models.nbeats.generate_legendre_basis': ( 'models.nbeats.html#generate_legendre_basis',
880+
'neuralforecast/models/nbeats.py'),
881+
'neuralforecast.models.nbeats.generate_linear_hat_basis': ( 'models.nbeats.html#generate_linear_hat_basis',
882+
'neuralforecast/models/nbeats.py'),
883+
'neuralforecast.models.nbeats.generate_piecewise_linear_basis': ( 'models.nbeats.html#generate_piecewise_linear_basis',
884+
'neuralforecast/models/nbeats.py'),
885+
'neuralforecast.models.nbeats.generate_polynomial_basis': ( 'models.nbeats.html#generate_polynomial_basis',
886+
'neuralforecast/models/nbeats.py'),
887+
'neuralforecast.models.nbeats.generate_spline_basis': ( 'models.nbeats.html#generate_spline_basis',
888+
'neuralforecast/models/nbeats.py'),
889+
'neuralforecast.models.nbeats.get_basis': ( 'models.nbeats.html#get_basis',
890+
'neuralforecast/models/nbeats.py')},
875891
'neuralforecast.models.nbeatsx': { 'neuralforecast.models.nbeatsx.ExogenousBasis': ( 'models.nbeatsx.html#exogenousbasis',
876892
'neuralforecast/models/nbeatsx.py'),
877893
'neuralforecast.models.nbeatsx.ExogenousBasis.__init__': ( 'models.nbeatsx.html#exogenousbasis.__init__',

0 commit comments

Comments
 (0)