
Commit c1b8004

refactor pt. 3
1 parent 4c01c9d commit c1b8004

File tree

src/diffusers/models/transformers/transformer_cogview4.py
src/diffusers/pipelines/cogview4/pipeline_cogview4.py

2 files changed: +56 −142 lines changed

src/diffusers/models/transformers/transformer_cogview4.py

Lines changed: 32 additions & 117 deletions
@@ -46,38 +46,23 @@ def __init__(
         self.patch_size = patch_size
         self.text_hidden_size = text_hidden_size
         self.pos_embed_max_size = pos_embed_max_size
-        # Linear projection for image patches
-        self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
 
-        # Linear projection for text embeddings
+        self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
         self.text_proj = nn.Linear(text_hidden_size, hidden_size)
 
-    def forward(
-        self, hidden_states: torch.Tensor, prompt_embeds: torch.Tensor, negative_prompt_embeds: torch.Tensor | None
-    ) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, channel, height, width = hidden_states.shape
+        post_patch_height = height // self.patch_size
+        post_patch_width = width // self.patch_size
 
-        if height % self.patch_size != 0 or width % self.patch_size != 0:
-            raise ValueError("Height and width must be divisible by patch size")
-
-        patch_height = height // self.patch_size
-        patch_width = width // self.patch_size
-
-        # b, c, h, w -> b, c, patch_height, patch_size, patch_width, patch_size
-        # -> b, patch_height, patch_width, c, patch_size, patch_size
-        # -> b, patch_height * patch_width, c * patch_size * patch_size
-        hidden_states = (
-            hidden_states.reshape(batch_size, channel, patch_height, self.patch_size, patch_width, self.patch_size)
-            .permute(0, 2, 4, 1, 3, 5)
-            .reshape(batch_size, patch_height * patch_width, channel * self.patch_size * self.patch_size)
+        hidden_states = hidden_states.reshape(
+            batch_size, channel, post_patch_height, self.patch_size, post_patch_width, self.patch_size
         )
+        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
+        hidden_states = self.proj(hidden_states)
+        encoder_hidden_states = self.text_proj(encoder_hidden_states)
 
-        # project
-        hidden_states = self.proj(hidden_states)  # embed_dim: 64 -> 4096
-        prompt_embeds = self.text_proj(prompt_embeds)  # embed_dim: 4096 -> 4096
-        if negative_prompt_embeds is not None:
-            negative_prompt_embeds = self.text_proj(negative_prompt_embeds)  # embed_dim: 4096 -> 4096
-        return hidden_states, prompt_embeds, negative_prompt_embeds
+        return hidden_states, encoder_hidden_states
 
 
 class CogView4AdaLayerNormZero(nn.Module):
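As a side note, the new flatten-based patchify above can be sanity-checked outside the model against the reshape chain it replaces. This is a standalone sketch (not part of the commit), with illustrative sizes loosely based on the config defaults (`in_channels=16`, `patch_size=2`):

```python
import torch

batch_size, channel, height, width = 2, 16, 32, 32
patch_size = 2
post_patch_height, post_patch_width = height // patch_size, width // patch_size

hidden_states = torch.randn(batch_size, channel, height, width)

# Old path: reshape -> permute -> reshape into (batch, num_patches, channel * patch_size**2)
old = (
    hidden_states.reshape(batch_size, channel, post_patch_height, patch_size, post_patch_width, patch_size)
    .permute(0, 2, 4, 1, 3, 5)
    .reshape(batch_size, post_patch_height * post_patch_width, channel * patch_size * patch_size)
)

# New path: same reshape and permute, then flatten the patch dims and the grid dims
new = hidden_states.reshape(
    batch_size, channel, post_patch_height, patch_size, post_patch_width, patch_size
)
new = new.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)

assert torch.equal(old, new)
print(new.shape)  # torch.Size([2, 256, 64])
```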
@@ -347,10 +332,10 @@ def __init__(
         self,
         patch_size: int = 2,
         in_channels: int = 16,
+        out_channels: int = 16,
         num_layers: int = 30,
         attention_head_dim: int = 40,
         num_attention_heads: int = 64,
-        out_channels: int = 16,
         text_embed_dim: int = 4096,
         time_embed_dim: int = 512,
         condition_dim: int = 256,
@@ -402,116 +387,46 @@ def __init__(
     def forward(
         self,
         hidden_states: torch.Tensor,
-        prompt_embeds: torch.Tensor,
-        negative_prompt_embeds: Optional[torch.Tensor],
+        encoder_hidden_states: torch.Tensor,
         timestep: torch.LongTensor,
         original_size: torch.Tensor,
         target_size: torch.Tensor,
         crop_coords: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
-        """
-        The [`CogView3PlusTransformer2DModel`] forward method.
-
-        Args:
-            hidden_states (`torch.Tensor`):
-                Input `hidden_states` of shape `(batch size, channel, height, width)`.
-            encoder_hidden_states (`torch.Tensor`):
-                Conditional embeddings (embeddings computed from the input conditions such as prompts) of shape
-                `(batch_size, sequence_len, text_embed_dim)`
-            timestep (`torch.LongTensor`):
-                Used to indicate denoising step.
-            original_size (`torch.Tensor`):
-                CogView3 uses SDXL-like micro-conditioning for original image size as explained in section 2.2 of
-                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
-            target_size (`torch.Tensor`):
-                CogView3 uses SDXL-like micro-conditioning for target image size as explained in section 2.2 of
-                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
-            crop_coords (`torch.Tensor`):
-                CogView3 uses SDXL-like micro-conditioning for crop coordinates as explained in section 2.2 of
-                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
-                tuple.
-
-        Returns:
-            `torch.Tensor` or [`~models.transformer_2d.Transformer2DModelOutput`]:
-                The denoised latents using provided inputs as conditioning.
-        """
         batch_size, num_channels, height, width = hidden_states.shape
-        do_cfg = negative_prompt_embeds is not None
-
-        if do_cfg:
-            assert (
-                batch_size == prompt_embeds.shape[0] + negative_prompt_embeds.shape[0]
-            ), "batch size mismatch in CFG mode"
-        else:
-            assert batch_size == prompt_embeds.shape[0], "batch size mismatch in non-CFG mode"
 
         # 1. RoPE
         image_rotary_emb = self.rope(hidden_states)
 
-        # 2. Conditional embeddings
+        # 2. Patch & Timestep embeddings
+        p = self.config.patch_size
+        post_patch_height = height // p
+        post_patch_width = width // p
+
+        hidden_states, encoder_hidden_states = self.patch_embed(hidden_states, encoder_hidden_states)
+
         temb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
         temb = F.silu(temb)
-        temb_cond, temb_uncond = temb.chunk(2)
-        hidden_states, prompt_embeds, negative_prompt_embeds = self.patch_embed(
-            hidden_states, prompt_embeds, negative_prompt_embeds
-        )
-        hidden_states_cond, hidden_states_uncond = hidden_states.chunk(2)
 
-        encoder_hidden_states_cond = prompt_embeds
-        encoder_hidden_states_uncond = negative_prompt_embeds
-
-        for index_block, block in enumerate(self.transformer_blocks):
+        # 3. Transformer blocks
+        for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states_cond, encoder_hidden_states_cond = self._gradient_checkpointing_func(
-                    block, hidden_states_cond, encoder_hidden_states_cond, temb_cond, image_rotary_emb
-                )
-                hidden_states_uncond, encoder_hidden_states_uncond = self._gradient_checkpointing_func(
-                    block, hidden_states_uncond, encoder_hidden_states_uncond, temb_uncond, image_rotary_emb
+                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+                    block, hidden_states, encoder_hidden_states, temb, image_rotary_emb
                 )
             else:
-                hidden_states_cond, encoder_hidden_states_cond = block(
-                    hidden_states_cond, encoder_hidden_states_cond, temb_cond, image_rotary_emb
-                )
-                hidden_states_uncond, encoder_hidden_states_uncond = block(
-                    hidden_states_uncond, encoder_hidden_states_uncond, temb_uncond, image_rotary_emb
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states, encoder_hidden_states, temb, image_rotary_emb
                 )
 
-        hidden_states_cond, encoder_hidden_states_cond = (
-            self.norm_out(hidden_states_cond, temb_cond),
-            self.norm_out(encoder_hidden_states_cond, temb_cond),
-        )
-        hidden_states_uncond, encoder_hidden_states_uncond = (
-            self.norm_out(hidden_states_uncond, temb_uncond),
-            self.norm_out(encoder_hidden_states_uncond, temb_uncond),
-        )
-
-        hidden_states_cond = self.proj_out(hidden_states_cond)
-        hidden_states_uncond = self.proj_out(hidden_states_uncond)
+        hidden_states = self.norm_out(hidden_states, temb)
+        hidden_states = self.proj_out(hidden_states)
 
-        # unpatchify
-        patch_size = self.config.patch_size
-        height = height // patch_size
-        width = width // patch_size
-
-        hidden_states_cond = hidden_states_cond.reshape(
-            shape=(hidden_states_cond.shape[0], height, width, -1, patch_size, patch_size)
-        )
-        hidden_states_cond = torch.einsum("nhwcpq->nchpwq", hidden_states_cond)
-        output_cond = hidden_states_cond.reshape(
-            shape=(hidden_states_cond.shape[0], -1, height * patch_size, width * patch_size)
-        )
-
-        hidden_states_uncond = hidden_states_uncond.reshape(
-            hidden_states_uncond.shape[0], height, width, -1, patch_size, patch_size
-        )
-        hidden_states_uncond = torch.einsum("nhwcpq->nchpwq", hidden_states_uncond)
-        output_uncond = hidden_states_uncond.reshape(
-            hidden_states_uncond.shape[0], -1, height * patch_size, width * patch_size
-        )
+        # 4. Unpatchify
+        hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
+        output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
 
         if not return_dict:
-            return (output_cond, output_uncond)
-        return Transformer2DModelOutput(sample=output_cond), Transformer2DModelOutput(sample=output_uncond)
+            return (output,)
+        return Transformer2DModelOutput(sample=output)
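The rewritten unpatchify at the end of `forward` can be checked the same way against the `einsum`-based path it replaces. Again a standalone sketch with made-up dimensions, not the model's actual code:

```python
import torch

batch_size, out_channels, patch_size = 2, 16, 2
post_patch_height, post_patch_width = 16, 16

# Transformer output after proj_out: (batch, num_patches, out_channels * patch_size**2)
hidden_states = torch.randn(batch_size, post_patch_height * post_patch_width, out_channels * patch_size**2)

# Old path: reshape -> einsum("nhwcpq->nchpwq") -> reshape into (batch, channels, height, width)
old = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, patch_size, patch_size)
old = torch.einsum("nhwcpq->nchpwq", old)
old = old.reshape(batch_size, -1, post_patch_height * patch_size, post_patch_width * patch_size)

# New path: same reshape, then permute + flatten instead of einsum + reshape
new = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, patch_size, patch_size)
new = new.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)

assert torch.equal(old, new)
print(new.shape)  # torch.Size([2, 16, 32, 32])
```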

src/diffusers/pipelines/cogview4/pipeline_cogview4.py

Lines changed: 24 additions & 25 deletions
@@ -453,16 +453,11 @@ def __call__(
 
         device = self._execution_device
 
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = self.do_classifier_free_guidance
-
         # Encode input prompt
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             negative_prompt,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
             num_images_per_prompt=num_images_per_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
@@ -484,18 +479,13 @@ def __call__(
         )
 
         # Prepare additional timestep conditions
-        original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype)
-        target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype)
-        crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype)
-
-        if do_classifier_free_guidance:
-            original_size = torch.cat([original_size, original_size])
-            target_size = torch.cat([target_size, target_size])
-            crops_coords_top_left = torch.cat([crops_coords_top_left, crops_coords_top_left])
+        original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=device)
+        target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device)
+        crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device)
 
-        original_size = original_size.to(device).repeat(batch_size * num_images_per_prompt, 1)
-        target_size = target_size.to(device).repeat(batch_size * num_images_per_prompt, 1)
-        crops_coords_top_left = crops_coords_top_left.to(device).repeat(batch_size * num_images_per_prompt, 1)
+        original_size = original_size.repeat(batch_size * num_images_per_prompt, 1)
+        target_size = target_size.repeat(batch_size * num_images_per_prompt, 1)
+        crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1)
 
         # Prepare timesteps
         image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
@@ -513,28 +503,37 @@ def __call__(
             if self.interrupt:
                 continue
 
-            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+            latent_model_input = self.scheduler.scale_model_input(latents, t)
             latent_model_input = latent_model_input.to(transformer_dtype)
 
             # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
             timestep = t.expand(latents.shape[0])
 
-            noise_pred = self.transformer(
+            noise_pred_cond = self.transformer(
                 hidden_states=latent_model_input,
-                prompt_embeds=prompt_embeds,
-                negative_prompt_embeds=negative_prompt_embeds,
+                encoder_hidden_states=prompt_embeds,
                 timestep=timestep,
                 original_size=original_size,
                 target_size=target_size,
                 crop_coords=crops_coords_top_left,
                 return_dict=False,
-            )
+            )[0]
 
             # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_cond, noise_pred_uncond = noise_pred
+            if self.do_classifier_free_guidance:
+                noise_pred_uncond = self.transformer(
+                    hidden_states=latent_model_input,
+                    encoder_hidden_states=negative_prompt_embeds,
+                    timestep=timestep,
+                    original_size=original_size,
+                    target_size=target_size,
+                    crop_coords=crops_coords_top_left,
+                    return_dict=False,
+                )[0]
+
                 noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+            else:
+                noise_pred = noise_pred_cond
 
             latents = self.scheduler.step(noise_pred, latents, t).prev_sample
 
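After this change the transformer processes a single stream per call, so the pipeline expresses classifier-free guidance as two forward passes blended with `guidance_scale`. A minimal, self-contained sketch of that per-step control flow follows; `FakeTransformer` is a stand-in for the real `CogView4Transformer2DModel`, and every shape is illustrative:

```python
import torch


class FakeTransformer:
    """Stand-in for CogView4Transformer2DModel, only here to make the loop body runnable."""

    def __call__(self, hidden_states, encoder_hidden_states, timestep=None, return_dict=False, **kwargs):
        # Pretend noise prediction with the same shape as the latents.
        return (hidden_states + encoder_hidden_states.mean(),)


transformer = FakeTransformer()
guidance_scale = 5.0
do_classifier_free_guidance = guidance_scale > 1.0

latents = torch.randn(1, 16, 32, 32)
prompt_embeds = torch.randn(1, 16, 4096)
negative_prompt_embeds = torch.randn(1, 16, 4096)
timestep = torch.tensor([999])

# One conditional pass...
noise_pred_cond = transformer(
    hidden_states=latents,
    encoder_hidden_states=prompt_embeds,
    timestep=timestep,
    return_dict=False,
)[0]

if do_classifier_free_guidance:
    # ...and one unconditional pass, blended with the usual CFG formula.
    noise_pred_uncond = transformer(
        hidden_states=latents,
        encoder_hidden_states=negative_prompt_embeds,
        timestep=timestep,
        return_dict=False,
    )[0]
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
else:
    noise_pred = noise_pred_cond

print(noise_pred.shape)  # torch.Size([1, 16, 32, 32])
```

Relative to the previous batched cond/uncond path, total compute per step is roughly the same, but each transformer call now sees half the batch, at the cost of an extra forward pass per denoising step.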