@@ -5410,95 +5410,45 @@ def __call__(
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        temb: Optional[torch.Tensor] = None,
-        *args,
-        **kwargs,
     ) -> torch.Tensor:
-        if len(args) > 0 or kwargs.get("scale", None) is not None:
-            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
-            deprecate("scale", "1.0.0", deprecation_message)
-
-        residual = hidden_states
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+        original_dtype = hidden_states.dtype

-        # chunk
         hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
         hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])

-        # original path
-        batch_size, sequence_length, _ = (
-            hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
         query = attn.to_q(hidden_states_org)
         key = attn.to_k(hidden_states_org)
         value = attn.to_v(hidden_states_org)

-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        dtype = query.dtype
-
-        query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
-        key = key.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1).transpose(-1, -2)
-        value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
+        query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
+        key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
+        value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))

-        query = self.kernel_func(query)  # B, h, h_d, N
+        query = self.kernel_func(query)
         key = self.kernel_func(key)

-        # need torch.float
         query, key, value = query.float(), key.float(), value.float()

         value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
-        vk = torch.matmul(value, key)
-        hidden_states_org = torch.matmul(vk, query)
+        scores = torch.matmul(value, key)
+        hidden_states_org = torch.matmul(scores, query)

-        if hidden_states_org.dtype in [torch.float16, torch.bfloat16]:
-            hidden_states_org = hidden_states_org.float()
         hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + self.eps)
+        hidden_states_org = hidden_states_org.flatten(1, 2).transpose(1, 2)
+        hidden_states_org = hidden_states_org.to(original_dtype)

-        hidden_states_org = hidden_states_org.view(batch_size, attn.heads * head_dim, -1).permute(0, 2, 1)
-        hidden_states_org = hidden_states_org.to(dtype)
-
-        # linear proj
         hidden_states_org = attn.to_out[0](hidden_states_org)
-        # dropout
         hidden_states_org = attn.to_out[1](hidden_states_org)

-        if input_ndim == 4:
-            hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
         # perturbed path (identity attention)
-        batch_size, sequence_length, _ = hidden_states_ptb.shape
+        hidden_states_ptb = attn.to_v(hidden_states_ptb).to(original_dtype)

-        value = attn.to_v(hidden_states_ptb)
-        hidden_states_ptb = value
-        hidden_states_ptb = hidden_states_ptb.to(dtype)
-
-        # linear proj
         hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
-        # dropout
         hidden_states_ptb = attn.to_out[1](hidden_states_ptb)

-        if input_ndim == 4:
-            hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        # cat
         hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])

-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        if hidden_states.dtype == torch.float16:
+        if original_dtype == torch.float16:
             hidden_states = hidden_states.clip(-65504, 65504)

         return hidden_states
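
In the rewritten original path above, `transpose(1, 2).unflatten(1, (attn.heads, -1))` is a drop-in replacement for the removed `transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)`: both turn a `(batch, seq_len, heads * head_dim)` projection into `(batch, heads, head_dim, seq_len)`, but the new form needs neither `batch_size` nor `head_dim` to be computed first. A minimal sketch checking the equivalence (the sizes below are illustrative, not taken from the diff):

import torch

B, N, heads, head_dim = 2, 16, 4, 8
x = torch.randn(B, N, heads * head_dim)  # stand-in for a to_q/to_k/to_v projection

old = x.transpose(-1, -2).reshape(B, heads, head_dim, -1)  # removed path
new = x.transpose(1, 2).unflatten(1, (heads, -1))          # added path
assert torch.equal(old, new)  # same (B, heads, head_dim, N) tensor
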
@@ -5520,93 +5470,47 @@ def __call__(
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        temb: Optional[torch.Tensor] = None,
-        *args,
-        **kwargs,
     ) -> torch.Tensor:
-        if len(args) > 0 or kwargs.get("scale", None) is not None:
-            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
-            deprecate("scale", "1.0.0", deprecation_message)
-
-        residual = hidden_states
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+        original_dtype = hidden_states.dtype

-        # chunk
         hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)

-        # original path
-        batch_size, sequence_length, _ = (
-            hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
         query = attn.to_q(hidden_states_org)
         key = attn.to_k(hidden_states_org)
         value = attn.to_v(hidden_states_org)

-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        dtype = query.dtype
-
-        query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
-        key = key.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1).transpose(-1, -2)
-        value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
+        query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
+        key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
+        value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))

-        query = self.kernel_func(query)  # B, h, h_d, N
+        query = self.kernel_func(query)
         key = self.kernel_func(key)

-        # need torch.float
         query, key, value = query.float(), key.float(), value.float()

         value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
-        vk = torch.matmul(value, key)
-        hidden_states_org = torch.matmul(vk, query)
+        scores = torch.matmul(value, key)
+        hidden_states_org = torch.matmul(scores, query)

         if hidden_states_org.dtype in [torch.float16, torch.bfloat16]:
             hidden_states_org = hidden_states_org.float()
-        hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + self.eps)

-        hidden_states_org = hidden_states_org.view(batch_size, attn.heads * head_dim, -1).permute(0, 2, 1)
-        hidden_states_org = hidden_states_org.to(dtype)
+        hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + self.eps)
+        hidden_states_org = hidden_states_org.flatten(1, 2).transpose(1, 2)
+        hidden_states_org = hidden_states_org.to(original_dtype)

-        # linear proj
         hidden_states_org = attn.to_out[0](hidden_states_org)
-        # dropout
         hidden_states_org = attn.to_out[1](hidden_states_org)

-        if input_ndim == 4:
-            hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
         # perturbed path (identity attention)
-        batch_size, sequence_length, _ = hidden_states_ptb.shape
+        hidden_states_ptb = attn.to_v(hidden_states_ptb).to(original_dtype)

-        hidden_states_ptb = attn.to_v(hidden_states_ptb)
-        hidden_states_ptb = hidden_states_ptb.to(dtype)
-
-        # linear proj
         hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
-        # dropout
         hidden_states_ptb = attn.to_out[1](hidden_states_ptb)

-        if input_ndim == 4:
-            hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        # cat
         hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])

-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        if hidden_states.dtype == torch.float16:
+        if original_dtype == torch.float16:
             hidden_states = hidden_states.clip(-65504, 65504)

         return hidden_states
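
Both processors keep the padded-value trick from the original code: `F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)` appends one extra row to `value`, so the two matmuls produce the linear-attention numerator in the first `head_dim` rows and the per-query normalizer in the last row, which the division by `hidden_states_org[:, :, -1:]` then applies. A minimal sketch of that trick, assuming `kernel_func` is ReLU and `pad_val` is 1.0 (assumptions for illustration, not read from this diff), with `eps` as an illustrative stand-in for `self.eps`:

import torch
import torch.nn.functional as F

B, heads, head_dim, N = 1, 2, 4, 6
eps = 1e-15  # illustrative stand-in for self.eps

q = F.relu(torch.randn(B, heads, head_dim, N))  # kernel_func(query), ReLU assumed
k = F.relu(torch.randn(B, heads, N, head_dim))  # kernel_func(key), already transposed
v = torch.randn(B, heads, head_dim, N)

# The appended row of ones makes the last row of the result equal to
# sum_j dot(phi(k_j), phi(q_i)) for each query position i, i.e. the normalizer.
v_pad = F.pad(v, (0, 0, 0, 1), mode="constant", value=1.0)  # (B, heads, head_dim + 1, N)
out = torch.matmul(torch.matmul(v_pad, k), q)               # (B, heads, head_dim + 1, N)
out = out[:, :, :-1] / (out[:, :, -1:] + eps)

# Same result with numerator and denominator computed separately.
num = torch.matmul(torch.matmul(v, k), q)
den = torch.matmul(k.sum(dim=2, keepdim=True), q)           # (B, heads, 1, N)
assert torch.allclose(out, num / (den + eps))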