
Commit 83956fe

feat: dit changes
1 parent 2458106 commit 83956fe

File tree

1 file changed: +10 −11 lines


flaxdiff/models/simple_dit.py

Lines changed: 10 additions & 11 deletions
@@ -96,8 +96,7 @@ def __call__(self, x, context=None, freqs_cis=None):
             _B, _H, _W, _C = context.shape
             context_seq_len = _H * _W
             context = context.reshape((B, context_seq_len, _C))
-        else:
-            _B, context_seq_len, _C = context.shape
+        # else: context is already [B, S_ctx, C]
 
         query = self.query(x)  # [B, S, H, D]
         key = self.key(context)  # [B, S_ctx, H, D]
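
For reference, a minimal sketch of the context handling this hunk settles on, assuming the shapes named in the comments: a 4D spatial context [B, H, W, C] is flattened into a token sequence, while a 3D context is assumed to already be [B, S_ctx, C]. The helper below (flatten_context) is purely illustrative and not part of the module.

    import jax.numpy as jnp

    def flatten_context(context: jnp.ndarray) -> jnp.ndarray:
        # Hypothetical helper: collapse the spatial dims of a 4D context into a sequence.
        if context.ndim == 4:
            b, h, w, c = context.shape
            return context.reshape((b, h * w, c))  # [B, H*W, C]
        return context  # assumed to already be [B, S_ctx, C]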
@@ -106,19 +105,21 @@ def __call__(self, x, context=None, freqs_cis=None):
         # Apply RoPE to query and key
         if freqs_cis is None:
             # Generate frequencies using the rope_emb instance
-            seq_len = query.shape[1]
-            freqs_cos, freqs_sin = self.rope_emb(seq_len)
+            seq_len_q = query.shape[1]  # Use query's sequence length
+            freqs_cos, freqs_sin = self.rope_emb(seq_len_q)
         else:
-            # If freqs_cis is passed in as a tuple (for backward compatibility)
+            # If freqs_cis is passed in as a tuple
             freqs_cos, freqs_sin = freqs_cis
 
         # Apply RoPE to query and key
-        # Permute to [B, H, S, D] for RoPE application if needed by apply_rotary_embedding
+        # Permute to [B, H, S, D] for RoPE application
         query = einops.rearrange(query, 'b s h d -> b h s d')
         key = einops.rearrange(key, 'b s h d -> b h s d')
 
+        # Apply RoPE only up to the context sequence length for keys if different
+        # Assuming self-attention or context has same seq len for simplicity here
         query = apply_rotary_embedding(query, freqs_cos, freqs_sin)
-        key = apply_rotary_embedding(key, freqs_cos, freqs_sin)
+        key = apply_rotary_embedding(key, freqs_cos, freqs_sin)  # Apply same freqs to key
 
         # Permute back to [B, S, H, D] for dot_product_attention
         query = einops.rearrange(query, 'b h s d -> b s h d')
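
The hunk treats self.rope_emb and apply_rotary_embedding as given. As a point of reference, here is a minimal sketch of the common half-rotation rotary formulation on [B, H, S, D] tensors; rope_frequencies and apply_rotary are hypothetical stand-ins, not flaxdiff's actual functions.

    import jax.numpy as jnp

    def rope_frequencies(seq_len: int, dim: int, base: float = 10000.0):
        # Hypothetical stand-in for self.rope_emb(seq_len): per-position cos/sin tables.
        inv_freq = 1.0 / (base ** (jnp.arange(0, dim, 2) / dim))   # [D/2]
        angles = jnp.arange(seq_len)[:, None] * inv_freq[None, :]  # [S, D/2]
        return jnp.cos(angles), jnp.sin(angles)

    def apply_rotary(x, freqs_cos, freqs_sin):
        # x: [B, H, S, D]; rotate pairs taken from the two halves of the last dim.
        x1, x2 = jnp.split(x, 2, axis=-1)       # each [B, H, S, D/2]
        cos = freqs_cos[None, None, :, :]       # broadcast to [1, 1, S, D/2]
        sin = freqs_sin[None, None, :, :]
        return jnp.concatenate([x1 * cos - x2 * sin,
                                x1 * sin + x2 * cos], axis=-1)

Note the cross-attention caveat the new comments allude to: a cos/sin table built from the query length only broadcasts against the key if both sequences have the same length, which is why the code assumes self-attention (or an equal-length context) here.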
@@ -130,10 +131,8 @@ def __call__(self, x, context=None, freqs_cis=None):
             deterministic=True
         )  # Output shape [B, S, H, D]
 
-        # Flatten H and D dimensions before projection
-        hidden_states_flat = einops.rearrange(hidden_states, 'b s h d -> b s (h d)')  # Shape [B, S, F]
-
-        proj = self.proj_attn(hidden_states_flat)  # Input [B, S, F], Output shape [B, S, C]
+        # Use the proj_attn from NormalAttention which expects [B, S, H, D]
+        proj = self.proj_attn(hidden_states)  # Output shape [B, S, C]
 
         if is_4d:
             proj = proj.reshape(orig_x_shape)  # Reshape back if input was 4D
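
The new path hands the unflattened [B, S, H, D] tensor straight to proj_attn. One way such a projection can be written is with a Flax DenseGeneral that contracts the head and per-head dims in one step; OutputProjection below is a hypothetical sketch under that assumption, not NormalAttention's actual module.

    import flax.linen as nn
    import jax.numpy as jnp

    class OutputProjection(nn.Module):
        features: int  # output channel dim C

        @nn.compact
        def __call__(self, hidden_states: jnp.ndarray) -> jnp.ndarray:
            # hidden_states: [B, S, H, D]; contract (H, D) -> C in a single dense layer,
            # folding the explicit einops flatten into the projection itself.
            return nn.DenseGeneral(features=self.features, axis=(-2, -1))(hidden_states)

Mathematically this is equivalent to the removed flatten-then-project path, since contracting (H, D) jointly is the same linear map as flattening to H*D and applying a dense layer.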
