class AttentionLayer(nn.Module):
    """
    Multi-head attention wrapper.

    Projects queries/keys/values into `n_heads` sub-spaces, delegates the scaled
    attention to `attention` (e.g. `FullAttention`), and merges the heads back
    with a final linear projection.

    **Parameters:**<br>
    `attention`: nn.Module, inner attention with signature
        (queries, keys, values, attn_mask, tau, delta) -> (values, attn).<br>
    `hidden_size`: int, model dimension of inputs and outputs.<br>
    `n_heads`: int, number of attention heads.<br>
    `d_keys`: int, optional, per-head query/key dim (default hidden_size // n_heads).<br>
    `d_values`: int, optional, per-head value dim (default hidden_size // n_heads).<br>
    """

    def __init__(self, attention, hidden_size, n_heads, d_keys=None,
                 d_values=None):
        super(AttentionLayer, self).__init__()

        d_keys = d_keys or (hidden_size // n_heads)
        d_values = d_values or (hidden_size // n_heads)

        self.inner_attention = attention
        self.query_projection = nn.Linear(hidden_size, d_keys * n_heads)
        self.key_projection = nn.Linear(hidden_size, d_keys * n_heads)
        self.value_projection = nn.Linear(hidden_size, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, hidden_size)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        # queries: [B, L, hidden]; keys/values: [B, S, hidden]
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        # Split projections into H heads: [B, L|S, H, d]
        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        out, attn = self.inner_attention(
            queries=queries,
            keys=keys,
            values=values,
            attn_mask=attn_mask,
            tau=tau,
            delta=delta
        )
        # Merge heads back to [B, L, H * d_values] and project to hidden_size.
        out = out.view(B, L, -1)

        return self.out_projection(out), attn


class TriangularCausalMask():
    """
    Boolean upper-triangular mask of shape [B, 1, L, L]; True entries mark
    future positions causal attention must not attend to.
    """
    def __init__(self, B, L, device="cpu"):
        mask_shape = [B, 1, L, L]
        with torch.no_grad():
            # Build the mask directly on the target device instead of
            # allocating on CPU and copying with .to(device) — avoids a
            # host->device transfer on every forward pass.
            self._mask = torch.triu(
                torch.ones(mask_shape, dtype=torch.bool, device=device), diagonal=1
            )

    @property
    def mask(self):
        return self._mask


class FullAttention(nn.Module):
    """
    Canonical scaled dot-product attention.

    **Parameters:**<br>
    `mask_flag`: bool, apply a causal mask when True.<br>
    `factor`: int, unused; kept for interface parity with sparse attentions.<br>
    `scale`: float, optional, score scaling (default 1/sqrt(E)).<br>
    `attention_dropout`: float, dropout rate on attention weights.<br>
    `output_attention`: bool, also return the attention matrix when True.<br>
    """
    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(FullAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        # queries: [B, L, H, E]; values: [B, S, H, D]. tau/delta are accepted
        # for interface compatibility and unused here.
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1. / math.sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)

        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)

            # float("-inf") == -np.inf; avoids a numpy dependency in this
            # otherwise pure-torch module. Softmax maps masked scores to 0.
            scores.masked_fill_(attn_mask.mask, float("-inf"))

        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", A, values)

        if self.output_attention:
            return V.contiguous(), A
        else:
            return V.contiguous(), None


class DataEmbedding_inverted(nn.Module):
    """
    Inverted embedding (iTransformer-style): embeds each variate's whole time
    series into a single token of size `hidden_size`.
    """
    def __init__(self, c_in, hidden_size, dropout=0.1):
        super(DataEmbedding_inverted, self).__init__()
        self.value_embedding = nn.Linear(c_in, hidden_size)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        x = x.permute(0, 2, 1)
        # x: [Batch Variate Time]
        if x_mark is None:
            x = self.value_embedding(x)
        else:
            # the potential to take covariates (e.g. timestamps) as tokens
            x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1))
        # x: [Batch Variate hidden_size]
        return self.dropout(x)
timestamps) as tokens\n", + " x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1)) \n", + " # x: [Batch Variate hidden_size]\n", " return self.dropout(x)" ] }, diff --git a/nbs/core.ipynb b/nbs/core.ipynb index 5138b5e68..6e663ded2 100644 --- a/nbs/core.ipynb +++ b/nbs/core.ipynb @@ -93,7 +93,7 @@ " StemGNN, PatchTST, TimesNet, TimeLLM, TSMixer, TSMixerx,\n", " MLPMultivariate, iTransformer,\n", " BiTCN, TiDE, DeepNPTS, SOFTS,\n", - " TimeMixer, KAN, RMoK\n", + " TimeMixer, KAN, RMoK, TimeXer\n", ")\n", "from neuralforecast.common._base_auto import BaseAuto, MockTrial\n", "from neuralforecast.utils import PredictionIntervals, get_prediction_interval_method" @@ -247,7 +247,8 @@ " 'softs': SOFTS, 'autosofts': SOFTS,\n", " 'timemixer': TimeMixer, 'autotimemixer': TimeMixer,\n", " 'kan': KAN, 'autokan': KAN,\n", - " 'rmok': RMoK, 'autormok': RMoK\n", + " 'rmok': RMoK, 'autormok': RMoK,\n", + " 'timexer': TimeXer, 'autotimexer': TimeXer\n", "}" ] }, diff --git a/nbs/docs/capabilities/01_overview.ipynb b/nbs/docs/capabilities/01_overview.ipynb index 11b964a7f..7c968c338 100644 --- a/nbs/docs/capabilities/01_overview.ipynb +++ b/nbs/docs/capabilities/01_overview.ipynb @@ -43,7 +43,8 @@ "|`TiDE` | `AutoTiDE` | MLP | Univariate | Direct | F/H/S | \n", "|`TimeMixer` | `AutoTimeMixer` | MLP | Multivariate | Direct | - | \n", "|`TimeLLM` | - | LLM | Univariate | Direct | - | \n", - "|`TimesNet` | `AutoTimesNet` | CNN | Univariate | Direct | F | \n", + "|`TimesNet` | `AutoTimesNet` | CNN | Univariate | Direct | F |\n", + "|`TimeXer` | `AutoTimeXer` | Transformer | Multivariate | Direct | F | \n", "|`TSMixer` | `AutoTSMixer` | MLP | Multivariate | Direct | - | \n", "|`TSMixerx` | `AutoTSMixerx` | MLP | Multivariate | Direct | F/H/S | \n", "|`VanillaTransformer` | `AutoVanillaTransformer` | Transformer | Univariate | Direct | F | \n", diff --git a/nbs/imgs_models/timexer.png b/nbs/imgs_models/timexer.png new file mode 100644 index 000000000..99a6e22ab Binary files 
class AutoTimeXer(BaseAuto):
    """Hyperparameter-search wrapper around `TimeXer` (ray / optuna backends)."""

    default_config = {
        "input_size_multiplier": [1, 2, 3, 4, 5],
        "h": None,
        "n_series": None,
        "hidden_size": tune.choice([128, 256, 512]),
        "n_heads": tune.choice([4, 8]),
        "learning_rate": tune.loguniform(1e-4, 1e-1),
        "scaler_type": tune.choice([None, 'robust', 'standard']),
        "max_steps": tune.choice([500, 1000, 2000]),
        "batch_size": tune.choice([32, 64, 128, 256]),
        "loss": None,
        "random_seed": tune.randint(1, 20),
    }

    def __init__(self,
                 h,
                 n_series,
                 loss=MAE(),
                 valid_loss=None,
                 config=None,
                 search_alg=BasicVariantGenerator(random_state=1),
                 num_samples=10,
                 refit_with_val=False,
                 cpus=cpu_count(),
                 gpus=torch.cuda.device_count(),
                 verbose=False,
                 alias=None,
                 backend='ray',
                 callbacks=None):

        # Fall back to the default search space when none was provided.
        if config is None:
            config = self.get_default_config(h=h, backend=backend, n_series=n_series)

        # n_series must always match the constructor argument: with ray we can
        # simply overwrite it, with optuna we can only validate a sampled config.
        if backend == 'ray':
            config['n_series'] = n_series
        elif backend == 'optuna':
            sampled = config(MockTrial())
            if 'n_series' not in sampled or sampled['n_series'] != n_series:
                raise Exception(f"config needs 'n_series': {n_series}")

        super(AutoTimeXer, self).__init__(
            cls_model=TimeXer,
            h=h,
            loss=loss,
            valid_loss=valid_loss,
            config=config,
            search_alg=search_alg,
            num_samples=num_samples,
            refit_with_val=refit_with_val,
            cpus=cpus,
            gpus=gpus,
            verbose=verbose,
            alias=alias,
            backend=backend,
            callbacks=callbacks,
        )

    @classmethod
    def get_default_config(cls, h, backend, n_series):
        """Build the default hyperparameter search space for horizon `h`."""
        config = cls.default_config.copy()

        # Turn the horizon multipliers into concrete candidate input sizes.
        multipliers = config.pop("input_size_multiplier")
        config['input_size'] = tune.choice([h * m for m in multipliers])

        # Rolling windows with step_size=1 or step_size=h
        # See `BaseWindows` and `BaseRNN`'s create_windows
        config['step_size'] = tune.choice([1, h])

        if backend == 'optuna':
            # Always use n_series from parameters
            config['n_series'] = n_series
            config = cls._ray_config_to_optuna(config)

        return config
= 1\n", + "my_config['val_check_steps'] = 1\n", + "my_config['input_size'] = 12\n", + "my_config['patch_len'] = 12\n", + "model = AutoTimeXer(h=12, n_series=1, config=my_config, backend='ray', num_samples=1, cpus=1)\n", + "model.fit(dataset=dataset)" + ] + }, { "attachments": {}, "cell_type": "markdown", diff --git a/nbs/models.itransformer.ipynb b/nbs/models.itransformer.ipynb index e8b6f15d7..c54e2a98e 100644 --- a/nbs/models.itransformer.ipynb +++ b/nbs/models.itransformer.ipynb @@ -66,125 +66,23 @@ "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", - "import numpy as np\n", - "\n", - "from math import sqrt\n", - "\n", "from neuralforecast.losses.pytorch import MAE\n", "from neuralforecast.common._base_multivariate import BaseMultivariate\n", "\n", - "from neuralforecast.common._modules import TransEncoder, TransEncoderLayer, AttentionLayer" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 1. Auxiliary functions" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1.1 Attention" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#| export\n", - "\n", - "class TriangularCausalMask():\n", - " \"\"\"\n", - " TriangularCausalMask\n", - " \"\"\" \n", - " def __init__(self, B, L, device=\"cpu\"):\n", - " mask_shape = [B, 1, L, L]\n", - " with torch.no_grad():\n", - " self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)\n", - "\n", - " @property\n", - " def mask(self):\n", - " return self._mask\n", - "\n", - "class FullAttention(nn.Module):\n", - " \"\"\"\n", - " FullAttention\n", - " \"\"\" \n", - " def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):\n", - " super(FullAttention, self).__init__()\n", - " self.scale = scale\n", - " self.mask_flag = mask_flag\n", - " self.output_attention = output_attention\n", - " self.dropout = 
nn.Dropout(attention_dropout)\n", - "\n", - " def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):\n", - " B, L, H, E = queries.shape\n", - " _, S, _, D = values.shape\n", - " scale = self.scale or 1. / sqrt(E)\n", - "\n", - " scores = torch.einsum(\"blhe,bshe->bhls\", queries, keys)\n", - "\n", - " if self.mask_flag:\n", - " if attn_mask is None:\n", - " attn_mask = TriangularCausalMask(B, L, device=queries.device)\n", - "\n", - " scores.masked_fill_(attn_mask.mask, -np.inf)\n", - "\n", - " A = self.dropout(torch.softmax(scale * scores, dim=-1))\n", - " V = torch.einsum(\"bhls,bshd->blhd\", A, values)\n", - "\n", - " if self.output_attention:\n", - " return (V.contiguous(), A)\n", - " else:\n", - " return (V.contiguous(), None) " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1.2 Inverted embedding" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#| export\n", - "\n", - "class DataEmbedding_inverted(nn.Module):\n", - " \"\"\"\n", + "from neuralforecast.common._modules import (\n", + " TransEncoder, \n", + " TransEncoderLayer, \n", + " AttentionLayer, \n", + " FullAttention, \n", " DataEmbedding_inverted\n", - " \"\"\" \n", - " def __init__(self, c_in, hidden_size, dropout=0.1):\n", - " super(DataEmbedding_inverted, self).__init__()\n", - " self.value_embedding = nn.Linear(c_in, hidden_size)\n", - " self.dropout = nn.Dropout(p=dropout)\n", - "\n", - " def forward(self, x, x_mark):\n", - " x = x.permute(0, 2, 1)\n", - " # x: [Batch Variate Time]\n", - " if x_mark is None:\n", - " x = self.value_embedding(x)\n", - " else:\n", - " # the potential to take covariates (e.g. timestamps) as tokens\n", - " x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1)) \n", - " # x: [Batch Variate hidden_size]\n", - " return self.dropout(x)" + ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "# 2. Model" + "# 1. 
Model" ] }, { @@ -410,7 +308,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# 3. Usage example" + "# 2. Usage example" ] }, { diff --git a/nbs/models.timexer.ipynb b/nbs/models.timexer.ipynb new file mode 100644 index 000000000..44565dce5 --- /dev/null +++ b/nbs/models.timexer.ipynb @@ -0,0 +1,528 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| default_exp models.timexer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "from nbdev.showdoc import show_doc" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# TimeXer\n", + "\n", + "TimeXer empowers the canonical Transformer with the ability to reconcile endogenous and exogenous information, where patch-wise self-attention and variate-wise cross-attention are used simultaneously.\n", + "\n", + "**References**\n", + "- [Yuxuan Wang, Haixu Wu, Jiaxiang Dong, Guo Qin, Haoran Zhang, Yong Liu, Yunzhong Qiu, Jianmin Wang, Mingsheng Long. \"TimeXer: Empowering Transformers for Time Series Forecasting with Exogenous Variables\"](https://arxiv.org/abs/2402.19072)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![Figure 1. 
class FlattenHead(nn.Module):
    """
    Forecasting head: flattens the trailing (d_model, patch_num) dimensions of
    each variate and maps them linearly onto the target window.
    """
    def __init__(self, n_vars, nf, target_window, head_dropout=0):
        super().__init__()
        self.n_vars = n_vars
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(nf, target_window)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        # x: [bs x nvars x d_model x patch_num] -> [bs x nvars x target_window]
        return self.dropout(self.linear(self.flatten(x)))


class Encoder(nn.Module):
    """Stack of encoder layers with optional final normalization and projection."""
    def __init__(self, layers, norm_layer=None, projection=None):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
        for encoder_layer in self.layers:
            x = encoder_layer(x, cross, x_mask=x_mask, cross_mask=cross_mask,
                              tau=tau, delta=delta)
        if self.norm is not None:
            x = self.norm(x)
        if self.projection is not None:
            x = self.projection(x)
        return x


class EncoderLayer(nn.Module):
    """
    TimeXer encoder layer: patch-wise self-attention over the endogenous
    tokens, variate-wise cross-attention from the global token to the
    exogenous embeddings, then a position-wise conv feed-forward block.
    """
    def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
                 dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
        batch, _, dim = cross.shape

        # Patch-wise self-attention; delta is deliberately not forwarded here.
        self_out = self.self_attention(x, x, x, attn_mask=x_mask,
                                       tau=tau, delta=None)[0]
        x = self.norm1(x + self.dropout(self_out))

        # The last token of every (batch*variate) row is the global token.
        glb_orig = x[:, -1, :].unsqueeze(1)
        glb = torch.reshape(glb_orig, (batch, -1, dim))

        # Variate-wise cross-attention: global token attends to exogenous tokens.
        cross_out = self.dropout(self.cross_attention(glb, cross, cross,
                                                      attn_mask=cross_mask,
                                                      tau=tau, delta=delta)[0])
        cross_out = torch.reshape(
            cross_out,
            (cross_out.shape[0] * cross_out.shape[1], cross_out.shape[2])
        ).unsqueeze(1)
        glb = self.norm2(glb_orig + cross_out)

        # Re-attach the updated global token, then conv feed-forward + residual.
        y = x = torch.cat([x[:, :-1, :], glb], dim=1)
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))

        return self.norm3(x + y)
class EnEmbedding(nn.Module):
    """
    Endogenous embedding: splits each series into non-overlapping patches of
    length `patch_len`, embeds each patch linearly plus a positional
    embedding, and appends a learnable per-variate global token.
    """
    def __init__(self, n_vars, d_model, patch_len, dropout):
        super(EnEmbedding, self).__init__()
        # Patching
        self.patch_len = patch_len

        self.value_embedding = nn.Linear(patch_len, d_model, bias=False)
        # One learnable global token per variate, broadcast over the batch.
        self.glb_token = nn.Parameter(torch.randn(1, n_vars, 1, d_model))
        self.position_embedding = PositionalEmbedding(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # do patching; x: [batch, n_vars, time] — any remainder of time not
        # divisible by patch_len is dropped by unfold.
        n_vars = x.shape[1]
        glb = self.glb_token.repeat((x.shape[0], 1, 1, 1))

        x = x.unfold(dimension=-1, size=self.patch_len, step=self.patch_len)
        x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
        # Input encoding
        x = self.value_embedding(x) + self.position_embedding(x)
        x = torch.reshape(x, (-1, n_vars, x.shape[-2], x.shape[-1]))
        # Append the global token as the last token of every variate.
        x = torch.cat([x, glb], dim=2)
        # Output: [batch * n_vars, patch_num + 1, d_model]
        x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
        return self.dropout(x), n_vars


class TimeXer(BaseMultivariate):
    """TimeXer

    TimeXer empowers the canonical Transformer to reconcile endogenous and
    exogenous information: patch-wise self-attention over the endogenous
    series and variate-wise cross-attention (through a global token) over
    the exogenous series are used simultaneously.

    **Parameters:**<br>
    `h`: int, Forecast horizon.<br>
    `input_size`: int, autorregresive inputs size, y=[1,2,3,4] input_size=2 -> y_[t-2:t]=[1,2].<br>
    `n_series`: int, number of time-series.<br>
    `futr_exog_list`: str list, future exogenous columns.<br>
    `hist_exog_list`: str list, historic exogenous columns.<br>
    `stat_exog_list`: str list, static exogenous columns.<br>
    `patch_len`: int, length of patches.<br>
    `hidden_size`: int, dimension of the model.<br>
    `n_heads`: int, number of heads.<br>
    `e_layers`: int, number of encoder layers.<br>
    `d_ff`: int, dimension of fully-connected layer.<br>
    `factor`: int, attention factor.<br>
    `dropout`: float, dropout rate.<br>
    `use_norm`: bool, whether to normalize or not.<br>
    `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>
    `valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>
    `max_steps`: int=1000, maximum number of training steps.<br>
    `learning_rate`: float=1e-3, Learning rate between (0, 1).<br>
    `num_lr_decays`: int=-1, Number of learning rate decays, evenly distributed across max_steps.<br>
    `early_stop_patience_steps`: int=-1, Number of validation iterations before early stopping.<br>
    `val_check_steps`: int=100, Number of training steps between every validation loss check.<br>
    `batch_size`: int=32, number of different series in each batch.<br>
    `step_size`: int=1, step size between each window of temporal data.<br>
    `scaler_type`: str='identity', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).<br>
    `random_seed`: int=1, random_seed for pytorch initializer and numpy generators.<br>
    `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.<br>
    `alias`: str, optional, Custom name of the model.<br>
    `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).<br>
    `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>
    `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>
    `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br>
    `dataloader_kwargs`: dict, optional, list of parameters passed into the PyTorch Lightning dataloader by the `TimeSeriesDataLoader`.<br>
    `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>

    **References**
    - [Yuxuan Wang, Haixu Wu, Jiaxiang Dong, Guo Qin, Haoran Zhang, Yong Liu, Yunzhong Qiu, Jianmin Wang, Mingsheng Long. "TimeXer: Empowering Transformers for Time Series Forecasting with Exogenous Variables"](https://arxiv.org/abs/2402.19072)
    """

    # Class attributes
    SAMPLING_TYPE = 'multivariate'
    EXOGENOUS_FUTR = True
    EXOGENOUS_HIST = False
    EXOGENOUS_STAT = False

    def __init__(self,
                 h,
                 input_size,
                 n_series,
                 futr_exog_list = None,
                 hist_exog_list = None,
                 stat_exog_list = None,
                 patch_len: int = 16,
                 hidden_size: int = 512,
                 n_heads: int = 8,
                 e_layers: int = 2,
                 d_ff: int = 2048,
                 factor: int = 1,
                 dropout: float = 0.1,
                 use_norm: bool = True,
                 loss = MAE(),
                 valid_loss = None,
                 max_steps: int = 1000,
                 learning_rate: float = 1e-3,
                 num_lr_decays: int = -1,
                 early_stop_patience_steps: int =-1,
                 val_check_steps: int = 100,
                 batch_size: int = 32,
                 step_size: int = 1,
                 scaler_type: str = 'identity',
                 random_seed: int = 1,
                 drop_last_loader: bool = False,
                 optimizer = None,
                 optimizer_kwargs = None,
                 lr_scheduler = None,
                 lr_scheduler_kwargs = None,
                 dataloader_kwargs = None,
                 **trainer_kwargs):

        super(TimeXer, self).__init__(h=h,
                                      input_size=input_size,
                                      n_series=n_series,
                                      stat_exog_list = stat_exog_list,
                                      futr_exog_list = futr_exog_list,
                                      hist_exog_list = hist_exog_list,
                                      loss=loss,
                                      valid_loss=valid_loss,
                                      max_steps=max_steps,
                                      learning_rate=learning_rate,
                                      num_lr_decays=num_lr_decays,
                                      early_stop_patience_steps=early_stop_patience_steps,
                                      val_check_steps=val_check_steps,
                                      batch_size=batch_size,
                                      step_size=step_size,
                                      scaler_type=scaler_type,
                                      random_seed=random_seed,
                                      drop_last_loader=drop_last_loader,
                                      optimizer=optimizer,
                                      optimizer_kwargs=optimizer_kwargs,
                                      lr_scheduler=lr_scheduler,
                                      lr_scheduler_kwargs=lr_scheduler_kwargs,
                                      dataloader_kwargs=dataloader_kwargs,
                                      **trainer_kwargs)

        self.enc_in = n_series
        self.hidden_size = hidden_size
        self.n_heads = n_heads
        self.e_layers = e_layers
        self.d_ff = d_ff
        self.dropout = dropout
        self.factor = factor
        self.patch_len = patch_len
        self.use_norm = use_norm
        # Number of whole patches; any remainder of input_size is dropped by
        # EnEmbedding's unfold.
        self.patch_num = int(input_size // self.patch_len)

        # Architecture
        self.en_embedding = EnEmbedding(n_series, self.hidden_size, self.patch_len, self.dropout)
        self.ex_embedding = DataEmbedding_inverted(input_size, self.hidden_size, self.dropout)

        # Both self- and cross-attention are unmasked (mask_flag=False).
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        FullAttention(False, self.factor, attention_dropout=self.dropout,
                                      output_attention=False),
                        self.hidden_size, self.n_heads),
                    AttentionLayer(
                        FullAttention(False, self.factor, attention_dropout=self.dropout,
                                      output_attention=False),
                        self.hidden_size, self.n_heads),
                    self.hidden_size,
                    self.d_ff,
                    dropout=self.dropout,
                    activation='relu',
                )
                for l in range(self.e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(self.hidden_size)
        )
        # +1 accounts for the global token appended by EnEmbedding.
        self.head_nf = self.hidden_size * (self.patch_num + 1)
        self.head = FlattenHead(self.enc_in, self.head_nf, h,
                                head_dropout=self.dropout)

    def forecast(self, x_enc, x_mark_enc):
        # x_enc: [batch, input_size, n_series]; x_mark_enc: flattened future
        # exogenous over the input window, or None.
        if self.use_norm:
            # Normalization from Non-stationary Transformer
            means = x_enc.mean(1, keepdim=True).detach()
            x_enc = x_enc - means
            stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
            x_enc /= stdev

        _, _, N = x_enc.shape

        # Endogenous tokens are per-variate patches; exogenous tokens are
        # whole-series (inverted) embeddings.
        en_embed, n_vars = self.en_embedding(x_enc.permute(0, 2, 1))
        ex_embed = self.ex_embedding(x_enc, x_mark_enc)

        enc_out = self.encoder(en_embed, ex_embed)
        enc_out = torch.reshape(
            enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1]))
        # z: [bs x nvars x d_model x patch_num]
        enc_out = enc_out.permute(0, 1, 3, 2)

        dec_out = self.head(enc_out)  # z: [bs x nvars x target_window]
        dec_out = dec_out.permute(0, 2, 1)

        if self.use_norm:
            # De-Normalization from Non-stationary Transformer
            dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.h, 1))
            dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.h, 1))

        return dec_out

    def forward(self, windows_batch):
        insample_y = windows_batch['insample_y']
        futr_exog = windows_batch['futr_exog']

        if self.futr_exog_size > 0:
            # Keep only the input window of the future exogenous and flatten
            # the (variate, feature) dims into channels: [B, T, V*D].
            x_mark_enc = futr_exog[:, :, :self.input_size, :]
            B, V, T, D = x_mark_enc.shape
            x_mark_enc = x_mark_enc.reshape(B, T, V*D)
        else:
            x_mark_enc = None

        y_pred = self.forecast(insample_y, x_mark_enc)
        y_pred = y_pred[:, -self.h:, :]
        y_pred = self.loss.domain_map(y_pred)

        # domain_map can collapse the last dimension; restore it for
        # univariate outputs so callers always see [B, h, n_series-like].
        if y_pred.ndim == 2:
            return y_pred.unsqueeze(-1)
        else:
            return y_pred
name='TimeXer.predict')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 3. Usage example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| eval: false\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from neuralforecast import NeuralForecast\n", + "from neuralforecast.models import TimeXer\n", + "from neuralforecast.losses.pytorch import MSE, MAE\n", + "from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic, augment_calendar_df\n", + "\n", + "AirPassengersPanel, calendar_cols = augment_calendar_df(df=AirPassengersPanel, freq='M')\n", + "\n", + "Y_train_df = AirPassengersPanel[AirPassengersPanel.ds<AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 132 train\n", + "Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test\n", + "\n", + "model = TimeXer(h=12,\n", + " input_size=24,\n", + " n_series=2,\n", + " futr_exog_list=[\"trend\", \"month\"],\n", + " patch_len=12,\n", + " hidden_size=128,\n", + " n_heads=2,\n", + " e_layers=2,\n", + " d_ff=4,\n", + " factor=1,\n", + " dropout=0.1,\n", + " use_norm=True,\n", + " loss=MSE(),\n", + " valid_loss=MAE(),\n", + " early_stop_patience_steps=3,\n", + " batch_size=32)\n", + "\n", + "fcst = NeuralForecast(models=[model], freq='M')\n", + "fcst.fit(df=Y_train_df, static_df=AirPassengersStatic, val_size=12)\n", + "forecasts = fcst.predict(futr_df=Y_test_df)\n", + "\n", + "# Plot predictions\n", + "fig, ax = plt.subplots(1, 1, figsize = (20, 7))\n", + "Y_hat_df = forecasts.reset_index(drop=False).drop(columns=['unique_id','ds'])\n", + "plot_df = pd.concat([Y_test_df, Y_hat_df], axis=1)\n", + "plot_df = pd.concat([Y_train_df, plot_df])\n", + "\n", + "plot_df = plot_df[plot_df.unique_id=='Airline1'].drop('unique_id', axis=1)\n", + "plt.plot(plot_df['ds'], plot_df['y'], c='black', label='True')\n", + "plt.plot(plot_df['ds'], plot_df['TimeXer'], c='blue', label='Forecast')\n", + "ax.set_title('AirPassengers Forecast', fontsize=22)\n", + 
"ax.set_ylabel('Monthly Passengers', fontsize=20)\n", + "ax.set_xlabel('Year', fontsize=20)\n", + "ax.legend(prop={'size': 15})\n", + "ax.grid()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/nbs/models.vanillatransformer.ipynb b/nbs/models.vanillatransformer.ipynb index 232de7dfa..e4354af1a 100644 --- a/nbs/models.vanillatransformer.ipynb +++ b/nbs/models.vanillatransformer.ipynb @@ -55,7 +55,6 @@ "outputs": [], "source": [ "#| export\n", - "import math\n", "import numpy as np\n", "from typing import Optional\n", "\n", @@ -65,7 +64,7 @@ "from neuralforecast.common._modules import (\n", " TransEncoderLayer, TransEncoder,\n", " TransDecoderLayer, TransDecoder,\n", - " DataEmbedding, AttentionLayer,\n", + " DataEmbedding, AttentionLayer, FullAttention\n", ")\n", "from neuralforecast.common._base_windows import BaseWindows\n", "\n", @@ -87,64 +86,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 1. 
Auxiliary Functions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#| export\n", - "class TriangularCausalMask():\n", - " def __init__(self, B, L, device=\"cpu\"):\n", - " mask_shape = [B, 1, L, L]\n", - " with torch.no_grad():\n", - " self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)\n", - "\n", - " @property\n", - " def mask(self):\n", - " return self._mask\n", - "\n", - "class FullAttention(nn.Module):\n", - " \"\"\"\n", - " FullAttention\n", - " \"\"\" \n", - " def __init__(self, mask_flag=True, scale=None, attention_dropout=0.1, output_attention=False):\n", - " super(FullAttention, self).__init__()\n", - " self.scale = scale\n", - " self.mask_flag = mask_flag\n", - " self.output_attention = output_attention\n", - " self.dropout = nn.Dropout(attention_dropout)\n", - "\n", - " def forward(self, queries, keys, values, attn_mask):\n", - " B, L, H, E = queries.shape\n", - " _, S, _, D = values.shape\n", - " scale = self.scale or 1. / math.sqrt(E)\n", - "\n", - " scores = torch.einsum(\"blhe,bshe->bhls\", queries, keys)\n", - " \n", - " if self.mask_flag:\n", - " if attn_mask is None:\n", - " attn_mask = TriangularCausalMask(B, L, device=queries.device)\n", - "\n", - " scores.masked_fill_(attn_mask.mask, -np.inf)\n", - "\n", - " A = self.dropout(torch.softmax(scale * scores, dim=-1))\n", - " V = torch.einsum(\"bhls,bshd->blhd\", A, values)\n", - "\n", - " if self.output_attention:\n", - " return (V.contiguous(), A)\n", - " else:\n", - " return (V.contiguous(), None)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. VanillaTransformer" + "## 1. 
VanillaTransformer" ] }, { diff --git a/nbs/sidebar.yml b/nbs/sidebar.yml index f4d7e3b2d..31ed63346 100644 --- a/nbs/sidebar.yml +++ b/nbs/sidebar.yml @@ -50,6 +50,7 @@ website: - models.timemixer.ipynb - models.timellm.ipynb - models.timesnet.ipynb + - models.timexer.ipynb - models.tsmixer.ipynb - models.tsmixerx.ipynb - models.vanillatransformer.ipynb diff --git a/neuralforecast/_modidx.py b/neuralforecast/_modidx.py index 445266155..ab1fe227b 100644 --- a/neuralforecast/_modidx.py +++ b/neuralforecast/_modidx.py @@ -140,6 +140,11 @@ 'neuralforecast/auto.py'), 'neuralforecast.auto.AutoTimeMixer.get_default_config': ( 'models.html#autotimemixer.get_default_config', 'neuralforecast/auto.py'), + 'neuralforecast.auto.AutoTimeXer': ('models.html#autotimexer', 'neuralforecast/auto.py'), + 'neuralforecast.auto.AutoTimeXer.__init__': ( 'models.html#autotimexer.__init__', + 'neuralforecast/auto.py'), + 'neuralforecast.auto.AutoTimeXer.get_default_config': ( 'models.html#autotimexer.get_default_config', + 'neuralforecast/auto.py'), 'neuralforecast.auto.AutoTimesNet': ('models.html#autotimesnet', 'neuralforecast/auto.py'), 'neuralforecast.auto.AutoTimesNet.__init__': ( 'models.html#autotimesnet.__init__', 'neuralforecast/auto.py'), @@ -788,25 +793,7 @@ 'neuralforecast/models/informer.py'), 'neuralforecast.models.informer.ProbMask.mask': ( 'models.informer.html#probmask.mask', 'neuralforecast/models/informer.py')}, - 'neuralforecast.models.itransformer': { 'neuralforecast.models.itransformer.DataEmbedding_inverted': ( 'models.itransformer.html#dataembedding_inverted', - 'neuralforecast/models/itransformer.py'), - 'neuralforecast.models.itransformer.DataEmbedding_inverted.__init__': ( 'models.itransformer.html#dataembedding_inverted.__init__', - 'neuralforecast/models/itransformer.py'), - 'neuralforecast.models.itransformer.DataEmbedding_inverted.forward': ( 'models.itransformer.html#dataembedding_inverted.forward', - 'neuralforecast/models/itransformer.py'), - 
'neuralforecast.models.itransformer.FullAttention': ( 'models.itransformer.html#fullattention', - 'neuralforecast/models/itransformer.py'), - 'neuralforecast.models.itransformer.FullAttention.__init__': ( 'models.itransformer.html#fullattention.__init__', - 'neuralforecast/models/itransformer.py'), - 'neuralforecast.models.itransformer.FullAttention.forward': ( 'models.itransformer.html#fullattention.forward', - 'neuralforecast/models/itransformer.py'), - 'neuralforecast.models.itransformer.TriangularCausalMask': ( 'models.itransformer.html#triangularcausalmask', - 'neuralforecast/models/itransformer.py'), - 'neuralforecast.models.itransformer.TriangularCausalMask.__init__': ( 'models.itransformer.html#triangularcausalmask.__init__', - 'neuralforecast/models/itransformer.py'), - 'neuralforecast.models.itransformer.TriangularCausalMask.mask': ( 'models.itransformer.html#triangularcausalmask.mask', - 'neuralforecast/models/itransformer.py'), - 'neuralforecast.models.itransformer.iTransformer': ( 'models.itransformer.html#itransformer', + 'neuralforecast.models.itransformer': { 'neuralforecast.models.itransformer.iTransformer': ( 'models.itransformer.html#itransformer', 'neuralforecast/models/itransformer.py'), 'neuralforecast.models.itransformer.iTransformer.__init__': ( 'models.itransformer.html#itransformer.__init__', 'neuralforecast/models/itransformer.py'), @@ -1292,6 +1279,38 @@ 'neuralforecast/models/timesnet.py'), 'neuralforecast.models.timesnet.TimesNet.forward': ( 'models.timesnet.html#timesnet.forward', 'neuralforecast/models/timesnet.py')}, + 'neuralforecast.models.timexer': { 'neuralforecast.models.timexer.EnEmbedding': ( 'models.timexer.html#enembedding', + 'neuralforecast/models/timexer.py'), + 'neuralforecast.models.timexer.EnEmbedding.__init__': ( 'models.timexer.html#enembedding.__init__', + 'neuralforecast/models/timexer.py'), + 'neuralforecast.models.timexer.EnEmbedding.forward': ( 'models.timexer.html#enembedding.forward', + 
'neuralforecast/models/timexer.py'), + 'neuralforecast.models.timexer.Encoder': ( 'models.timexer.html#encoder', + 'neuralforecast/models/timexer.py'), + 'neuralforecast.models.timexer.Encoder.__init__': ( 'models.timexer.html#encoder.__init__', + 'neuralforecast/models/timexer.py'), + 'neuralforecast.models.timexer.Encoder.forward': ( 'models.timexer.html#encoder.forward', + 'neuralforecast/models/timexer.py'), + 'neuralforecast.models.timexer.EncoderLayer': ( 'models.timexer.html#encoderlayer', + 'neuralforecast/models/timexer.py'), + 'neuralforecast.models.timexer.EncoderLayer.__init__': ( 'models.timexer.html#encoderlayer.__init__', + 'neuralforecast/models/timexer.py'), + 'neuralforecast.models.timexer.EncoderLayer.forward': ( 'models.timexer.html#encoderlayer.forward', + 'neuralforecast/models/timexer.py'), + 'neuralforecast.models.timexer.FlattenHead': ( 'models.timexer.html#flattenhead', + 'neuralforecast/models/timexer.py'), + 'neuralforecast.models.timexer.FlattenHead.__init__': ( 'models.timexer.html#flattenhead.__init__', + 'neuralforecast/models/timexer.py'), + 'neuralforecast.models.timexer.FlattenHead.forward': ( 'models.timexer.html#flattenhead.forward', + 'neuralforecast/models/timexer.py'), + 'neuralforecast.models.timexer.TimeXer': ( 'models.timexer.html#timexer', + 'neuralforecast/models/timexer.py'), + 'neuralforecast.models.timexer.TimeXer.__init__': ( 'models.timexer.html#timexer.__init__', + 'neuralforecast/models/timexer.py'), + 'neuralforecast.models.timexer.TimeXer.forecast': ( 'models.timexer.html#timexer.forecast', + 'neuralforecast/models/timexer.py'), + 'neuralforecast.models.timexer.TimeXer.forward': ( 'models.timexer.html#timexer.forward', + 'neuralforecast/models/timexer.py')}, 'neuralforecast.models.tsmixer': { 'neuralforecast.models.tsmixer.FeatureMixing': ( 'models.tsmixer.html#featuremixing', 'neuralforecast/models/tsmixer.py'), 'neuralforecast.models.tsmixer.FeatureMixing.__init__': ( 
'models.tsmixer.html#featuremixing.__init__', @@ -1362,19 +1381,7 @@ 'neuralforecast/models/tsmixerx.py'), 'neuralforecast.models.tsmixerx.TemporalMixing.forward': ( 'models.tsmixerx.html#temporalmixing.forward', 'neuralforecast/models/tsmixerx.py')}, - 'neuralforecast.models.vanillatransformer': { 'neuralforecast.models.vanillatransformer.FullAttention': ( 'models.vanillatransformer.html#fullattention', - 'neuralforecast/models/vanillatransformer.py'), - 'neuralforecast.models.vanillatransformer.FullAttention.__init__': ( 'models.vanillatransformer.html#fullattention.__init__', - 'neuralforecast/models/vanillatransformer.py'), - 'neuralforecast.models.vanillatransformer.FullAttention.forward': ( 'models.vanillatransformer.html#fullattention.forward', - 'neuralforecast/models/vanillatransformer.py'), - 'neuralforecast.models.vanillatransformer.TriangularCausalMask': ( 'models.vanillatransformer.html#triangularcausalmask', - 'neuralforecast/models/vanillatransformer.py'), - 'neuralforecast.models.vanillatransformer.TriangularCausalMask.__init__': ( 'models.vanillatransformer.html#triangularcausalmask.__init__', - 'neuralforecast/models/vanillatransformer.py'), - 'neuralforecast.models.vanillatransformer.TriangularCausalMask.mask': ( 'models.vanillatransformer.html#triangularcausalmask.mask', - 'neuralforecast/models/vanillatransformer.py'), - 'neuralforecast.models.vanillatransformer.VanillaTransformer': ( 'models.vanillatransformer.html#vanillatransformer', + 'neuralforecast.models.vanillatransformer': { 'neuralforecast.models.vanillatransformer.VanillaTransformer': ( 'models.vanillatransformer.html#vanillatransformer', 'neuralforecast/models/vanillatransformer.py'), 'neuralforecast.models.vanillatransformer.VanillaTransformer.__init__': ( 'models.vanillatransformer.html#vanillatransformer.__init__', 'neuralforecast/models/vanillatransformer.py'), diff --git a/neuralforecast/auto.py b/neuralforecast/auto.py index b3c85892a..1d407bd58 100644 --- 
a/neuralforecast/auto.py +++ b/neuralforecast/auto.py @@ -4,7 +4,7 @@ __all__ = ['AutoRNN', 'AutoLSTM', 'AutoGRU', 'AutoTCN', 'AutoDeepAR', 'AutoDilatedRNN', 'AutoBiTCN', 'AutoMLP', 'AutoNBEATS', 'AutoNBEATSx', 'AutoNHITS', 'AutoDLinear', 'AutoNLinear', 'AutoTiDE', 'AutoDeepNPTS', 'AutoKAN', 'AutoTFT', 'AutoVanillaTransformer', 'AutoInformer', 'AutoAutoformer', 'AutoFEDformer', 'AutoPatchTST', - 'AutoiTransformer', 'AutoTimesNet', 'AutoStemGNN', 'AutoHINT', 'AutoTSMixer', 'AutoTSMixerx', + 'AutoiTransformer', 'AutoTimeXer', 'AutoTimesNet', 'AutoStemGNN', 'AutoHINT', 'AutoTSMixer', 'AutoTSMixerx', 'AutoMLPMultivariate', 'AutoSOFTS', 'AutoTimeMixer', 'AutoRMoK'] # %% ../nbs/models.ipynb 2 @@ -42,6 +42,7 @@ from .models.patchtst import PatchTST from .models.timesnet import TimesNet from .models.itransformer import iTransformer +from .models.timexer import TimeXer from .models.kan import KAN from .models.rmok import RMoK @@ -1672,7 +1673,92 @@ def get_default_config(cls, h, backend, n_series): return config -# %% ../nbs/models.ipynb 109 +# %% ../nbs/models.ipynb 108 +class AutoTimeXer(BaseAuto): + + default_config = { + "input_size_multiplier": [1, 2, 3, 4, 5], + "h": None, + "n_series": None, + "hidden_size": tune.choice([128, 256, 512]), + "n_heads": tune.choice([4, 8]), + "learning_rate": tune.loguniform(1e-4, 1e-1), + "scaler_type": tune.choice([None, "robust", "standard"]), + "max_steps": tune.choice([500, 1000, 2000]), + "batch_size": tune.choice([32, 64, 128, 256]), + "loss": None, + "random_seed": tune.randint(1, 20), + } + + def __init__( + self, + h, + n_series, + loss=MAE(), + valid_loss=None, + config=None, + search_alg=BasicVariantGenerator(random_state=1), + num_samples=10, + refit_with_val=False, + cpus=cpu_count(), + gpus=torch.cuda.device_count(), + verbose=False, + alias=None, + backend="ray", + callbacks=None, + ): + + # Define search space, input/output sizes + if config is None: + config = self.get_default_config(h=h, backend=backend, 
n_series=n_series) + + # Always use n_series from parameters, raise exception with Optuna because we can't enforce it + if backend == "ray": + config["n_series"] = n_series + elif backend == "optuna": + mock_trial = MockTrial() + if ( + "n_series" in config(mock_trial) + and config(mock_trial)["n_series"] != n_series + ) or ("n_series" not in config(mock_trial)): + raise Exception(f"config needs 'n_series': {n_series}") + + super(AutoTimeXer, self).__init__( + cls_model=TimeXer, + h=h, + loss=loss, + valid_loss=valid_loss, + config=config, + search_alg=search_alg, + num_samples=num_samples, + refit_with_val=refit_with_val, + cpus=cpus, + gpus=gpus, + verbose=verbose, + alias=alias, + backend=backend, + callbacks=callbacks, + ) + + @classmethod + def get_default_config(cls, h, backend, n_series): + config = cls.default_config.copy() + config["input_size"] = tune.choice( + [h * x for x in config["input_size_multiplier"]] + ) + + # Rolling windows with step_size=1 or step_size=h + # See `BaseWindows` and `BaseRNN`'s create_windows + config["step_size"] = tune.choice([1, h]) + del config["input_size_multiplier"] + if backend == "optuna": + # Always use n_series from parameters + config["n_series"] = n_series + config = cls._ray_config_to_optuna(config) + + return config + +# %% ../nbs/models.ipynb 113 class AutoTimesNet(BaseAuto): default_config = { @@ -1740,7 +1826,7 @@ def get_default_config(cls, h, backend, n_series=None): return config -# %% ../nbs/models.ipynb 114 +# %% ../nbs/models.ipynb 118 class AutoStemGNN(BaseAuto): default_config = { @@ -1825,7 +1911,7 @@ def get_default_config(cls, h, backend, n_series): return config -# %% ../nbs/models.ipynb 118 +# %% ../nbs/models.ipynb 122 class AutoHINT(BaseAuto): def __init__( @@ -1897,7 +1983,7 @@ def _fit_model( def get_default_config(cls, h, backend, n_series=None): raise Exception("AutoHINT has no default configuration.") -# %% ../nbs/models.ipynb 123 +# %% ../nbs/models.ipynb 127 class AutoTSMixer(BaseAuto): 
default_config = { @@ -1983,7 +2069,7 @@ def get_default_config(cls, h, backend, n_series): return config -# %% ../nbs/models.ipynb 127 +# %% ../nbs/models.ipynb 131 class AutoTSMixerx(BaseAuto): default_config = { @@ -2069,7 +2155,7 @@ def get_default_config(cls, h, backend, n_series): return config -# %% ../nbs/models.ipynb 131 +# %% ../nbs/models.ipynb 135 class AutoMLPMultivariate(BaseAuto): default_config = { @@ -2154,7 +2240,7 @@ def get_default_config(cls, h, backend, n_series): return config -# %% ../nbs/models.ipynb 135 +# %% ../nbs/models.ipynb 139 class AutoSOFTS(BaseAuto): default_config = { @@ -2239,7 +2325,7 @@ def get_default_config(cls, h, backend, n_series): return config -# %% ../nbs/models.ipynb 139 +# %% ../nbs/models.ipynb 143 class AutoTimeMixer(BaseAuto): default_config = { @@ -2325,7 +2411,7 @@ def get_default_config(cls, h, backend, n_series): return config -# %% ../nbs/models.ipynb 143 +# %% ../nbs/models.ipynb 147 class AutoRMoK(BaseAuto): default_config = { diff --git a/neuralforecast/common/_modules.py b/neuralforecast/common/_modules.py index d50228b87..9c73ff85c 100644 --- a/neuralforecast/common/_modules.py +++ b/neuralforecast/common/_modules.py @@ -2,12 +2,13 @@ # %% auto 0 __all__ = ['ACTIVATIONS', 'MLP', 'Chomp1d', 'CausalConv1d', 'TemporalConvolutionEncoder', 'TransEncoderLayer', 'TransEncoder', - 'TransDecoderLayer', 'TransDecoder', 'AttentionLayer', 'PositionalEmbedding', 'TokenEmbedding', - 'TimeFeatureEmbedding', 'FixedEmbedding', 'TemporalEmbedding', 'DataEmbedding', 'MovingAvg', 'SeriesDecomp', - 'RevIN'] + 'TransDecoderLayer', 'TransDecoder', 'AttentionLayer', 'TriangularCausalMask', 'FullAttention', + 'PositionalEmbedding', 'TokenEmbedding', 'TimeFeatureEmbedding', 'FixedEmbedding', 'TemporalEmbedding', + 'DataEmbedding', 'DataEmbedding_inverted', 'MovingAvg', 'SeriesDecomp', 'RevIN'] # %% ../../nbs/common.modules.ipynb 3 import math +import numpy as np import torch import torch.nn as nn @@ -317,34 +318,95 @@ def 
forward(self, x, cross, x_mask=None, cross_mask=None): # %% ../../nbs/common.modules.ipynb 17 class AttentionLayer(nn.Module): - def __init__(self, attention, hidden_size, n_head, d_keys=None, d_values=None): + def __init__(self, attention, hidden_size, n_heads, d_keys=None, d_values=None): super(AttentionLayer, self).__init__() - d_keys = d_keys or (hidden_size // n_head) - d_values = d_values or (hidden_size // n_head) + d_keys = d_keys or (hidden_size // n_heads) + d_values = d_values or (hidden_size // n_heads) self.inner_attention = attention - self.query_projection = nn.Linear(hidden_size, d_keys * n_head) - self.key_projection = nn.Linear(hidden_size, d_keys * n_head) - self.value_projection = nn.Linear(hidden_size, d_values * n_head) - self.out_projection = nn.Linear(d_values * n_head, hidden_size) - self.n_head = n_head + self.query_projection = nn.Linear(hidden_size, d_keys * n_heads) + self.key_projection = nn.Linear(hidden_size, d_keys * n_heads) + self.value_projection = nn.Linear(hidden_size, d_values * n_heads) + self.out_projection = nn.Linear(d_values * n_heads, hidden_size) + self.n_heads = n_heads - def forward(self, queries, keys, values, attn_mask): + def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): B, L, _ = queries.shape _, S, _ = keys.shape - H = self.n_head + H = self.n_heads queries = self.query_projection(queries).view(B, L, H, -1) keys = self.key_projection(keys).view(B, S, H, -1) values = self.value_projection(values).view(B, S, H, -1) - out, attn = self.inner_attention(queries, keys, values, attn_mask) + out, attn = self.inner_attention( + queries=queries, + keys=keys, + values=values, + attn_mask=attn_mask, + tau=tau, + delta=delta, + ) out = out.view(B, L, -1) return self.out_projection(out), attn # %% ../../nbs/common.modules.ipynb 18 +class TriangularCausalMask: + """ + TriangularCausalMask + """ + + def __init__(self, B, L, device="cpu"): + mask_shape = [B, 1, L, L] + with torch.no_grad(): + self._mask = 
torch.triu( + torch.ones(mask_shape, dtype=torch.bool), diagonal=1 + ).to(device) + + @property + def mask(self): + return self._mask + + +class FullAttention(nn.Module): + def __init__( + self, + mask_flag=True, + factor=5, + scale=None, + attention_dropout=0.1, + output_attention=False, + ): + super(FullAttention, self).__init__() + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + + def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): + B, L, H, E = queries.shape + _, S, _, D = values.shape + scale = self.scale or 1.0 / math.sqrt(E) + + scores = torch.einsum("blhe,bshe->bhls", queries, keys) + + if self.mask_flag: + if attn_mask is None: + attn_mask = TriangularCausalMask(B, L, device=queries.device) + + scores.masked_fill_(attn_mask.mask, -np.inf) + + A = self.dropout(torch.softmax(scale * scores, dim=-1)) + V = torch.einsum("bhls,bshd->blhd", A, values) + + if self.output_attention: + return V.contiguous(), A + else: + return V.contiguous(), None + +# %% ../../nbs/common.modules.ipynb 19 class PositionalEmbedding(nn.Module): def __init__(self, hidden_size, max_len=5000): super(PositionalEmbedding, self).__init__() @@ -487,7 +549,29 @@ def forward(self, x, x_mark=None): return self.dropout(x) -# %% ../../nbs/common.modules.ipynb 19 + +class DataEmbedding_inverted(nn.Module): + """ + DataEmbedding_inverted + """ + + def __init__(self, c_in, hidden_size, dropout=0.1): + super(DataEmbedding_inverted, self).__init__() + self.value_embedding = nn.Linear(c_in, hidden_size) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x, x_mark): + x = x.permute(0, 2, 1) + # x: [Batch Variate Time] + if x_mark is None: + x = self.value_embedding(x) + else: + # the potential to take covariates (e.g. 
timestamps) as tokens + x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1)) + # x: [Batch Variate hidden_size] + return self.dropout(x) + +# %% ../../nbs/common.modules.ipynb 20 class MovingAvg(nn.Module): """ Moving average block to highlight the trend of time series @@ -522,7 +606,7 @@ def forward(self, x): res = x - moving_mean return res, moving_mean -# %% ../../nbs/common.modules.ipynb 20 +# %% ../../nbs/common.modules.ipynb 21 class RevIN(nn.Module): """RevIN (Reversible-Instance-Normalization)""" diff --git a/neuralforecast/core.py b/neuralforecast/core.py index f8b254745..74050c6c4 100644 --- a/neuralforecast/core.py +++ b/neuralforecast/core.py @@ -66,6 +66,7 @@ TimeMixer, KAN, RMoK, + TimeXer, ) from .common._base_auto import BaseAuto, MockTrial from .utils import PredictionIntervals, get_prediction_interval_method @@ -193,6 +194,8 @@ def _insample_times( "autokan": KAN, "rmok": RMoK, "autormok": RMoK, + "timexer": TimeXer, + "autotimexer": TimeXer, } # %% ../nbs/core.ipynb 8 diff --git a/neuralforecast/models/__init__.py b/neuralforecast/models/__init__.py index 414689631..6a12e7b4a 100644 --- a/neuralforecast/models/__init__.py +++ b/neuralforecast/models/__init__.py @@ -3,6 +3,7 @@ 'TFT', 'VanillaTransformer', 'Informer', 'Autoformer', 'PatchTST', 'FEDformer', 'StemGNN', 'HINT', 'TimesNet', 'TimeLLM', 'TSMixer', 'TSMixerx', 'MLPMultivariate', 'iTransformer', 'BiTCN', 'TiDE', 'DeepNPTS', 'SOFTS', 'TimeMixer', 'KAN', 'RMoK', + 'TimeXer', ] from .rnn import RNN @@ -38,3 +39,4 @@ from .timemixer import TimeMixer from .kan import KAN from .rmok import RMoK +from .timexer import TimeXer diff --git a/neuralforecast/models/informer.py b/neuralforecast/models/informer.py index cb4ff2622..7b2465987 100644 --- a/neuralforecast/models/informer.py +++ b/neuralforecast/models/informer.py @@ -149,7 +149,7 @@ def _update_context(self, context_in, V, scores, index, L_Q, attn_mask): else: return (context_in, None) - def forward(self, queries, keys, 
values, attn_mask): + def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): B, L_Q, H, D = queries.shape _, L_K, _, _ = keys.shape diff --git a/neuralforecast/models/itransformer.py b/neuralforecast/models/itransformer.py index 121eac2b5..45315d8ab 100644 --- a/neuralforecast/models/itransformer.py +++ b/neuralforecast/models/itransformer.py @@ -1,17 +1,13 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/models.itransformer.ipynb. # %% auto 0 -__all__ = ['TriangularCausalMask', 'FullAttention', 'DataEmbedding_inverted', 'iTransformer'] +__all__ = ['iTransformer'] # %% ../../nbs/models.itransformer.ipynb 6 import torch import torch.nn as nn import torch.nn.functional as F -import numpy as np - -from math import sqrt - from ..losses.pytorch import MAE from ..common._base_multivariate import BaseMultivariate @@ -19,89 +15,11 @@ TransEncoder, TransEncoderLayer, AttentionLayer, + FullAttention, + DataEmbedding_inverted, ) -# %% ../../nbs/models.itransformer.ipynb 9 -class TriangularCausalMask: - """ - TriangularCausalMask - """ - - def __init__(self, B, L, device="cpu"): - mask_shape = [B, 1, L, L] - with torch.no_grad(): - self._mask = torch.triu( - torch.ones(mask_shape, dtype=torch.bool), diagonal=1 - ).to(device) - - @property - def mask(self): - return self._mask - - -class FullAttention(nn.Module): - """ - FullAttention - """ - - def __init__( - self, - mask_flag=True, - factor=5, - scale=None, - attention_dropout=0.1, - output_attention=False, - ): - super(FullAttention, self).__init__() - self.scale = scale - self.mask_flag = mask_flag - self.output_attention = output_attention - self.dropout = nn.Dropout(attention_dropout) - - def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): - B, L, H, E = queries.shape - _, S, _, D = values.shape - scale = self.scale or 1.0 / sqrt(E) - - scores = torch.einsum("blhe,bshe->bhls", queries, keys) - - if self.mask_flag: - if attn_mask is None: - attn_mask = 
TriangularCausalMask(B, L, device=queries.device) - - scores.masked_fill_(attn_mask.mask, -np.inf) - - A = self.dropout(torch.softmax(scale * scores, dim=-1)) - V = torch.einsum("bhls,bshd->blhd", A, values) - - if self.output_attention: - return (V.contiguous(), A) - else: - return (V.contiguous(), None) - -# %% ../../nbs/models.itransformer.ipynb 11 -class DataEmbedding_inverted(nn.Module): - """ - DataEmbedding_inverted - """ - - def __init__(self, c_in, hidden_size, dropout=0.1): - super(DataEmbedding_inverted, self).__init__() - self.value_embedding = nn.Linear(c_in, hidden_size) - self.dropout = nn.Dropout(p=dropout) - - def forward(self, x, x_mark): - x = x.permute(0, 2, 1) - # x: [Batch Variate Time] - if x_mark is None: - x = self.value_embedding(x) - else: - # the potential to take covariates (e.g. timestamps) as tokens - x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1)) - # x: [Batch Variate hidden_size] - return self.dropout(x) - -# %% ../../nbs/models.itransformer.ipynb 13 +# %% ../../nbs/models.itransformer.ipynb 8 class iTransformer(BaseMultivariate): """iTransformer diff --git a/neuralforecast/models/timexer.py b/neuralforecast/models/timexer.py new file mode 100644 index 000000000..14adf1d15 --- /dev/null +++ b/neuralforecast/models/timexer.py @@ -0,0 +1,352 @@ +# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/models.timexer.ipynb. 
+ +# %% auto 0 +__all__ = ['FlattenHead', 'Encoder', 'EncoderLayer', 'EnEmbedding', 'TimeXer'] + +# %% ../../nbs/models.timexer.ipynb 5 +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..losses.pytorch import MAE +from ..common._base_multivariate import BaseMultivariate +from neuralforecast.common._modules import ( + DataEmbedding_inverted, + PositionalEmbedding, + FullAttention, + AttentionLayer, +) + +# %% ../../nbs/models.timexer.ipynb 7 +class FlattenHead(nn.Module): + def __init__(self, n_vars, nf, target_window, head_dropout=0): + super().__init__() + self.n_vars = n_vars + self.flatten = nn.Flatten(start_dim=-2) + self.linear = nn.Linear(nf, target_window) + self.dropout = nn.Dropout(head_dropout) + + def forward(self, x): # x: [bs x nvars x d_model x patch_num] + x = self.flatten(x) + x = self.linear(x) + x = self.dropout(x) + return x + +# %% ../../nbs/models.timexer.ipynb 8 +class Encoder(nn.Module): + def __init__(self, layers, norm_layer=None, projection=None): + super(Encoder, self).__init__() + self.layers = nn.ModuleList(layers) + self.norm = norm_layer + self.projection = projection + + def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): + for layer in self.layers: + x = layer( + x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta + ) + + if self.norm is not None: + x = self.norm(x) + + if self.projection is not None: + x = self.projection(x) + return x + +# %% ../../nbs/models.timexer.ipynb 9 +class EncoderLayer(nn.Module): + def __init__( + self, + self_attention, + cross_attention, + d_model, + d_ff=None, + dropout=0.1, + activation="relu", + ): + super(EncoderLayer, self).__init__() + d_ff = d_ff or 4 * d_model + self.self_attention = self_attention + self.cross_attention = cross_attention + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) + self.norm1 = 
nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + + def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): + B, L, D = cross.shape + x = x + self.dropout( + self.self_attention(x, x, x, attn_mask=x_mask, tau=tau, delta=None)[0] + ) + x = self.norm1(x) + + x_glb_ori = x[:, -1, :].unsqueeze(1) + x_glb = torch.reshape(x_glb_ori, (B, -1, D)) + x_glb_attn = self.dropout( + self.cross_attention( + x_glb, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta + )[0] + ) + x_glb_attn = torch.reshape( + x_glb_attn, (x_glb_attn.shape[0] * x_glb_attn.shape[1], x_glb_attn.shape[2]) + ).unsqueeze(1) + x_glb = x_glb_ori + x_glb_attn + x_glb = self.norm2(x_glb) + + y = x = torch.cat([x[:, :-1, :], x_glb], dim=1) + + y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) + y = self.dropout(self.conv2(y).transpose(-1, 1)) + + return self.norm3(x + y) + +# %% ../../nbs/models.timexer.ipynb 10 +class EnEmbedding(nn.Module): + def __init__(self, n_vars, d_model, patch_len, dropout): + super(EnEmbedding, self).__init__() + # Patching + self.patch_len = patch_len + + self.value_embedding = nn.Linear(patch_len, d_model, bias=False) + self.glb_token = nn.Parameter(torch.randn(1, n_vars, 1, d_model)) + self.position_embedding = PositionalEmbedding(d_model) + + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + # do patching + n_vars = x.shape[1] + glb = self.glb_token.repeat((x.shape[0], 1, 1, 1)) + + x = x.unfold(dimension=-1, size=self.patch_len, step=self.patch_len) + x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3])) + # Input encoding + x = self.value_embedding(x) + self.position_embedding(x) + x = torch.reshape(x, (-1, n_vars, x.shape[-2], x.shape[-1])) + x = torch.cat([x, glb], dim=2) + x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3])) + return 
self.dropout(x), n_vars + +# %% ../../nbs/models.timexer.ipynb 12 +class TimeXer(BaseMultivariate): + """ + TimeXer + + **Parameters:**
+ `h`: int, Forecast horizon.
+ `input_size`: int, autoregressive inputs size, y=[1,2,3,4] input_size=2 -> y_[t-2:t]=[1,2].
+ `n_series`: int, number of time-series.
+ `futr_exog_list`: str list, future exogenous columns.
+ `hist_exog_list`: str list, historic exogenous columns.
+ `stat_exog_list`: str list, static exogenous columns.
+ `patch_len`: int, length of patches.
+ `hidden_size`: int, dimension of the model.
+ `n_heads`: int, number of heads.
+ `e_layers`: int, number of encoder layers.
+ `d_ff`: int, dimension of fully-connected layer.
+ `factor`: int, attention factor.
+ `dropout`: float, dropout rate.
+ `use_norm`: bool, whether to normalize or not.
+ `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
+ `valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
+ `max_steps`: int=1000, maximum number of training steps.
+ `learning_rate`: float=1e-3, Learning rate between (0, 1).
+ `num_lr_decays`: int=-1, Number of learning rate decays, evenly distributed across max_steps.
+ `early_stop_patience_steps`: int=-1, Number of validation iterations before early stopping.
+ `val_check_steps`: int=100, Number of training steps between every validation loss check.
+ `batch_size`: int=32, number of different series in each batch.
+ `step_size`: int=1, step size between each window of temporal data.
+ `scaler_type`: str='identity', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).
+ `random_seed`: int=1, random_seed for pytorch initializer and numpy generators.
+ `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
+ `alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.
+ `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).
+ `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.
+ `dataloader_kwargs`: dict, optional, list of parameters passed into the PyTorch Lightning dataloader by the `TimeSeriesDataLoader`.
+ `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lightning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
+ +
+ + **References** + - [Yuxuan Wang, Haixu Wu, Jiaxiang Dong, Guo Qin, Haoran Zhang, Yong Liu, Yunzhong Qiu, Jianmin Wang, Mingsheng Long. "TimeXer: Empowering Transformers for Time Series Forecasting with Exogenous Variables"](https://arxiv.org/abs/2402.19072) + """ + + # Class attributes + SAMPLING_TYPE = "multivariate" + EXOGENOUS_FUTR = True + EXOGENOUS_HIST = False + EXOGENOUS_STAT = False + + def __init__( + self, + h, + input_size, + n_series, + futr_exog_list=None, + hist_exog_list=None, + stat_exog_list=None, + patch_len: int = 16, + hidden_size: int = 512, + n_heads: int = 8, + e_layers: int = 2, + d_ff: int = 2048, + factor: int = 1, + dropout: float = 0.1, + use_norm: bool = True, + loss=MAE(), + valid_loss=None, + max_steps: int = 1000, + learning_rate: float = 1e-3, + num_lr_decays: int = -1, + early_stop_patience_steps: int = -1, + val_check_steps: int = 100, + batch_size: int = 32, + step_size: int = 1, + scaler_type: str = "identity", + random_seed: int = 1, + drop_last_loader: bool = False, + optimizer=None, + optimizer_kwargs=None, + lr_scheduler=None, + lr_scheduler_kwargs=None, + dataloader_kwargs=None, + **trainer_kwargs + ): + + super(TimeXer, self).__init__( + h=h, + input_size=input_size, + n_series=n_series, + stat_exog_list=stat_exog_list, + futr_exog_list=futr_exog_list, + hist_exog_list=hist_exog_list, + loss=loss, + valid_loss=valid_loss, + max_steps=max_steps, + learning_rate=learning_rate, + num_lr_decays=num_lr_decays, + early_stop_patience_steps=early_stop_patience_steps, + val_check_steps=val_check_steps, + batch_size=batch_size, + step_size=step_size, + scaler_type=scaler_type, + random_seed=random_seed, + drop_last_loader=drop_last_loader, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, + lr_scheduler=lr_scheduler, + lr_scheduler_kwargs=lr_scheduler_kwargs, + dataloader_kwargs=dataloader_kwargs, + **trainer_kwargs + ) + + self.enc_in = n_series + self.hidden_size = hidden_size + self.n_heads = n_heads + self.e_layers 
= e_layers + self.d_ff = d_ff + self.dropout = dropout + self.factor = factor + self.patch_len = patch_len + self.use_norm = use_norm + self.patch_num = int(input_size // self.patch_len) + + # Architecture + self.en_embedding = EnEmbedding( + n_series, self.hidden_size, self.patch_len, self.dropout + ) + self.ex_embedding = DataEmbedding_inverted( + input_size, self.hidden_size, self.dropout + ) + + self.encoder = Encoder( + [ + EncoderLayer( + AttentionLayer( + FullAttention( + False, + self.factor, + attention_dropout=self.dropout, + output_attention=False, + ), + self.hidden_size, + self.n_heads, + ), + AttentionLayer( + FullAttention( + False, + self.factor, + attention_dropout=self.dropout, + output_attention=False, + ), + self.hidden_size, + self.n_heads, + ), + self.hidden_size, + self.d_ff, + dropout=self.dropout, + activation="relu", + ) + for l in range(self.e_layers) + ], + norm_layer=torch.nn.LayerNorm(self.hidden_size), + ) + self.head_nf = self.hidden_size * (self.patch_num + 1) + self.head = FlattenHead(self.enc_in, self.head_nf, h, head_dropout=self.dropout) + + def forecast(self, x_enc, x_mark_enc): + if self.use_norm: + # Normalization from Non-stationary Transformer + means = x_enc.mean(1, keepdim=True).detach() + x_enc = x_enc - means + stdev = torch.sqrt( + torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5 + ) + x_enc /= stdev + + _, _, N = x_enc.shape + + en_embed, n_vars = self.en_embedding(x_enc.permute(0, 2, 1)) + ex_embed = self.ex_embedding(x_enc, x_mark_enc) + + enc_out = self.encoder(en_embed, ex_embed) + enc_out = torch.reshape( + enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1]) + ) + # z: [bs x nvars x d_model x patch_num] + enc_out = enc_out.permute(0, 1, 3, 2) + + dec_out = self.head(enc_out) # z: [bs x nvars x target_window] + dec_out = dec_out.permute(0, 2, 1) + + if self.use_norm: + # De-Normalization from Non-stationary Transformer + dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.h, 1)) + 
dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.h, 1)) + + return dec_out + + def forward(self, windows_batch): + insample_y = windows_batch["insample_y"] + futr_exog = windows_batch["futr_exog"] + + if self.futr_exog_size > 0: + x_mark_enc = futr_exog[:, :, : self.input_size, :] + B, V, T, D = x_mark_enc.shape + x_mark_enc = x_mark_enc.reshape(B, T, V * D) + else: + x_mark_enc = None + + y_pred = self.forecast(insample_y, x_mark_enc) + y_pred = y_pred[:, -self.h :, :] + y_pred = self.loss.domain_map(y_pred) + + if y_pred.ndim == 2: + return y_pred.unsqueeze(-1) + else: + return y_pred diff --git a/neuralforecast/models/vanillatransformer.py b/neuralforecast/models/vanillatransformer.py index e38c03fc9..a8276d577 100644 --- a/neuralforecast/models/vanillatransformer.py +++ b/neuralforecast/models/vanillatransformer.py @@ -1,10 +1,9 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/models.vanillatransformer.ipynb. # %% auto 0 -__all__ = ['TriangularCausalMask', 'FullAttention', 'VanillaTransformer'] +__all__ = ['VanillaTransformer'] # %% ../../nbs/models.vanillatransformer.ipynb 5 -import math import numpy as np from typing import Optional @@ -18,61 +17,13 @@ TransDecoder, DataEmbedding, AttentionLayer, + FullAttention, ) from ..common._base_windows import BaseWindows from ..losses.pytorch import MAE # %% ../../nbs/models.vanillatransformer.ipynb 8 -class TriangularCausalMask: - def __init__(self, B, L, device="cpu"): - mask_shape = [B, 1, L, L] - with torch.no_grad(): - self._mask = torch.triu( - torch.ones(mask_shape, dtype=torch.bool), diagonal=1 - ).to(device) - - @property - def mask(self): - return self._mask - - -class FullAttention(nn.Module): - """ - FullAttention - """ - - def __init__( - self, mask_flag=True, scale=None, attention_dropout=0.1, output_attention=False - ): - super(FullAttention, self).__init__() - self.scale = scale - self.mask_flag = mask_flag - self.output_attention = output_attention - self.dropout = 
nn.Dropout(attention_dropout) - - def forward(self, queries, keys, values, attn_mask): - B, L, H, E = queries.shape - _, S, _, D = values.shape - scale = self.scale or 1.0 / math.sqrt(E) - - scores = torch.einsum("blhe,bshe->bhls", queries, keys) - - if self.mask_flag: - if attn_mask is None: - attn_mask = TriangularCausalMask(B, L, device=queries.device) - - scores.masked_fill_(attn_mask.mask, -np.inf) - - A = self.dropout(torch.softmax(scale * scores, dim=-1)) - V = torch.einsum("bhls,bshd->blhd", A, values) - - if self.output_attention: - return (V.contiguous(), A) - else: - return (V.contiguous(), None) - -# %% ../../nbs/models.vanillatransformer.ipynb 10 class VanillaTransformer(BaseWindows): """VanillaTransformer