@@ -107,14 +107,10 @@ def fuse_qkv_projections(self):
         are fused. For cross-attention modules, key and value projection matrices are fused.

         """
-        self.original_attn_processors = None
-
         for _, attn_processor in self.attn_processors.items():
             if "Added" in str(attn_processor.__class__.__name__):
                 raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

-        self.original_attn_processors = self.attn_processors
-
         for module in self.modules():
             if isinstance(module, AttentionModuleMixin):
                 module.fuse_projections(fuse=True)
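For orientation, a minimal sketch of how this model-level toggle is typically driven, assuming a diffusers transformer that inherits the mixin above; the checkpoint name is illustrative only:

```python
import torch
from diffusers import FluxTransformer2DModel

# Hypothetical usage: any model exposing fuse_qkv_projections()/unfuse_qkv_projections()
# is driven the same way; the repo id below is just an example.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.fuse_qkv_projections()    # each attention module gains a fused to_qkv (or to_kv) linear
# ... run inference ...
transformer.unfuse_qkv_projections()  # go back to the separate to_q/to_k/to_v path
```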
@@ -129,30 +125,58 @@ def unfuse_qkv_projections(self):
         </Tip>

         """
-        if self.original_attn_processors is not None:
-            self.set_attn_processor(self.original_attn_processors)
+        for _, attn_processor in self.attn_processors.items():
+            attn_processor.fused_projections = False


 class AttentionModuleMixin:
-    """
-    A mixin class that provides common methods for attention modules.
+    _default_processor_cls = None
+    _available_processors = []
+    fused_projections = False

-    This mixin adds functionality to set different attention processors, handle attention masks, compute attention
-    scores, and manage projections.
-    """
+    def set_processor(self, processor: "AttnProcessor") -> None:
+        """
+        Set the attention processor to use.

-    # Default processor classes to be overridden by subclasses
-    default_processor_cls = None
-    _available_processors = []
+        Args:
+            processor (`AttnProcessor`):
+                The attention processor to use.
+        """
+        # if current processor is in `self._modules` and if passed `processor` is not, we need to
+        # pop `processor` from `self._modules`
+        if (
+            hasattr(self, "processor")
+            and isinstance(self.processor, torch.nn.Module)
+            and not isinstance(processor, torch.nn.Module)
+        ):
+            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
+            self._modules.pop("processor")

-    fused_projections = False
-    is_cross_attention = False
+        self.processor = processor
+
+    def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
+        """
+        Get the attention processor in use.
+
+        Args:
+            return_deprecated_lora (`bool`, *optional*, defaults to `False`):
+                Set to `True` to return the deprecated LoRA attention processor.

-    def _get_compatible_processor(self, backend):
-        for processor_cls in self._available_processors:
-            if backend in processor_cls.compatible_backends:
-                processor = processor_cls()
-                return processor
+        Returns:
+            "AttentionProcessor": The attention processor in use.
+        """
+        if not return_deprecated_lora:
+            return self.processor
+
+    def set_attention_backend(self, backend: str):
+        from .attention_dispatch import AttentionBackendName
+
+        available_backends = {x.value for x in AttentionBackendName.__members__.values()}
+        if backend not in available_backends:
+            raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
+
+        backend = AttentionBackendName(backend.lower())
+        self.processor._attention_backend = backend

     def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
         """
@@ -161,14 +185,12 @@ def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
         Args:
             use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not.
         """
-        processor = self.default_processor_cls()

         if use_npu_flash_attention:
             if not is_torch_npu_available():
                 raise ImportError("torch_npu is not available")
-            processor = self._get_compatible_processor("npu")

-        self.set_processor(processor)
+        self.set_attention_backend("_native_npu")

     def set_use_xla_flash_attention(
         self,
@@ -187,52 +209,85 @@ def set_use_xla_flash_attention(
             is_flux (`bool`, *optional*, defaults to `False`):
                 Whether the model is a Flux model.
         """
-        processor = self.default_processor_cls()
         if use_xla_flash_attention:
             if not is_torch_xla_available():
                 raise ImportError("torch_xla is not available")
-            processor = self._get_compatible_processor("xla")

-        self.set_processor(processor)
+        self.set_attention_backend("_native_xla")

-    @torch.no_grad()
-    def fuse_projections(self, fuse=True):
+    def set_use_memory_efficient_attention_xformers(
+        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+    ) -> None:
         """
-        Fuse the query, key, and value projections into a single projection for efficiency.
+        Set whether to use memory efficient attention from `xformers` or not.

         Args:
-            fuse (`bool`): Whether to fuse the projections or not.
+            use_memory_efficient_attention_xformers (`bool`):
+                Whether to use memory efficient attention from `xformers` or not.
+            attention_op (`Callable`, *optional*):
+                The attention operation to use. Defaults to `None` which uses the default attention operation from
+                `xformers`.
         """
-        # Skip if already in desired state
-        if getattr(self, "fused_projections", False) == fuse:
+        if use_memory_efficient_attention_xformers:
+            if not is_xformers_available():
+                raise ModuleNotFoundError(
+                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers",
+                    name="xformers",
+                )
+            elif not torch.cuda.is_available():
+                raise ValueError(
+                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
+                    " only available for GPU "
+                )
+            else:
+                try:
+                    # Make sure we can run the memory efficient attention
+                    if xformers is not None:
+                        dtype = None
+                        if attention_op is not None:
+                            op_fw, op_bw = attention_op
+                            dtype, *_ = op_fw.SUPPORTED_DTYPES
+                        q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
+                        _ = xformers.ops.memory_efficient_attention(q, q, q)
+                except Exception as e:
+                    raise e
+
+        self.set_attention_backend("xformers")
+
+    @torch.no_grad()
+    def fuse_projections(self):
+        """
+        Fuse the query, key, and value projections into a single projection for efficiency.
+        """
+        # Skip if already fused
+        if getattr(self, "fused_projections", False):
             return

         device = self.to_q.weight.data.device
         dtype = self.to_q.weight.data.dtype

-        if not self.is_cross_attention:
-            # Fuse self-attention projections
-            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
-            in_features = concatenated_weights.shape[1]
-            out_features = concatenated_weights.shape[0]
-
-            self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
-            self.to_qkv.weight.copy_(concatenated_weights)
-            if self.use_bias:
-                concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
-                self.to_qkv.bias.copy_(concatenated_bias)
-
-        else:
+        if hasattr(self, "is_cross_attention") and self.is_cross_attention:
             # Fuse cross-attention key-value projections
             concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
             in_features = concatenated_weights.shape[1]
             out_features = concatenated_weights.shape[0]

             self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
             self.to_kv.weight.copy_(concatenated_weights)
-            if self.use_bias:
+            if hasattr(self, "use_bias") and self.use_bias:
                 concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
                 self.to_kv.bias.copy_(concatenated_bias)
+        else:
+            # Fuse self-attention projections
+            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+            self.to_qkv.weight.copy_(concatenated_weights)
+            if hasattr(self, "use_bias") and self.use_bias:
+                concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+                self.to_qkv.bias.copy_(concatenated_bias)

         # Handle added projections for models like SD3, Flux, etc.
         if (
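As a standalone illustration of the fusion performed above (independent of diffusers), concatenating the q/k/v weight matrices along the output dimension yields a single linear layer whose output splits back into the three projections. Shapes are illustrative:

```python
import torch
import torch.nn as nn

dim = 64
to_q, to_k, to_v = (nn.Linear(dim, dim, bias=False) for _ in range(3))

# nn.Linear stores weights as (out_features, in_features), so stacking along dim 0
# stacks the q, k and v output rows into one projection.
to_qkv = nn.Linear(dim, 3 * dim, bias=False)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))

x = torch.randn(2, 16, dim)
q, k, v = to_qkv(x).chunk(3, dim=-1)
assert torch.allclose(q, to_q(x), atol=1e-6)
assert torch.allclose(k, to_k(x), atol=1e-6)
assert torch.allclose(v, to_v(x), atol=1e-6)
```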
@@ -256,52 +311,28 @@ def fuse_projections(self, fuse=True):
             )
             self.to_added_qkv.bias.copy_(concatenated_bias)

-        self.fused_projections = fuse
+        self.fused_projections = True

-    def set_use_memory_efficient_attention_xformers(
-        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
-    ) -> None:
+    @torch.no_grad()
+    def unfuse_projections(self):
         """
-        Set whether to use memory efficient attention from `xformers` or not.
-
-        Args:
-            use_memory_efficient_attention_xformers (`bool`):
-                Whether to use memory efficient attention from `xformers` or not.
-            attention_op (`Callable`, *optional*):
-                The attention operation to use. Defaults to `None` which uses the default attention operation from
-                `xformers`.
+        Unfuse the query, key, and value projections back to separate projections.
         """
-        if use_memory_efficient_attention_xformers:
-            if not is_xformers_available():
-                raise ModuleNotFoundError(
-                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers",
-                    name="xformers",
-                )
-            elif not torch.cuda.is_available():
-                raise ValueError(
-                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
-                    " only available for GPU "
-                )
-            else:
-                try:
-                    # Make sure we can run the memory efficient attention
-                    if xformers is not None:
-                        dtype = None
-                        if attention_op is not None:
-                            op_fw, op_bw = attention_op
-                            dtype, *_ = op_fw.SUPPORTED_DTYPES
-                        q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
-                        _ = xformers.ops.memory_efficient_attention(q, q, q)
-                except Exception as e:
-                    raise e
+        # Skip if not fused
+        if not getattr(self, "fused_projections", False):
+            return

-            processor = self._get_compatible_processor("xformers")
-        else:
-            # Set default processor
-            processor = self.default_processor_cls()
+        # Remove fused projection layers
+        if hasattr(self, "to_qkv"):
+            delattr(self, "to_qkv")
+
+        if hasattr(self, "to_kv"):
+            delattr(self, "to_kv")

-        if processor is not None:
-            self.set_processor(processor)
+        if hasattr(self, "to_added_qkv"):
+            delattr(self, "to_added_qkv")
+
+        self.fused_projections = False

     def set_attention_slice(self, slice_size: int) -> None:
         """
@@ -326,40 +357,6 @@ def set_attention_slice(self, slice_size: int) -> None:

         self.set_processor(processor)

-    def set_processor(self, processor: "AttnProcessor") -> None:
-        """
-        Set the attention processor to use.
-
-        Args:
-            processor (`AttnProcessor`):
-                The attention processor to use.
-        """
-        # if current processor is in `self._modules` and if passed `processor` is not, we need to
-        # pop `processor` from `self._modules`
-        if (
-            hasattr(self, "processor")
-            and isinstance(self.processor, torch.nn.Module)
-            and not isinstance(processor, torch.nn.Module)
-        ):
-            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
-            self._modules.pop("processor")
-
-        self.processor = processor
-
-    def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
-        """
-        Get the attention processor in use.
-
-        Args:
-            return_deprecated_lora (`bool`, *optional*, defaults to `False`):
-                Set to `True` to return the deprecated LoRA attention processor.
-
-        Returns:
-            "AttentionProcessor": The attention processor in use.
-        """
-        if not return_deprecated_lora:
-            return self.processor
-
     def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
         """
         Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`.