@@ -65,12 +65,18 @@ def __init__(
6565 nn .Sigmoid ()
6666 ) if gate else None
6767
68- self .talking_heads = nn .Conv2d (heads , heads , 1 , bias = False ) if talking_heads else nn .Identity ()
68+ self .qa_talking_heads = nn .Conv2d (heads , heads , 1 , bias = False ) if talking_heads else nn .Identity ()
6969 self .ak_talking_heads = nn .Conv2d (heads , heads , 1 , bias = False ) if talking_heads else nn .Identity ()
7070
7171 self .qa_dropout = nn .Dropout (dropout )
7272 self .ak_dropout = nn .Dropout (dropout )
7373
74+ self .to_agent_out = nn .Sequential (
75+ nn .LayerNorm (dim_head ) if sub_layernorm else nn .Identity (),
76+ Rearrange ('b h n d -> b n (h d)' ),
77+ nn .Linear (dim_inner , dim , bias = False )
78+ )
79+
7480 self .to_out = nn .Sequential (
7581 nn .LayerNorm (dim_head ) if sub_layernorm else nn .Identity (),
7682 Rearrange ('b h n d -> b n (h d)' ),
@@ -127,7 +133,7 @@ def forward(
127133 agent_out = agent_out * self .to_gates (a )
128134
129135 out = self .to_out (out )
130- agent_out = self .to_out (agent_out )
136+ agent_out = self .to_agent_out (agent_out )
131137
132138 if not return_agent_tokens :
133139 return out
@@ -183,7 +189,8 @@ def forward(
183189 attn_out , agent_out = attn (
184190 x ,
185191 agent_tokens = a ,
186- mask = mask
192+ mask = mask ,
193+ return_agent_tokens = True
187194 )
188195
189196 a = a + agent_out
0 commit comments