Commit f06f329

separate combine head projection for agent tokens

1 parent e5653e8

2 files changed: +11 −4 lines

agent_attention_pytorch/agent_transformer.py

Lines changed: 10 additions & 3 deletions
```diff
@@ -65,12 +65,18 @@ def __init__(
             nn.Sigmoid()
         ) if gate else None
 
-        self.talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()
+        self.qa_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()
         self.ak_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()
 
         self.qa_dropout = nn.Dropout(dropout)
         self.ak_dropout = nn.Dropout(dropout)
 
+        self.to_agent_out = nn.Sequential(
+            nn.LayerNorm(dim_head) if sub_layernorm else nn.Identity(),
+            Rearrange('b h n d -> b n (h d)'),
+            nn.Linear(dim_inner, dim, bias = False)
+        )
+
         self.to_out = nn.Sequential(
             nn.LayerNorm(dim_head) if sub_layernorm else nn.Identity(),
             Rearrange('b h n d -> b n (h d)'),
@@ -127,7 +133,7 @@ def forward(
             agent_out = agent_out * self.to_gates(a)
 
         out = self.to_out(out)
-        agent_out = self.to_out(agent_out)
+        agent_out = self.to_agent_out(agent_out)
 
         if not return_agent_tokens:
             return out
@@ -183,7 +189,8 @@ def forward(
             attn_out, agent_out = attn(
                 x,
                 agent_tokens = a,
-                mask = mask
+                mask = mask,
+                return_agent_tokens = True
             )
 
             a = a + agent_out
```
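
The gist of the change: the attention module used to push both the main sequence output and the agent-token output through the same combine-heads projection `self.to_out`; it now keeps a second projection `self.to_agent_out` with its own weights for the agent tokens, and the transformer block passes `return_agent_tokens = True` so the residual update `a = a + agent_out` receives the separately projected tokens. Below is a minimal sketch of just the separated projections, assuming illustrative dims (`dim = 512`, `heads = 8`, `dim_head = 64`) and omitting the attention computation and the optional `sub_layernorm`; the names mirror the diff, but the module itself is hypothetical, not the repo's actual class.

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

class SeparateCombineHeads(nn.Module):
    # hypothetical reduced module: only the two combine-heads projections
    def __init__(self, dim = 512, heads = 8, dim_head = 64):
        super().__init__()
        dim_inner = heads * dim_head

        # combine-heads projection for the main sequence (pre-existing)
        self.to_out = nn.Sequential(
            Rearrange('b h n d -> b n (h d)'),  # fold heads back into the feature dim
            nn.Linear(dim_inner, dim, bias = False)
        )

        # after this commit, agent tokens get their own projection weights
        self.to_agent_out = nn.Sequential(
            Rearrange('b h n d -> b n (h d)'),
            nn.Linear(dim_inner, dim, bias = False)
        )

    def forward(self, out, agent_out):
        # out:       (batch, heads, seq, dim_head)    - main sequence, per head
        # agent_out: (batch, heads, agents, dim_head) - agent tokens, per head
        return self.to_out(out), self.to_agent_out(agent_out)

# shape check with made-up sizes
proj = SeparateCombineHeads()
out, agent_out = proj(torch.randn(2, 8, 128, 64), torch.randn(2, 8, 16, 64))
assert out.shape == (2, 128, 512) and agent_out.shape == (2, 16, 512)
```

Since agent tokens and the main sequence play different roles, sharing one `nn.Linear` tied their output subspaces together; untying them costs roughly `dim_inner * dim` extra parameters per attention layer.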

setup.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'agent-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.6',
+  version = '0.1.7',
   license='MIT',
   description = 'Agent Attention - Pytorch',
   author = 'Phil Wang',
```
