@@ -70,7 +70,7 @@ def cross_attention(dkey, params: tuple, x1: jax.Array, x2: jax.Array, mask: jax
     score = jax.nn.softmax(score, axis=-1)  # (B, H, T, S)
     score = score.astype(q.dtype)  # (B, H, T, S)
     if dropout_rate > 0.:
-        score, _ = drop_out(dkey, input=score, rate=dropout_rate)  ## NOTE: normally you apply dropout here
+        score, _ = drop_out(dkey, score, rate=dropout_rate)  ## NOTE: normally you apply dropout here
     attention = jnp.einsum("BHTS,BHSE->BHTE", score, v)  # (B, H, T, E)
     attention = attention.transpose([0, 2, 1, 3]).reshape((B, T, -1))  # (B, T, H, E) => (B, T, D)
     return attention @ Wout + bout  # (B, T, Dq)
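
Note: the project's `drop_out` helper is not part of this diff; judging from the call site above, it takes the score tensor positionally and returns a pair whose second element is discarded. A minimal sketch of what dropout over the attention weights typically looks like under that assumed (output, keep-mask) return convention (the name `drop_out_sketch` is illustrative, not the repo's implementation):

```python
import jax
import jax.numpy as jnp

def drop_out_sketch(dkey, x, rate=0.1):
    # Inverted dropout: zero each attention weight with probability `rate`
    # and rescale the survivors by 1/(1 - rate) so the expected magnitude
    # of the softmax scores is preserved.
    keep = jax.random.bernoulli(dkey, p=1.0 - rate, shape=x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0), keep
```
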
@@ -105,6 +105,8 @@ def run_attention_probe(
     Returns:
         output scores/probabilities, cross-attention (hidden) features
     """
+    # Two separate dkeys, one for the dropout in each of the two attention blocks
+    dkey1, dkey2 = random.split(dkey, 2)
     # encoded_image_feature: (B, hw, dim)
     #learnable_query, *_params) = params
     learnable_query, Wq, bq, Wk, bk, Wv, bv, Wout, bout,\
@@ -116,14 +118,14 @@ def run_attention_probe(
     if use_LN_input:
         learnable_query = layer_normalize(learnable_query, ln_in_mu, ln_in_scale)
         encodings = layer_normalize(encodings, ln_in_mu2, ln_in_scale2)
-    features = cross_attention(dkey, cross_attn_params, learnable_query, encodings, mask, n_heads, dropout)
+    features = cross_attention(dkey1, cross_attn_params, learnable_query, encodings, mask, n_heads, dropout)
     # Perform a single self-attention block here
     # Self-Attention
     self_attn_params = (Wqs, bqs, Wks, bks, Wvs, bvs, Wouts, bouts)
     skip = features
     if use_LN:
         features = layer_normalize(features, Wlnattn_mu, Wlnattn_scale)
-    features = cross_attention(dkey, self_attn_params, features, features, None, n_heads, dropout)
+    features = cross_attention(dkey2, self_attn_params, features, features, None, n_heads, dropout)
     features = features + skip
     features = features[:, 0]  # (B, 1, dim) => (B, dim)
     # MLP
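
The substantive change in these hunks is that the cross-attention and self-attention blocks now draw their dropout randomness from two different subkeys. In JAX, reusing the same PRNG key reproduces the same random draw, so passing `dkey` to both calls would tie the two dropout masks to one random stream; `random.split` yields independent subkeys. A small standalone illustration of that behavior:

```python
from jax import random

dkey = random.PRNGKey(0)
dkey1, dkey2 = random.split(dkey, 2)

shape = (2, 4)
m_same  = random.bernoulli(dkey,  p=0.9, shape=shape)
m_same2 = random.bernoulli(dkey,  p=0.9, shape=shape)
m_split = random.bernoulli(dkey1, p=0.9, shape=shape)

print(bool((m_same == m_same2).all()))  # True: same key, identical mask
print(bool((m_same == m_split).all()))  # almost surely False: split key gives an independent mask
```
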
@@ -222,7 +224,7 @@ def __init__(
         super().__init__(dkey, batch_size, **kwargs)
         assert attn_dim % num_heads == 0, f"`attn_dim` must be divisible by `num_heads`. Got {attn_dim} and {num_heads}."
         assert learnable_query_dim % num_heads == 0, f"`learnable_query_dim` must be divisible by `num_heads`. Got {learnable_query_dim} and {num_heads}."
-        self.dkey, *subkeys = random.split(self.dkey, 25)
+        self.dkey, *subkeys = random.split(self.dkey, 26)
         self.num_heads = num_heads
         self.source_seq_length = source_seq_length
         self.input_dim = input_dim
@@ -287,8 +289,12 @@ def __init__(
         self.optim_params = adam.adam_init(self.probe_params)
         self.eta = 0.0002  #0.001
 
+        # Finally, reserve the last subkey as the noise_key
+        self.noise_key = subkeys[24]
+
     def process(self, embeddings, dkey=None):
-        noise_key = None
+        # noise_key = None
+        noise_key = self.noise_key
         if dkey is not None:
             dkey, *subkeys = random.split(dkey, 2)
             noise_key = subkeys[0]
@@ -299,7 +305,8 @@ def process(self, embeddings, dkey=None):
         return outs
 
     def update(self, embeddings, labels, dkey=None):
-        noise_key = None
+        # noise_key = None
+        noise_key = self.noise_key
         if dkey is not None:
             dkey, *subkeys = random.split(dkey, 2)
             noise_key = subkeys[0]
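
The remaining hunks widen the constructor's key split from 25 to 26 subkeys, store the extra subkey as `self.noise_key`, and use it as the default source of dropout randomness whenever `process`/`update` are called without a `dkey`. A minimal sketch of that fallback pattern (class and method names here are illustrative, not the repo's):

```python
from jax import random

class ProbeKeySketch:
    def __init__(self, seed=0):
        dkey = random.PRNGKey(seed)
        # Split off 26 subkeys: the first 25 for parameter init etc.,
        # the last one reserved as the fallback key for dropout noise.
        self.dkey, *subkeys = random.split(dkey, 26)
        self.noise_key = subkeys[24]

    def noise(self, dkey=None):
        # Default to the stored key; a caller-supplied dkey overrides it
        # and yields fresh randomness on every call.
        noise_key = self.noise_key
        if dkey is not None:
            dkey, *subkeys = random.split(dkey, 2)
            noise_key = subkeys[0]
        return noise_key
```

One property of this sketch worth noting: the stored `noise_key` is fixed at construction, so calls that rely on the fallback reuse the same dropout randomness until a fresh `dkey` is supplied.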