@@ -5446,11 +5446,6 @@ class SanaLinearAttnProcessor2_0:
     Processor for implementing scaled dot-product linear attention.
     """
 
-    def __init__(self, pad_val=1.0, eps=1e-15):
-        self.pad_val = pad_val
-        self.eps = eps
-        self.kernel_func = nn.ReLU(inplace=False)
-
     def __call__(
         self,
         attn: Attention,
@@ -5471,16 +5466,16 @@ def __call__(
         key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
         value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
 
-        query = self.kernel_func(query)
-        key = self.kernel_func(key)
+        query = F.relu(query)
+        key = F.relu(key)
 
         query, key, value = query.float(), key.float(), value.float()
 
-        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
+        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
         scores = torch.matmul(value, key)
         hidden_states = torch.matmul(scores, query)
 
-        hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
+        hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + 1e-15)
         hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)
         hidden_states = hidden_states.to(original_dtype)
 
@@ -5498,11 +5493,6 @@ class PAGCFGSanaLinearAttnProcessor2_0:
     Processor for implementing scaled dot-product linear attention.
     """
 
-    def __init__(self, pad_val=1.0, eps=1e-15):
-        self.pad_val = pad_val
-        self.eps = eps
-        self.kernel_func = nn.ReLU(inplace=False)
-
     def __call__(
         self,
         attn: Attention,
@@ -5523,16 +5513,16 @@ def __call__(
         key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
         value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
 
-        query = self.kernel_func(query)
-        key = self.kernel_func(key)
+        query = F.relu(query)
+        key = F.relu(key)
 
         query, key, value = query.float(), key.float(), value.float()
 
-        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
+        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
         scores = torch.matmul(value, key)
         hidden_states_org = torch.matmul(scores, query)
 
-        hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + self.eps)
+        hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + 1e-15)
         hidden_states_org = hidden_states_org.flatten(1, 2).transpose(1, 2)
         hidden_states_org = hidden_states_org.to(original_dtype)
 
@@ -5558,11 +5548,6 @@ class PAGIdentitySanaLinearAttnProcessor2_0:
     Processor for implementing scaled dot-product linear attention.
     """
 
-    def __init__(self, pad_val=1.0, eps=1e-15):
-        self.pad_val = pad_val
-        self.eps = eps
-        self.kernel_func = nn.ReLU(inplace=False)
-
     def __call__(
         self,
         attn: Attention,
@@ -5587,14 +5572,14 @@ def __call__(
 
         query, key, value = query.float(), key.float(), value.float()
 
-        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
+        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
         scores = torch.matmul(value, key)
         hidden_states_org = torch.matmul(scores, query)
 
         if hidden_states_org.dtype in [torch.float16, torch.bfloat16]:
            hidden_states_org = hidden_states_org.float()
 
-        hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + self.eps)
+        hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + 1e-15)
        hidden_states_org = hidden_states_org.flatten(1, 2).transpose(1, 2)
        hidden_states_org = hidden_states_org.to(original_dtype)
 
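For reference, here is a minimal standalone sketch of the ReLU-kernel linear attention the simplified processors compute after this change. The function name and the (batch, seq_len, channels) calling convention are illustrative assumptions, not diffusers API; the tensor layout and operations mirror the post-change code above.

import torch
import torch.nn.functional as F

def sana_linear_attention(query, key, value, heads, eps=1e-15):
    # query, key, value: (B, N, C) -- hypothetical standalone helper, not part of diffusers.
    query = query.transpose(1, 2).unflatten(1, (heads, -1))                 # (B, H, D, N)
    key = key.transpose(1, 2).unflatten(1, (heads, -1)).transpose(2, 3)     # (B, H, N, D)
    value = value.transpose(1, 2).unflatten(1, (heads, -1))                 # (B, H, D, N)

    # ReLU feature map in place of softmax; compute in float32 for stability.
    query, key = F.relu(query), F.relu(key)
    query, key, value = query.float(), key.float(), value.float()

    # Append a row of ones to value so the same two matmuls also accumulate
    # the normalizer sum_j k_j (the denominator of linear attention).
    value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)          # (B, H, D+1, N)
    scores = torch.matmul(value, key)                                       # (B, H, D+1, D)
    out = torch.matmul(scores, query)                                       # (B, H, D+1, N)

    # Split off the ones-row and normalize: numerator / (denominator + eps).
    out = out[:, :, :-1] / (out[:, :, -1:] + eps)
    return out.flatten(1, 2).transpose(1, 2)                                # (B, N, H*D)

The ones-row appended by F.pad is what lets a single pair of matmuls produce both the attention numerator and its normalizer, so no separate denominator pass is needed; dropping __init__ in favor of F.relu and literal constants does not change this computation.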