from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
from diffsynth_engine.models.basic import attention as attention_ops
from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
-from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm, ApproximateGELU, RMSNorm
+from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm, GELU, RMSNorm
from diffsynth_engine.utils.gguf import gguf_inference
from diffsynth_engine.utils.fp8_linear import fp8_inference
from diffsynth_engine.utils.parallel import (
@@ -144,7 +144,7 @@ def __init__(
        super().__init__()
        inner_dim = int(dim * 4)
        self.net = nn.ModuleList([])
-        self.net.append(ApproximateGELU(dim, inner_dim, device=device, dtype=dtype))
+        self.net.append(GELU(dim, inner_dim, approximate="tanh", device=device, dtype=dtype))
        self.net.append(nn.Dropout(dropout))
        self.net.append(nn.Linear(inner_dim, dim_out, device=device, dtype=dtype))

@@ -155,8 +155,8 @@ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:


def apply_rotary_emb_qwen(x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]):
-    x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
-    x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
+    x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))  # (b, s, h, d) -> (b, s, h, d/2, 2)
+    x_out = torch.view_as_real(x_rotated * freqs_cis.unsqueeze(1)).flatten(3)  # (b, s, h, d/2, 2) -> (b, s, h, d)
    return x_out.type_as(x)


@@ -200,13 +200,13 @@ def forward(
        img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
        txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)

-        img_q = rearrange(img_q, "b s (h d) -> b h s d", h=self.num_heads)
-        img_k = rearrange(img_k, "b s (h d) -> b h s d", h=self.num_heads)
-        img_v = rearrange(img_v, "b s (h d) -> b h s d", h=self.num_heads)
+        img_q = rearrange(img_q, "b s (h d) -> b s h d", h=self.num_heads)
+        img_k = rearrange(img_k, "b s (h d) -> b s h d", h=self.num_heads)
+        img_v = rearrange(img_v, "b s (h d) -> b s h d", h=self.num_heads)

-        txt_q = rearrange(txt_q, "b s (h d) -> b h s d", h=self.num_heads)
-        txt_k = rearrange(txt_k, "b s (h d) -> b h s d", h=self.num_heads)
-        txt_v = rearrange(txt_v, "b s (h d) -> b h s d", h=self.num_heads)
+        txt_q = rearrange(txt_q, "b s (h d) -> b s h d", h=self.num_heads)
+        txt_k = rearrange(txt_k, "b s (h d) -> b s h d", h=self.num_heads)
+        txt_v = rearrange(txt_v, "b s (h d) -> b s h d", h=self.num_heads)

        img_q, img_k = self.norm_q(img_q), self.norm_k(img_k)
        txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k)
@@ -218,13 +218,9 @@ def forward(
        txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs)
        txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs)

-        joint_q = torch.cat([txt_q, img_q], dim=2)
-        joint_k = torch.cat([txt_k, img_k], dim=2)
-        joint_v = torch.cat([txt_v, img_v], dim=2)
-
-        joint_q = joint_q.transpose(1, 2)
-        joint_k = joint_k.transpose(1, 2)
-        joint_v = joint_v.transpose(1, 2)
+        joint_q = torch.cat([txt_q, img_q], dim=1)
+        joint_k = torch.cat([txt_k, img_k], dim=1)
+        joint_v = torch.cat([txt_v, img_v], dim=1)

        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
        joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, attn_mask=attn_mask, **attn_kwargs)
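A minimal standalone sketch (not part of the diff) of the shape contract implied by the rotary-embedding hunk above: q/k are laid out as (batch, seq, heads, head_dim), and a complex freqs_cis of shape (seq, head_dim/2) is broadcast over the head axis via unsqueeze(1). The sizes and the frequency construction below are illustrative assumptions, not values taken from this repository.

import torch

# Illustrative sizes only; the real values come from the model config.
b, s, h, d = 2, 16, 4, 64

# Complex rotary frequencies, one per (position, channel pair): shape (s, d/2).
inv_freq = 1.0 / (10000 ** (torch.arange(0, d, 2, dtype=torch.float32) / d))
theta = torch.outer(torch.arange(s, dtype=torch.float32), inv_freq)
freqs_cis = torch.polar(torch.ones_like(theta), theta)  # complex64, (s, d/2)

q = torch.randn(b, s, h, d)

# Same steps as apply_rotary_emb_qwen above.
q_rotated = torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2))  # (b, s, h, d/2), complex
q_out = torch.view_as_real(q_rotated * freqs_cis.unsqueeze(1)).flatten(3)   # back to (b, s, h, d)
assert q_out.shape == (b, s, h, d)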