@@ -236,7 +236,8 @@ def __init__(
         Wlnattn_scale = jnp.ones((1, learnable_query_dim)) ## LN parameter (applied to output of attention)
         self_attn_params = (Wqs, bqs, Wks, bks, Wvs, bvs, Wouts, bouts, Wlnattn_mu, Wlnattn_scale)
         learnable_query = jnp.zeros((batch_size, 1, learnable_query_dim)) # (B, T, D)
-        self.mask = np.zeros((batch_size, target_seq_length, source_seq_length)).astype(bool) ## mask tensor
+        self.mask = np.zeros((self.batch_size, target_seq_length, source_seq_length)).astype(bool) ## mask tensor
+        self.dev_mask = np.zeros((self.dev_batch_size, target_seq_length, source_seq_length)).astype(bool) ## mask tensor for dev/eval batches (used in process)
         ## MLP parameters
         Whid1 = random.normal(subkeys[16], (learnable_query_dim, learnable_query_dim)) * sigma
         bhid1 = random.normal(subkeys[17], (1, learnable_query_dim)) * sigma
@@ -259,7 +260,7 @@ def __init__(
         self.grad_fx = jax.value_and_grad(eval_attention_probe, argnums=0, has_aux=True)
         ## set up update rule/optimizer
         self.optim_params = adam.adam_init(self.probe_params)
-        self.eta = 0.001
+        self.eta = 0.0002 #0.001
 
     def process(self, embedding_sequence):
         """
@@ -271,13 +272,14 @@ def process(self, embedding_sequence):
         Returns:
             probe output scores/probability values
         """
+        #print(embedding_sequence.shape)
         outs, feats = run_attention_probe(
-            self.probe_params, embedding_sequence, self.mask, self.num_heads, 0.0, use_LN=self.use_LN,
+            self.probe_params, embedding_sequence, self.dev_mask, self.num_heads, 0.0, use_LN=self.use_LN,
             use_softmax=self.use_softmax
         )
         return outs
 
-    def update(self, embedding_sequence, labels):
+    def update(self, embedding_sequence, labels, dkey=None):
         """
         Runs and updates this probe given an input batch of sequences of encodings/embeddings and their externally
         assigned labels/target vector values.
@@ -290,9 +292,10 @@ def update(self, embedding_sequence, labels):
         Returns:
             probe output scores/probability values
         """
+        # TODO: put in dkey to facilitate dropout
         ## compute partial derivatives / adjustments to probe parameters
         outputs, grads = self.grad_fx(
-            self.probe_params, embedding_sequence, labels, self.mask, self.num_heads, dropout=0.5, use_LN=self.use_LN,
+            self.probe_params, embedding_sequence, labels, self.mask, self.num_heads, dropout=0.5, use_LN=self.use_LN,
             use_softmax=self.use_softmax
         )
        loss, predictions = outputs
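A minimal, self-contained sketch of the pattern this `update` change relies on: routing a dropout rate (and, per the TODO, eventually a PRNG key `dkey`) into a loss evaluated through `jax.value_and_grad(..., argnums=0, has_aux=True)`. The function `toy_probe_loss`, its parameter shapes, and the plain SGD step are illustrative stand-ins, not the repository's actual `eval_attention_probe` or its `adam` optimizer.

import jax
import jax.numpy as jnp
from jax import random

def toy_probe_loss(params, x, y, dropout, dkey):
    ## toy stand-in for eval_attention_probe: a linear probe with dropout on its features
    W, b = params
    feats = x @ W + b                                        # (B, D) -> (B, C)
    if dropout > 0.0 and dkey is not None:
        keep = random.bernoulli(dkey, 1.0 - dropout, feats.shape)
        feats = jnp.where(keep, feats / (1.0 - dropout), 0.0)  # inverted dropout
    preds = jax.nn.softmax(feats, axis=-1)
    loss = -jnp.mean(jnp.sum(y * jnp.log(preds + 1e-8), axis=-1))
    return loss, preds                                       # aux output, mirroring (loss, predictions)

grad_fx = jax.value_and_grad(toy_probe_loss, argnums=0, has_aux=True)

key = random.PRNGKey(0)
key, pkey, dkey = random.split(key, 3)
W = random.normal(pkey, (16, 4)) * 0.02
b = jnp.zeros((1, 4))
x = random.normal(key, (8, 16))                              # batch of 8 dummy embeddings
y = jax.nn.one_hot(jnp.arange(8) % 4, 4)

(loss, preds), grads = grad_fx((W, b), x, y, 0.5, dkey)      # dropout=0.5, as in the diff
eta = 0.0002                                                 # matches the new learning rate
W, b = (p - eta * g for p, g in zip((W, b), grads))          # plain SGD step in place of the adam update
print(float(loss), preds.shape)

In the probe itself, the parameter tuple instead flows through `adam.adam_init` and the corresponding adam step, and the mask and `num_heads` arguments drive the attention pooling that this toy loss omits.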