Commit f3297c4

Merge pull request #111 from jrzaurin/fix_additive_attention
Fix additive attention (#110)
2 parents 1c8709f + 7e5d118 commit f3297c4

File tree

5 files changed: +20 -8 lines changed


VERSION

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-1.2.0
+1.2.1

examples/notebooks/10_3rd_party_integration-RayTune_WnB.ipynb

Lines changed: 8 additions & 2 deletions
@@ -143,6 +143,7 @@
 "    quantity.\n",
 "\n",
 "    \"\"\"\n",
+"\n",
 "    def __init__(\n",
 "        self,\n",
 "        wb: object,\n",
@@ -1061,7 +1062,7 @@
 "    \"wandb\": {\n",
 "        \"project\": \"test\",\n",
 "        # \"api_key_file\": os.getcwd() + \"/wandb_api.key\",\n",
-"        \"api_key\": \"WNB_API_KEY\", \n",
+"        \"api_key\": \"WNB_API_KEY\",\n",
 "    },\n",
 "}\n",
 "\n",
@@ -1080,7 +1081,12 @@
 "trainer = Trainer(\n",
 "    model,\n",
 "    objective=\"binary_focal_loss\",\n",
-"    callbacks=[RayTuneReporter, WnBReportBest(wb=wandb), early_stopping, model_checkpoint],\n",
+"    callbacks=[\n",
+"        RayTuneReporter,\n",
+"        WnBReportBest(wb=wandb),\n",
+"        early_stopping,\n",
+"        model_checkpoint,\n",
+"    ],\n",
 "    lr_schedulers={\"deeptabular\": deep_sch},\n",
 "    initializers={\"deeptabular\": XavierNormal},\n",
 "    optimizers={\"deeptabular\": deep_opt},\n",

mkdocs/sources/examples/10_3rd_party_integration-RayTune_WnB.ipynb

Lines changed: 8 additions & 2 deletions
@@ -143,6 +143,7 @@
 "    quantity.\n",
 "\n",
 "    \"\"\"\n",
+"\n",
 "    def __init__(\n",
 "        self,\n",
 "        wb: object,\n",
@@ -1061,7 +1062,7 @@
 "    \"wandb\": {\n",
 "        \"project\": \"test\",\n",
 "        # \"api_key_file\": os.getcwd() + \"/wandb_api.key\",\n",
-"        \"api_key\": \"WNB_API_KEY\", \n",
+"        \"api_key\": \"WNB_API_KEY\",\n",
 "    },\n",
 "}\n",
 "\n",
@@ -1080,7 +1081,12 @@
 "trainer = Trainer(\n",
 "    model,\n",
 "    objective=\"binary_focal_loss\",\n",
-"    callbacks=[RayTuneReporter, WnBReportBest(wb=wandb), early_stopping, model_checkpoint],\n",
+"    callbacks=[\n",
+"        RayTuneReporter,\n",
+"        WnBReportBest(wb=wandb),\n",
+"        early_stopping,\n",
+"        model_checkpoint,\n",
+"    ],\n",
 "    lr_schedulers={\"deeptabular\": deep_sch},\n",
 "    initializers={\"deeptabular\": XavierNormal},\n",
 "    optimizers={\"deeptabular\": deep_opt},\n",

pytorch_widedeep/models/tabular/transformers/_attention_layers.py

Lines changed: 2 additions & 2 deletions
@@ -217,14 +217,14 @@ def forward(self, X: Tensor) -> Tensor:
         v = self.qv_proj(X) if self.share_qv_weights else self.v_proj(X)
         k = self.k_proj(X)

-        alphas = (self.W_q(q) / math.sqrt(self.head_dim)).softmax(dim=-1)
+        alphas = (self.W_q(q) / math.sqrt(self.head_dim)).softmax(dim=1)
         q_r = einops.rearrange(q, "b s (h d) -> b s h d", h=self.n_heads)
         global_query = einsum(" b s h, b s h d -> b h d", alphas, q_r)
         global_query = einops.rearrange(global_query, "b h d -> b () (h d)")

         p = k * global_query

-        betas = (self.W_k(p) / math.sqrt(self.head_dim)).softmax(dim=-1)
+        betas = (self.W_k(p) / math.sqrt(self.head_dim)).softmax(dim=1)
         p_r = einops.rearrange(p, "b s (h d) -> b s h d", h=self.n_heads)
         global_key = einsum(" b s h, b s h d -> b h d", betas, p_r)
         global_key = einops.rearrange(global_key, "b h d -> b () (h d)")
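
Context on the fix above (not part of the commit): W_q(q) produces attention scores of shape (batch, seq_len, n_heads), so the softmax should normalise over the sequence dimension (dim=1), giving each head a weight per token, rather than over the heads (dim=-1). The following is a minimal, self-contained sketch of the difference; the tensor names and shapes are illustrative assumptions, not taken from the library.

# Minimal sketch (assumed shapes): why the softmax goes over the sequence dimension.
import torch

batch, seq_len, n_heads, head_dim = 2, 5, 4, 3
scores = torch.randn(batch, seq_len, n_heads)  # stand-in for W_q(q) / sqrt(head_dim)

alphas_old = scores.softmax(dim=-1)  # old behaviour: normalises across heads
alphas_fix = scores.softmax(dim=1)   # fixed: normalises across tokens

# With the fix, each head's weights sum to 1 over the sequence, as additive
# attention expects when pooling tokens into a single global query:
print(alphas_fix.sum(dim=1))   # all ones, shape (batch, n_heads)
print(alphas_old.sum(dim=-1))  # all ones over heads instead, shape (batch, seq_len)

# The weighted pooling then mirrors the einsum "b s h, b s h d -> b h d" in the diff:
q_r = torch.randn(batch, seq_len, n_heads, head_dim)
global_query = torch.einsum("bsh,bshd->bhd", alphas_fix, q_r)
print(global_query.shape)  # torch.Size([2, 4, 3])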

pytorch_widedeep/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = "1.2.0"
+__version__ = "1.2.1"
