@@ -103,6 +103,7 @@ def __init__(
         upcast_softmax: bool = False,
         cross_attention_norm: Optional[str] = None,
         cross_attention_norm_num_groups: int = 32,
+        qk_norm: Optional[str] = None,
         added_kv_proj_dim: Optional[int] = None,
         norm_num_groups: Optional[int] = None,
         spatial_norm_dim: Optional[int] = None,
@@ -161,6 +162,15 @@ def __init__(
         else:
             self.spatial_norm = None
 
+        if qk_norm is None:
+            self.norm_q = None
+            self.norm_k = None
+        elif qk_norm == "layer_norm":
+            self.norm_q = nn.LayerNorm(dim_head, eps=eps)
+            self.norm_k = nn.LayerNorm(dim_head, eps=eps)
+        else:
+            raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
+
         if cross_attention_norm is None:
             self.norm_cross = None
         elif cross_attention_norm == "layer_norm":
@@ -1426,6 +1436,104 @@ def __call__(
         return hidden_states
 
 
+class HunyuanAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+    used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on the query and key vectors.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("HunyuanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        from .embeddings import apply_rotary_emb
+
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb)
+            if not attn.is_cross_attention:
+                key = apply_rotary_emb(key, image_rotary_emb)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class FusedAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
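For context on the new `qk_norm` argument: when set to `"layer_norm"`, the `Attention` block creates `norm_q`/`norm_k` as `nn.LayerNorm(dim_head)` modules that a processor can apply to the per-head query/key tensors. A minimal sketch of exercising it (the dimensions below are illustrative assumptions, not values from this change):

```python
import torch
from diffusers.models.attention_processor import Attention

# Illustrative dims: with qk_norm="layer_norm", norm_q and norm_k are created
# in __init__ as LayerNorm(dim_head) modules (here dim_head=16).
attn = Attention(query_dim=128, heads=8, dim_head=16, qk_norm="layer_norm")

print(attn.norm_q)  # LayerNorm((16,), eps=1e-05, elementwise_affine=True)
print(attn.norm_k)  # LayerNorm((16,), eps=1e-05, elementwise_affine=True)

# Any other string raises ValueError: unknown qk_norm: ... Should be None or 'layer_norm'
```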
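And a sketch of wiring the new processor into an `Attention` block, assuming the usual `Attention(processor=...)` constructor argument. The query/key normalization happens inside the processor's `__call__`, and `image_rotary_emb` is optional (omitted here); shapes and dims are again illustrative assumptions:

```python
import torch
from diffusers.models.attention_processor import Attention, HunyuanAttnProcessor2_0

attn = Attention(
    query_dim=128,
    heads=8,
    dim_head=16,
    qk_norm="layer_norm",
    processor=HunyuanAttnProcessor2_0(),
)

hidden_states = torch.randn(2, 64, 128)  # (batch, seq_len, query_dim)
out = attn(hidden_states)                # norm_q/norm_k applied per head before SDPA
print(out.shape)                         # torch.Size([2, 64, 128])
```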