@@ -78,7 +78,7 @@ def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float =
7878 """
7979 Runs full nonlinear attentive probe on input encodings (typically embedding vectors produced by some other model).
8080
81- Args:
81+ Args:
8282 params: parameters tuple/list of probe
8383
8484 encodings: input encoding vectors/data
@@ -98,18 +98,35 @@ def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float =
9898 """
9999 # encoded_image_feature: (B, hw, dim)
100100 #learnable_query, *_params) = params
101- learnable_query , Wq , bq , Wk , bk , Wv , bv , Wout , bout , Whid , bhid , Wln_mu , Wln_scale , Wy , by = params
102- attn_params = (Wq , bq , Wk , bk , Wv , bv , Wout , bout )
103- features = cross_attention (attn_params , learnable_query , encodings , mask , n_heads , dropout )
101+ learnable_query , Wq , bq , Wk , bk , Wv , bv , Wout , bout ,\
102+ Wqs , bqs , Wks , bks , Wvs , bvs , Wouts , bouts , Wlnattn_mu ,\
103+ Wlnattn_scale , Whid1 , bhid1 , Wln_mu1 , Wln_scale1 , Whid2 ,\
104+ bhid2 , Wln_mu2 , Wln_scale2 , Whid3 , bhid3 , Wln_mu3 , Wln_scale3 , Wy , by = params
105+ cross_attn_params = (Wq , bq , Wk , bk , Wv , bv , Wout , bout )
106+ features = cross_attention (cross_attn_params , learnable_query , encodings , mask , n_heads , dropout )
107+ # Perform a single self-attention block here
108+ # Self-Attention
109+ self_attn_params = (Wqs , bqs , Wks , bks , Wvs , bvs , Wouts , bouts )
110+ skip = features
111+ if use_LN :
112+ features = layer_normalize (features , Wlnattn_mu , Wlnattn_scale )
113+ features = cross_attention (self_attn_params , features , features , None , n_heads , dropout )
114+ features = features + skip
104115 features = features [:, 0 ] # (B, 1, dim) => (B, dim)
105116 # MLP
106- residual = features
117+ skip = features
107118 if use_LN : ## normalize hidden layer output of probe predictor
108- features = layer_normalize (features , Wln_mu , Wln_scale )
109- features = jnp .matmul ((features ), Whid ) + bhid
119+ features = layer_normalize (features , Wln_mu1 , Wln_scale1 )
120+ features = jnp .matmul ((features ), Whid1 ) + bhid1
110121 features = gelu (features )
111- features = residual + features
112-
122+ if use_LN : ## normalize hidden layer output of probe predictor
123+ features = layer_normalize (features , Wln_mu2 , Wln_scale2 )
124+ features = jnp .matmul ((features ), Whid2 ) + bhid2
125+ features = gelu (features )
126+ if use_LN : ## normalize hidden layer output of probe predictor
127+ features = layer_normalize (features , Wln_mu3 , Wln_scale3 )
128+ features = jnp .matmul ((features ), Whid3 ) + bhid3
129+ features = features + skip
113130 outs = jnp .matmul (features , Wy ) + by
114131 if use_softmax : ## apply softmax output nonlinearity
115132 outs = softmax (outs )
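The forward pass above relies on `layer_normalize`, `gelu`, and `cross_attention` helpers defined elsewhere in this module and not shown in this diff. A minimal sketch of what `layer_normalize` and `gelu` are assumed to compute, matching the (1, dim) shift/scale parameter shapes used above; the in-repo versions may differ in epsilon or approximation details:

import jax.numpy as jnp  # already imported by the module itself

def layer_normalize(x, shift, scale, eps=1e-6):
    ## Normalize each feature vector over its last axis, then apply the
    ## learned (1, dim) shift and scale parameters (e.g., Wln_mu1, Wln_scale1).
    mu = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return ((x - mu) / jnp.sqrt(var + eps)) * scale + shift

def gelu(x):
    ## Tanh approximation of the Gaussian error linear unit.
    return 0.5 * x * (1.0 + jnp.tanh(jnp.sqrt(2.0 / jnp.pi) * (x + 0.044715 * jnp.power(x, 3))))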
@@ -183,11 +200,12 @@ class AttentiveProbe(Probe):
183200 """
184201 def __init__ (
185202 self , dkey , source_seq_length , input_dim , out_dim , num_heads = 8 , attn_dim = 64 ,
186- target_seq_length = 1 , learnable_query_dim = 31 , batch_size = 1 , hid_dim = 32 , use_LN = True , use_softmax = True , ** kwargs
203+ target_seq_length = 1 , learnable_query_dim = 32 , batch_size = 1 , hid_dim = 32 , use_LN = True , use_softmax = True , ** kwargs
187204 ):
188205 super ().__init__ (dkey , batch_size , ** kwargs )
189206 assert attn_dim % num_heads == 0 , f"`attn_dim` must be divisible by `num_heads`. Got { attn_dim } and { num_heads } ."
190- self .dkey , * subkeys = random .split (self .dkey , 12 )
207+ assert learnable_query_dim % num_heads == 0 , f"`learnable_query_dim` must be divisible by `num_heads`. Got { learnable_query_dim } and { num_heads } ."
208+ self .dkey , * subkeys = random .split (self .dkey , 25 )
191209 self .num_heads = num_heads
192210 self .source_seq_length = source_seq_length
193211 self .input_dim = input_dim
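The new assert mirrors the existing `attn_dim` check: multi-head attention splits the feature axis evenly across heads, and the self-attention block now operates in `learnable_query_dim`-sized space. A hypothetical illustration of the head-splitting reshape these checks guard (the actual `cross_attention` helper is not part of this diff):

import jax.numpy as jnp

B, T, num_heads, learnable_query_dim = 2, 1, 8, 32
q = jnp.zeros((B, T, learnable_query_dim))
## Split the feature axis into heads; this reshape errors out when
## learnable_query_dim is not a multiple of num_heads.
q_heads = q.reshape(B, T, num_heads, learnable_query_dim // num_heads)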
@@ -205,19 +223,37 @@ def __init__(
         bv = random.normal(subkeys[5], (1, attn_dim)) * sigma
         Wout = random.normal(subkeys[6], (attn_dim, learnable_query_dim)) * sigma
         bout = random.normal(subkeys[7], (1, learnable_query_dim)) * sigma
-        #params = (Wq, bq, Wk, bk, Wv, bv, Wout, bout)
+        cross_attn_params = (Wq, bq, Wk, bk, Wv, bv, Wout, bout)
+        Wqs = random.normal(subkeys[8], (learnable_query_dim, learnable_query_dim)) * sigma
+        bqs = random.normal(subkeys[9], (1, learnable_query_dim)) * sigma
+        Wks = random.normal(subkeys[10], (learnable_query_dim, learnable_query_dim)) * sigma
+        bks = random.normal(subkeys[11], (1, learnable_query_dim)) * sigma
+        Wvs = random.normal(subkeys[12], (learnable_query_dim, learnable_query_dim)) * sigma
+        bvs = random.normal(subkeys[13], (1, learnable_query_dim)) * sigma
+        Wouts = random.normal(subkeys[14], (learnable_query_dim, learnable_query_dim)) * sigma
+        bouts = random.normal(subkeys[15], (1, learnable_query_dim)) * sigma
+        Wlnattn_mu = jnp.zeros((1, learnable_query_dim))
+        Wlnattn_scale = jnp.ones((1, learnable_query_dim))
+        self_attn_params = (Wqs, bqs, Wks, bks, Wvs, bvs, Wouts, bouts, Wlnattn_mu, Wlnattn_scale)
         learnable_query = jnp.zeros((batch_size, 1, learnable_query_dim))  # (B, T, D)
-        #self.all_params = (learnable_query, *params)
         self.mask = np.zeros((batch_size, target_seq_length, source_seq_length)).astype(bool)  ## mask tensor
         ## MLP parameters
-        Whid = random.normal(subkeys[8], (learnable_query_dim, learnable_query_dim)) * sigma
-        bhid = random.normal(subkeys[9], (1, learnable_query_dim)) * sigma
-        Wln_mu = jnp.zeros((1, learnable_query_dim))
-        Wln_scale = jnp.ones((1, learnable_query_dim))
-        Wy = random.normal(subkeys[8], (learnable_query_dim, out_dim)) * sigma
-        by = random.normal(subkeys[9], (1, out_dim)) * sigma
-        #mlp_params = (Whid, bhid, Wln_mu, Wln_scale, Wy, by)
-        self.probe_params = (learnable_query, Wq, bq, Wk, bk, Wv, bv, Wout, bout, Whid, bhid, Wln_mu, Wln_scale, Wy, by)
+        Whid1 = random.normal(subkeys[16], (learnable_query_dim, learnable_query_dim)) * sigma
+        bhid1 = random.normal(subkeys[17], (1, learnable_query_dim)) * sigma
+        Wln_mu1 = jnp.zeros((1, learnable_query_dim))
+        Wln_scale1 = jnp.ones((1, learnable_query_dim))
+        Whid2 = random.normal(subkeys[18], (learnable_query_dim, learnable_query_dim * 4)) * sigma
+        bhid2 = random.normal(subkeys[19], (1, learnable_query_dim * 4)) * sigma
+        Wln_mu2 = jnp.zeros((1, learnable_query_dim))
+        Wln_scale2 = jnp.ones((1, learnable_query_dim))
+        Whid3 = random.normal(subkeys[20], (learnable_query_dim * 4, learnable_query_dim)) * sigma
+        bhid3 = random.normal(subkeys[21], (1, learnable_query_dim)) * sigma
+        Wln_mu3 = jnp.zeros((1, learnable_query_dim * 4))
+        Wln_scale3 = jnp.ones((1, learnable_query_dim * 4))
+        Wy = random.normal(subkeys[22], (learnable_query_dim, out_dim)) * sigma
+        by = random.normal(subkeys[23], (1, out_dim)) * sigma
+        mlp_params = (Whid1, bhid1, Wln_mu1, Wln_scale1, Whid2, bhid2, Wln_mu2, Wln_scale2, Whid3, bhid3, Wln_mu3, Wln_scale3, Wy, by)
+        self.probe_params = (learnable_query, *cross_attn_params, *self_attn_params, *mlp_params)

         ## set up gradient calculator
         self.grad_fx = jax.value_and_grad(eval_attention_probe, argnums=0, has_aux=True)
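For reference, a usage sketch of the updated probe. The `use_LN`/`use_softmax` keywords are inferred from the function body above, and the placeholder sizes and the `backbone_features` name are illustrative, not part of this commit:

import jax.numpy as jnp
from jax import random

dkey = random.PRNGKey(0)
B, hw, dim, n_classes = 4, 196, 768, 10  # placeholder sizes
probe = AttentiveProbe(
    dkey, source_seq_length=hw, input_dim=dim, out_dim=n_classes,
    num_heads=8, attn_dim=64, learnable_query_dim=32, batch_size=B,
)
backbone_features = jnp.ones((B, hw, dim))  # e.g., frozen encoder embeddings
outs = run_attention_probe(
    probe.probe_params, backbone_features, probe.mask, n_heads=8,
    dropout=0.0, use_LN=True, use_softmax=True,
)  # expected: (B, n_classes) class probabilities when use_softmax=True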