@@ -46,38 +46,23 @@ def __init__(
         self.patch_size = patch_size
         self.text_hidden_size = text_hidden_size
         self.pos_embed_max_size = pos_embed_max_size
-        # Linear projection for image patches
-        self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)

-        # Linear projection for text embeddings
+        self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
         self.text_proj = nn.Linear(text_hidden_size, hidden_size)

-    def forward(
-        self, hidden_states: torch.Tensor, prompt_embeds: torch.Tensor, negative_prompt_embeds: torch.Tensor | None
-    ) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, channel, height, width = hidden_states.shape
+        post_patch_height = height // self.patch_size
+        post_patch_width = width // self.patch_size

-        if height % self.patch_size != 0 or width % self.patch_size != 0:
-            raise ValueError("Height and width must be divisible by patch size")
-
-        patch_height = height // self.patch_size
-        patch_width = width // self.patch_size
-
-        # b, c, h, w -> b, c, patch_height, patch_size, patch_width, patch_size
-        # -> b, patch_height, patch_width, c, patch_size, patch_size
-        # -> b, patch_height * patch_width, c * patch_size * patch_size
-        hidden_states = (
-            hidden_states.reshape(batch_size, channel, patch_height, self.patch_size, patch_width, self.patch_size)
-            .permute(0, 2, 4, 1, 3, 5)
-            .reshape(batch_size, patch_height * patch_width, channel * self.patch_size * self.patch_size)
+        hidden_states = hidden_states.reshape(
+            batch_size, channel, post_patch_height, self.patch_size, post_patch_width, self.patch_size
         )
+        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
+        hidden_states = self.proj(hidden_states)
+        encoder_hidden_states = self.text_proj(encoder_hidden_states)

-        # project
-        hidden_states = self.proj(hidden_states)  # embed_dim: 64 -> 4096
-        prompt_embeds = self.text_proj(prompt_embeds)  # embed_dim: 4096 -> 4096
-        if negative_prompt_embeds is not None:
-            negative_prompt_embeds = self.text_proj(negative_prompt_embeds)  # embed_dim: 4096 -> 4096
-        return hidden_states, prompt_embeds, negative_prompt_embeds
+        return hidden_states, encoder_hidden_states


 class CogView4AdaLayerNormZero(nn.Module):
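
Two things worth noting on this hunk. First, the divisibility check on `height`/`width` is gone, so the caller is now implicitly trusted to pass latents whose spatial dims are multiples of `patch_size`. Second, the patchify rewrite itself is behavior-preserving: `.flatten(3, 5).flatten(1, 2)` on the permuted tensor yields the same layout as the old chained `.reshape`. A minimal standalone sketch to check the equivalence (shapes are illustrative, taken from the config defaults `in_channels=16`, `patch_size=2`):

```python
import torch

# Illustrative shapes only, not part of the diff.
b, c, h, w, p = 2, 16, 32, 32, 2
x = torch.randn(b, c, h, w)
ph, pw = h // p, w // p

# Old formulation: chained reshape/permute/reshape.
old = (
    x.reshape(b, c, ph, p, pw, p)
    .permute(0, 2, 4, 1, 3, 5)
    .reshape(b, ph * pw, c * p * p)
)

# New formulation from the diff: permute then two flattens.
new = x.reshape(b, c, ph, p, pw, p).permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)

assert torch.equal(old, new)  # identical (b, num_patches, c * p * p) tensor
```

The `flatten` form also avoids spelling out `c * p * p` by hand, which is presumably the motivation for the change.
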
@@ -347,10 +332,10 @@ def __init__(
         self,
         patch_size: int = 2,
         in_channels: int = 16,
+        out_channels: int = 16,
         num_layers: int = 30,
         attention_head_dim: int = 40,
         num_attention_heads: int = 64,
-        out_channels: int = 16,
         text_embed_dim: int = 4096,
         time_embed_dim: int = 512,
         condition_dim: int = 256,
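
Moving `out_channels` up in the signature changes the positional argument order, so any caller instantiating the model positionally would silently misbind `num_layers` and the parameters after it. Keyword instantiation sidesteps this; a sketch assuming the class is named `CogView4Transformer2DModel` (the name is not shown in this hunk, and the values are the defaults from the signature above):

```python
# Hypothetical keyword-only instantiation; immune to the parameter reorder.
model = CogView4Transformer2DModel(
    patch_size=2,
    in_channels=16,
    out_channels=16,
    num_layers=30,
    attention_head_dim=40,
    num_attention_heads=64,
    text_embed_dim=4096,
    time_embed_dim=512,
    condition_dim=256,
)
```
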
@@ -402,116 +387,46 @@ def __init__(
     def forward(
         self,
         hidden_states: torch.Tensor,
-        prompt_embeds: torch.Tensor,
-        negative_prompt_embeds: Optional[torch.Tensor],
+        encoder_hidden_states: torch.Tensor,
         timestep: torch.LongTensor,
         original_size: torch.Tensor,
         target_size: torch.Tensor,
         crop_coords: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
-        """
-        The [`CogView3PlusTransformer2DModel`] forward method.
-
-        Args:
-            hidden_states (`torch.Tensor`):
-                Input `hidden_states` of shape `(batch size, channel, height, width)`.
-            encoder_hidden_states (`torch.Tensor`):
-                Conditional embeddings (embeddings computed from the input conditions such as prompts) of shape
-                `(batch_size, sequence_len, text_embed_dim)`
-            timestep (`torch.LongTensor`):
-                Used to indicate denoising step.
-            original_size (`torch.Tensor`):
-                CogView3 uses SDXL-like micro-conditioning for original image size as explained in section 2.2 of
-                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
-            target_size (`torch.Tensor`):
-                CogView3 uses SDXL-like micro-conditioning for target image size as explained in section 2.2 of
-                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
-            crop_coords (`torch.Tensor`):
-                CogView3 uses SDXL-like micro-conditioning for crop coordinates as explained in section 2.2 of
-                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
-                tuple.
-
-        Returns:
-            `torch.Tensor` or [`~models.transformer_2d.Transformer2DModelOutput`]:
-                The denoised latents using provided inputs as conditioning.
-        """
         batch_size, num_channels, height, width = hidden_states.shape
-        do_cfg = negative_prompt_embeds is not None
-
-        if do_cfg:
-            assert (
-                batch_size == prompt_embeds.shape[0] + negative_prompt_embeds.shape[0]
-            ), "batch size mismatch in CFG mode"
-        else:
-            assert batch_size == prompt_embeds.shape[0], "batch size mismatch in non-CFG mode"

         # 1. RoPE
         image_rotary_emb = self.rope(hidden_states)

-        # 2. Conditional embeddings
+        # 2. Patch & Timestep embeddings
+        p = self.config.patch_size
+        post_patch_height = height // p
+        post_patch_width = width // p
+
+        hidden_states, encoder_hidden_states = self.patch_embed(hidden_states, encoder_hidden_states)
+
         temb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
         temb = F.silu(temb)
-        temb_cond, temb_uncond = temb.chunk(2)
-        hidden_states, prompt_embeds, negative_prompt_embeds = self.patch_embed(
-            hidden_states, prompt_embeds, negative_prompt_embeds
-        )
-        hidden_states_cond, hidden_states_uncond = hidden_states.chunk(2)

-        encoder_hidden_states_cond = prompt_embeds
-        encoder_hidden_states_uncond = negative_prompt_embeds
-
-        for index_block, block in enumerate(self.transformer_blocks):
+        # 3. Transformer blocks
+        for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states_cond, encoder_hidden_states_cond = self._gradient_checkpointing_func(
-                    block, hidden_states_cond, encoder_hidden_states_cond, temb_cond, image_rotary_emb
-                )
-                hidden_states_uncond, encoder_hidden_states_uncond = self._gradient_checkpointing_func(
-                    block, hidden_states_uncond, encoder_hidden_states_uncond, temb_uncond, image_rotary_emb
+                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+                    block, hidden_states, encoder_hidden_states, temb, image_rotary_emb
                 )
             else:
-                hidden_states_cond, encoder_hidden_states_cond = block(
-                    hidden_states_cond, encoder_hidden_states_cond, temb_cond, image_rotary_emb
-                )
-                hidden_states_uncond, encoder_hidden_states_uncond = block(
-                    hidden_states_uncond, encoder_hidden_states_uncond, temb_uncond, image_rotary_emb
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states, encoder_hidden_states, temb, image_rotary_emb
                 )

-        hidden_states_cond, encoder_hidden_states_cond = (
-            self.norm_out(hidden_states_cond, temb_cond),
-            self.norm_out(encoder_hidden_states_cond, temb_cond),
-        )
-        hidden_states_uncond, encoder_hidden_states_uncond = (
-            self.norm_out(hidden_states_uncond, temb_uncond),
-            self.norm_out(encoder_hidden_states_uncond, temb_uncond),
-        )
-
-        hidden_states_cond = self.proj_out(hidden_states_cond)
-        hidden_states_uncond = self.proj_out(hidden_states_uncond)
+        hidden_states = self.norm_out(hidden_states, temb)
+        hidden_states = self.proj_out(hidden_states)

-        # unpatchify
-        patch_size = self.config.patch_size
-        height = height // patch_size
-        width = width // patch_size
-
-        hidden_states_cond = hidden_states_cond.reshape(
-            shape=(hidden_states_cond.shape[0], height, width, -1, patch_size, patch_size)
-        )
-        hidden_states_cond = torch.einsum("nhwcpq->nchpwq", hidden_states_cond)
-        output_cond = hidden_states_cond.reshape(
-            shape=(hidden_states_cond.shape[0], -1, height * patch_size, width * patch_size)
-        )
-
-        hidden_states_uncond = hidden_states_uncond.reshape(
-            hidden_states_uncond.shape[0], height, width, -1, patch_size, patch_size
-        )
-        hidden_states_uncond = torch.einsum("nhwcpq->nchpwq", hidden_states_uncond)
-        output_uncond = hidden_states_uncond.reshape(
-            hidden_states_uncond.shape[0], -1, height * patch_size, width * patch_size
-        )
+        # 4. Unpatchify
+        hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
+        output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)

         if not return_dict:
-            return (output_cond, output_uncond)
-        return Transformer2DModelOutput(sample=output_cond), Transformer2DModelOutput(sample=output_uncond)
+            return (output,)
+        return Transformer2DModelOutput(sample=output)
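
With this change the transformer no longer runs cond/uncond branches internally: it takes a single `encoder_hidden_states` and a single `temb`, so classifier-free guidance moves to the caller, which concatenates the conditional and unconditional inputs along the batch dimension. A minimal sketch of that calling pattern (all names and shapes here are illustrative; `transformer` is assumed to be an already-constructed model exposing the forward shown above):

```python
import torch

# Assumed setup, not part of the diff: a loaded model plus dummy inputs.
# transformer = <CogView4 transformer instance with the new forward signature>
prompt_embeds = torch.randn(1, 224, 4096)           # conditional text embeddings
negative_prompt_embeds = torch.randn(1, 224, 4096)  # unconditional text embeddings
latents = torch.randn(1, 16, 64, 64)
timestep = torch.tensor([999])
original_size = torch.tensor([[512, 512]])
target_size = torch.tensor([[512, 512]])
crop_coords = torch.tensor([[0, 0]])
guidance_scale = 5.0

# Batch cond first, uncond second (the order the old in-model split used).
latent_in = torch.cat([latents, latents], dim=0)
embeds_in = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)

noise_pred = transformer(
    hidden_states=latent_in,
    encoder_hidden_states=embeds_in,
    timestep=timestep.repeat(2),
    original_size=original_size.repeat(2, 1),
    target_size=target_size.repeat(2, 1),
    crop_coords=crop_coords.repeat(2, 1),
    return_dict=False,
)[0]

# Standard CFG combination of the two halves of the batch.
noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
```

This matches how diffusers pipelines typically drive CFG, and it also means non-CFG callers no longer pay for an unconditional branch they don't use.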