 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-class FluxAttnProcessor:
-    _attention_backend = None
+def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
+    query = attn.to_q(hidden_states)
+    key = attn.to_k(hidden_states)
+    value = attn.to_v(hidden_states)

-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
+    encoder_query = encoder_key = encoder_value = None
+    if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
+        encoder_query = attn.add_q_proj(encoder_hidden_states)
+        encoder_key = attn.add_k_proj(encoder_hidden_states)
+        encoder_value = attn.add_v_proj(encoder_hidden_states)

-    def _get_projections(self, attn, hidden_states, encoder_hidden_states=None):
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
+    return query, key, value, encoder_query, encoder_key, encoder_value

-        encoder_query = encoder_key = encoder_value = None
-        if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
-            encoder_query = attn.add_q_proj(encoder_hidden_states)
-            encoder_key = attn.add_k_proj(encoder_hidden_states)
-            encoder_value = attn.add_v_proj(encoder_hidden_states)

-        return query, key, value, encoder_query, encoder_key, encoder_value
+def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
+    query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)

-    def _get_fused_projections(self, attn, hidden_states, encoder_hidden_states=None):
-        query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+    encoder_query = encoder_key = encoder_value = (None,)
+    if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
+        encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)

-        encoder_query = encoder_key = encoder_value = (None,)
-        if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
-            encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
+    return query, key, value, encoder_query, encoder_key, encoder_value

-        return query, key, value, encoder_query, encoder_key, encoder_value

-    def get_qkv_projections(self, attn: AttentionModuleMixin, hidden_states, encoder_hidden_states=None):
-        if attn.fused_projections:
-            return self._get_fused_projections(attn, hidden_states, encoder_hidden_states)
-        return self._get_projections(attn, hidden_states, encoder_hidden_states)
+def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
+    if attn.fused_projections:
+        return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
+    return _get_projections(attn, hidden_states, encoder_hidden_states)
+
+
+class FluxAttnProcessor:
+    _attention_backend = None
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")

     def __call__(
         self,
@@ -84,7 +87,7 @@ def __call__(
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        query, key, value, encoder_query, encoder_key, encoder_value = self.get_qkv_projections(
+        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
             attn, hidden_states, encoder_hidden_states
         )

@@ -180,55 +183,35 @@ def __call__(
         ip_hidden_states: Optional[List[torch.Tensor]] = None,
         ip_adapter_masks: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        batch_size = hidden_states.shape[0]

-        # `sample` projections.
-        hidden_states_query_proj = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
+        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+            attn, hidden_states, encoder_hidden_states
+        )

-        hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        query = query.unflatten(-1, (attn.heads, -1))
+        key = key.unflatten(-1, (attn.heads, -1))
+        value = value.unflatten(-1, (attn.heads, -1))

-        if attn.norm_q is not None:
-            hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)
+        ip_query = query

-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
         if encoder_hidden_states is not None:
-            # `context` projections.
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+
+            encoder_query = attn.norm_added_q(encoder_query)
+            encoder_key = attn.norm_added_k(encoder_key)
+
+            query = torch.cat([encoder_query, query], dim=1)
+            key = torch.cat([encoder_key, key], dim=1)
+            value = torch.cat([encoder_value, value], dim=1)

         if image_rotary_emb is not None:
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
+            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

         hidden_states = dispatch_attention_fn(
             query,
@@ -239,23 +222,18 @@ def __call__(
             is_causal=False,
             backend=self._attention_backend,
         )
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.to(query.dtype)

         if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
+            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
             )
-
-            # linear proj
             hidden_states = attn.to_out[0](hidden_states)
-            # dropout
             hidden_states = attn.to_out[1](hidden_states)
             encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

         # IP-adapter
-        ip_query = hidden_states_query_proj
         ip_attn_output = torch.zeros_like(hidden_states)

         for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
@@ -264,10 +242,9 @@ def __call__(
             ip_key = to_k_ip(current_ip_hidden_states)
             ip_value = to_v_ip(current_ip_hidden_states)

-            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            # the output of sdp = (batch, num_heads, seq_len, head_dim)
-            # TODO: add support for attn.scale when we move to Torch 2.1
+            ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim)
+            ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim)
+
             current_ip_hidden_states = dispatch_attention_fn(
                 ip_query,
                 ip_key,
@@ -277,9 +254,7 @@ def __call__(
                 is_causal=False,
                 backend=self._attention_backend,
             )
-            current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
-                batch_size, -1, attn.heads * head_dim
-            )
+            current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim)
             current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
             ip_attn_output += scale * current_ip_hidden_states

@@ -316,6 +291,7 @@ def __init__(
         super().__init__()
         assert qk_norm == "rms_norm", "Flux uses RMSNorm"

+        self.head_dim = dim_head
         self.inner_dim = out_dim if out_dim is not None else dim_head * heads
         self.query_dim = query_dim
         self.use_bias = bias
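
Note (illustrative sketch, not part of the commit, with made-up shapes): the refactor keeps attention tensors in (batch, seq_len, heads, head_dim) layout via unflatten instead of moving heads to dim 1 with view(...).transpose(1, 2). That is why the text/image concatenation and apply_rotary_emb now operate on dim 1, and why flatten(2, 3) replaces the old transpose(1, 2).reshape(...) on the attention output.

import torch

batch, seq_len, heads, head_dim = 2, 8, 4, 16
x = torch.randn(batch, seq_len, heads * head_dim)

# Old layout: heads moved to dim 1 -> (batch, heads, seq_len, head_dim)
old = x.view(batch, -1, heads, head_dim).transpose(1, 2)
# New layout: sequence stays in dim 1 -> (batch, seq_len, heads, head_dim)
new = x.unflatten(-1, (heads, -1))

assert torch.equal(old, new.transpose(1, 2))
# flatten(2, 3) merges the heads back, standing in for transpose(1, 2).reshape(...)
assert torch.equal(new.flatten(2, 3), x)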
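
Note (also illustrative, outside the commit): Tensor.split_with_sizes reproduces the manual slice pair the old code used to separate text and image tokens after attention.

import torch

joint = torch.randn(2, 10, 64)  # [text tokens | image tokens] along dim 1
text_len = 3

# One call replaces the two manual slices of the joint sequence.
text, image = joint.split_with_sizes([text_len, joint.shape[1] - text_len], dim=1)

assert torch.equal(text, joint[:, :text_len])
assert torch.equal(image, joint[:, text_len:])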