@@ -1430,6 +1430,90 @@ def shape(x):
         return a[:, 0, :]  # cls_token
 
 
+class MochiAttentionPool(nn.Module):
+    def __init__(
+        self,
+        num_attention_heads: int,
+        embed_dim: int,
+        output_dim: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+
+        self.output_dim = output_dim or embed_dim
+        self.num_attention_heads = num_attention_heads
+
+        self.to_kv = nn.Linear(embed_dim, 2 * embed_dim)
+        self.to_q = nn.Linear(embed_dim, embed_dim)
+        self.to_out = nn.Linear(embed_dim, self.output_dim)
+
+    @staticmethod
+    def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
+        """
+        Pool tokens in x using mask.
+
+        NOTE: We assume x does not require gradients.
+
+        Args:
+            x: (B, L, D) tensor of tokens.
+            mask: (B, L) boolean tensor indicating which tokens are not padding.
+
+        Returns:
+            pooled: (B, D) tensor of pooled tokens.
+        """
+        assert x.size(1) == mask.size(1)  # Expected mask to have same length as tokens.
+        assert x.size(0) == mask.size(0)  # Expected mask to have same batch size as tokens.
+        mask = mask[:, :, None].to(dtype=x.dtype)
+        mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
+        pooled = (x * mask).sum(dim=1, keepdim=keepdim)
+        return pooled
+
+    def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
+        r"""
+        Args:
+            x (`torch.Tensor`):
+                Tensor of shape `(B, S, D)` of input tokens.
+            mask (`torch.Tensor`):
+                Boolean tensor of shape `(B, S)` indicating which tokens are not padding.
+
+        Returns:
+            `torch.Tensor`:
+                `(B, D)` tensor of pooled tokens.
+        """
+        D = x.size(2)
+
+        # Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
+        attn_mask = mask[:, None, None, :].bool()  # (B, 1, 1, L).
+        attn_mask = F.pad(attn_mask, (1, 0), value=True)  # (B, 1, 1, 1+L).
+
+        # Average non-padding token features. These will be used as the query.
+        x_pool = self.pool_tokens(x, mask, keepdim=True)  # (B, 1, D)
+
+        # Concat pooled features to input sequence.
+        x = torch.cat([x_pool, x], dim=1)  # (B, L+1, D)
+
+        # Compute queries, keys, values. Only the mean token is used to create a query.
+        kv = self.to_kv(x)  # (B, L+1, 2 * D)
+        q = self.to_q(x[:, 0])  # (B, D)
+
+        # Extract heads.
+        head_dim = D // self.num_attention_heads
+        kv = kv.unflatten(2, (2, self.num_attention_heads, head_dim))  # (B, 1+L, 2, H, head_dim)
+        kv = kv.transpose(1, 3)  # (B, H, 2, 1+L, head_dim)
+        k, v = kv.unbind(2)  # (B, H, 1+L, head_dim)
+        q = q.unflatten(1, (self.num_attention_heads, head_dim))  # (B, H, head_dim)
+        q = q.unsqueeze(2)  # (B, H, 1, head_dim)
+
+        # Compute attention.
+        x = F.scaled_dot_product_attention(
+            q, k, v, attn_mask=attn_mask, dropout_p=0.0
+        )  # (B, H, 1, head_dim)
+
+        # Concatenate heads and run output.
+        x = x.squeeze(2).flatten(1, 2)  # (B, D = H * head_dim)
+        x = self.to_out(x)
+        return x
+
+
 def get_fourier_embeds_from_boundingbox(embed_dim, box):
     """
     Args:
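
For reference, a minimal usage sketch of the new pooling module (not part of this diff). It assumes the MochiAttentionPool class added above is already in scope; the batch size, sequence length, and dimensions below are arbitrary illustrative values, with embed_dim chosen to be divisible by num_attention_heads.

import torch

# Hypothetical configuration: 8 heads over 256-dim tokens, projected to a 1024-dim pooled vector.
pool = MochiAttentionPool(num_attention_heads=8, embed_dim=256, output_dim=1024)

# Two sequences of 16 tokens each; the last 6 tokens of the second sequence are padding.
x = torch.randn(2, 16, 256)
mask = torch.ones(2, 16, dtype=torch.bool)
mask[1, 10:] = False

# The masked mean of the tokens is prepended as the single query token and attends
# over the (mean + tokens) sequence, with padding positions masked out.
pooled = pool(x, mask)
print(pooled.shape)  # torch.Size([2, 1024])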