@@ -219,7 +219,9 @@ class AttentiveProbe(Probe):
219219 """
220220 def __init__ (
221221 self , dkey , source_seq_length , input_dim , out_dim , num_heads = 8 , attn_dim = 64 ,
222- target_seq_length = 1 , learnable_query_dim = 32 , batch_size = 1 , hid_dim = 32 , use_LN = True , use_LN_input = False , use_softmax = True , ** kwargs
222+ target_seq_length = 1 , learnable_query_dim = 32 , batch_size = 1 , hid_dim = 32 ,
223+ use_LN = True , use_LN_input = False , use_softmax = True , dropout = 0.5 , eta = 0.0002 ,
224+ eta_decay = 0.0 , min_eta = 1e-5 , ** kwargs
223225 ):
224226 super ().__init__ (dkey , batch_size , ** kwargs )
225227 assert attn_dim % num_heads == 0 , f"`attn_dim` must be divisible by `num_heads`. Got { attn_dim } and { num_heads } ."
@@ -232,9 +234,9 @@ def __init__(
232234 self .use_softmax = use_softmax
233235 self .use_LN = use_LN
234236 self .use_LN_input = use_LN_input
235- self .dropout = 0.5
237+ self .dropout = dropout
236238
237- sigma = 0.05
239+ sigma = 0.02
238240 ## cross-attention parameters
239241 Wq = random .normal (subkeys [0 ], (learnable_query_dim , attn_dim )) * sigma
240242 bq = random .normal (subkeys [1 ], (1 , attn_dim )) * sigma
@@ -287,7 +289,10 @@ def __init__(
287289 self .grad_fx = jax .value_and_grad (eval_attention_probe , argnums = 1 , has_aux = True ) #, allow_int=True)
288290 ## set up update rule/optimizer
289291 self .optim_params = adam .adam_init (self .probe_params )
290- self .eta = 0.0002 #0.001
292+ # Learning rate scheduling
293+ self .eta = eta #0.001
294+ self .eta_decay = eta_decay
295+ self .min_eta = min_eta
291296
292297 # Finally, the dkey for the noise_key
293298 self .noise_key = subkeys [24 ]
@@ -319,5 +324,7 @@ def update(self, embeddings, labels, dkey=None):
319324 self .optim_params , self .probe_params = adam .adam_step (
320325 self .optim_params , self .probe_params , grads , eta = self .eta
321326 )
327+
328+ self .eta = max (self .min_eta , self .eta - self .eta_decay * self .eta )
322329 return loss , predictions
323330
0 commit comments