@@ -102,11 +102,15 @@ def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float =
     attn_params = (Wq, bq, Wk, bk, Wv, bv, Wout, bout)
     features = cross_attention(attn_params, learnable_query, encodings, mask, n_heads, dropout)
     features = features[:, 0]  # (B, 1, dim) => (B, dim)
-    hids = jnp.matmul((features + learnable_query[:, 0]), Whid) + bhid
-    hids = gelu(hids)
+    # MLP
+    residual = features
     if use_LN:  ## normalize hidden layer output of probe predictor
-        hids = layer_normalize(hids, Wln_mu, Wln_scale)
-    outs = jnp.matmul(hids, Wy) + by
+        features = layer_normalize(features, Wln_mu, Wln_scale)
+    features = jnp.matmul(features, Whid) + bhid
+    features = gelu(features)
+    features = residual + features
+
+    outs = jnp.matmul(features, Wy) + by
     if use_softmax:  ## apply softmax output nonlinearity
         outs = softmax(outs)
     return outs, features
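For context, here is a minimal standalone sketch of the new pre-LN residual head this hunk introduces. The inlined normalization is an assumption standing in for the module's own `layer_normalize` (whose definition is outside this diff), and the softmax output nonlinearity is omitted:

```python
import jax.numpy as jnp
from jax import nn

def mlp_head(features, Whid, bhid, Wln_mu, Wln_scale, Wy, by, use_LN=True):
    """Pre-LN residual MLP mirroring the new control flow above:
    (optional) LN -> dense -> GELU -> skip connection -> linear readout."""
    residual = features                               # (B, learnable_query_dim)
    if use_LN:
        # assumed standard layer norm; the repo's layer_normalize may differ
        mu = jnp.mean(features, axis=-1, keepdims=True)
        sd = jnp.std(features, axis=-1, keepdims=True)
        features = ((features - mu) / (sd + 1e-6)) * Wln_scale + Wln_mu
    features = jnp.matmul(features, Whid) + bhid      # Whid must be square ...
    features = nn.gelu(features)
    features = residual + features                    # ... or this add fails
    return jnp.matmul(features, Wy) + by, features
```

The residual add is what forces `Whid` to map `learnable_query_dim` back to `learnable_query_dim`, which the initializer changes further down accommodate.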
@@ -178,10 +182,11 @@ class AttentiveProbe(Probe):
 
     """
     def __init__(
-        self, dkey, source_seq_length, input_dim, out_dim, num_heads=8, head_dim=64,
+        self, dkey, source_seq_length, input_dim, out_dim, num_heads=8, attn_dim=64,
         target_seq_length=1, learnable_query_dim=31, batch_size=1, hid_dim=32, use_LN=True, use_softmax=True, **kwargs
     ):
         super().__init__(dkey, batch_size, **kwargs)
+        assert attn_dim % num_heads == 0, f"`attn_dim` must be divisible by `num_heads`. Got {attn_dim} and {num_heads}."
         self.dkey, *subkeys = random.split(self.dkey, 12)
         self.num_heads = num_heads
         self.source_seq_length = source_seq_length
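The new assert guards the per-head split inside multi-head attention: each head gets `attn_dim // num_heads` channels, so an uneven division would break the head reshape. A small illustration of the standard pattern such an assert protects (the reshape itself is assumed here, since `cross_attention` is outside this diff):

```python
import jax.numpy as jnp

B, T, attn_dim, num_heads = 2, 1, 64, 8
head_dim = attn_dim // num_heads          # 8 channels per head

q = jnp.zeros((B, T, attn_dim))           # (B, T, attn_dim)
q = q.reshape(B, T, num_heads, head_dim)  # ok: 64 == 8 * 8
q = q.transpose(0, 2, 1, 3)               # (B, H, T, head_dim)
# With attn_dim=60 and num_heads=8 the reshape raises, which is the
# misconfiguration the new assert now catches with a clearer message.
```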
@@ -192,24 +197,24 @@ def __init__(
 
         sigma = 0.05
         ## cross-attention parameters
-        Wq = random.normal(subkeys[0], (learnable_query_dim, head_dim)) * sigma
-        bq = random.normal(subkeys[1], (1, head_dim)) * sigma
-        Wk = random.normal(subkeys[2], (input_dim, head_dim)) * sigma
-        bk = random.normal(subkeys[3], (1, head_dim)) * sigma
-        Wv = random.normal(subkeys[4], (input_dim, head_dim)) * sigma
-        bv = random.normal(subkeys[5], (1, head_dim)) * sigma
-        Wout = random.normal(subkeys[6], (head_dim, learnable_query_dim)) * sigma
+        Wq = random.normal(subkeys[0], (learnable_query_dim, attn_dim)) * sigma
+        bq = random.normal(subkeys[1], (1, attn_dim)) * sigma
+        Wk = random.normal(subkeys[2], (input_dim, attn_dim)) * sigma
+        bk = random.normal(subkeys[3], (1, attn_dim)) * sigma
+        Wv = random.normal(subkeys[4], (input_dim, attn_dim)) * sigma
+        bv = random.normal(subkeys[5], (1, attn_dim)) * sigma
+        Wout = random.normal(subkeys[6], (attn_dim, learnable_query_dim)) * sigma
         bout = random.normal(subkeys[7], (1, learnable_query_dim)) * sigma
         #params = (Wq, bq, Wk, bk, Wv, bv, Wout, bout)
         learnable_query = jnp.zeros((batch_size, 1, learnable_query_dim))  # (B, T, D)
         #self.all_params = (learnable_query, *params)
         self.mask = np.zeros((batch_size, target_seq_length, source_seq_length)).astype(bool)  ## mask tensor
         ## MLP parameters
-        Whid = random.normal(subkeys[8], (learnable_query_dim, hid_dim)) * sigma
-        bhid = random.normal(subkeys[9], (1, hid_dim)) * sigma
-        Wln_mu = jnp.zeros((1, hid_dim))
-        Wln_scale = jnp.ones((1, hid_dim))
-        Wy = random.normal(subkeys[8], (hid_dim, out_dim)) * sigma
+        Whid = random.normal(subkeys[8], (learnable_query_dim, learnable_query_dim)) * sigma
+        bhid = random.normal(subkeys[9], (1, learnable_query_dim)) * sigma
+        Wln_mu = jnp.zeros((1, learnable_query_dim))
+        Wln_scale = jnp.ones((1, learnable_query_dim))
+        Wy = random.normal(subkeys[8], (learnable_query_dim, out_dim)) * sigma
         by = random.normal(subkeys[9], (1, out_dim)) * sigma
         #mlp_params = (Whid, bhid, Wln_mu, Wln_scale, Wy, by)
         self.probe_params = (learnable_query, Wq, bq, Wk, bk, Wv, bv, Wout, bout, Whid, bhid, Wln_mu, Wln_scale, Wy, by)
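A quick, self-contained shape check (hypothetical dimension values and variable names) confirming that the re-dimensioned parameters compose with the residual path:

```python
import jax.numpy as jnp
from jax import random

# input_dim, learnable_query_dim, attn_dim, out_dim, num_heads (hypothetical)
Din, Dq, A, Dout, H = 128, 31, 64, 10, 8
assert A % H == 0                              # mirrors the new constructor assert

ks = random.split(random.PRNGKey(0), 6)
sigma = 0.05
Wq   = random.normal(ks[0], (Dq, A)) * sigma   # queries from the learnable query
Wk   = random.normal(ks[1], (Din, A)) * sigma
Wv   = random.normal(ks[2], (Din, A)) * sigma
Wout = random.normal(ks[3], (A, Dq)) * sigma   # back to learnable_query_dim
Whid = random.normal(ks[4], (Dq, Dq)) * sigma  # square, so the residual add works
Wy   = random.normal(ks[5], (Dq, Dout)) * sigma

x = jnp.zeros((4, Dq))                         # pooled attended features, (B, Dq)
h = x + jnp.matmul(x, Whid)                    # residual branch: shapes agree
y = jnp.matmul(h, Wy)                          # (B, out_dim)
assert y.shape == (4, Dout)
```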