1515"""
1616
1717from typing import Tuple , Optional , Dict , Union , Any
18+ import contextlib
1819import math
1920import jax
2021import jax .numpy as jnp
@@ -205,11 +206,13 @@ def __init__(
       dtype: jnp.dtype = jnp.float32,
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
+      enable_jax_named_scopes: bool = False,
   ):
     if inner_dim is None:
       inner_dim = int(dim * mult)
     dim_out = dim_out if dim_out is not None else dim
 
+    self.enable_jax_named_scopes = enable_jax_named_scopes
     self.act_fn = nnx.data(None)
     if activation_fn == "gelu-approximate":
       self.act_fn = ApproximateGELU(
@@ -236,11 +239,17 @@ def __init__(
         ),
     )
 
+  def conditional_named_scope(self, name: str):
+    """Return a JAX named scope if enabled, otherwise a null context."""
+    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
   def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array:
-    hidden_states = self.act_fn(hidden_states)  # Output is (4, 75600, 13824)
-    hidden_states = checkpoint_name(hidden_states, "ffn_activation")
-    hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
-    return self.proj_out(hidden_states)  # output is (4, 75600, 5120)
+    with self.conditional_named_scope("mlp_up_proj_and_gelu"):
+      hidden_states = self.act_fn(hidden_states)  # Output is (4, 75600, 13824)
+      hidden_states = checkpoint_name(hidden_states, "ffn_activation")
+      hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
+    with self.conditional_named_scope("mlp_down_proj"):
+      return self.proj_out(hidden_states)  # output is (4, 75600, 5120)
 
 
 class WanTransformerBlock(nnx.Module):
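The helper added above is the whole mechanism: when `enable_jax_named_scopes` is False it returns a no-op context manager, so the `with` blocks change nothing. A minimal self-contained sketch of the same pattern (the standalone helper and the toy `ffn` below are illustrative, not code from this diff):

```python
import contextlib
import jax
import jax.numpy as jnp


def conditional_named_scope(name: str, enabled: bool):
  """jax.named_scope(name) when enabled, otherwise a do-nothing context."""
  return jax.named_scope(name) if enabled else contextlib.nullcontext()


def ffn(x: jax.Array, enabled: bool = True) -> jax.Array:
  # Ops traced under an active scope carry its name in their name stack,
  # which is what surfaces in HLO metadata and profiler traces.
  with conditional_named_scope("mlp_up_proj_and_gelu", enabled):
    x = jax.nn.gelu(x)
  with conditional_named_scope("mlp_down_proj", enabled):
    return x @ jnp.ones((x.shape[-1], x.shape[-1]))
```

With `enabled=False`, `contextlib.nullcontext()` neither traces nor allocates anything, so the default behaviour matches the pre-change code.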
@@ -265,8 +274,11 @@ def __init__(
       attention: str = "dot_product",
       dropout: float = 0.0,
       mask_padding_tokens: bool = True,
+      enable_jax_named_scopes: bool = False,
   ):
 
+    self.enable_jax_named_scopes = enable_jax_named_scopes
+
     # 1. Self-attention
     self.norm1 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)
     self.attn1 = FlaxWanAttention(
@@ -287,6 +299,7 @@ def __init__(
         is_self_attention=True,
         mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name="self_attn",
+        enable_jax_named_scopes=enable_jax_named_scopes,
     )
 
     # 1. Cross-attention
@@ -308,6 +321,7 @@ def __init__(
         is_self_attention=False,
         mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name="cross_attn",
+        enable_jax_named_scopes=enable_jax_named_scopes,
     )
     assert cross_attn_norm is True
     self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -322,6 +336,7 @@ def __init__(
         weights_dtype=weights_dtype,
         precision=precision,
         dropout=dropout,
+        enable_jax_named_scopes=enable_jax_named_scopes,
     )
     self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)
 
@@ -330,6 +345,10 @@ def __init__(
         jax.random.normal(key, (1, 6, dim)) / dim**0.5,
     )
 
+  def conditional_named_scope(self, name: str):
+    """Return a JAX named scope if enabled, otherwise a null context."""
+    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
   def __call__(
       self,
       hidden_states: jax.Array,
@@ -339,45 +358,59 @@ def __call__(
       deterministic: bool = True,
       rngs: nnx.Rngs = None,
   ):
-    shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
-        (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
-    )
-    hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
-    hidden_states = checkpoint_name(hidden_states, "hidden_states")
-    encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))
-
-    # 1. Self-attention
-    norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
-        hidden_states.dtype
-    )
-    attn_output = self.attn1(
-        hidden_states=norm_hidden_states,
-        encoder_hidden_states=norm_hidden_states,
-        rotary_emb=rotary_emb,
-        deterministic=deterministic,
-        rngs=rngs,
-    )
-    hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)
-
-    # 2. Cross-attention
-    norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
-    attn_output = self.attn2(
-        hidden_states=norm_hidden_states,
-        encoder_hidden_states=encoder_hidden_states,
-        deterministic=deterministic,
-        rngs=rngs,
-    )
-    hidden_states = hidden_states + attn_output
-
-    # 3. Feed-forward
-    norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
-        hidden_states.dtype
-    )
-    ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
-    hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
-        hidden_states.dtype
-    )
-    return hidden_states
+    with self.conditional_named_scope("transformer_block"):
+      with self.conditional_named_scope("adaln"):
+        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
+            (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
+        )
+      hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
+      hidden_states = checkpoint_name(hidden_states, "hidden_states")
+      encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))
+
+      # 1. Self-attention
+      with self.conditional_named_scope("self_attn"):
+        with self.conditional_named_scope("self_attn_norm"):
+          norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
+              hidden_states.dtype
+          )
+        with self.conditional_named_scope("self_attn_attn"):
+          attn_output = self.attn1(
+              hidden_states=norm_hidden_states,
+              encoder_hidden_states=norm_hidden_states,
+              rotary_emb=rotary_emb,
+              deterministic=deterministic,
+              rngs=rngs,
+          )
+        with self.conditional_named_scope("self_attn_residual"):
+          hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)
+
+      # 2. Cross-attention
+      with self.conditional_named_scope("cross_attn"):
+        with self.conditional_named_scope("cross_attn_norm"):
+          norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
+        with self.conditional_named_scope("cross_attn_attn"):
+          attn_output = self.attn2(
+              hidden_states=norm_hidden_states,
+              encoder_hidden_states=encoder_hidden_states,
+              deterministic=deterministic,
+              rngs=rngs,
+          )
+        with self.conditional_named_scope("cross_attn_residual"):
+          hidden_states = hidden_states + attn_output
+
+      # 3. Feed-forward
+      with self.conditional_named_scope("mlp"):
+        with self.conditional_named_scope("mlp_norm"):
+          norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
+              hidden_states.dtype
+          )
+        with self.conditional_named_scope("mlp_ffn"):
+          ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
+        with self.conditional_named_scope("mlp_residual"):
+          hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
+              hidden_states.dtype
+          )
+      return hidden_states
 
 
 class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin):
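For reference, nested `jax.named_scope` contexts compose into slash-joined labels (e.g. `transformer_block/self_attn/self_attn_attn`), which is how the regions above show up in profiler traces and XLA op metadata. A toy illustration of the nesting, independent of this model:

```python
import jax
import jax.numpy as jnp


def toy_block(x: jax.Array) -> jax.Array:
  # Ops inside the inner scopes are labelled "transformer_block/self_attn"
  # and "transformer_block/mlp" in profiler traces and HLO op metadata.
  with jax.named_scope("transformer_block"):
    with jax.named_scope("self_attn"):
      x = x @ x.T
    with jax.named_scope("mlp"):
      x = jax.nn.gelu(x)
  return x


out = jax.jit(toy_block)(jnp.ones((8, 8)))
```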
@@ -416,11 +449,13 @@ def __init__(
       names_which_can_be_offloaded: list = [],
       mask_padding_tokens: bool = True,
       scan_layers: bool = True,
+      enable_jax_named_scopes: bool = False,
   ):
     inner_dim = num_attention_heads * attention_head_dim
     out_channels = out_channels or in_channels
     self.num_layers = num_layers
     self.scan_layers = scan_layers
+    self.enable_jax_named_scopes = enable_jax_named_scopes
 
     # 1. Patch & position embedding
     self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
@@ -472,6 +507,7 @@ def init_block(rngs):
           attention=attention,
           dropout=dropout,
           mask_padding_tokens=mask_padding_tokens,
+          enable_jax_named_scopes=enable_jax_named_scopes,
       )
 
     self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
@@ -497,6 +533,7 @@ def init_block(rngs):
             weights_dtype=weights_dtype,
             precision=precision,
             attention=attention,
+            enable_jax_named_scopes=enable_jax_named_scopes,
         )
         blocks.append(block)
       self.blocks = blocks
@@ -517,6 +554,10 @@ def init_block(rngs):
         kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")),
     )
 
+  def conditional_named_scope(self, name: str):
+    """Return a JAX named scope if enabled, otherwise a null context."""
+    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
   def __call__(
       self,
       hidden_states: jax.Array,
@@ -536,14 +577,15 @@ def __call__(
     post_patch_width = width // p_w
 
     hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))
-    rotary_emb = self.rope(hidden_states)
-    with jax.named_scope("PatchEmbedding"):
+    with self.conditional_named_scope("rotary_embedding"):
+      rotary_emb = self.rope(hidden_states)
+    with self.conditional_named_scope("patch_embedding"):
       hidden_states = self.patch_embedding(hidden_states)
-    hidden_states = jax.lax.collapse(hidden_states, 1, -1)
-
-    temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
-        timestep, encoder_hidden_states, encoder_hidden_states_image
-    )
+      hidden_states = jax.lax.collapse(hidden_states, 1, -1)
+    with self.conditional_named_scope("condition_embedder"):
+      temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+          timestep, encoder_hidden_states, encoder_hidden_states_image
+      )
     timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)
 
     if encoder_hidden_states_image is not None:
@@ -583,9 +625,10 @@ def layer_forward(hidden_states):
       hidden_states = rematted_layer_forward(hidden_states)
 
     shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1)
-
-    hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype)
-    hidden_states = self.proj_out(hidden_states)
+    with self.conditional_named_scope("output_norm"):
+      hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype)
+    with self.conditional_named_scope("output_proj"):
+      hidden_states = self.proj_out(hidden_states)
 
     hidden_states = hidden_states.reshape(
         batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
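Taken together, the flag defaults to False at every level, so existing call sites are unchanged; a profiling run opts in explicitly. A hedged sketch of how a caller might capture a labelled trace (the `model(...)` call signature and log directory are placeholders, not the repository's actual training entry point):

```python
import jax

# Illustrative only: however the model is constructed elsewhere,
# the scopes are opt-in via the new keyword argument.
# model = WanModel(..., enable_jax_named_scopes=True)


def profile_one_step(model, *model_args, log_dir: str = "/tmp/wan_trace"):
  # With scopes enabled, spans such as "transformer_block/self_attn" and
  # "mlp_up_proj_and_gelu" appear as named regions in the captured trace
  # instead of an undifferentiated stream of fused XLA ops.
  with jax.profiler.trace(log_dir):
    out = model(*model_args)
    jax.block_until_ready(out)
  return out
```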