@@ -96,8 +96,7 @@ def __call__(self, x, context=None, freqs_cis=None):
             _B, _H, _W, _C = context.shape
             context_seq_len = _H * _W
             context = context.reshape((B, context_seq_len, _C))
-        else:
-            _B, context_seq_len, _C = context.shape
+        # else: context is already [B, S_ctx, C]

         query = self.query(x)        # [B, S, H, D]
         key = self.key(context)      # [B, S_ctx, H, D]
@@ -106,19 +105,21 @@ def __call__(self, x, context=None, freqs_cis=None):
         # Apply RoPE to query and key
         if freqs_cis is None:
             # Generate frequencies using the rope_emb instance
-            seq_len = query.shape[1]
-            freqs_cos, freqs_sin = self.rope_emb(seq_len)
+            seq_len_q = query.shape[1]  # Use query's sequence length
+            freqs_cos, freqs_sin = self.rope_emb(seq_len_q)
         else:
-            # If freqs_cis is passed in as a tuple (for backward compatibility)
+            # If freqs_cis is passed in as a tuple
             freqs_cos, freqs_sin = freqs_cis

         # Apply RoPE to query and key
-        # Permute to [B, H, S, D] for RoPE application if needed by apply_rotary_embedding
+        # Permute to [B, H, S, D] for RoPE application
         query = einops.rearrange(query, 'b s h d -> b h s d')
         key = einops.rearrange(key, 'b s h d -> b h s d')

+        # Apply RoPE only up to the context sequence length for keys if different
+        # Assuming self-attention or context has same seq len for simplicity here
         query = apply_rotary_embedding(query, freqs_cos, freqs_sin)
-        key = apply_rotary_embedding(key, freqs_cos, freqs_sin)
+        key = apply_rotary_embedding(key, freqs_cos, freqs_sin)  # Apply same freqs to key

         # Permute back to [B, S, H, D] for dot_product_attention
         query = einops.rearrange(query, 'b h s d -> b s h d')
@@ -130,10 +131,8 @@ def __call__(self, x, context=None, freqs_cis=None):
             deterministic=True
         )  # Output shape [B, S, H, D]

-        # Flatten H and D dimensions before projection
-        hidden_states_flat = einops.rearrange(hidden_states, 'b s h d -> b s (h d)')  # Shape [B, S, F]
-
-        proj = self.proj_attn(hidden_states_flat)  # Input [B, S, F], Output shape [B, S, C]
+        # Use the proj_attn from NormalAttention which expects [B, S, H, D]
+        proj = self.proj_attn(hidden_states)  # Output shape [B, S, C]

         if is_4d:
             proj = proj.reshape(orig_x_shape)  # Reshape back if input was 4D
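
For reference, here is a minimal sketch of what the `(freqs_cos, freqs_sin)` tables and `apply_rotary_embedding` helper referenced in the diff might look like. This is an assumption, not the repo's actual implementation: it uses the common "rotate-half" RoPE convention, and `make_rope_freqs` is a hypothetical stand-in for what `self.rope_emb(seq_len)` returns.

```python
# Minimal RoPE sketch (assumed rotate-half convention; names mirror the diff only as assumptions).
import jax.numpy as jnp

def make_rope_freqs(seq_len, head_dim, base=10000.0):
    # Rotation angles per position and per feature pair: [S, D/2] -> broadcastable [1, 1, S, D]
    inv_freq = 1.0 / (base ** (jnp.arange(0, head_dim, 2) / head_dim))
    angles = jnp.arange(seq_len)[:, None] * inv_freq[None, :]          # [S, D/2]
    angles = jnp.concatenate([angles, angles], axis=-1)                # [S, D]
    return jnp.cos(angles)[None, None], jnp.sin(angles)[None, None]    # [1, 1, S, D]

def apply_rotary_embedding(x, freqs_cos, freqs_sin):
    # x: [B, H, S, D] -- rotate the two halves of the head dimension by the per-position angles.
    d = x.shape[-1]
    x1, x2 = x[..., : d // 2], x[..., d // 2 :]
    rotated = jnp.concatenate([-x2, x1], axis=-1)
    return x * freqs_cos + rotated * freqs_sin
```

With this convention the cosine/sine tables broadcast against a [B, H, S, D] tensor, which is why the diff rearranges query and key to 'b h s d' before applying RoPE and back to 'b s h d' for dot_product_attention.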