
Commit 37ac734

eltsai authored
Added named_scope for WAN 2.1 Xprof profiling (#285)
* Added named_scope for WAN 2.1 Xprof profiling
* Added flag to enable named_scope
* Make named_scope in flash_attention respect enable_jax_named_scopes flag
* Added named scope for WanModel output

Co-authored-by: eltsai <[email protected]>
1 parent 2017810 commit 37ac734
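
The change is mechanically simple: each touched module gains a conditional_named_scope helper that returns jax.named_scope(name) when the new config flag is on and contextlib.nullcontext() otherwise, so the annotated with-blocks become no-ops when profiling scopes are disabled. A minimal standalone sketch of the pattern (the ScopedLayer class is illustrative only, not part of the commit):

import contextlib

import jax
import jax.numpy as jnp


class ScopedLayer:
  """Toy module using the same gating pattern as the commit."""

  def __init__(self, enable_jax_named_scopes: bool = False):
    self.enable_jax_named_scopes = enable_jax_named_scopes

  def conditional_named_scope(self, name: str):
    # Return a JAX named scope if enabled, otherwise a no-op context manager.
    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()

  def __call__(self, x: jax.Array) -> jax.Array:
    with self.conditional_named_scope("attn_compute"):
      return jnp.tanh(x @ x.T)


layer = ScopedLayer(enable_jax_named_scopes=True)
print(layer(jnp.ones((4, 4))).shape)  # (4, 4); the scope only labels the trace, it does not change results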

File tree

4 files changed (+133, -66 lines):

src/maxdiffusion/configs/base_wan_14b.yml
src/maxdiffusion/models/attention_flax.py
src/maxdiffusion/models/wan/transformers/transformer_wan.py
src/maxdiffusion/pipelines/wan/wan_pipeline.py

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 4 additions & 0 deletions
@@ -283,6 +283,10 @@ enable_profiler: False
 skip_first_n_steps_for_profiler: 5
 profiler_steps: 10

+# Enable JAX named scopes for detailed profiling and debugging
+# When enabled, adds named scopes around key operations in transformer and attention layers
+enable_jax_named_scopes: False
+
 # Generation parameters
 prompt: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
 prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
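
The new flag sits next to the existing profiler knobs (enable_profiler, skip_first_n_steps_for_profiler, profiler_steps), so a typical workflow is to turn both enable_profiler and enable_jax_named_scopes on for a short profiling run, either in the YAML or through whatever config-override mechanism your launch script already uses. The sketch below is not the MaxDiffusion trainer; it only illustrates how jax.named_scope annotations surface in an Xprof/TensorBoard trace once a trace is captured (the /tmp/xprof_wan log directory is an arbitrary choice):

import jax
import jax.numpy as jnp


@jax.jit
def toy_step(x):
  # Stands in for one transformer block; with scopes enabled, these regions
  # appear as labeled spans in the Xprof trace viewer.
  with jax.named_scope("transformer_block"):
    with jax.named_scope("attn_compute"):
      x = jnp.tanh(x @ x.T)
    with jax.named_scope("mlp"):
      x = jax.nn.gelu(x)
  return x


x = jnp.ones((1024, 1024))
with jax.profiler.trace("/tmp/xprof_wan"):  # inspect with: tensorboard --logdir /tmp/xprof_wan
  toy_step(x).block_until_ready()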

src/maxdiffusion/models/attention_flax.py

Lines changed: 32 additions & 13 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import contextlib
 import functools
 import math
 from typing import Optional, Callable, Tuple
@@ -805,6 +806,7 @@ def __init__(
       is_self_attention: bool = True,
       mask_padding_tokens: bool = True,
       residual_checkpoint_name: str | None = None,
+      enable_jax_named_scopes: bool = False,
   ):
     if attention_kernel == "cudnn_flash_te":
       raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")
@@ -820,6 +822,7 @@ def __init__(
     self.key_axis_names = key_axis_names
     self.value_axis_names = value_axis_names
     self.out_axis_names = out_axis_names
+    self.enable_jax_named_scopes = enable_jax_named_scopes

     if is_self_attention:
       axis_names_q = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH, D_KV)
@@ -952,6 +955,10 @@ def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tup

     return xq_out, xk_out

+  def conditional_named_scope(self, name: str):
+    """Return a JAX named scope if enabled, otherwise a null context."""
+    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
   def __call__(
       self,
       hidden_states: jax.Array,
@@ -966,29 +973,41 @@ def __call__(
     if encoder_hidden_states is None:
       encoder_hidden_states = hidden_states

-    query_proj = self.query(hidden_states)
-    key_proj = self.key(encoder_hidden_states)
-    value_proj = self.value(encoder_hidden_states)
+    with self.conditional_named_scope("attn_qkv_proj"):
+      with self.conditional_named_scope("proj_query"):
+        query_proj = self.query(hidden_states)
+      with self.conditional_named_scope("proj_key"):
+        key_proj = self.key(encoder_hidden_states)
+      with self.conditional_named_scope("proj_value"):
+        value_proj = self.value(encoder_hidden_states)

     if self.qk_norm:
-      query_proj = self.norm_q(query_proj)
-      key_proj = self.norm_k(key_proj)
+      with self.conditional_named_scope("attn_q_norm"):
+        query_proj = self.norm_q(query_proj)
+      with self.conditional_named_scope("attn_k_norm"):
+        key_proj = self.norm_k(key_proj)
+
     if rotary_emb is not None:
-      query_proj = _unflatten_heads(query_proj, self.heads)
-      key_proj = _unflatten_heads(key_proj, self.heads)
-      value_proj = _unflatten_heads(value_proj, self.heads)
-      # output of _unflatten_heads Batch, heads, seq_len, head_dim
-      query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
+      with self.conditional_named_scope("attn_rope"):
+        query_proj = _unflatten_heads(query_proj, self.heads)
+        key_proj = _unflatten_heads(key_proj, self.heads)
+        value_proj = _unflatten_heads(value_proj, self.heads)
+        # output of _unflatten_heads Batch, heads, seq_len, head_dim
+        query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)

     query_proj = checkpoint_name(query_proj, "query_proj")
     key_proj = checkpoint_name(key_proj, "key_proj")
     value_proj = checkpoint_name(value_proj, "value_proj")
-    attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
+
+    with self.conditional_named_scope("attn_compute"):
+      attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)

     attn_output = attn_output.astype(dtype=dtype)
     attn_output = checkpoint_name(attn_output, "attn_output")
-    hidden_states = self.proj_attn(attn_output)
-    hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
+
+    with self.conditional_named_scope("attn_out_proj"):
+      hidden_states = self.proj_attn(attn_output)
+      hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
     return hidden_states

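In FlaxWanAttention the scopes now label the QKV projections (attn_qkv_proj, with proj_query / proj_key / proj_value inside), the optional q/k norms, RoPE application (attn_rope), the attention kernel itself (attn_compute), and the output projection plus dropout (attn_out_proj). Because the gate only swaps in a null context, enabling it changes trace metadata, never the computed values; a quick sanity sketch (toy function under that assumption, not the module above):

import contextlib

import jax
import jax.numpy as jnp


def attn_like(x, enable_scopes: bool):
  # Same numerics with or without the named scope; only the trace labels differ.
  scope = jax.named_scope("attn_compute") if enable_scopes else contextlib.nullcontext()
  with scope:
    return jax.nn.softmax(x @ x.T, axis=-1) @ x


x = jax.random.normal(jax.random.PRNGKey(0), (16, 8))
assert jnp.allclose(attn_like(x, True), attn_like(x, False))
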
src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 96 additions & 53 deletions
@@ -15,6 +15,7 @@
 """

 from typing import Tuple, Optional, Dict, Union, Any
+import contextlib
 import math
 import jax
 import jax.numpy as jnp
@@ -205,11 +206,13 @@ def __init__(
       dtype: jnp.dtype = jnp.float32,
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
+      enable_jax_named_scopes: bool = False,
   ):
     if inner_dim is None:
       inner_dim = int(dim * mult)
     dim_out = dim_out if dim_out is not None else dim

+    self.enable_jax_named_scopes = enable_jax_named_scopes
     self.act_fn = nnx.data(None)
     if activation_fn == "gelu-approximate":
       self.act_fn = ApproximateGELU(
@@ -236,11 +239,17 @@
         ),
     )

+  def conditional_named_scope(self, name: str):
+    """Return a JAX named scope if enabled, otherwise a null context."""
+    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
   def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array:
-    hidden_states = self.act_fn(hidden_states)  # Output is (4, 75600, 13824)
-    hidden_states = checkpoint_name(hidden_states, "ffn_activation")
-    hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
-    return self.proj_out(hidden_states)  # output is (4, 75600, 5120)
+    with self.conditional_named_scope("mlp_up_proj_and_gelu"):
+      hidden_states = self.act_fn(hidden_states)  # Output is (4, 75600, 13824)
+      hidden_states = checkpoint_name(hidden_states, "ffn_activation")
+      hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
+    with self.conditional_named_scope("mlp_down_proj"):
+      return self.proj_out(hidden_states)  # output is (4, 75600, 5120)


 class WanTransformerBlock(nnx.Module):
@@ -265,8 +274,11 @@ def __init__(
       attention: str = "dot_product",
       dropout: float = 0.0,
       mask_padding_tokens: bool = True,
+      enable_jax_named_scopes: bool = False,
   ):

+    self.enable_jax_named_scopes = enable_jax_named_scopes
+
     # 1. Self-attention
     self.norm1 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)
     self.attn1 = FlaxWanAttention(
@@ -287,6 +299,7 @@ def __init__(
         is_self_attention=True,
         mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name="self_attn",
+        enable_jax_named_scopes=enable_jax_named_scopes,
     )

     # 1. Cross-attention
@@ -308,6 +321,7 @@ def __init__(
         is_self_attention=False,
         mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name="cross_attn",
+        enable_jax_named_scopes=enable_jax_named_scopes,
     )
     assert cross_attn_norm is True
     self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -322,6 +336,7 @@ def __init__(
         weights_dtype=weights_dtype,
         precision=precision,
         dropout=dropout,
+        enable_jax_named_scopes=enable_jax_named_scopes,
     )
     self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)

@@ -330,6 +345,10 @@ def __init__(
         jax.random.normal(key, (1, 6, dim)) / dim**0.5,
     )

+  def conditional_named_scope(self, name: str):
+    """Return a JAX named scope if enabled, otherwise a null context."""
+    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
   def __call__(
       self,
       hidden_states: jax.Array,
@@ -339,45 +358,59 @@ def __call__(
       deterministic: bool = True,
       rngs: nnx.Rngs = None,
   ):
-    shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
-        (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
-    )
-    hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
-    hidden_states = checkpoint_name(hidden_states, "hidden_states")
-    encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))
-
-    # 1. Self-attention
-    norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
-        hidden_states.dtype
-    )
-    attn_output = self.attn1(
-        hidden_states=norm_hidden_states,
-        encoder_hidden_states=norm_hidden_states,
-        rotary_emb=rotary_emb,
-        deterministic=deterministic,
-        rngs=rngs,
-    )
-    hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)
-
-    # 2. Cross-attention
-    norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
-    attn_output = self.attn2(
-        hidden_states=norm_hidden_states,
-        encoder_hidden_states=encoder_hidden_states,
-        deterministic=deterministic,
-        rngs=rngs,
-    )
-    hidden_states = hidden_states + attn_output
-
-    # 3. Feed-forward
-    norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
-        hidden_states.dtype
-    )
-    ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
-    hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
-        hidden_states.dtype
-    )
-    return hidden_states
+    with self.conditional_named_scope("transformer_block"):
+      with self.conditional_named_scope("adaln"):
+        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
+            (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
+        )
+        hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
+        hidden_states = checkpoint_name(hidden_states, "hidden_states")
+        encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))
+
+      # 1. Self-attention
+      with self.conditional_named_scope("self_attn"):
+        with self.conditional_named_scope("self_attn_norm"):
+          norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
+              hidden_states.dtype
+          )
+        with self.conditional_named_scope("self_attn_attn"):
+          attn_output = self.attn1(
+              hidden_states=norm_hidden_states,
+              encoder_hidden_states=norm_hidden_states,
+              rotary_emb=rotary_emb,
+              deterministic=deterministic,
+              rngs=rngs,
+          )
+        with self.conditional_named_scope("self_attn_residual"):
+          hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)
+
+      # 2. Cross-attention
+      with self.conditional_named_scope("cross_attn"):
+        with self.conditional_named_scope("cross_attn_norm"):
+          norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
+        with self.conditional_named_scope("cross_attn_attn"):
+          attn_output = self.attn2(
+              hidden_states=norm_hidden_states,
+              encoder_hidden_states=encoder_hidden_states,
+              deterministic=deterministic,
+              rngs=rngs,
+          )
+        with self.conditional_named_scope("cross_attn_residual"):
+          hidden_states = hidden_states + attn_output
+
+      # 3. Feed-forward
+      with self.conditional_named_scope("mlp"):
+        with self.conditional_named_scope("mlp_norm"):
+          norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
+              hidden_states.dtype
+          )
+        with self.conditional_named_scope("mlp_ffn"):
+          ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
+        with self.conditional_named_scope("mlp_residual"):
+          hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
+              hidden_states.dtype
+          )
+      return hidden_states


 class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin):
@@ -416,11 +449,13 @@ def __init__(
       names_which_can_be_offloaded: list = [],
       mask_padding_tokens: bool = True,
       scan_layers: bool = True,
+      enable_jax_named_scopes: bool = False,
   ):
     inner_dim = num_attention_heads * attention_head_dim
     out_channels = out_channels or in_channels
     self.num_layers = num_layers
     self.scan_layers = scan_layers
+    self.enable_jax_named_scopes = enable_jax_named_scopes

     # 1. Patch & position embedding
     self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
@@ -472,6 +507,7 @@ def init_block(rngs):
           attention=attention,
           dropout=dropout,
           mask_padding_tokens=mask_padding_tokens,
+          enable_jax_named_scopes=enable_jax_named_scopes,
       )

     self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
@@ -497,6 +533,7 @@ def init_block(rngs):
           weights_dtype=weights_dtype,
           precision=precision,
           attention=attention,
+          enable_jax_named_scopes=enable_jax_named_scopes,
       )
       blocks.append(block)
     self.blocks = blocks
@@ -517,6 +554,10 @@ def init_block(rngs):
         kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")),
     )

+  def conditional_named_scope(self, name: str):
+    """Return a JAX named scope if enabled, otherwise a null context."""
+    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
   def __call__(
       self,
       hidden_states: jax.Array,
@@ -536,14 +577,15 @@ def __call__(
     post_patch_width = width // p_w

     hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))
-    rotary_emb = self.rope(hidden_states)
-    with jax.named_scope("PatchEmbedding"):
+    with self.conditional_named_scope("rotary_embedding"):
+      rotary_emb = self.rope(hidden_states)
+    with self.conditional_named_scope("patch_embedding"):
       hidden_states = self.patch_embedding(hidden_states)
-    hidden_states = jax.lax.collapse(hidden_states, 1, -1)
-
-    temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
-        timestep, encoder_hidden_states, encoder_hidden_states_image
-    )
+      hidden_states = jax.lax.collapse(hidden_states, 1, -1)
+    with self.conditional_named_scope("condition_embedder"):
+      temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+          timestep, encoder_hidden_states, encoder_hidden_states_image
+      )
     timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)

     if encoder_hidden_states_image is not None:
@@ -583,9 +625,10 @@ def layer_forward(hidden_states):
       hidden_states = rematted_layer_forward(hidden_states)

     shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1)
-
-    hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype)
-    hidden_states = self.proj_out(hidden_states)
+    with self.conditional_named_scope("output_norm"):
+      hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype)
+    with self.conditional_named_scope("output_proj"):
+      hidden_states = self.proj_out(hidden_states)

     hidden_states = hidden_states.reshape(
         batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1

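Taken together, the with-blocks in WanTransformerBlock and WanModel should yield a scope hierarchy in the Xprof timeline roughly like the one below (names taken from the diff; JAX adds its own jit/scan/remat frames around them, and the transformer_block scope repeats per block or sits inside the scan body when scan_layers is on):

rotary_embedding
patch_embedding
condition_embedder
transformer_block
  adaln
  self_attn
    self_attn_norm
    self_attn_attn        (wraps attn_qkv_proj, attn_q_norm / attn_k_norm, attn_rope, attn_compute, attn_out_proj)
    self_attn_residual
  cross_attn
    cross_attn_norm
    cross_attn_attn
    cross_attn_residual
  mlp
    mlp_norm
    mlp_ffn               (wraps mlp_up_proj_and_gelu, mlp_down_proj)
    mlp_residual
output_norm
output_proj
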
src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 1 addition & 0 deletions
@@ -114,6 +114,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
   wan_config["dropout"] = config.dropout
   wan_config["mask_padding_tokens"] = config.mask_padding_tokens
   wan_config["scan_layers"] = config.scan_layers
+  wan_config["enable_jax_named_scopes"] = config.enable_jax_named_scopes

   # 2. eval_shape - will not use flops or create weights on device
   # thus not using HBM memory.
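
This single pipeline line is what ties the four files together: create_model copies config.enable_jax_named_scopes from the YAML into wan_config, WanModel stores it and hands it to every WanTransformerBlock it builds, and each block forwards it to its FlaxWanAttention and FeedForward submodules, where the with-blocks above pick it up. Leaving the flag at its default of False keeps the previous behavior, since every conditional_named_scope call then degenerates to contextlib.nullcontext().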
