|
376 | 376 | "\n", |
377 | 377 | " **Parameters:**<br>\n", |
378 | 378 | " `h`: int, forecast horizon.<br>\n", |
379 | | - " `input_size`: int, maximum sequence length for truncated train backpropagation. Default -1 uses all history.<br>\n", |
380 | | - " `inference_input_size`: int, maximum sequence length for truncated inference. Default -1 uses all history.<br>\n", |
| 379 | + "    `input_size`: int, maximum sequence length for truncated train backpropagation. Default -1 uses 3 * horizon.<br>\n", |
| 380 | + " `inference_input_size`: int, maximum sequence length for truncated inference. Default None uses input_size history.<br>\n", |
381 | 381 | " `cell_type`: str, type of RNN cell to use. Options: 'GRU', 'RNN', 'LSTM', 'ResLSTM', 'AttentiveLSTM'.<br>\n", |
382 | 382 | "    `dilations`: int list, dilations between layers.<br>\n", |
383 | 383 | " `encoder_hidden_size`: int=200, units for the RNN's hidden state size.<br>\n", |
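The new context-length defaults documented above can be sketched in plain Python. `resolve_context_lengths` is a hypothetical helper written only to illustrate the documented semantics (`input_size=-1` resolves to `3 * h`; `inference_input_size=None` falls back to the resolved `input_size`); it is not a function in neuralforecast.

```python
from typing import Optional, Tuple

def resolve_context_lengths(h: int,
                            input_size: int = -1,
                            inference_input_size: Optional[int] = None
                            ) -> Tuple[int, int]:
    # Default -1 uses 3 * horizon (per the docstring above).
    if input_size == -1:
        input_size = 3 * h
    # Default None uses the (resolved) training input_size.
    if inference_input_size is None:
        inference_input_size = input_size
    return input_size, inference_input_size
```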
|
387 | 387 | " `futr_exog_list`: str list, future exogenous columns.<br>\n", |
388 | 388 | " `hist_exog_list`: str list, historic exogenous columns.<br>\n", |
389 | 389 | " `stat_exog_list`: str list, static exogenous columns.<br>\n", |
| 390 | + " `exclude_insample_y`: bool=False, the model skips the autoregressive features y[t-input_size:t] if True.<br>\n", |
390 | 391 | " `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n", |
391 | 392 | " `valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n", |
392 | 393 | " `max_steps`: int, maximum number of training steps.<br>\n", |
|
396 | 397 | " `val_check_steps`: int, Number of training steps between every validation loss check.<br>\n", |
397 | 398 | " `batch_size`: int=32, number of different series in each batch.<br>\n", |
398 | 399 | " `valid_batch_size`: int=None, number of different series in each validation and test batch.<br>\n", |
| 400 | + "    `windows_batch_size`: int=128, number of windows to sample in each training batch.<br>\n", |
| 401 | + " `inference_windows_batch_size`: int=1024, number of windows to sample in each inference batch, -1 uses all.<br>\n", |
| 402 | + "    `start_padding_enabled`: bool=False, if True, the model will pad the time series with zeros at the beginning by input size.<br>\n", |
399 | 403 | " `step_size`: int=1, step size between each window of temporal data.<br>\n", |
400 | 404 | " `scaler_type`: str='robust', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).<br>\n", |
401 | 405 | " `random_seed`: int=1, random_seed for pytorch initializer and numpy generators.<br>\n", |
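The `start_padding_enabled` option added above can be illustrated with a small sketch. `pad_series_start` is a hypothetical helper, assuming "pad the time series with zeros at the beginning by input size" means prepending `input_size` zeros so the earliest training windows still have full context.

```python
import numpy as np

def pad_series_start(y: np.ndarray, input_size: int) -> np.ndarray:
    # Prepend input_size zeros to the series; an assumed reading of
    # start_padding_enabled, not the library's actual implementation.
    return np.concatenate([np.zeros(input_size, dtype=y.dtype), y])
```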
|
417 | 421 | "\n", |
418 | 422 | " def __init__(self,\n", |
419 | 423 | " h: int,\n", |
420 | | - " input_size: int,\n", |
421 | | - " inference_input_size: int = -1,\n", |
| 424 | + " input_size: int = -1,\n", |
| 425 | + " inference_input_size: Optional[int] = None,\n", |
422 | 426 | " cell_type: str = 'LSTM',\n", |
423 | 427 | " dilations: List[List[int]] = [[1, 2], [4, 8]],\n", |
424 | 428 | " encoder_hidden_size: int = 128,\n", |
|
445 | 449 | " scaler_type: str = 'robust',\n", |
446 | 450 | " random_seed: int = 1,\n", |
447 | 451 | " drop_last_loader: bool = False,\n", |
| 452 | + " alias: Optional[str] = None,\n", |
448 | 453 | " optimizer = None,\n", |
449 | 454 | " optimizer_kwargs = None,\n", |
450 | 455 | " lr_scheduler = None,\n", |
|
454 | 459 | " super(DilatedRNN, self).__init__(\n", |
455 | 460 | " h=h,\n", |
456 | 461 | " input_size=input_size,\n", |
| 462 | + " inference_input_size=inference_input_size,\n", |
457 | 463 | " futr_exog_list=futr_exog_list,\n", |
458 | 464 | " hist_exog_list=hist_exog_list,\n", |
459 | 465 | " stat_exog_list=stat_exog_list,\n", |
|
474 | 480 | " scaler_type=scaler_type,\n", |
475 | 481 | " random_seed=random_seed,\n", |
476 | 482 | " drop_last_loader=drop_last_loader,\n", |
| 483 | + " alias=alias,\n", |
477 | 484 | " optimizer=optimizer,\n", |
478 | 485 | " optimizer_kwargs=optimizer_kwargs,\n", |
479 | 486 | " lr_scheduler=lr_scheduler,\n", |
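The `alias` parameter threaded through to the base class above can be sketched as follows. `_ModelBase` is a stand-in for illustration, assuming (as in neuralforecast's base model) that `alias`, when set, replaces the class name used to label the model's outputs.

```python
from typing import Optional

class _ModelBase:
    # Hypothetical stand-in for the shared base class, for illustration only.
    def __init__(self, alias: Optional[str] = None):
        self.alias = alias

    def __repr__(self) -> str:
        # alias, when given, overrides the class name shown in outputs.
        return self.alias if self.alias is not None else type(self).__name__

class DilatedRNN(_ModelBase):
    pass
```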
|