Commit e5653e8 ("ready for use in another project")
Parent: 1e6b926

2 files changed: +9 additions, -2 deletions

agent_attention_pytorch/agent_transformer.py

Lines changed: 8 additions & 1 deletion
@@ -65,7 +65,7 @@ def __init__(
             nn.Sigmoid()
         ) if gate else None
 
-        self.qa_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()
+        self.talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()
         self.ak_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()
 
         self.qa_dropout = nn.Dropout(dropout)
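
The nn.Conv2d(heads, heads, 1) modules implement talking-heads attention: with a 1x1 kernel, the convolution does no mixing over the spatial axes and only learns a linear recombination across the head axis, here applied to the similarity logits. A minimal standalone sketch, with illustrative shapes that are assumptions rather than values from this file:

import torch
from torch import nn

heads = 8

# kernel size 1 means no mixing over the last two axes, only a learned
# linear recombination across the `heads` channel dimension
talking_heads = nn.Conv2d(heads, heads, 1, bias = False)

# e.g. query <-> agent similarity logits: (batch, heads, seq_len, num_agent_tokens)
sim = torch.randn(2, heads, 1024, 128)

mixed = talking_heads(sim)
assert mixed.shape == sim.shape  # shape preserved; heads now exchange information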
@@ -146,6 +146,7 @@ def __init__(
         heads = 8,
         dim_head = 64,
         ff_mult = 4,
+        final_norm = True,
         **attn_kwargs: dict
     ):
         super().__init__()
@@ -167,6 +168,8 @@ def __init__(
                 FeedForward(dim = dim, mult = ff_mult)
             ]))
 
+        self.final_norm = RMSNorm(dim) if final_norm else None
+
     def forward(
         self,
         x,
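
RMSNorm(dim) is defined elsewhere in the repo; for reference, a minimal RMSNorm sketch in the style lucidrains repositories commonly use (the repo's actual definition may differ in details):

import torch
from torch import nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5                      # restores magnitude after unit-normalization
        self.gamma = nn.Parameter(torch.ones(dim))   # learned per-feature gain

    def forward(self, x):
        # l2-normalize the feature (last) dimension, then rescale
        return F.normalize(x, dim = -1) * self.scale * self.gamma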
@@ -192,6 +195,10 @@ def forward(
 
         a, x = unpack(x, ps, 'b * d')
 
+        if exists(self.final_norm):
+            x = self.final_norm(x)
+            a = self.final_norm(a)
+
         if not return_agent_tokens:
             return x
 
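Taken together, the commit has AgentTransformer emit RMS-normalized outputs by default, for both the sequence tokens and the agent tokens. A hypothetical usage sketch; the constructor arguments mirror the repo's README, and the (out, agent_tokens) return order is an assumption:

import torch
from agent_attention_pytorch import AgentTransformer

transformer = AgentTransformer(
    dim = 512,
    depth = 6,
    num_agent_tokens = 128,   # assumed keyword, per the README's AgentSelfAttention example
    final_norm = True         # new flag from this commit; defaults to True
)

x = torch.randn(2, 1024, 512)   # (batch, seq, dim)

# when final_norm is enabled, both returned tensors pass through the shared RMSNorm
out, agent_tokens = transformer(x, return_agent_tokens = True)

assert out.shape == x.shape
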
setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'agent-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.5',
+  version = '0.1.6',
   license='MIT',
   description = 'Agent Attention - Pytorch',
   author = 'Phil Wang',
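
Since the only setup.py change is the version bump, a downstream project can pin the new release and sanity-check what it installed via stdlib package metadata (a sketch; assumes a pip install):

# after: pip install agent-attention-pytorch==0.1.6
from importlib.metadata import version

assert version('agent-attention-pytorch') == '0.1.6'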
