Skip to content

Commit 8e93aae

Browse files
committed
backward compatibility
1 parent cc8b359 commit 8e93aae

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

nbs/common.modules.ipynb

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,13 @@
509509
" A = self.dropout(torch.softmax(scale * scores, dim=-1))\n",
510510
" V = torch.einsum(\"bhls,bshd->blhd\", A, values)\n",
511511
"\n",
512-
" return (V.contiguous(), A) if self.output_attention else (V.contiguous(), None) "
512+
" return (V.contiguous(), A) if self.output_attention else (V.contiguous(), None) \n",
513+
"\n",
514+
" def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):\n",
515+
" # Manually initialize `atten` if not in state_dict\n",
516+
" if prefix + 'atten' not in state_dict:\n",
517+
" self.atten = \"full\"\n",
518+
" super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)"
513519
]
514520
},
515521
{

neuralforecast/common/_modules.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,12 @@ def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
429429
(V.contiguous(), A) if self.output_attention else (V.contiguous(), None)
430430
)
431431

432+
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
433+
# Manually initialize `atten` if not in state_dict
434+
if prefix + "atten" not in state_dict:
435+
self.atten = "full"
436+
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
437+
432438
# %% ../../nbs/common.modules.ipynb 19
433439
class PositionalEmbedding(nn.Module):
434440
def __init__(self, hidden_size, max_len=5000):

0 commit comments

Comments
 (0)