|
89 | 89 | "# 1. Model" |
90 | 90 | ] |
91 | 91 | }, |
| 92 | + { |
| 93 | + "cell_type": "code", |
| 94 | + "execution_count": null, |
| 95 | + "metadata": {}, |
| 96 | + "outputs": [], |
| 97 | + "source": [ |
| 98 | + "#| exporti\n", |
| 99 | + "class FeatureEmbedding(nn.Module):\n", |
| 100 | + " \"\"\"\n", |
| 101 | + " Feature-fusion module; keeps the parameter count bounded via per-channel embedding:\n", |
| 102 | + " 1. Split the original hidden_size evenly across the feature channels\n", |
| 103 | + " 2. Embed each feature independently\n", |
| 104 | + " 3. Concatenate the per-feature results along the embedding dimension\n", |
| 105 | + " \"\"\"\n", |
| 106 | + " def __init__(self, input_size, h, hidden_size, hist_exog_size, futr_exog_size, stat_exog_size, dropout):\n", |
| 107 | + " super().__init__()\n", |
| 108 | + " self.futr_input_size = input_size + h\n", |
| 109 | + " self.futr_exog_size = futr_exog_size\n", |
| 110 | + " self.hist_exog_size = hist_exog_size\n", |
| 111 | + " self.stat_exog_size = stat_exog_size\n", |
| 112 | + " self.base_embed = DataEmbedding_inverted(input_size, hidden_size, dropout)\n", |
| 113 | + " \n", |
| 114 | + " # One independent encoder per historical exogenous feature\n", |
| 115 | + " self.hist_embed = nn.ModuleList([\n", |
| 116 | + " DataEmbedding_inverted(input_size, hidden_size, dropout)\n", |
| 117 | + " for _ in range(hist_exog_size)\n", |
| 118 | + " ])\n", |
| 119 | + " \n", |
| 120 | + " # One encoder per future exogenous feature (its input spans history + horizon, hence input_size + h)\n", |
| 121 | + " self.futr_embed = nn.ModuleList([\n", |
| 122 | + " DataEmbedding_inverted(self.futr_input_size, hidden_size, dropout)\n", |
| 123 | + " for _ in range(futr_exog_size)\n", |
| 124 | + " ])\n", |
| 125 | + " # Static feature encoding (linear projection); None when there are no static features\n", |
| 126 | + " self.stat_embed = nn.Linear(stat_exog_size, hidden_size) if stat_exog_size > 0 else None\n", |
| 127 | + "\n", |
| 128 | + " def forward(self, y, hist, futr, stat):\n", |
| 129 | + " # Base series embedding [B, N, E]\n", |
| 130 | + " embeddings = [self.base_embed(y, None)]\n", |
| 131 | + " \n", |
| 132 | + " # Historical exogenous embeddings, one [B, N, E] per feature (H total); assumes hist is [B, H, L, N] -- TODO confirm against caller\n", |
| 133 | + " if self.hist_exog_size > 0:\n", |
| 134 | + " for i, embed in enumerate(self.hist_embed):\n", |
| 135 | + " embeddings.append(embed(hist[:, i, :, :], None))\n", |
| 136 | + " \n", |
| 137 | + " # Future exogenous embeddings, one [B, N, E] per feature (F total)\n", |
| 138 | + " if self.futr_exog_size > 0:\n", |
| 139 | + " for i, embed in enumerate(self.futr_embed):\n", |
| 140 | + " embeddings.append(embed(futr[:, i, :, :], None))\n", |
| 141 | + " \n", |
| 142 | + " # Static embedding, shared across the batch [B, N, E]\n", |
| 143 | + " if self.stat_embed is not None:\n", |
| 144 | + " stat_feat = self.stat_embed(stat) # [N, S] -> [N, E]\n", |
| 145 | + " stat_feat = stat_feat.unsqueeze(0).expand(y.size(0), -1, -1) # [N, E] -> [B, N, E]\n", |
| 146 | + " embeddings.append(stat_feat)\n", |
| 147 | + " \n", |
| 148 | + " # Concatenate along the embedding dimension [B, N, E*(1+H+F+S)]\n", |
| 149 | + " return torch.cat(embeddings, dim=-1)" |
| 150 | + ] |
| 151 | + }, |
92 | 152 | { |
93 | 153 | "cell_type": "code", |
94 | 154 | "execution_count": null, |
|
147 | 207 | " \"\"\"\n", |
148 | 208 | "\n", |
149 | 209 | " # Class attributes\n", |
150 | | - " EXOGENOUS_FUTR = False\n", |
151 | | - " EXOGENOUS_HIST = False\n", |
152 | | - " EXOGENOUS_STAT = False\n", |
| 210 | + " EXOGENOUS_FUTR = True\n", |
| 211 | + " EXOGENOUS_HIST = True\n", |
| 212 | + " EXOGENOUS_STAT = True\n", |
153 | 213 | " MULTIVARIATE = True\n", |
154 | 214 | " RECURRENT = False\n", |
155 | 215 | "\n", |
|
238 | 298 | " self.use_norm = use_norm\n", |
239 | 299 | "\n", |
240 | 300 | " # Architecture\n", |
241 | | - " self.enc_embedding = DataEmbedding_inverted(input_size, self.hidden_size, self.dropout)\n", |
| 301 | + " # Mix all features into one\n", |
| 302 | + " self.num_features = 1 + \\\n", |
| 303 | + " (len(hist_exog_list) if hist_exog_list else 0) + \\\n", |
| 304 | + " (len(futr_exog_list) if futr_exog_list else 0) + \\\n", |
| 305 | + " (len(stat_exog_list) if stat_exog_list else 0)\n", |
| 306 | + " adjusted_hidden = hidden_size // self.num_features\n", |
| 307 | + " self.hidden_size = adjusted_hidden * self.num_features\n", |
| 308 | + " self.feature_embedding = FeatureEmbedding(\n", |
| 309 | + " input_size=input_size,\n", |
| 310 | + " h=h,\n", |
| 311 | + " hidden_size=adjusted_hidden,\n", |
| 312 | + " hist_exog_size=len(hist_exog_list) if hist_exog_list else 0,\n", |
| 313 | + " futr_exog_size=len(futr_exog_list) if futr_exog_list else 0,\n", |
| 314 | + " stat_exog_size=len(stat_exog_list) if stat_exog_list else 0,\n", |
| 315 | + " dropout=dropout\n", |
| 316 | + " )\n", |
242 | 317 | "\n", |
243 | 318 | " self.encoder = TransEncoder(\n", |
244 | 319 | " [\n", |
|
256 | 331 | "\n", |
257 | 332 | " self.projector = nn.Linear(self.hidden_size, h * self.loss.outputsize_multiplier, bias=True)\n", |
258 | 333 | "\n", |
259 | | - " def forecast(self, x_enc):\n", |
| 334 | + " def forecast(self, x_enc, hist_exog, futr_exog, stat_exog):\n", |
260 | 335 | " if self.use_norm:\n", |
261 | 336 | " # Normalization from Non-stationary Transformer\n", |
262 | 337 | " means = x_enc.mean(1, keepdim=True).detach()\n", |
|
271 | 346 | "\n", |
272 | 347 | " # Embedding\n", |
273 | 348 | " # B L N -> B N E (B L N -> B L E in the vanilla Transformer)\n", |
274 | | - " enc_out = self.enc_embedding(x_enc, None) # covariates (e.g timestamp) can be also embedded as tokens\n", |
| 349 | + " # 特征融合\n", |
| 350 | + " enc_embed = self.feature_embedding(\n", |
| 351 | + " x_enc, \n", |
| 352 | + " hist_exog,\n", |
| 353 | + " futr_exog,\n", |
| 354 | + " stat_exog\n", |
| 355 | + " )\n", |
275 | 356 | " \n", |
276 | | - " # B N E -> B N E (B L E -> B L E in the vanilla Transformer)\n", |
277 | | - " # the dimensions of embedded time series has been inverted, and then processed by native attn, layernorm and ffn modules\n", |
278 | | - " enc_out, attns = self.encoder(enc_out, attn_mask=None)\n", |
279 | | - "\n", |
280 | | - " # B N E -> B N S -> B S N \n", |
281 | | - " dec_out = self.projector(enc_out).permute(0, 2, 1)[:, :, :N] # filter the covariates\n", |
| 357 | + " # 后续处理保持原有流程不变\n", |
| 358 | + " enc_out, attns = self.encoder(enc_embed, attn_mask=None)\n", |
| 359 | + " dec_out = self.projector(enc_out).permute(0, 2, 1)[:, :, :self.n_series]\n", |
282 | 360 | "\n", |
283 | 361 | " if self.use_norm:\n", |
284 | 362 | " # De-Normalization from Non-stationary Transformer\n", |
|
288 | 366 | " return dec_out\n", |
289 | 367 | " \n", |
290 | 368 | " def forward(self, windows_batch):\n", |
291 | | - " insample_y = windows_batch['insample_y']\n", |
| 369 | + " insample_y = windows_batch['insample_y'] # [batch_size (B), input_size (L), n_series (N)]\n", |
| 370 | + " hist_exog = windows_batch['hist_exog'] # [B, hist_exog_size (X), L, N]\n", |
| 371 | + " futr_exog = windows_batch['futr_exog'] # [B, futr_exog_size (F), L + h, N]\n", |
| 372 | + " stat_exog = windows_batch['stat_exog'] # [N, stat_exog_size (S)]\n", |
292 | 373 | "\n", |
293 | | - " y_pred = self.forecast(insample_y)\n", |
| 374 | + " y_pred = self.forecast(insample_y, hist_exog, futr_exog, stat_exog)\n", |
294 | 375 | " y_pred = y_pred.reshape(insample_y.shape[0],\n", |
295 | 376 | " self.h,\n", |
296 | 377 | " -1)\n", |
|
0 commit comments