From 5687dc6d39fb27b0bad1654f012919f657362e8f Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 20 Sep 2024 00:00:05 +0200 Subject: [PATCH 1/5] profile cogvideox --- .../transformers/cogvideox_transformer_3d.py | 99 +++++++++-------- .../pipelines/cogvideo/pipeline_cogvideox.py | 100 ++++++++++++------ .../schedulers/scheduling_ddim_cogvideox.py | 99 +++++++++++------ 3 files changed, 187 insertions(+), 111 deletions(-) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 821da6d032d5..f79e80f62e8a 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -17,6 +17,7 @@ import torch from torch import nn +from torch.profiler import record_function from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin @@ -433,49 +434,52 @@ def forward( batch_size, num_frames, channels, height, width = hidden_states.shape # 1. Time embedding - timesteps = timestep - t_emb = self.time_proj(timesteps) + with record_function("time embedding"): + timesteps = timestep + t_emb = self.time_proj(timesteps) - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=hidden_states.dtype) - emb = self.time_embedding(t_emb, timestep_cond) + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = self.time_embedding(t_emb, timestep_cond) # 2. Patch embedding - hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) - hidden_states = self.embedding_dropout(hidden_states) + with record_function("patch embedding"): + hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) + hidden_states = self.embedding_dropout(hidden_states) - text_seq_length = encoder_hidden_states.shape[1] - encoder_hidden_states = hidden_states[:, :text_seq_length] - hidden_states = hidden_states[:, text_seq_length:] + text_seq_length = encoder_hidden_states.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] # 3. Transformer blocks - for i, block in enumerate(self.transformer_blocks): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - emb, - image_rotary_emb, - **ckpt_kwargs, - ) - else: - hidden_states, encoder_hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=emb, - image_rotary_emb=image_rotary_emb, - ) + with record_function("blocks"): + for i, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + emb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + ) if not self.config.use_rotary_positional_embeddings: # CogVideoX-2B @@ -487,16 +491,17 @@ def custom_forward(*inputs): hidden_states = hidden_states[:, text_seq_length:] # 4. Final block - hidden_states = self.norm_out(hidden_states, temb=emb) - hidden_states = self.proj_out(hidden_states) - - # 5. Unpatchify - # Note: we use `-1` instead of `channels`: - # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels) - # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels) - p = self.config.patch_size - output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) - output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + with record_function("final output"): + hidden_states = self.norm_out(hidden_states, temb=emb) + hidden_states = self.proj_out(hidden_states) + + # 5. Unpatchify + # Note: we use `-1` instead of `channels`: + # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels) + # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels) + p = self.config.patch_size + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 02497e77edb7..88379c1814f8 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -18,6 +18,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch +from torch.profiler import record_function from transformers import T5EncoderModel, T5Tokenizer from ...callbacks import MultiPipelineCallbacks, PipelineCallback @@ -679,39 +680,73 @@ def __call__( timestep = t.expand(latent_model_input.shape[0]) # predict noise model_output - noise_pred = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - image_rotary_emb=image_rotary_emb, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - noise_pred = noise_pred.float() + with record_function(f"transformer_iteration_{i}"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + # noise_pred = noise_pred.float() # perform guidance - if use_dynamic_cfg: - self._guidance_scale = 1 + guidance_scale * ( - (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 - ) - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - if not isinstance(self.scheduler, CogVideoXDPMScheduler): - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - else: - latents, old_pred_original_sample = self.scheduler.step( - noise_pred, - old_pred_original_sample, - t, - timesteps[i - 1] if i > 0 else None, - latents, - **extra_step_kwargs, - return_dict=False, + with record_function(f"guidance_{i}"): + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) + / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + with record_function("1.1 scheduler"): + prev_timestep = t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps + + with record_function("1.2 scheduler"): + alpha_prod_t = self.scheduler.alphas_cumprod[t] + + with record_function("1.3 scheduler"): + alpha_prod_t_prev = ( + self.scheduler.alphas_cumprod[prev_timestep] + if prev_timestep >= 0 + else self.scheduler.final_alpha_cumprod ) - latents = latents.to(prompt_embeds.dtype) + + with record_function("1.4 scheduler"): + beta_prod_t = 1 - alpha_prod_t + + with record_function("1.5 scheduler"): + pred_original_sample = (alpha_prod_t**0.5) * latents - (beta_prod_t**0.5) * noise_pred + + with record_function("1.6 scheduler"): + a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 + + with record_function("1.7 scheduler"): + b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t + + with record_function("1.8 scheduler"): + prev_sample = a_t * latents + b_t * pred_original_sample + + latents = prev_sample + + # # compute the previous noisy sample x_t -> x_t-1 + # with record_function(f"scheduler_step_{i}"): + # if not isinstance(self.scheduler, CogVideoXDPMScheduler): + # latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + # else: + # latents, old_pred_original_sample = self.scheduler.step( + # noise_pred, + # old_pred_original_sample, + # t, + # timesteps[i - 1] if i > 0 else None, + # latents, + # **extra_step_kwargs, + # return_dict=False, + # ) + # # latents = latents.to(prompt_embeds.dtype) # call the callback, if provided if callback_on_step_end is not None: @@ -728,8 +763,9 @@ def __call__( progress_bar.update() if not output_type == "latent": - video = self.decode_latents(latents) - video = self.video_processor.postprocess_video(video=video, output_type=output_type) + with record_function("decode_latents"): + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) else: video = latents diff --git a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py index ec5c5f3e1c5d..6a4d979fb4f7 100644 --- a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py +++ b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py @@ -22,6 +22,7 @@ import numpy as np import torch +from torch.profiler import record_function from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput @@ -362,41 +363,75 @@ def step( # - pred_prev_sample -> "x_t-1" # 1. get previous step value (=t-1) - prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps - - # 2. compute alphas, betas - alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod - - beta_prod_t = 1 - alpha_prod_t - - # 3. compute predicted original sample from predicted noise also called - # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - # To make style tests pass, commented out `pred_epsilon` as it is an unused variable - if self.config.prediction_type == "epsilon": - pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - # pred_epsilon = model_output - elif self.config.prediction_type == "sample": - pred_original_sample = model_output - # pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) - elif self.config.prediction_type == "v_prediction": - pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output - # pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" - " `v_prediction`" + with record_function("get original prediction"): + with record_function("step 1"): + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + with record_function("step 2"): + print(self.alphas_cumprod.device, self.alphas_cumprod.dtype) + print(timestep.device, timestep.type) + print(prev_timestep.device, prev_timestep.dtype) + with record_function("step 2.1"): + alpha_prod_t = self.alphas_cumprod[timestep] + + with record_function("step 2.2"): + alpha_prod_t_prev = ( + self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + ) + + with record_function("step 2.3"): + beta_prod_t = 1 - alpha_prod_t + print(beta_prod_t.device, beta_prod_t.dtype) + print("======") + + with record_function("step 3"): + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + # To make style tests pass, commented out `pred_epsilon` as it is an unused variable + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + # pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + # pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + print( + "vpred:", + sample.dtype, + model_output.dtype, + alpha_prod_t.dtype, + beta_prod_t.dtype, + pred_original_sample.dtype, + ) + # pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + with record_function("compute prev sample"): + a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 + b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t + + prev_sample = a_t * sample + b_t * pred_original_sample + print( + "prevsample devices:", + a_t.device, + b_t.device, + sample.device, + pred_original_sample.device, + prev_sample.device, ) + print("prevsample:", a_t.dtype, b_t.dtype, sample.dtype, pred_original_sample.dtype, prev_sample.dtype) + print("=== done ===") - a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 - b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t - - prev_sample = a_t * sample + b_t * pred_original_sample - - if not return_dict: - return (prev_sample,) + if not return_dict: + return (prev_sample,) - return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise def add_noise( From fb0ebbb7313ca6a4c30fb36ae63d62e991950a2b Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 20 Sep 2024 01:04:12 +0200 Subject: [PATCH 2/5] update --- .../pipelines/cogvideo/pipeline_cogvideox.py | 8 +++++-- .../schedulers/scheduling_ddim_cogvideox.py | 21 +++++++++++++------ 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 88379c1814f8..c2a0d1fcf305 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -703,10 +703,14 @@ def __call__( noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) with record_function("1.1 scheduler"): - prev_timestep = t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps + prev_timestep = ( + self.scheduler.timesteps_numpy[i] + - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps + ) with record_function("1.2 scheduler"): - alpha_prod_t = self.scheduler.alphas_cumprod[t] + # alpha_prod_t = self.scheduler.alphas_cumprod[t] + alpha_prod_t = self.scheduler.alphas_cumprod[self.scheduler.timesteps_numpy[i]] with record_function("1.3 scheduler"): alpha_prod_t_prev = ( diff --git a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py index 6a4d979fb4f7..b0f2aabb4856 100644 --- a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py +++ b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py @@ -106,11 +106,15 @@ def rescale_zero_terminal_snr(alphas_cumprod): `torch.Tensor`: rescaled betas with zero terminal SNR """ - alphas_bar_sqrt = alphas_cumprod.sqrt() + # alphas_bar_sqrt = alphas_cumprod.sqrt() + alphas_bar_sqrt = np.sqrt(alphas_cumprod) # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + # alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + # alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + print("alphas_bar_sqrt", alphas_bar_sqrt[0]) + alphas_bar_sqrt_0 = np.copy(alphas_bar_sqrt[0]) + alphas_bar_sqrt_T = np.copy(alphas_bar_sqrt[-1]) # Shift so the last timestep is zero. alphas_bar_sqrt -= alphas_bar_sqrt_T @@ -208,8 +212,10 @@ def __init__( else: raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + self.betas = self.betas.cpu().detach().numpy() self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + self.alphas_cumprod = np.cumprod(self.alphas, axis=0) # Modify: SNR shift following SD3 self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * self.alphas_cumprod) @@ -222,14 +228,16 @@ def __init__( # For the final step, there is no previous alphas_cumprod because we are already at 0 # `set_alpha_to_one` decides whether we set this parameter simply to one or # whether we use the final alpha of the "non-previous" one. - self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + # self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + self.final_alpha_cumprod = 1.0 if set_alpha_to_one else self.alphas_cumprod[0] # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 # setable values self.num_inference_steps = None - self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + # self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + self.timesteps = np.arange(0, num_train_timesteps)[::-1] def _get_variance(self, timestep, prev_timestep): alpha_prod_t = self.alphas_cumprod[timestep] @@ -302,6 +310,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic ) self.timesteps = torch.from_numpy(timesteps).to(device) + self.timesteps_numpy = timesteps def step( self, From 3ae6094d1a7e27231291b7cfbc83a54bc70bc975 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 24 Sep 2024 06:31:53 +0200 Subject: [PATCH 3/5] dump --- .../transformers/cogvideox_transformer_3d.py | 99 +++++++------- src/diffusers/schedulers/scheduling_ddim.py | 129 +++++++++++------- .../schedulers/scheduling_ddim_cogvideox.py | 13 +- 3 files changed, 129 insertions(+), 112 deletions(-) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index f79e80f62e8a..821da6d032d5 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -17,7 +17,6 @@ import torch from torch import nn -from torch.profiler import record_function from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin @@ -434,52 +433,49 @@ def forward( batch_size, num_frames, channels, height, width = hidden_states.shape # 1. Time embedding - with record_function("time embedding"): - timesteps = timestep - t_emb = self.time_proj(timesteps) + timesteps = timestep + t_emb = self.time_proj(timesteps) - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=hidden_states.dtype) - emb = self.time_embedding(t_emb, timestep_cond) + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = self.time_embedding(t_emb, timestep_cond) # 2. Patch embedding - with record_function("patch embedding"): - hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) - hidden_states = self.embedding_dropout(hidden_states) + hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) + hidden_states = self.embedding_dropout(hidden_states) - text_seq_length = encoder_hidden_states.shape[1] - encoder_hidden_states = hidden_states[:, :text_seq_length] - hidden_states = hidden_states[:, text_seq_length:] + text_seq_length = encoder_hidden_states.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] # 3. Transformer blocks - with record_function("blocks"): - for i, block in enumerate(self.transformer_blocks): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - emb, - image_rotary_emb, - **ckpt_kwargs, - ) - else: - hidden_states, encoder_hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=emb, - image_rotary_emb=image_rotary_emb, - ) + for i, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + emb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + ) if not self.config.use_rotary_positional_embeddings: # CogVideoX-2B @@ -491,17 +487,16 @@ def custom_forward(*inputs): hidden_states = hidden_states[:, text_seq_length:] # 4. Final block - with record_function("final output"): - hidden_states = self.norm_out(hidden_states, temb=emb) - hidden_states = self.proj_out(hidden_states) - - # 5. Unpatchify - # Note: we use `-1` instead of `channels`: - # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels) - # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels) - p = self.config.patch_size - output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) - output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + hidden_states = self.norm_out(hidden_states, temb=emb) + hidden_states = self.proj_out(hidden_states) + + # 5. Unpatchify + # Note: we use `-1` instead of `channels`: + # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels) + # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels) + p = self.config.patch_size + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 14356eafdaea..83c261d0fe77 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -21,6 +21,7 @@ import numpy as np import torch +from torch.profiler import record_function from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput @@ -231,7 +232,10 @@ def __init__( # setable values self.num_inference_steps = None - self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + # TODO: discuss with YiYi why we have a .copy() here and if it's really needed. I've removed it for now + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].astype(np.int64)) + self.timesteps_cpu = np.arange(0, num_train_timesteps)[::-1].astype(np.int64) def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: """ @@ -251,8 +255,11 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None return sample def _get_variance(self, timestep, prev_timestep): - alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep) + safe_prev_timestep = torch.clamp(prev_timestep, min=0) + gathered = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep) + alpha_prod_t_prev = torch.where(prev_timestep >= 0, gathered, self.final_alpha_cumprod) + # alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev @@ -338,11 +345,12 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic ) self.timesteps = torch.from_numpy(timesteps).to(device) + self.timesteps_cpu = timesteps def step( self, model_output: torch.Tensor, - timestep: int, + timestep: Union[torch.Tensor, float], sample: torch.Tensor, eta: float = 0.0, use_clipped_model_output: bool = False, @@ -357,7 +365,7 @@ def step( Args: model_output (`torch.Tensor`): The direct output from learned diffusion model. - timestep (`float`): + timestep (`torch.Tensor` or `float`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. @@ -399,68 +407,87 @@ def step( # - pred_prev_sample -> "x_t-1" # 1. get previous step value (=t-1) - prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + with record_function("1 scheduler"): + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps # 2. compute alphas, betas - alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + with record_function("2 scheduler"): + alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep) + # alpha_prod_t = self.alphas_cumprod[timestep] - beta_prod_t = 1 - alpha_prod_t + with record_function("3 scheduler"): + safe_prev_timestep = torch.clamp(prev_timestep, min=0) + gathered = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep) + alpha_prod_t_prev = torch.where(prev_timestep >= 0, gathered, self.final_alpha_cumprod) + # alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + with record_function("4 scheduler"): + beta_prod_t = 1 - alpha_prod_t # 3. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - if self.config.prediction_type == "epsilon": - pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - pred_epsilon = model_output - elif self.config.prediction_type == "sample": - pred_original_sample = model_output - pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) - elif self.config.prediction_type == "v_prediction": - pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output - pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" - " `v_prediction`" - ) + with record_function("5 scheduler"): + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) # 4. Clip or threshold "predicted x_0" - if self.config.thresholding: - pred_original_sample = self._threshold_sample(pred_original_sample) - elif self.config.clip_sample: - pred_original_sample = pred_original_sample.clamp( - -self.config.clip_sample_range, self.config.clip_sample_range - ) + with record_function("6 scheduler"): + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) # 5. compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) - variance = self._get_variance(timestep, prev_timestep) - std_dev_t = eta * variance ** (0.5) + with record_function("7 scheduler"): + with record_function("7.1 scheduler"): + variance = self._get_variance(timestep, prev_timestep) + + with record_function("7.2 scheduler"): + std_dev_t = eta * variance ** (0.5) - if use_clipped_model_output: - # the pred_epsilon is always re-derived from the clipped x_0 in Glide - pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + with record_function("7.3 scheduler"): + if use_clipped_model_output: + # the pred_epsilon is always re-derived from the clipped x_0 in Glide + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon + with record_function("8 scheduler"): + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction - - if eta > 0: - if variance_noise is not None and generator is not None: - raise ValueError( - "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" - " `variance_noise` stays `None`." - ) - - if variance_noise is None: - variance_noise = randn_tensor( - model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype - ) - variance = std_dev_t * variance_noise - - prev_sample = prev_sample + variance + with record_function("9 scheduler"): + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + with record_function("10 scheduler"): + if eta > 0: + if variance_noise is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" + " `variance_noise` stays `None`." + ) + + if variance_noise is None: + variance_noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + variance = std_dev_t * variance_noise + + prev_sample = prev_sample + variance if not return_dict: return (prev_sample,) diff --git a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py index b0f2aabb4856..333f3bf08d20 100644 --- a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py +++ b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py @@ -110,11 +110,8 @@ def rescale_zero_terminal_snr(alphas_cumprod): alphas_bar_sqrt = np.sqrt(alphas_cumprod) # Store old values. - # alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - # alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - print("alphas_bar_sqrt", alphas_bar_sqrt[0]) - alphas_bar_sqrt_0 = np.copy(alphas_bar_sqrt[0]) - alphas_bar_sqrt_T = np.copy(alphas_bar_sqrt[-1]) + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() # Shift so the last timestep is zero. alphas_bar_sqrt -= alphas_bar_sqrt_T @@ -228,16 +225,14 @@ def __init__( # For the final step, there is no previous alphas_cumprod because we are already at 0 # `set_alpha_to_one` decides whether we set this parameter simply to one or # whether we use the final alpha of the "non-previous" one. - # self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] - self.final_alpha_cumprod = 1.0 if set_alpha_to_one else self.alphas_cumprod[0] + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 # setable values self.num_inference_steps = None - # self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) - self.timesteps = np.arange(0, num_train_timesteps)[::-1] + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) def _get_variance(self, timestep, prev_timestep): alpha_prod_t = self.alphas_cumprod[timestep] From 8e78c9d1d1b677a058e5d6f9146b7df7204dc489 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 24 Sep 2024 07:34:53 +0200 Subject: [PATCH 4/5] update --- .../pipelines/cogvideo/pipeline_cogvideox.py | 102 +++++--------- src/diffusers/schedulers/scheduling_ddim.py | 50 ++++--- .../schedulers/scheduling_ddim_cogvideox.py | 124 +++++++----------- 3 files changed, 101 insertions(+), 175 deletions(-) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 547c11631be4..82839ffd2c92 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -18,7 +18,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch -from torch.profiler import record_function from transformers import T5EncoderModel, T5Tokenizer from ...callbacks import MultiPipelineCallbacks, PipelineCallback @@ -682,77 +681,39 @@ def __call__( timestep = t.expand(latent_model_input.shape[0]) # predict noise model_output - with record_function(f"transformer_iteration_{i}"): - noise_pred = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - image_rotary_emb=image_rotary_emb, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - # noise_pred = noise_pred.float() + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred.float() # perform guidance - with record_function(f"guidance_{i}"): - if use_dynamic_cfg: - self._guidance_scale = 1 + guidance_scale * ( - (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) - / 2 - ) - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - with record_function("1.1 scheduler"): - prev_timestep = ( - self.scheduler.timesteps_numpy[i] - - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 ) - - with record_function("1.2 scheduler"): - # alpha_prod_t = self.scheduler.alphas_cumprod[t] - alpha_prod_t = self.scheduler.alphas_cumprod[self.scheduler.timesteps_numpy[i]] - - with record_function("1.3 scheduler"): - alpha_prod_t_prev = ( - self.scheduler.alphas_cumprod[prev_timestep] - if prev_timestep >= 0 - else self.scheduler.final_alpha_cumprod + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, ) - - with record_function("1.4 scheduler"): - beta_prod_t = 1 - alpha_prod_t - - with record_function("1.5 scheduler"): - pred_original_sample = (alpha_prod_t**0.5) * latents - (beta_prod_t**0.5) * noise_pred - - with record_function("1.6 scheduler"): - a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 - - with record_function("1.7 scheduler"): - b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t - - with record_function("1.8 scheduler"): - prev_sample = a_t * latents + b_t * pred_original_sample - - latents = prev_sample - - # # compute the previous noisy sample x_t -> x_t-1 - # with record_function(f"scheduler_step_{i}"): - # if not isinstance(self.scheduler, CogVideoXDPMScheduler): - # latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - # else: - # latents, old_pred_original_sample = self.scheduler.step( - # noise_pred, - # old_pred_original_sample, - # t, - # timesteps[i - 1] if i > 0 else None, - # latents, - # **extra_step_kwargs, - # return_dict=False, - # ) - # # latents = latents.to(prompt_embeds.dtype) + latents = latents.to(prompt_embeds.dtype) # call the callback, if provided if callback_on_step_end is not None: @@ -769,9 +730,8 @@ def __call__( progress_bar.update() if not output_type == "latent": - with record_function("decode_latents"): - video = self.decode_latents(latents) - video = self.video_processor.postprocess_video(video=video, output_type=output_type) + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) else: video = latents diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 83c261d0fe77..998b5b069fd4 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -21,13 +21,14 @@ import numpy as np import torch -from torch.profiler import record_function from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput from ..utils.torch_utils import randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin +from torch.profiler import record_function + @dataclass # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM @@ -232,10 +233,9 @@ def __init__( # setable values self.num_inference_steps = None - + # TODO: discuss with YiYi why we have a .copy() here and if it's really needed. I've removed it for now self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].astype(np.int64)) - self.timesteps_cpu = np.arange(0, num_train_timesteps)[::-1].astype(np.int64) def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: """ @@ -256,10 +256,11 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None def _get_variance(self, timestep, prev_timestep): alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep) + safe_prev_timestep = torch.clamp(prev_timestep, min=0) - gathered = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep) - alpha_prod_t_prev = torch.where(prev_timestep >= 0, gathered, self.final_alpha_cumprod) - # alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + safe_alpha_prod_t_prev = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep) + alpha_prod_t_prev = torch.where(prev_timestep >= 0, safe_alpha_prod_t_prev, self.final_alpha_cumprod) + beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev @@ -345,12 +346,13 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic ) self.timesteps = torch.from_numpy(timesteps).to(device) - self.timesteps_cpu = timesteps + self.alphas_cumprod = self.alphas_cumprod.to(device) + self.final_alpha_cumprod = self.final_alpha_cumprod.to(device) def step( self, model_output: torch.Tensor, - timestep: Union[torch.Tensor, float], + timestep: int, sample: torch.Tensor, eta: float = 0.0, use_clipped_model_output: bool = False, @@ -365,7 +367,7 @@ def step( Args: model_output (`torch.Tensor`): The direct output from learned diffusion model. - timestep (`torch.Tensor` or `float`): + timestep (`float`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. @@ -413,13 +415,11 @@ def step( # 2. compute alphas, betas with record_function("2 scheduler"): alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep) - # alpha_prod_t = self.alphas_cumprod[timestep] - + with record_function("3 scheduler"): safe_prev_timestep = torch.clamp(prev_timestep, min=0) - gathered = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep) - alpha_prod_t_prev = torch.where(prev_timestep >= 0, gathered, self.final_alpha_cumprod) - # alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + safe_alpha_prod_t_prev = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep) + alpha_prod_t_prev = torch.where(prev_timestep >= 0, safe_alpha_prod_t_prev, self.final_alpha_cumprod) with record_function("4 scheduler"): beta_prod_t = 1 - alpha_prod_t @@ -454,26 +454,20 @@ def step( # 5. compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) with record_function("7 scheduler"): - with record_function("7.1 scheduler"): - variance = self._get_variance(timestep, prev_timestep) + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) - with record_function("7.2 scheduler"): - std_dev_t = eta * variance ** (0.5) - - with record_function("7.3 scheduler"): - if use_clipped_model_output: - # the pred_epsilon is always re-derived from the clipped x_0 in Glide - pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + if use_clipped_model_output: + # the pred_epsilon is always re-derived from the clipped x_0 in Glide + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) - # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - with record_function("8 scheduler"): + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon - # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - with record_function("9 scheduler"): + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction - with record_function("10 scheduler"): + with record_function("8 scheduler"): if eta > 0: if variance_noise is not None and generator is not None: raise ValueError( diff --git a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py index 333f3bf08d20..f4e4945b703d 100644 --- a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py +++ b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py @@ -22,7 +22,6 @@ import numpy as np import torch -from torch.profiler import record_function from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput @@ -106,8 +105,7 @@ def rescale_zero_terminal_snr(alphas_cumprod): `torch.Tensor`: rescaled betas with zero terminal SNR """ - # alphas_bar_sqrt = alphas_cumprod.sqrt() - alphas_bar_sqrt = np.sqrt(alphas_cumprod) + alphas_bar_sqrt = alphas_cumprod.sqrt() # Store old values. alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() @@ -209,10 +207,8 @@ def __init__( else: raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") - self.betas = self.betas.cpu().detach().numpy() self.alphas = 1.0 - self.betas - # self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) # Modify: SNR shift following SD3 self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * self.alphas_cumprod) @@ -232,11 +228,17 @@ def __init__( # setable values self.num_inference_steps = None - self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + # TODO: discuss with YiYi why we have a .copy() here and if it's really needed. I've removed it for now + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].astype(np.int64)) def _get_variance(self, timestep, prev_timestep): - alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep) + + safe_prev_timestep = torch.clamp(prev_timestep, min=0) + safe_alpha_prod_t_prev = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep) + alpha_prod_t_prev = torch.where(prev_timestep >= 0, safe_alpha_prod_t_prev, self.final_alpha_cumprod) + beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev @@ -305,7 +307,8 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic ) self.timesteps = torch.from_numpy(timesteps).to(device) - self.timesteps_numpy = timesteps + self.alphas_cumprod = self.alphas_cumprod.to(device) + self.final_alpha_cumprod = self.final_alpha_cumprod.to(device) def step( self, @@ -367,75 +370,44 @@ def step( # - pred_prev_sample -> "x_t-1" # 1. get previous step value (=t-1) - with record_function("get original prediction"): - with record_function("step 1"): - prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps - - # 2. compute alphas, betas - with record_function("step 2"): - print(self.alphas_cumprod.device, self.alphas_cumprod.dtype) - print(timestep.device, timestep.type) - print(prev_timestep.device, prev_timestep.dtype) - with record_function("step 2.1"): - alpha_prod_t = self.alphas_cumprod[timestep] - - with record_function("step 2.2"): - alpha_prod_t_prev = ( - self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod - ) - - with record_function("step 2.3"): - beta_prod_t = 1 - alpha_prod_t - print(beta_prod_t.device, beta_prod_t.dtype) - print("======") - - with record_function("step 3"): - # 3. compute predicted original sample from predicted noise also called - # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - # To make style tests pass, commented out `pred_epsilon` as it is an unused variable - if self.config.prediction_type == "epsilon": - pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - # pred_epsilon = model_output - elif self.config.prediction_type == "sample": - pred_original_sample = model_output - # pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) - elif self.config.prediction_type == "v_prediction": - pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output - print( - "vpred:", - sample.dtype, - model_output.dtype, - alpha_prod_t.dtype, - beta_prod_t.dtype, - pred_original_sample.dtype, - ) - # pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" - " `v_prediction`" - ) - - with record_function("compute prev sample"): - a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 - b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t - - prev_sample = a_t * sample + b_t * pred_original_sample - print( - "prevsample devices:", - a_t.device, - b_t.device, - sample.device, - pred_original_sample.device, - prev_sample.device, + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep) + + safe_prev_timestep = torch.clamp(prev_timestep, min=0) + safe_alpha_prod_t_prev = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep) + alpha_prod_t_prev = torch.where(prev_timestep >= 0, safe_alpha_prod_t_prev, self.final_alpha_cumprod) + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + # To make style tests pass, commented out `pred_epsilon` as it is an unused variable + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + # pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + # pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + # pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" ) - print("prevsample:", a_t.dtype, b_t.dtype, sample.dtype, pred_original_sample.dtype, prev_sample.dtype) - print("=== done ===") - if not return_dict: - return (prev_sample,) + a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 + b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t + + prev_sample = a_t * sample + b_t * pred_original_sample + + if not return_dict: + return (prev_sample,) - return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise def add_noise( From 3bf39613fbc052602e07d68644182fea0984befc Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 24 Sep 2024 13:23:29 +0200 Subject: [PATCH 5/5] update --- src/diffusers/schedulers/scheduling_ddim.py | 124 ++++++++---------- .../schedulers/scheduling_ddim_cogvideox.py | 4 +- 2 files changed, 59 insertions(+), 69 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 998b5b069fd4..bd567babd7c7 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -27,8 +27,6 @@ from ..utils.torch_utils import randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin -from torch.profiler import record_function - @dataclass # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM @@ -233,7 +231,7 @@ def __init__( # setable values self.num_inference_steps = None - + # TODO: discuss with YiYi why we have a .copy() here and if it's really needed. I've removed it for now self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].astype(np.int64)) @@ -256,11 +254,11 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None def _get_variance(self, timestep, prev_timestep): alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep) - + safe_prev_timestep = torch.clamp(prev_timestep, min=0) safe_alpha_prod_t_prev = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep) alpha_prod_t_prev = torch.where(prev_timestep >= 0, safe_alpha_prod_t_prev, self.final_alpha_cumprod) - + beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev @@ -409,79 +407,71 @@ def step( # - pred_prev_sample -> "x_t-1" # 1. get previous step value (=t-1) - with record_function("1 scheduler"): - prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps # 2. compute alphas, betas - with record_function("2 scheduler"): - alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep) - - with record_function("3 scheduler"): - safe_prev_timestep = torch.clamp(prev_timestep, min=0) - safe_alpha_prod_t_prev = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep) - alpha_prod_t_prev = torch.where(prev_timestep >= 0, safe_alpha_prod_t_prev, self.final_alpha_cumprod) + alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep) - with record_function("4 scheduler"): - beta_prod_t = 1 - alpha_prod_t + safe_prev_timestep = torch.clamp(prev_timestep, min=0) + safe_alpha_prod_t_prev = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep) + alpha_prod_t_prev = torch.where(prev_timestep >= 0, safe_alpha_prod_t_prev, self.final_alpha_cumprod) + + beta_prod_t = 1 - alpha_prod_t # 3. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - with record_function("5 scheduler"): - if self.config.prediction_type == "epsilon": - pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - pred_epsilon = model_output - elif self.config.prediction_type == "sample": - pred_original_sample = model_output - pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) - elif self.config.prediction_type == "v_prediction": - pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output - pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" - " `v_prediction`" - ) + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) # 4. Clip or threshold "predicted x_0" - with record_function("6 scheduler"): - if self.config.thresholding: - pred_original_sample = self._threshold_sample(pred_original_sample) - elif self.config.clip_sample: - pred_original_sample = pred_original_sample.clamp( - -self.config.clip_sample_range, self.config.clip_sample_range - ) + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) # 5. compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) - with record_function("7 scheduler"): - variance = self._get_variance(timestep, prev_timestep) - std_dev_t = eta * variance ** (0.5) - - if use_clipped_model_output: - # the pred_epsilon is always re-derived from the clipped x_0 in Glide - pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) - - # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon - - # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction - - with record_function("8 scheduler"): - if eta > 0: - if variance_noise is not None and generator is not None: - raise ValueError( - "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" - " `variance_noise` stays `None`." - ) - - if variance_noise is None: - variance_noise = randn_tensor( - model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype - ) - variance = std_dev_t * variance_noise - - prev_sample = prev_sample + variance + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + if use_clipped_model_output: + # the pred_epsilon is always re-derived from the clipped x_0 in Glide + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if eta > 0: + if variance_noise is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" + " `variance_noise` stays `None`." + ) + + if variance_noise is None: + variance_noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + variance = std_dev_t * variance_noise + + prev_sample = prev_sample + variance if not return_dict: return (prev_sample,) diff --git a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py index f4e4945b703d..20a5f1a75f33 100644 --- a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py +++ b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py @@ -229,12 +229,12 @@ def __init__( # setable values self.num_inference_steps = None - # TODO: discuss with YiYi why we have a .copy() here and if it's really needed. I've removed it for now + # TODO: discuss with YiYi why we have a .copy() here and if it's really needed. I've removed it for now self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].astype(np.int64)) def _get_variance(self, timestep, prev_timestep): alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep) - + safe_prev_timestep = torch.clamp(prev_timestep, min=0) safe_alpha_prod_t_prev = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep) alpha_prod_t_prev = torch.where(prev_timestep >= 0, safe_alpha_prod_t_prev, self.final_alpha_cumprod)