 )
 from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
 from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
+from diffsynth_engine.models.basic import attention as attention_ops
 from diffsynth_engine.models.utils import no_init_weights
 from diffsynth_engine.utils.gguf import gguf_inference
 from diffsynth_engine.utils.fp8_linear import fp8_inference
 from diffsynth_engine.utils.constants import FLUX_DIT_CONFIG_FILE
-from diffsynth_engine.models.basic.attention import attention
+from diffsynth_engine.utils.parallel import sequence_parallel, sequence_parallel_unshard
 from diffsynth_engine.utils import logging


@@ -198,7 +199,7 @@ def forward(self, image, text, rope_emb, image_emb):
         k = torch.cat([self.norm_k_b(k_b), self.norm_k_a(k_a)], dim=1)
         v = torch.cat([v_b, v_a], dim=1)
         q, k = apply_rope(q, k, rope_emb)
-        attn_out = attention(q, k, v, attn_impl=self.attn_impl)
+        attn_out = attention_ops.attention(q, k, v, attn_impl=self.attn_impl)
         attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype)
         text_out, image_out = attn_out[:, : text.shape[1]], attn_out[:, text.shape[1] :]
         image_out, text_out = self.attention_callback(
@@ -286,7 +287,7 @@ def attention_callback(self, attn_out, x, q, k, v, rope_emb, image_emb):
     def forward(self, x, rope_emb, image_emb):
         q, k, v = rearrange(self.to_qkv(x), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2)
         q, k = apply_rope(self.norm_q_a(q), self.norm_k_a(k), rope_emb)
-        attn_out = attention(q, k, v, attn_impl=self.attn_impl)
+        attn_out = attention_ops.attention(q, k, v, attn_impl=self.attn_impl)
         attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype)
         return self.attention_callback(attn_out=attn_out, x=x, q=q, k=k, v=v, rope_emb=rope_emb, image_emb=image_emb)

@@ -324,6 +325,7 @@ def __init__(
         self,
         in_channel: int = 64,
         attn_impl: Optional[str] = None,
+        use_usp: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -349,6 +351,8 @@ def __init__(
         self.final_norm_out = AdaLayerNorm(3072, device=device, dtype=dtype)
         self.final_proj_out = nn.Linear(3072, 64, device=device, dtype=dtype)

+        self.use_usp = use_usp
+
     def patchify(self, hidden_states):
         hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
         return hidden_states
@@ -359,7 +363,8 @@ def unpatchify(self, hidden_states, height, width):
         )
         return hidden_states

-    def prepare_image_ids(self, latents):
+    @staticmethod
+    def prepare_image_ids(latents: torch.Tensor):
         batch_size, _, height, width = latents.shape
         latent_image_ids = torch.zeros(height // 2, width // 2, 3)
         latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
@@ -389,7 +394,14 @@ def forward(
389394 controlnet_single_block_output = None ,
390395 ** kwargs ,
391396 ):
392- height , width = hidden_states .shape [- 2 :]
397+ h , w = hidden_states .shape [- 2 :]
398+ controlnet_double_block_output = (
399+ controlnet_double_block_output if controlnet_double_block_output is not None else ()
400+ )
401+ controlnet_single_block_output = (
402+ controlnet_single_block_output if controlnet_single_block_output is not None else ()
403+ )
404+
393405 fp8_linear_enabled = getattr (self , "fp8_linear_enabled" , False )
394406 with fp8_inference (fp8_linear_enabled ), gguf_inference ():
395407 if image_ids is None :
@@ -402,28 +414,54 @@ def forward(
                 guidance = guidance * 1000
                 conditioning += self.guidance_embedder(guidance, hidden_states.dtype)
             conditioning += self.pooled_text_embedder(pooled_prompt_emb)
-            prompt_emb = self.context_embedder(prompt_emb)
             rope_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
+            text_rope_emb = rope_emb[:, :, : text_ids.size(1)]
+            image_rope_emb = rope_emb[:, :, text_ids.size(1) :]
             hidden_states = self.patchify(hidden_states)
-            hidden_states = self.x_embedder(hidden_states)
-            for i, block in enumerate(self.blocks):
-                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
-                if controlnet_double_block_output is not None:
-                    interval_control = len(self.blocks) / len(controlnet_double_block_output)
-                    interval_control = int(np.ceil(interval_control))
-                    hidden_states = hidden_states + controlnet_double_block_output[i // interval_control]
-            hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
-            for i, block in enumerate(self.single_blocks):
-                hidden_states = block(hidden_states, conditioning, rope_emb, image_emb)
-                if controlnet_single_block_output is not None:
-                    interval_control = len(self.single_blocks) / len(controlnet_double_block_output)
-                    interval_control = int(np.ceil(interval_control))
-                    hidden_states = hidden_states + controlnet_single_block_output[i // interval_control]
-
-            hidden_states = hidden_states[:, prompt_emb.shape[1] :]
-            hidden_states = self.final_norm_out(hidden_states, conditioning)
-            hidden_states = self.final_proj_out(hidden_states)
-            hidden_states = self.unpatchify(hidden_states, height, width)
+
+            with sequence_parallel(
+                (
+                    hidden_states,
+                    prompt_emb,
+                    text_rope_emb,
+                    image_rope_emb,
+                    *controlnet_double_block_output,
+                    *controlnet_single_block_output,
+                ),
+                seq_dims=(
+                    1,
+                    1,
+                    2,
+                    2,
+                    *(1 for _ in controlnet_double_block_output),
+                    *(1 for _ in controlnet_single_block_output),
+                ),
+                enabled=self.use_usp,
+            ):
+                hidden_states = self.x_embedder(hidden_states)
+                prompt_emb = self.context_embedder(prompt_emb)
+                rope_emb = torch.cat((text_rope_emb, image_rope_emb), dim=2)
+
+                for i, block in enumerate(self.blocks):
+                    hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
+                    if len(controlnet_double_block_output) > 0:
+                        interval_control = len(self.blocks) / len(controlnet_double_block_output)
+                        interval_control = int(np.ceil(interval_control))
+                        hidden_states = hidden_states + controlnet_double_block_output[i // interval_control]
+                hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
+                for i, block in enumerate(self.single_blocks):
+                    hidden_states = block(hidden_states, conditioning, rope_emb, image_emb)
+                    if len(controlnet_single_block_output) > 0:
+                        interval_control = len(self.single_blocks) / len(controlnet_double_block_output)
+                        interval_control = int(np.ceil(interval_control))
+                        hidden_states = hidden_states + controlnet_single_block_output[i // interval_control]
+
+                hidden_states = hidden_states[:, prompt_emb.shape[1] :]
+                hidden_states = self.final_norm_out(hidden_states, conditioning)
+                hidden_states = self.final_proj_out(hidden_states)
+                (hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(h * w // 4,))
+
+            hidden_states = self.unpatchify(hidden_states, h, w)
         return hidden_states

     @classmethod
@@ -434,6 +472,7 @@ def from_state_dict(
         dtype: torch.dtype,
         in_channel: int = 64,
         attn_impl: Optional[str] = None,
+        use_usp: bool = False,
     ):
         with no_init_weights():
             model = torch.nn.utils.skip_init(
@@ -442,6 +481,7 @@ def from_state_dict(
                 dtype=dtype,
                 in_channel=in_channel,
                 attn_impl=attn_impl,
+                use_usp=use_usp,
             )
         model = model.requires_grad_(False)  # for loading gguf
         model.load_state_dict(state_dict, assign=True)
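
For orientation, here is a minimal usage sketch of the new `use_usp` switch introduced by this diff. It is illustrative only: the module path, checkpoint filename, and the safetensors loading step are assumptions not taken from the diff; only the `from_state_dict(..., use_usp=...)` signature and the fact that `forward()` wraps the transformer blocks in `sequence_parallel(..., enabled=self.use_usp)` come from the change itself. It also assumes the caller has already initialized a torch.distributed process group for the sequence-parallel ranks.

```python
import torch
from safetensors.torch import load_file  # assumed loading path; any state-dict source works

# Module path is an assumption; adjust to wherever FluxDiT lives in diffsynth_engine.
from diffsynth_engine.models.flux.flux_dit import FluxDiT

# Placeholder checkpoint path.
state_dict = load_file("flux-dit.safetensors")

# use_usp=True turns on the sequence_parallel(...) context inside forward(),
# which is presumably responsible for sharding hidden_states / prompt_emb /
# rope embeddings along the seq_dims given in the diff across ranks; with
# use_usp=False the context is entered with enabled=False and the model runs
# as before.
model = FluxDiT.from_state_dict(
    state_dict,
    device="cuda:0",
    dtype=torch.bfloat16,
    in_channel=64,
    attn_impl=None,
    use_usp=True,
)
```

Note that the final `sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(h * w // 4,))` restores the full image sequence before unpatchify: with the 2x2 patchify (`P=2, Q=2`), a latent of height `h` and width `w` becomes a sequence of `(h // 2) * (w // 2) = h * w // 4` tokens.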