
Commit 008bf04

fix: fixed circular imports
1 parent 37e47b9 commit 008bf04

File tree: 4 files changed, +266 −377 lines changed


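Editorial note: the shared ViT/DiT building blocks (patch embedding, RoPE, AdaLN parameter projection) now live in flaxdiff/models/vit_common.py, and both DiT variants import them from there instead of from simple_vit.py, which the commit message indicates was creating a circular import. A sketch of the resulting import pattern, taken directly from the added lines in the diffs below:

# In flaxdiff/models/simple_dit.py
from .vit_common import PatchEmbedding, unpatchify, RotaryEmbedding, RoPEAttention, AdaLNParams

# In flaxdiff/models/simple_mmdit.py
from .vit_common import PatchEmbedding, unpatchify, RotaryEmbedding, RoPEAttention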
flaxdiff/models/simple_dit.py

Lines changed: 1 addition & 202 deletions
@@ -1,4 +1,3 @@
-from .simple_vit import PatchEmbedding, unpatchify
 import jax
 import jax.numpy as jnp
 from flax import linen as nn
@@ -7,6 +6,7 @@
 from functools import partial
 
 # Re-use existing components if they are suitable
+from .vit_common import PatchEmbedding, unpatchify, RotaryEmbedding, RoPEAttention, AdaLNParams
 from .common import kernel_init, FourierEmbedding, TimeProjection
 # Using NormalAttention for RoPE integration
 from .attention import NormalAttention
@@ -15,207 +15,6 @@
 # Use our improved Hilbert implementation
 from .hilbert import hilbert_indices, inverse_permutation, hilbert_patchify, hilbert_unpatchify
 
-# --- Rotary Positional Embedding (RoPE) ---
-# Adapted from https://github.com/google-deepmind/ring_attention/blob/main/ring_attention/layers/rotary.py
-
-
-def _rotate_half(x: jax.Array) -> jax.Array:
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2:]
-    return jnp.concatenate((-x2, x1), axis=-1)
-
-def apply_rotary_embedding(
-    x: jax.Array, freqs_cos: jax.Array, freqs_sin: jax.Array
-) -> jax.Array:
-    """Applies rotary embedding to the input tensor using rotate_half method."""
-    # x shape: [..., Sequence, Dimension] e.g. [B, H, S, D] or [B, S, D]
-    # freqs_cos/sin shape: [Sequence, Dimension / 2]
-
-    # Expand dims for broadcasting: [1, 1, S, D/2] or [1, S, D/2]
-    if x.ndim == 4:  # [B, H, S, D]
-        cos_freqs = jnp.expand_dims(freqs_cos, axis=(0, 1))
-        sin_freqs = jnp.expand_dims(freqs_sin, axis=(0, 1))
-    elif x.ndim == 3:  # [B, S, D]
-        cos_freqs = jnp.expand_dims(freqs_cos, axis=0)
-        sin_freqs = jnp.expand_dims(freqs_sin, axis=0)
-
-    # Duplicate cos and sin for the full dimension D
-    # Shape becomes [..., S, D]
-    cos_freqs = jnp.concatenate([cos_freqs, cos_freqs], axis=-1)
-    sin_freqs = jnp.concatenate([sin_freqs, sin_freqs], axis=-1)
-
-    # Apply rotation: x * cos + rotate_half(x) * sin
-    x_rotated = x * cos_freqs + _rotate_half(x) * sin_freqs
-    return x_rotated.astype(x.dtype)
-
-class RotaryEmbedding(nn.Module):
-    dim: int  # Dimension of the head
-    max_seq_len: int = 2048
-    base: int = 10000
-    dtype: Dtype = jnp.float32
-
-    def setup(self):
-        inv_freq = 1.0 / (
-            self.base ** (jnp.arange(0, self.dim, 2,
-                                     dtype=jnp.float32) / self.dim)
-        )
-        t = jnp.arange(self.max_seq_len, dtype=jnp.float32)
-        freqs = jnp.outer(t, inv_freq)  # Shape: [max_seq_len, dim / 2]
-
-        # Store cosine and sine separately instead of as complex numbers
-        self.freqs_cos = jnp.cos(freqs)  # Shape: [max_seq_len, dim / 2]
-        self.freqs_sin = jnp.sin(freqs)  # Shape: [max_seq_len, dim / 2]
-
-    def __call__(self, seq_len: int):
-        if seq_len > self.max_seq_len:
-            raise ValueError(
-                f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}")
-        # Return separate cos and sin components
-        return self.freqs_cos[:seq_len, :], self.freqs_sin[:seq_len, :]
-# --- Attention with RoPE ---
-
-
-class RoPEAttention(NormalAttention):
-    rope_emb: RotaryEmbedding = None  # Instance of RotaryEmbedding
-
-    @nn.compact
-    def __call__(self, x, context=None, freqs_cis=None):
-        # x has shape [B, H, W, C] or [B, S, C]
-        orig_x_shape = x.shape
-        is_4d = len(orig_x_shape) == 4
-        if is_4d:
-            B, H, W, C = x.shape
-            seq_len = H * W
-            x = x.reshape((B, seq_len, C))
-        else:
-            B, seq_len, C = x.shape
-
-        context = x if context is None else context
-        if len(context.shape) == 4:
-            _B, _H, _W, _C = context.shape
-            context_seq_len = _H * _W
-            context = context.reshape((B, context_seq_len, _C))
-        # else: context is already [B, S_ctx, C]
-
-        query = self.query(x)  # [B, S, H, D]
-        key = self.key(context)  # [B, S_ctx, H, D]
-        value = self.value(context)  # [B, S_ctx, H, D]
-
-        # Apply RoPE to query and key
-        if freqs_cis is None:
-            # Generate frequencies using the rope_emb instance
-            seq_len_q = query.shape[1]  # Use query's sequence length
-            freqs_cos, freqs_sin = self.rope_emb(seq_len_q)
-        else:
-            # If freqs_cis is passed in as a tuple
-            freqs_cos, freqs_sin = freqs_cis
-
-        # Apply RoPE to query and key
-        # Permute to [B, H, S, D] for RoPE application
-        query = einops.rearrange(query, 'b s h d -> b h s d')
-        key = einops.rearrange(key, 'b s h d -> b h s d')
-
-        # Apply RoPE only up to the context sequence length for keys if different
-        # Assuming self-attention or context has same seq len for simplicity here
-        query = apply_rotary_embedding(query, freqs_cos, freqs_sin)
-        key = apply_rotary_embedding(key, freqs_cos, freqs_sin)  # Apply same freqs to key
-
-        # Permute back to [B, S, H, D] for dot_product_attention
-        query = einops.rearrange(query, 'b h s d -> b s h d')
-        key = einops.rearrange(key, 'b h s d -> b s h d')
-
-        hidden_states = nn.dot_product_attention(
-            query, key, value, dtype=self.dtype, broadcast_dropout=False,
-            dropout_rng=None, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax,
-            deterministic=True
-        )  # Output shape [B, S, H, D]
-
-        # Use the proj_attn from NormalAttention which expects [B, S, H, D]
-        proj = self.proj_attn(hidden_states)  # Output shape [B, S, C]
-
-        if is_4d:
-            proj = proj.reshape(orig_x_shape)  # Reshape back if input was 4D
-
-        return proj
-
-
-# --- adaLN-Zero ---
-
-
-class AdaLNZero(nn.Module):
-    features: int
-    dtype: Optional[Dtype] = None
-    precision: PrecisionLike = None
-    norm_epsilon: float = 1e-5  # Standard LayerNorm epsilon
-
-    @nn.compact
-    def __call__(self, x, conditioning):
-        # Project conditioning signal to get scale and shift parameters
-        # Conditioning shape: [B, D_cond] -> [B, 1, ..., 1, 6 * features] for broadcasting
-        # Or [B, 1, 6*features] if x is [B, S, F]
-
-        # Ensure conditioning has seq dim if x does
-        # x=[B,S,F], cond=[B,D_cond]
-        if x.ndim == 3 and conditioning.ndim == 2:
-            conditioning = jnp.expand_dims(
-                conditioning, axis=1)  # cond=[B,1,D_cond]
-
-        # Project conditioning to get 6 params per feature (scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn)
-        # Using nn.DenseGeneral for flexibility if needed, but nn.Dense is fine if cond is [B, D_cond] or [B, 1, D_cond]
-        ada_params = nn.Dense(
-            features=6 * self.features,
-            dtype=self.dtype,
-            precision=self.precision,
-            # Initialize projection to zero (Zero init)
-            kernel_init=nn.initializers.zeros,
-            name="ada_proj"
-        )(conditioning)
-
-        # Split into scale, shift, gate for MLP and Attention
-        scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split(
-            ada_params, 6, axis=-1)
-
-        scale_mlp = jnp.clip(scale_mlp, -10.0, 10.0)
-        shift_mlp = jnp.clip(shift_mlp, -10.0, 10.0)
-        # Apply Layer Normalization
-        norm = nn.LayerNorm(epsilon=self.norm_epsilon,
-                            use_scale=False, use_bias=False, dtype=self.dtype)
-        # norm = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype)  # Alternative: RMSNorm
-
-        norm_x = norm(x)
-
-        # Modulate for Attention path
-        x_attn = norm_x * (1 + scale_attn) + shift_attn
-
-        # Modulate for MLP path
-        x_mlp = norm_x * (1 + scale_mlp) + shift_mlp
-
-        # Return modulated outputs and gates
-        return x_attn, gate_attn, x_mlp, gate_mlp
-
-class AdaLNParams(nn.Module):  # Renamed for clarity
-    features: int
-    dtype: Optional[Dtype] = None
-    precision: PrecisionLike = None
-
-    @nn.compact
-    def __call__(self, conditioning):
-        # Ensure conditioning is broadcastable if needed (e.g., [B, 1, D_cond])
-        if conditioning.ndim == 2:
-            conditioning = jnp.expand_dims(conditioning, axis=1)
-
-        # Project conditioning to get 6 params per feature
-        ada_params = nn.Dense(
-            features=6 * self.features,
-            dtype=self.dtype,
-            precision=self.precision,
-            kernel_init=nn.initializers.zeros,
-            name="ada_proj"
-        )(conditioning)
-        # Return all params (or split if preferred, but maybe return tuple/dict)
-        # Shape: [B, 1, 6*F]
-        return ada_params  # Or split and return tuple: jnp.split(ada_params, 6, axis=-1)
 # --- DiT Block ---
 class DiTBlock(nn.Module):
     features:

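The RoPE helpers deleted above are now imported from .vit_common rather than being redefined here. For readers skimming the deletion, the following is a minimal standalone sketch of the rotate-half rotation those helpers implement; the toy shapes and the plain-function form (no Flax module, no flaxdiff import) are illustrative only, not the package's actual API.

# Standalone sketch of the rotate-half RoPE shown in the removed code above.
import jax.numpy as jnp

def rotate_half(x):
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate((-x2, x1), axis=-1)

def apply_rope(x, freqs_cos, freqs_sin):
    # x: [B, H, S, D]; freqs_cos / freqs_sin: [S, D/2], duplicated to cover all of D
    cos = jnp.concatenate([freqs_cos, freqs_cos], axis=-1)[None, None]  # [1, 1, S, D]
    sin = jnp.concatenate([freqs_sin, freqs_sin], axis=-1)[None, None]
    return (x * cos + rotate_half(x) * sin).astype(x.dtype)

# Toy usage: batch 1, 2 heads, sequence length 4, head dim 8
dim, seq_len = 8, 4
inv_freq = 1.0 / (10000 ** (jnp.arange(0, dim, 2) / dim))  # [D/2]
freqs = jnp.outer(jnp.arange(seq_len), inv_freq)           # [S, D/2]
q = jnp.ones((1, 2, seq_len, dim))
q_rot = apply_rope(q, jnp.cos(freqs), jnp.sin(freqs))
print(q_rot.shape)  # (1, 2, 4, 8)

The frequency table here mirrors the deleted RotaryEmbedding.setup: inverse frequencies per even channel index, outer product with the position index, stored as separate cos and sin rather than complex numbers.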
flaxdiff/models/simple_mmdit.py

Lines changed: 1 addition & 132 deletions
@@ -7,143 +7,12 @@
 from flax.typing import Dtype, PrecisionLike
 
 # Imports from local modules
-from .simple_vit import PatchEmbedding, unpatchify
+from .vit_common import PatchEmbedding, unpatchify, RotaryEmbedding, RoPEAttention
 from .common import kernel_init, FourierEmbedding, TimeProjection
 from .attention import NormalAttention  # Base for RoPEAttention
 # Replace common.hilbert_indices with improved implementation from hilbert.py
 from .hilbert import hilbert_indices, inverse_permutation, hilbert_patchify, hilbert_unpatchify
 
-# --- Rotary Positional Embedding (RoPE) ---
-# Re-used from simple_dit.py
-
-
-def _rotate_half(x: jax.Array) -> jax.Array:
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2:]
-    return jnp.concatenate((-x2, x1), axis=-1)
-
-
-def apply_rotary_embedding(
-    x: jax.Array, freqs_cos: jax.Array, freqs_sin: jax.Array
-) -> jax.Array:
-    """Applies rotary embedding to the input tensor using rotate_half method."""
-    if x.ndim == 4:  # [B, H, S, D]
-        cos_freqs = jnp.expand_dims(freqs_cos, axis=(0, 1))
-        sin_freqs = jnp.expand_dims(freqs_sin, axis=(0, 1))
-    elif x.ndim == 3:  # [B, S, D]
-        cos_freqs = jnp.expand_dims(freqs_cos, axis=0)
-        sin_freqs = jnp.expand_dims(freqs_sin, axis=0)
-    else:
-        raise ValueError(f"Unsupported input dimension: {x.ndim}")
-
-    cos_freqs = jnp.concatenate([cos_freqs, cos_freqs], axis=-1)
-    sin_freqs = jnp.concatenate([sin_freqs, sin_freqs], axis=-1)
-
-    x_rotated = x * cos_freqs + _rotate_half(x) * sin_freqs
-    return x_rotated.astype(x.dtype)
-
-
-class RotaryEmbedding(nn.Module):
-    dim: int
-    max_seq_len: int = 4096  # Increased default based on SimpleDiT
-    base: int = 10000
-    dtype: Dtype = jnp.float32
-
-    def setup(self):
-        inv_freq = 1.0 / (
-            self.base ** (jnp.arange(0, self.dim, 2,
-                                     dtype=jnp.float32) / self.dim)
-        )
-        t = jnp.arange(self.max_seq_len, dtype=jnp.float32)
-        freqs = jnp.outer(t, inv_freq)
-        self.freqs_cos = jnp.cos(freqs)
-        self.freqs_sin = jnp.sin(freqs)
-
-    def __call__(self, seq_len: int):
-        if seq_len > self.max_seq_len:
-            # Dynamically extend frequencies if needed (more robust)
-            t = jnp.arange(seq_len, dtype=jnp.float32)
-            inv_freq = 1.0 / (
-                self.base ** (jnp.arange(0, self.dim, 2,
-                                         dtype=jnp.float32) / self.dim)
-            )
-            freqs = jnp.outer(t, inv_freq)
-            freqs_cos = jnp.cos(freqs)
-            freqs_sin = jnp.sin(freqs)
-            # Consider caching extended freqs if this happens often
-            return freqs_cos, freqs_sin
-            # Or raise error like before:
-            # raise ValueError(f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}")
-        return self.freqs_cos[:seq_len, :], self.freqs_sin[:seq_len, :]
-
-# --- Attention with RoPE ---
-# Re-used from simple_dit.py
-
-
-class RoPEAttention(NormalAttention):
-    rope_emb: RotaryEmbedding = None
-
-    @nn.compact
-    def __call__(self, x, context=None, freqs_cis=None):
-        orig_x_shape = x.shape
-        is_4d = len(orig_x_shape) == 4
-        if is_4d:
-            B, H, W, C = x.shape
-            seq_len = H * W
-            x = x.reshape((B, seq_len, C))
-        else:
-            B, seq_len, C = x.shape
-
-        context = x if context is None else context
-        if len(context.shape) == 4:
-            _B, _H, _W, _C = context.shape
-            context_seq_len = _H * _W
-            context = context.reshape((B, context_seq_len, _C))
-        # else: # context is already [B, S_ctx, C]
-
-        query = self.query(x)  # [B, S, H, D]
-        key = self.key(context)  # [B, S_ctx, H, D]
-        value = self.value(context)  # [B, S_ctx, H, D]
-
-        if freqs_cis is None and self.rope_emb is not None:
-            seq_len_q = query.shape[1]  # Use query's sequence length
-            freqs_cos, freqs_sin = self.rope_emb(seq_len_q)
-        elif freqs_cis is not None:
-            freqs_cos, freqs_sin = freqs_cis
-        else:
-            # Should not happen if rope_emb is provided or freqs_cis are passed
-            raise ValueError("RoPE frequencies not provided.")
-
-        # Apply RoPE to query and key
-        # Permute to [B, H, S, D] for RoPE application
-        query = einops.rearrange(query, 'b s h d -> b h s d')
-        key = einops.rearrange(key, 'b s h d -> b h s d')
-
-        # Apply RoPE only up to the context sequence length for keys if different
-        # Assuming self-attention or context has same seq len for simplicity here
-        query = apply_rotary_embedding(query, freqs_cos, freqs_sin)
-        key = apply_rotary_embedding(
-            key, freqs_cos, freqs_sin)  # Apply same freqs to key
-
-        # Permute back to [B, S, H, D] for dot_product_attention
-        query = einops.rearrange(query, 'b h s d -> b s h d')
-        key = einops.rearrange(key, 'b h s d -> b s h d')
-
-        hidden_states = nn.dot_product_attention(
-            query, key, value, dtype=self.dtype, broadcast_dropout=False,
-            dropout_rng=None, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax,
-            deterministic=True
-        )
-
-        proj = self.proj_attn(hidden_states)
-
-        if is_4d:
-            proj = proj.reshape(orig_x_shape)
-
-        return proj
-
-
 # --- MM-DiT AdaLN-Zero ---
 class MMAdaLNZero(nn.Module):
     """

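The adaLN-Zero modulation removed from simple_dit.py (AdaLNZero/AdaLNParams, with AdaLNParams now imported from vit_common) and continued by MMAdaLNZero above follows one pattern: project the conditioning vector through a zero-initialized Dense layer into six per-feature parameters, then use them to shift/scale a parameter-free LayerNorm and gate the attention and MLP branches. Below is a minimal standalone sketch of that pattern; the class name, shapes, and usage call are illustrative, not flaxdiff's exact API.

# Standalone sketch of the zero-initialized adaLN modulation pattern shown above.
import jax
import jax.numpy as jnp
from flax import linen as nn

class AdaLNZeroSketch(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x, cond):
        # x: [B, S, F]; cond: [B, D_cond]
        cond = cond[:, None, :]  # add a sequence axis for broadcasting
        # Zero-initialized projection to 6 per-feature parameters
        params = nn.Dense(6 * self.features, kernel_init=nn.initializers.zeros,
                          name="ada_proj")(cond)
        scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split(params, 6, axis=-1)
        norm_x = nn.LayerNorm(use_scale=False, use_bias=False)(x)
        x_attn = norm_x * (1 + scale_attn) + shift_attn  # modulated attention input
        x_mlp = norm_x * (1 + scale_mlp) + shift_mlp     # modulated MLP input
        return x_attn, gate_attn, x_mlp, gate_mlp

# Toy usage
mod = AdaLNZeroSketch(features=16)
x = jnp.ones((2, 8, 16))   # [B, S, F]
cond = jnp.ones((2, 32))   # [B, D_cond]
out, _ = mod.init_with_output(jax.random.PRNGKey(0), x, cond)
x_attn, gate_attn, x_mlp, gate_mlp = out
print(x_attn.shape, gate_attn.shape)  # (2, 8, 16) (2, 1, 16)

Because ada_proj is zero-initialized, the scales, shifts, and gates all start at zero, so the modulation begins as an identity and the gated residual branches contribute nothing at initialization; that is the "Zero init" property the comments in the deleted code refer to.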