|
1 | | -from .simple_vit import PatchEmbedding, unpatchify |
2 | 1 | import jax |
3 | 2 | import jax.numpy as jnp |
4 | 3 | from flax import linen as nn |
|
7 | 6 | from functools import partial |
8 | 7 |
|
9 | 8 | # Re-use existing components if they are suitable |
| 9 | +from .vit_common import PatchEmbedding, unpatchify, RotaryEmbedding, RoPEAttention, AdaLNParams |
10 | 10 | from .common import kernel_init, FourierEmbedding, TimeProjection |
11 | 11 | # Using NormalAttention for RoPE integration |
12 | 12 | from .attention import NormalAttention |
|
15 | 15 | # Use our improved Hilbert implementation |
16 | 16 | from .hilbert import hilbert_indices, inverse_permutation, hilbert_patchify, hilbert_unpatchify |
17 | 17 |
|
18 | | -# --- Rotary Positional Embedding (RoPE) --- |
19 | | -# Adapted from https://github.com/google-deepmind/ring_attention/blob/main/ring_attention/layers/rotary.py |
20 | | - |
21 | | - |
22 | | -def _rotate_half(x: jax.Array) -> jax.Array: |
23 | | - """Rotates half the hidden dims of the input.""" |
24 | | - x1 = x[..., : x.shape[-1] // 2] |
25 | | - x2 = x[..., x.shape[-1] // 2:] |
26 | | - return jnp.concatenate((-x2, x1), axis=-1) |
27 | | - |
28 | | -def apply_rotary_embedding( |
29 | | - x: jax.Array, freqs_cos: jax.Array, freqs_sin: jax.Array |
30 | | -) -> jax.Array: |
31 | | - """Applies rotary embedding to the input tensor using rotate_half method.""" |
32 | | - # x shape: [..., Sequence, Dimension] e.g. [B, H, S, D] or [B, S, D] |
33 | | - # freqs_cos/sin shape: [Sequence, Dimension / 2] |
34 | | - |
35 | | - # Expand dims for broadcasting: [1, 1, S, D/2] or [1, S, D/2] |
36 | | - if x.ndim == 4: # [B, H, S, D] |
37 | | - cos_freqs = jnp.expand_dims(freqs_cos, axis=(0, 1)) |
38 | | - sin_freqs = jnp.expand_dims(freqs_sin, axis=(0, 1)) |
39 | | - elif x.ndim == 3: # [B, S, D] |
40 | | - cos_freqs = jnp.expand_dims(freqs_cos, axis=0) |
41 | | - sin_freqs = jnp.expand_dims(freqs_sin, axis=0) |
42 | | - |
43 | | - # Duplicate cos and sin for the full dimension D |
44 | | - # Shape becomes [..., S, D] |
45 | | - cos_freqs = jnp.concatenate([cos_freqs, cos_freqs], axis=-1) |
46 | | - sin_freqs = jnp.concatenate([sin_freqs, sin_freqs], axis=-1) |
47 | | - |
48 | | - # Apply rotation: x * cos + rotate_half(x) * sin |
49 | | - x_rotated = x * cos_freqs + _rotate_half(x) * sin_freqs |
50 | | - return x_rotated.astype(x.dtype) |
51 | | - |
52 | | -class RotaryEmbedding(nn.Module): |
53 | | - dim: int # Dimension of the head |
54 | | - max_seq_len: int = 2048 |
55 | | - base: int = 10000 |
56 | | - dtype: Dtype = jnp.float32 |
57 | | - |
58 | | - def setup(self): |
59 | | - inv_freq = 1.0 / ( |
60 | | - self.base ** (jnp.arange(0, self.dim, 2, |
61 | | - dtype=jnp.float32) / self.dim) |
62 | | - ) |
63 | | - t = jnp.arange(self.max_seq_len, dtype=jnp.float32) |
64 | | - freqs = jnp.outer(t, inv_freq) # Shape: [max_seq_len, dim / 2] |
65 | | - |
66 | | - # Store cosine and sine separately instead of as complex numbers |
67 | | - self.freqs_cos = jnp.cos(freqs) # Shape: [max_seq_len, dim / 2] |
68 | | - self.freqs_sin = jnp.sin(freqs) # Shape: [max_seq_len, dim / 2] |
69 | | - |
70 | | - def __call__(self, seq_len: int): |
71 | | - if seq_len > self.max_seq_len: |
72 | | - raise ValueError( |
73 | | - f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}") |
74 | | - # Return separate cos and sin components |
75 | | - return self.freqs_cos[:seq_len, :], self.freqs_sin[:seq_len, :] |
76 | | -# --- Attention with RoPE --- |
77 | | - |
78 | | - |
79 | | -class RoPEAttention(NormalAttention): |
80 | | - rope_emb: RotaryEmbedding = None # Instance of RotaryEmbedding |
81 | | - |
82 | | - @nn.compact |
83 | | - def __call__(self, x, context=None, freqs_cis=None): |
84 | | - # x has shape [B, H, W, C] or [B, S, C] |
85 | | - orig_x_shape = x.shape |
86 | | - is_4d = len(orig_x_shape) == 4 |
87 | | - if is_4d: |
88 | | - B, H, W, C = x.shape |
89 | | - seq_len = H * W |
90 | | - x = x.reshape((B, seq_len, C)) |
91 | | - else: |
92 | | - B, seq_len, C = x.shape |
93 | | - |
94 | | - context = x if context is None else context |
95 | | - if len(context.shape) == 4: |
96 | | - _B, _H, _W, _C = context.shape |
97 | | - context_seq_len = _H * _W |
98 | | - context = context.reshape((B, context_seq_len, _C)) |
99 | | - # else: context is already [B, S_ctx, C] |
100 | | - |
101 | | - query = self.query(x) # [B, S, H, D] |
102 | | - key = self.key(context) # [B, S_ctx, H, D] |
103 | | - value = self.value(context) # [B, S_ctx, H, D] |
104 | | - |
105 | | - # Apply RoPE to query and key |
106 | | - if freqs_cis is None: |
107 | | - # Generate frequencies using the rope_emb instance |
108 | | - seq_len_q = query.shape[1] # Use query's sequence length |
109 | | - freqs_cos, freqs_sin = self.rope_emb(seq_len_q) |
110 | | - else: |
111 | | - # If freqs_cis is passed in as a tuple |
112 | | - freqs_cos, freqs_sin = freqs_cis |
113 | | - |
114 | | - # Apply RoPE to query and key |
115 | | - # Permute to [B, H, S, D] for RoPE application |
116 | | - query = einops.rearrange(query, 'b s h d -> b h s d') |
117 | | - key = einops.rearrange(key, 'b s h d -> b h s d') |
118 | | - |
119 | | - # Apply RoPE only up to the context sequence length for keys if different |
120 | | - # Assuming self-attention or context has same seq len for simplicity here |
121 | | - query = apply_rotary_embedding(query, freqs_cos, freqs_sin) |
122 | | - key = apply_rotary_embedding(key, freqs_cos, freqs_sin) # Apply same freqs to key |
123 | | - |
124 | | - # Permute back to [B, S, H, D] for dot_product_attention |
125 | | - query = einops.rearrange(query, 'b h s d -> b s h d') |
126 | | - key = einops.rearrange(key, 'b h s d -> b s h d') |
127 | | - |
128 | | - hidden_states = nn.dot_product_attention( |
129 | | - query, key, value, dtype=self.dtype, broadcast_dropout=False, |
130 | | - dropout_rng=None, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax, |
131 | | - deterministic=True |
132 | | - ) # Output shape [B, S, H, D] |
133 | | - |
134 | | - # Use the proj_attn from NormalAttention which expects [B, S, H, D] |
135 | | - proj = self.proj_attn(hidden_states) # Output shape [B, S, C] |
136 | | - |
137 | | - if is_4d: |
138 | | - proj = proj.reshape(orig_x_shape) # Reshape back if input was 4D |
139 | | - |
140 | | - return proj |
141 | | - |
142 | | -# --- adaLN-Zero --- |
143 | | - |
144 | | - |
145 | | -class AdaLNZero(nn.Module): |
146 | | - features: int |
147 | | - dtype: Optional[Dtype] = None |
148 | | - precision: PrecisionLike = None |
149 | | - norm_epsilon: float = 1e-5 # Standard LayerNorm epsilon |
150 | | - |
151 | | - @nn.compact |
152 | | - def __call__(self, x, conditioning): |
153 | | - # Project conditioning signal to get scale and shift parameters |
154 | | - # Conditioning shape: [B, D_cond] -> [B, 1, ..., 1, 6 * features] for broadcasting |
155 | | - # Or [B, 1, 6*features] if x is [B, S, F] |
156 | | - |
157 | | - # Ensure conditioning has seq dim if x does |
158 | | - # x=[B,S,F], cond=[B,D_cond] |
159 | | - if x.ndim == 3 and conditioning.ndim == 2: |
160 | | - conditioning = jnp.expand_dims( |
161 | | - conditioning, axis=1) # cond=[B,1,D_cond] |
162 | | - |
163 | | - # Project conditioning to get 6 params per feature (scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn) |
164 | | - # Using nn.DenseGeneral for flexibility if needed, but nn.Dense is fine if cond is [B, D_cond] or [B, 1, D_cond] |
165 | | - ada_params = nn.Dense( |
166 | | - features=6 * self.features, |
167 | | - dtype=self.dtype, |
168 | | - precision=self.precision, |
169 | | - # Initialize projection to zero (Zero init) |
170 | | - kernel_init=nn.initializers.zeros, |
171 | | - name="ada_proj" |
172 | | - )(conditioning) |
173 | | - |
174 | | - # Split into scale, shift, gate for MLP and Attention |
175 | | - scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split( |
176 | | - ada_params, 6, axis=-1) |
177 | | - |
178 | | - scale_mlp = jnp.clip(scale_mlp, -10.0, 10.0) |
179 | | - shift_mlp = jnp.clip(shift_mlp, -10.0, 10.0) |
180 | | - # Apply Layer Normalization |
181 | | - norm = nn.LayerNorm(epsilon=self.norm_epsilon, |
182 | | - use_scale=False, use_bias=False, dtype=self.dtype) |
183 | | - # norm = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype) # Alternative: RMSNorm |
184 | | - |
185 | | - norm_x = norm(x) |
186 | | - |
187 | | - # Modulate for Attention path |
188 | | - x_attn = norm_x * (1 + scale_attn) + shift_attn |
189 | | - |
190 | | - # Modulate for MLP path |
191 | | - x_mlp = norm_x * (1 + scale_mlp) + shift_mlp |
192 | | - |
193 | | - # Return modulated outputs and gates |
194 | | - return x_attn, gate_attn, x_mlp, gate_mlp |
195 | | - |
196 | | -class AdaLNParams(nn.Module): # Renamed for clarity |
197 | | - features: int |
198 | | - dtype: Optional[Dtype] = None |
199 | | - precision: PrecisionLike = None |
200 | | - |
201 | | - @nn.compact |
202 | | - def __call__(self, conditioning): |
203 | | - # Ensure conditioning is broadcastable if needed (e.g., [B, 1, D_cond]) |
204 | | - if conditioning.ndim == 2: |
205 | | - conditioning = jnp.expand_dims(conditioning, axis=1) |
206 | | - |
207 | | - # Project conditioning to get 6 params per feature |
208 | | - ada_params = nn.Dense( |
209 | | - features=6 * self.features, |
210 | | - dtype=self.dtype, |
211 | | - precision=self.precision, |
212 | | - kernel_init=nn.initializers.zeros, |
213 | | - name="ada_proj" |
214 | | - )(conditioning) |
215 | | - # Return all params (or split if preferred, but maybe return tuple/dict) |
216 | | - # Shape: [B, 1, 6*F] |
217 | | - return ada_params # Or split and return tuple: jnp.split(ada_params, 6, axis=-1) |
218 | | - |
219 | 18 | # --- DiT Block --- |
220 | 19 | class DiTBlock(nn.Module): |
221 | 20 | features: int |
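
For reference, the rotate-half RoPE application that this commit relocates into `.vit_common` can be exercised standalone. The sketch below mirrors the removed `apply_rotary_embedding` and the frequency table built in `RotaryEmbedding.setup`; the helper names and shapes are illustrative assumptions, not the `vit_common` API.

```python
# Standalone sketch of the rotate-half RoPE math moved into .vit_common.
# Helper names and shapes below are illustrative, not the vit_common API.
import jax.numpy as jnp

def rotate_half(x):
    # Swap the two halves of the last dimension, negating the original second half.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return jnp.concatenate((-x2, x1), axis=-1)

def apply_rope(x, freqs_cos, freqs_sin):
    # x: [B, H, S, D]; freqs_cos / freqs_sin: [S, D/2], duplicated to cover D.
    cos = jnp.concatenate([freqs_cos, freqs_cos], axis=-1)[None, None]
    sin = jnp.concatenate([freqs_sin, freqs_sin], axis=-1)[None, None]
    return (x * cos + rotate_half(x) * sin).astype(x.dtype)

# Frequency table as in RotaryEmbedding.setup: outer product of positions
# and inverse frequencies with base 10000.
dim, seq_len = 64, 16
inv_freq = 1.0 / (10000 ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim))
freqs = jnp.outer(jnp.arange(seq_len, dtype=jnp.float32), inv_freq)
q = jnp.ones((1, 4, seq_len, dim))  # [B, H, S, D]
q_rot = apply_rope(q, jnp.cos(freqs), jnp.sin(freqs))
print(q_rot.shape)  # (1, 4, 16, 64)
```

Applying the same frequency table to both query and key, as the removed `RoPEAttention.__call__` does, is what makes the attention scores depend only on relative positions.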
|
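The parameters produced by `AdaLNParams` (now imported from `.vit_common`) are consumed the same way the removed `AdaLNZero` block shows: a zero-initialized projection yields scale/shift/gate triples for the attention and MLP paths. Below is a minimal, hypothetical usage sketch of that modulation inside a transformer block; `attn_fn`, `mlp_fn`, and `layer_norm` are placeholders, not the actual `DiTBlock` code.

```python
# Hypothetical consumer of the 6 * features adaLN-Zero parameters.
import jax.numpy as jnp

def adaln_zero_block(x, ada_params, attn_fn, mlp_fn, layer_norm):
    # ada_params: [B, 1, 6 * F], produced by a zero-initialized Dense layer,
    # so every scale, shift, and gate starts at 0 and the block begins as an
    # identity mapping around the residual stream.
    (scale_mlp, shift_mlp, gate_mlp,
     scale_attn, shift_attn, gate_attn) = jnp.split(ada_params, 6, axis=-1)

    # Attention path: modulate the normalized input, gate the residual update.
    h = layer_norm(x) * (1 + scale_attn) + shift_attn
    x = x + gate_attn * attn_fn(h)

    # MLP path: same pattern with the second parameter triple.
    h = layer_norm(x) * (1 + scale_mlp) + shift_mlp
    x = x + gate_mlp * mlp_fn(h)
    return x

# Example: with zero-initialized parameters the block reduces to identity.
B, S, F = 2, 16, 64
x = jnp.ones((B, S, F))
ada = jnp.zeros((B, 1, 6 * F))
ln = lambda t: (t - t.mean(-1, keepdims=True)) / jnp.sqrt(t.var(-1, keepdims=True) + 1e-5)
out = adaln_zero_block(x, ada, attn_fn=lambda h: h, mlp_fn=lambda h: h, layer_norm=ln)
print(jnp.allclose(out, x))  # True: gates are zero, so x passes through unchanged
```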