Skip to content

Commit 4ccf66a

Browse files
committed
add future/historical/static exogenous-variable support for models
1 parent 8e93aae commit 4ccf66a

File tree

8 files changed

+539
-101
lines changed

8 files changed

+539
-101
lines changed

nbs/models.itransformer.ipynb

Lines changed: 95 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,66 @@
8989
"# 1. Model"
9090
]
9191
},
92+
{
93+
"cell_type": "code",
94+
"execution_count": null,
95+
"metadata": {},
96+
"outputs": [],
97+
"source": [
98+
"#| exporti\n",
99+
class FeatureEmbedding(nn.Module):
    """Fuse target, historical, future, and static features into one token embedding.

    Parameter count is controlled by splitting the embedding width across channels:

    1. The model's hidden size is divided evenly among all feature channels.
    2. Each feature channel is embedded independently.
    3. The per-channel embeddings are concatenated along the embedding dimension.

    Parameters
    ----------
    input_size : int
        Length L of the input (historical) window.
    h : int
        Forecast horizon; future exogenous inputs cover L + h steps.
    hidden_size : int
        Per-channel embedding width E (already divided by the caller).
    hist_exog_size : int
        Number of historical exogenous features.
    futr_exog_size : int
        Number of future exogenous features.
    stat_exog_size : int
        Number of static exogenous features.
    dropout : float
        Dropout rate passed to each `DataEmbedding_inverted`.
    """

    def __init__(self, input_size, h, hidden_size, hist_exog_size, futr_exog_size, stat_exog_size, dropout):
        super().__init__()
        self.futr_input_size = input_size + h
        self.futr_exog_size = futr_exog_size
        self.hist_exog_size = hist_exog_size
        self.stat_exog_size = stat_exog_size

        # Base embedding of the target series.
        self.base_embed = DataEmbedding_inverted(input_size, hidden_size, dropout)

        # One independent encoder per historical exogenous feature.
        self.hist_embed = nn.ModuleList(
            [DataEmbedding_inverted(input_size, hidden_size, dropout) for _ in range(hist_exog_size)]
        )

        # One independent encoder per future exogenous feature; these see the
        # input window plus the horizon (L + h steps).
        self.futr_embed = nn.ModuleList(
            [DataEmbedding_inverted(self.futr_input_size, hidden_size, dropout) for _ in range(futr_exog_size)]
        )

        # Static features are embedded with a single linear map, shared across
        # the batch; created only when static features exist.
        self.stat_embed = nn.Linear(stat_exog_size, hidden_size) if stat_exog_size > 0 else None

    def forward(self, y, hist, futr, stat):
        """Embed and concatenate all feature channels.

        Parameters
        ----------
        y : Tensor
            Target window; assumed [B, L, N] — TODO confirm against caller.
        hist : Tensor or None
            Historical exogenous; assumed [B, hist_exog_size, L, N].
        futr : Tensor or None
            Future exogenous; assumed [B, futr_exog_size, L + h, N].
        stat : Tensor or None
            Static exogenous; assumed [N, stat_exog_size].

        Returns
        -------
        Tensor
            Concatenated embedding [B, N, E * (1 + H + F + S)].
        """
        # Base series embedding -> [B, N, E]
        embeddings = [self.base_embed(y, None)]

        # Historical feature embeddings -> H tensors of [B, N, E]
        if self.hist_exog_size > 0:
            for i, embed in enumerate(self.hist_embed):
                embeddings.append(embed(hist[:, i, :, :], None))

        # Future feature embeddings -> F tensors of [B, N, E]
        if self.futr_exog_size > 0:
            for i, embed in enumerate(self.futr_embed):
                embeddings.append(embed(futr[:, i, :, :], None))

        # Static feature embedding -> [B, N, E], broadcast over the batch.
        if self.stat_embed is not None:
            stat_feat = self.stat_embed(stat)  # [N, S] -> [N, E]
            stat_feat = stat_feat.unsqueeze(0).expand(y.size(0), -1, -1)  # [N, E] -> [B, N, E]
            embeddings.append(stat_feat)

        # Concatenate along the embedding dimension.
        return torch.cat(embeddings, dim=-1)
150+
]
151+
},
92152
{
93153
"cell_type": "code",
94154
"execution_count": null,
@@ -147,9 +207,9 @@
147207
" \"\"\"\n",
148208
"\n",
149209
" # Class attributes\n",
150-
" EXOGENOUS_FUTR = False\n",
151-
" EXOGENOUS_HIST = False\n",
152-
" EXOGENOUS_STAT = False\n",
210+
" EXOGENOUS_FUTR = True\n",
211+
" EXOGENOUS_HIST = True\n",
212+
" EXOGENOUS_STAT = True\n",
153213
" MULTIVARIATE = True\n",
154214
" RECURRENT = False\n",
155215
"\n",
@@ -238,7 +298,22 @@
238298
" self.use_norm = use_norm\n",
239299
"\n",
240300
" # Architecture\n",
241-
" self.enc_embedding = DataEmbedding_inverted(input_size, self.hidden_size, self.dropout)\n",
301+
" # Mix all features into one\n",
302+
" self.num_features = 1 + \\\n",
303+
" (len(hist_exog_list) if hist_exog_list else 0) + \\\n",
304+
" (len(futr_exog_list) if futr_exog_list else 0) + \\\n",
305+
" (len(stat_exog_list) if stat_exog_list else 0)\n",
306+
" adjusted_hidden = hidden_size // self.num_features\n",
307+
" self.hidden_size = adjusted_hidden * self.num_features\n",
308+
" self.feature_embedding = FeatureEmbedding(\n",
309+
" input_size=input_size,\n",
310+
" h=h,\n",
311+
" hidden_size=adjusted_hidden,\n",
312+
" hist_exog_size=len(hist_exog_list) if hist_exog_list else 0,\n",
313+
" futr_exog_size=len(futr_exog_list) if futr_exog_list else 0,\n",
314+
" stat_exog_size=len(stat_exog_list) if stat_exog_list else 0,\n",
315+
" dropout=dropout\n",
316+
" )\n",
242317
"\n",
243318
" self.encoder = TransEncoder(\n",
244319
" [\n",
@@ -256,7 +331,7 @@
256331
"\n",
257332
" self.projector = nn.Linear(self.hidden_size, h * self.loss.outputsize_multiplier, bias=True)\n",
258333
"\n",
259-
" def forecast(self, x_enc):\n",
334+
" def forecast(self, x_enc, hist_exog, futr_exog, stat_exog):\n",
260335
" if self.use_norm:\n",
261336
" # Normalization from Non-stationary Transformer\n",
262337
" means = x_enc.mean(1, keepdim=True).detach()\n",
@@ -271,14 +346,17 @@
271346
"\n",
272347
" # Embedding\n",
273348
" # B L N -> B N E (B L N -> B L E in the vanilla Transformer)\n",
274-
" enc_out = self.enc_embedding(x_enc, None) # covariates (e.g timestamp) can be also embedded as tokens\n",
349+
" # 特征融合\n",
350+
" enc_embed = self.feature_embedding(\n",
351+
" x_enc, \n",
352+
" hist_exog,\n",
353+
" futr_exog,\n",
354+
" stat_exog\n",
355+
" )\n",
275356
" \n",
276-
" # B N E -> B N E (B L E -> B L E in the vanilla Transformer)\n",
277-
" # the dimensions of embedded time series has been inverted, and then processed by native attn, layernorm and ffn modules\n",
278-
" enc_out, attns = self.encoder(enc_out, attn_mask=None)\n",
279-
"\n",
280-
" # B N E -> B N S -> B S N \n",
281-
" dec_out = self.projector(enc_out).permute(0, 2, 1)[:, :, :N] # filter the covariates\n",
357+
" # 后续处理保持原有流程不变\n",
358+
" enc_out, attns = self.encoder(enc_embed, attn_mask=None)\n",
359+
" dec_out = self.projector(enc_out).permute(0, 2, 1)[:, :, :self.n_series]\n",
282360
"\n",
283361
" if self.use_norm:\n",
284362
" # De-Normalization from Non-stationary Transformer\n",
@@ -288,9 +366,12 @@
288366
" return dec_out\n",
289367
" \n",
290368
" def forward(self, windows_batch):\n",
291-
" insample_y = windows_batch['insample_y']\n",
369+
" insample_y = windows_batch['insample_y'] # [batch_size (B), input_size (L), n_series (N)]\n",
370+
" hist_exog = windows_batch['hist_exog'] # [B, hist_exog_size (X), L, N]\n",
371+
" futr_exog = windows_batch['futr_exog'] # [B, futr_exog_size (F), L + h, N]\n",
372+
" stat_exog = windows_batch['stat_exog'] # [N, stat_exog_size (S)]\n",
292373
"\n",
293-
" y_pred = self.forecast(insample_y)\n",
374+
" y_pred = self.forecast(insample_y, hist_exog, futr_exog, stat_exog)\n",
294375
" y_pred = y_pred.reshape(insample_y.shape[0],\n",
295376
" self.h,\n",
296377
" -1)\n",

nbs/models.softs.ipynb

Lines changed: 98 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,73 @@
158158
" return output, None"
159159
]
160160
},
161+
{
162+
"cell_type": "markdown",
163+
"metadata": {},
164+
"source": [
165+
"### 1.3 FeatureEmbedding (mix of [y, futr, hist, stat])"
166+
]
167+
},
168+
{
169+
"cell_type": "code",
170+
"execution_count": null,
171+
"metadata": {},
172+
"outputs": [],
173+
"source": [
174+
"#| exporti\n",
175+
class FeatureEmbedding(nn.Module):
    """Fuse target, historical, future, and static features into one token embedding.

    Parameter count is controlled by splitting the embedding width across channels:

    1. The model's hidden size is divided evenly among all feature channels.
    2. Each feature channel is embedded independently.
    3. The per-channel embeddings are concatenated along the embedding dimension.

    Parameters
    ----------
    input_size : int
        Length L of the input (historical) window.
    h : int
        Forecast horizon; future exogenous inputs cover L + h steps.
    hidden_size : int
        Per-channel embedding width E (already divided by the caller).
    hist_exog_size : int
        Number of historical exogenous features.
    futr_exog_size : int
        Number of future exogenous features.
    stat_exog_size : int
        Number of static exogenous features.
    dropout : float
        Dropout rate passed to each `DataEmbedding_inverted`.
    """

    def __init__(self, input_size, h, hidden_size, hist_exog_size, futr_exog_size, stat_exog_size, dropout):
        super().__init__()
        self.futr_input_size = input_size + h
        self.futr_exog_size = futr_exog_size
        self.hist_exog_size = hist_exog_size
        self.stat_exog_size = stat_exog_size

        # Base embedding of the target series.
        self.base_embed = DataEmbedding_inverted(input_size, hidden_size, dropout)

        # One independent encoder per historical exogenous feature.
        self.hist_embed = nn.ModuleList(
            [DataEmbedding_inverted(input_size, hidden_size, dropout) for _ in range(hist_exog_size)]
        )

        # One independent encoder per future exogenous feature; these see the
        # input window plus the horizon (L + h steps).
        self.futr_embed = nn.ModuleList(
            [DataEmbedding_inverted(self.futr_input_size, hidden_size, dropout) for _ in range(futr_exog_size)]
        )

        # Static features are embedded with a single linear map, shared across
        # the batch; created only when static features exist.
        self.stat_embed = nn.Linear(stat_exog_size, hidden_size) if stat_exog_size > 0 else None

    def forward(self, y, hist, futr, stat):
        """Embed and concatenate all feature channels.

        Parameters
        ----------
        y : Tensor
            Target window; assumed [B, L, N] — TODO confirm against caller.
        hist : Tensor or None
            Historical exogenous; assumed [B, hist_exog_size, L, N].
        futr : Tensor or None
            Future exogenous; assumed [B, futr_exog_size, L + h, N].
        stat : Tensor or None
            Static exogenous; assumed [N, stat_exog_size].

        Returns
        -------
        Tensor
            Concatenated embedding [B, N, E * (1 + H + F + S)].
        """
        # Base series embedding -> [B, N, E]
        embeddings = [self.base_embed(y, None)]

        # Historical feature embeddings -> H tensors of [B, N, E]
        if self.hist_exog_size > 0:
            for i, embed in enumerate(self.hist_embed):
                embeddings.append(embed(hist[:, i, :, :], None))

        # Future feature embeddings -> F tensors of [B, N, E]
        if self.futr_exog_size > 0:
            for i, embed in enumerate(self.futr_embed):
                embeddings.append(embed(futr[:, i, :, :], None))

        # Static feature embedding -> [B, N, E], broadcast over the batch.
        if self.stat_embed is not None:
            stat_feat = self.stat_embed(stat)  # [N, S] -> [N, E]
            stat_feat = stat_feat.unsqueeze(0).expand(y.size(0), -1, -1)  # [N, E] -> [B, N, E]
            embeddings.append(stat_feat)

        # Concatenate along the embedding dimension.
        return torch.cat(embeddings, dim=-1)
226+
]
227+
},
161228
{
162229
"cell_type": "markdown",
163230
"metadata": {},
@@ -220,9 +287,9 @@
220287
" \"\"\"\n",
221288
"\n",
222289
" # Class attributes\n",
223-
" EXOGENOUS_FUTR = False\n",
224-
" EXOGENOUS_HIST = False\n",
225-
" EXOGENOUS_STAT = False\n",
290+
" EXOGENOUS_FUTR = True\n",
291+
" EXOGENOUS_HIST = True\n",
292+
" EXOGENOUS_STAT = True\n",
226293
" MULTIVARIATE = True\n",
227294
" RECURRENT = False\n",
228295
"\n",
@@ -302,9 +369,22 @@
302369
" self.use_norm = use_norm\n",
303370
"\n",
304371
" # Architecture\n",
305-
" self.enc_embedding = DataEmbedding_inverted(input_size, \n",
306-
" hidden_size, \n",
307-
" dropout)\n",
372+
" # Mix all features into one\n",
373+
" self.num_features = 1 + \\\n",
374+
" (len(hist_exog_list) if hist_exog_list else 0) + \\\n",
375+
" (len(futr_exog_list) if futr_exog_list else 0) + \\\n",
376+
" (len(stat_exog_list) if stat_exog_list else 0)\n",
377+
" adjusted_hidden = hidden_size // self.num_features\n",
378+
" self.hidden_size = adjusted_hidden * self.num_features\n",
379+
" self.feature_embedding = FeatureEmbedding(\n",
380+
" input_size=input_size,\n",
381+
" h=h,\n",
382+
" hidden_size=adjusted_hidden,\n",
383+
" hist_exog_size=len(hist_exog_list) if hist_exog_list else 0,\n",
384+
" futr_exog_size=len(futr_exog_list) if futr_exog_list else 0,\n",
385+
" stat_exog_size=len(stat_exog_list) if stat_exog_list else 0,\n",
386+
" dropout=dropout\n",
387+
" )\n",
308388
" \n",
309389
" self.encoder = TransEncoder(\n",
310390
" [\n",
@@ -320,7 +400,7 @@
320400
"\n",
321401
" self.projection = nn.Linear(hidden_size, self.h * self.loss.outputsize_multiplier, bias=True)\n",
322402
"\n",
323-
" def forecast(self, x_enc):\n",
403+
" def forecast(self, x_enc, hist_exog, futr_exog, stat_exog):\n",
324404
" # Normalization from Non-stationary Transformer\n",
325405
" if self.use_norm:\n",
326406
" means = x_enc.mean(1, keepdim=True).detach()\n",
@@ -329,7 +409,12 @@
329409
" x_enc /= stdev\n",
330410
"\n",
331411
" _, _, N = x_enc.shape\n",
332-
" enc_out = self.enc_embedding(x_enc, None)\n",
412+
" enc_out = self.feature_embedding(\n",
413+
" x_enc, \n",
414+
" hist_exog,\n",
415+
" futr_exog,\n",
416+
" stat_exog\n",
417+
" )\n",
333418
" enc_out, attns = self.encoder(enc_out, attn_mask=None)\n",
334419
" dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N]\n",
335420
"\n",
@@ -340,9 +425,12 @@
340425
" return dec_out\n",
341426
" \n",
342427
" def forward(self, windows_batch):\n",
343-
" insample_y = windows_batch['insample_y']\n",
428+
" insample_y = windows_batch['insample_y'] # [batch_size (B), input_size (L), n_series (N)]\n",
429+
" hist_exog = windows_batch['hist_exog'] # [B, hist_exog_size (X), L, N]\n",
430+
" futr_exog = windows_batch['futr_exog'] # [B, futr_exog_size (F), L + h, N]\n",
431+
" stat_exog = windows_batch['stat_exog'] # [N, stat_exog_size (S)]\n",
344432
"\n",
345-
" y_pred = self.forecast(insample_y)\n",
433+
" y_pred = self.forecast(insample_y, hist_exog, futr_exog, stat_exog)\n",
346434
" y_pred = y_pred.reshape(insample_y.shape[0],\n",
347435
" self.h,\n",
348436
" -1)\n",

nbs/models.timexer.ipynb

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,8 @@
281281
"\n",
282282
" # Class attributes\n",
283283
" EXOGENOUS_FUTR = True\n",
284-
" EXOGENOUS_HIST = False\n",
285-
" EXOGENOUS_STAT = False\n",
284+
" EXOGENOUS_HIST = True\n",
285+
" EXOGENOUS_STAT = True\n",
286286
" MULTIVARIATE = True # If the model produces multivariate forecasts (True) or univariate (False)\n",
287287
" RECURRENT = False # If the model produces forecasts recursively (True) or direct (False)\n",
288288
"\n",
@@ -367,10 +367,18 @@
367367
" self.patch_len = patch_len\n",
368368
" self.use_norm = use_norm\n",
369369
" self.patch_num = int(input_size // self.patch_len)\n",
370+
" \n",
371+
" self.futr_exog_size = len(futr_exog_list) if futr_exog_list is not None else 0\n",
372+
" self.hist_exog_size = len(hist_exog_list) if hist_exog_list is not None else 0\n",
373+
" self.stat_exog_size = len(stat_exog_list) if stat_exog_list is not None else 0\n",
370374
"\n",
371375
" # Architecture\n",
372376
" self.en_embedding = EnEmbedding(n_series, self.hidden_size, self.patch_len, self.dropout)\n",
373-
" self.ex_embedding = DataEmbedding_inverted(input_size, self.hidden_size, self.dropout)\n",
377+
" self.hist_ex_embedding = DataEmbedding_inverted(input_size, self.hidden_size, self.dropout)\n",
378+
" if futr_exog_list is not None:\n",
379+
" self.futr_ex_embedding = DataEmbedding_inverted(input_size+h, self.hidden_size, self.dropout)\n",
380+
" if stat_exog_list is not None:\n",
381+
" self.stat_ex_embedding = nn.Linear(len(stat_exog_list), hidden_size)\n",
374382
"\n",
375383
" self.encoder = Encoder(\n",
376384
" [\n",
@@ -396,18 +404,33 @@
396404
" self.head = FlattenHead(self.enc_in, self.head_nf, h * self.loss.outputsize_multiplier,\n",
397405
" head_dropout=self.dropout)\n",
398406
" \n",
399-
" def forecast(self, x_enc, x_mark_enc):\n",
407+
" def forecast(self, x_enc, futr_exog, hist_exog, stat_exog):\n",
400408
" if self.use_norm:\n",
401409
" # Normalization from Non-stationary Transformer\n",
402410
" means = x_enc.mean(1, keepdim=True).detach()\n",
403411
" x_enc = x_enc - means\n",
404412
" stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)\n",
405413
" x_enc /= stdev\n",
406414
"\n",
407-
" _, _, N = x_enc.shape\n",
415+
" B, _, N = x_enc.shape\n",
408416
"\n",
417+
" \n",
409418
" en_embed, n_vars = self.en_embedding(x_enc.permute(0, 2, 1))\n",
410-
" ex_embed = self.ex_embedding(x_enc, x_mark_enc)\n",
419+
" # concat exogenous embedding if exist\n",
420+
" if self.hist_exog_size > 0:\n",
421+
" B, V, T, D = hist_exog.shape\n",
422+
" ex_embed = self.hist_ex_embedding(x_enc, hist_exog.reshape(B, T, V*D))\n",
423+
" else:\n",
424+
" ex_embed = self.hist_ex_embedding(x_enc, None)\n",
425+
" if self.futr_exog_size > 0:\n",
426+
" B, V, T, D = futr_exog.shape\n",
427+
" futr_ex_embed = self.futr_ex_embedding(futr_exog.reshape(B, T, V*D), None)\n",
428+
" ex_embed = torch.cat([ex_embed, futr_ex_embed], dim=1)\n",
429+
" if self.stat_exog_size > 0:\n",
430+
" # stat_exog: [N, S] -> [N, E] -> [B, N, E]\n",
431+
" stat_embed = self.stat_ex_embedding(stat_exog) # [N, E]\n",
432+
" stat_embed = stat_embed.unsqueeze(0).expand(B, -1, -1) # [B, N, E]\n",
433+
" ex_embed = torch.cat([ex_embed, stat_embed], dim=1)\n",
411434
"\n",
412435
" enc_out = self.encoder(en_embed, ex_embed)\n",
413436
" enc_out = torch.reshape(\n",
@@ -426,17 +449,12 @@
426449
" return dec_out\n",
427450
" \n",
428451
" def forward(self, windows_batch):\n",
429-
" insample_y = windows_batch['insample_y']\n",
430-
" futr_exog = windows_batch['futr_exog']\n",
431-
" \n",
432-
" if self.futr_exog_size > 0:\n",
433-
" x_mark_enc = futr_exog[:, :, :self.input_size, :]\n",
434-
" B, V, T, D = x_mark_enc.shape\n",
435-
" x_mark_enc = x_mark_enc.reshape(B, T, V*D)\n",
436-
" else:\n",
437-
" x_mark_enc = None\n",
452+
" insample_y = windows_batch['insample_y'] # [batch_size (B), input_size (L), n_series (N)]\n",
453+
" hist_exog = windows_batch['hist_exog'] # [B, hist_exog_size (X), L, N]\n",
454+
" futr_exog = windows_batch['futr_exog'] # [B, futr_exog_size (F), L + h, N]\n",
455+
" stat_exog = windows_batch['stat_exog'] # [N, stat_exog_size (S)]\n",
438456
"\n",
439-
" y_pred = self.forecast(insample_y, x_mark_enc)\n",
457+
" y_pred = self.forecast(insample_y, futr_exog, hist_exog, stat_exog)\n",
440458
" y_pred = y_pred.reshape(insample_y.shape[0],\n",
441459
" self.h,\n",
442460
" -1)\n",

0 commit comments

Comments
 (0)