@@ -22,7 +22,25 @@ def masked_fill(x: jax.Array, mask: jax.Array, value=0) -> jax.Array:
     return jnp.where(mask, jnp.broadcast_to(value, x.shape), x)
 
 @bind(jax.jit, static_argnums=[4, 5])
-def cross_attention(params: tuple, x1: jax.Array, x2: jax.Array, mask: jax.Array, n_heads: int = 8, dropout_rate: float = 0.0):
+def cross_attention(params: tuple, x1: jax.Array, x2: jax.Array, mask: jax.Array, n_heads: int = 8, dropout_rate: float = 0.0) -> jax.Array:
26+ """
27+ Run cross-attention function given a list of parameters and two sequences (x1 and x2).
28+ The function takes in a query sequence x1 and a key-value sequence x2, and returns an output of the same shape as x1.
29+ T is the length of the query sequence, and S is the length of the key-value sequence.
30+ Dq is the dimension of the query sequence, and Dkv is the dimension of the key-value sequence.
31+ H is the number of attention heads.
32+
33+ Args:
34+ params (tuple): tuple of parameters
35+ x1 (jax.Array): query sequence. Shape: (B, T, Dq)
36+ x2 (jax.Array): key-value sequence. Shape: (B, S, Dkv)
37+ mask (jax.Array): mask tensor. Shape: (B, T, S)
38+ n_heads (int, optional): number of attention heads. Defaults to 8.
39+ dropout_rate (float, optional): dropout rate. Defaults to 0.0.
40+
41+ Returns:
42+ jax.Array: output of cross-attention
43+ """
     B, T, Dq = x1.shape  # The original shape
     _, S, Dkv = x2.shape
     # here x1 (queries) attends to x2 (keys/values)