From bfce9c2583e575646b69e293e99df4ecaae4889a Mon Sep 17 00:00:00 2001 From: "mengli.cml" Date: Mon, 6 Jan 2025 20:45:10 +0800 Subject: [PATCH 1/3] support for parallel inference using xfuser --- easyanimate/models/attention.py | 2 +- easyanimate/models/processor.py | 53 +++- easyanimate/models/transformer3d.py | 64 ++++- predict_i2v_multi_gpus.py | 401 ++++++++++++++++++++++++++++ 4 files changed, 512 insertions(+), 8 deletions(-) create mode 100644 predict_i2v_multi_gpus.py diff --git a/easyanimate/models/attention.py b/easyanimate/models/attention.py index 0c36a0ea1..b788fc0e8 100644 --- a/easyanimate/models/attention.py +++ b/easyanimate/models/attention.py @@ -1145,4 +1145,4 @@ def forward( norm_encoder_hidden_states = self.ff(norm_encoder_hidden_states) hidden_states = hidden_states + gate_ff * norm_hidden_states encoder_hidden_states = encoder_hidden_states + enc_gate_ff * norm_encoder_hidden_states - return hidden_states, encoder_hidden_states \ No newline at end of file + return hidden_states, encoder_hidden_states diff --git a/easyanimate/models/processor.py b/easyanimate/models/processor.py index 3f9224085..f9d331a3d 100644 --- a/easyanimate/models/processor.py +++ b/easyanimate/models/processor.py @@ -6,6 +6,21 @@ from diffusers.models.embeddings import apply_rotary_emb from einops import rearrange, repeat +try: + import xfuser + from xfuser.core.distributed import ( + get_sequence_parallel_world_size, + get_sequence_parallel_rank, + get_sp_group, + initialize_model_parallel, + init_distributed_environment + ) + from xfuser.core.long_ctx_attention import xFuserLongContextAttention +except Exception as ex: + get_sequence_parallel_world_size = None + get_sequence_parallel_rank = None + xFuserLongContextAttention = None + class HunyuanAttnProcessor2_0: r""" @@ -217,7 +232,14 @@ def __call__( class EasyAnimateAttnProcessor2_0: def __init__(self): - pass + if xFuserLongContextAttention is not None: + try: + get_sequence_parallel_world_size() + self.hybrid_seq_parallel_attn = xFuserLongContextAttention() + except Exception: + self.hybrid_seq_parallel_attn = None + else: + self.hybrid_seq_parallel_attn = None def __call__( self, @@ -284,11 +306,30 @@ def __call__( if not attn.is_cross_attention: key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + if self.hybrid_seq_parallel_attn is None: + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2) + else: + sp_world_rank = get_sequence_parallel_rank() + sp_world_size = get_sequence_parallel_world_size() + + img_q = query[:, :, text_seq_length:].transpose(1,2) + txt_q = query[:, :, :text_seq_length].transpose(1,2) + img_k = key[:, :, text_seq_length:].transpose(1,2) + txt_k = key[:, :, :text_seq_length].transpose(1,2) + img_v = value[:, :, text_seq_length:].transpose(1,2) + txt_v = value[:, :, :text_seq_length].transpose(1,2) + + hidden_states = self.hybrid_seq_parallel_attn(None, + img_q, img_k, img_v, dropout_p=0.0, causal=False, + joint_tensor_query=txt_q, + joint_tensor_key=txt_k, + joint_tensor_value=txt_v, + joint_strategy='front',) + + hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim) if attn2 is None: # linear proj 
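The sequence-parallel path added to EasyAnimateAttnProcessor2_0 above assumes a specific token layout: the text tokens sit at the front of the joint text-image sequence and stay replicated on every rank, while the image tokens are sharded across sequence-parallel ranks (the actual split happens upstream in transformer3d.py, along the latent height or width); the replicated text tokens are then handed to xFuserLongContextAttention as joint_tensor_query/key/value with joint_strategy='front'. A minimal, xfuser-free sketch of that layout, with hypothetical sizes chosen only for illustration:

import torch

# Hypothetical sizes, for illustration only.
batch_size, heads, head_dim = 1, 8, 64
text_seq_length, image_seq_length = 77, 1024
sp_world_size, sp_world_rank = 2, 0  # pretend this process is rank 0 of 2

# Joint text+image sequence as it would look without sequence parallelism,
# in the (batch, heads, seq_len, head_dim) layout the processor builds.
query = torch.randn(batch_size, heads, text_seq_length + image_seq_length, head_dim)

# Text tokens are kept in full on every rank ...
txt_q = query[:, :, :text_seq_length]
# ... while image tokens are split along the sequence dimension, so each rank
# only ever sees its own shard plus the full text tokens.
img_q_local = torch.chunk(query[:, :, text_seq_length:], sp_world_size, dim=2)[sp_world_rank]

print(txt_q.shape)        # torch.Size([1, 8, 77, 64])
print(img_q_local.shape)  # torch.Size([1, 8, 512, 64])

When xfuser is unavailable or the sequence-parallel world size is 1, the processor falls back to the original F.scaled_dot_product_attention path, so single-GPU behaviour is unchanged.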
diff --git a/easyanimate/models/transformer3d.py b/easyanimate/models/transformer3d.py index 42fa3051f..7854916f0 100644 --- a/easyanimate/models/transformer3d.py +++ b/easyanimate/models/transformer3d.py @@ -51,6 +51,23 @@ from diffusers.models.embeddings import \ CaptionProjection as PixArtAlphaTextProjection +try: + import xfuser + from xfuser.core.distributed import ( + get_sequence_parallel_world_size, + get_sequence_parallel_rank, + get_sp_group, + initialize_model_parallel, + init_distributed_environment + ) +except Exception as ex: + xfuser = None + get_sequence_parallel_world_size = None + get_sequence_parallel_rank = None + get_sp_group = None + initialize_model_parallel = None + init_distributed_environment = None + class CLIPProjection(nn.Module): """ @@ -1375,6 +1392,14 @@ def __init__( self.gradient_checkpointing = False + try: + self.sp_world_size = get_sequence_parallel_world_size() + self.sp_world_rank = get_sequence_parallel_rank() + except Exception: + self.sp_world_size = 1 + self.sp_world_rank = 0 + xfuser = None + def _set_gradient_checkpointing(self, module, value=False): self.gradient_checkpointing = value @@ -1399,6 +1424,40 @@ def forward( ): batch_size, channels, video_length, height, width = hidden_states.size() + if xfuser is not None and self.sp_world_size > 1: + if hidden_states.shape[-2] // self.patch_size % self.sp_world_size == 0: + split_height = height // self.sp_world_size + split_dim = -2 + elif hidden_states.shape[-2] // self.patch_size % self.sp_world_size == 0: + split_width = width // self.sp_world_size + split_dim = -1 + else: + raise ValueError("Cannot split video sequence into ulysses_degree x ring_degree=%d parts evenly, hidden_states.shape=%s" % (self.sp_world_size, str(hidden_states.shape))) + + + hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=split_dim)[self.sp_world_rank] + if inpaint_latents is not None: + inpaint_latents = torch.chunk(inpaint_latents, self.sp_world_size, dim=split_dim)[self.sp_world_rank] + + if image_rotary_emb is not None: + embed_dim = image_rotary_emb[0].shape[-1] + freq_cos = image_rotary_emb[0].reshape(video_length, height // self.patch_size, width // self.patch_size, embed_dim) + freq_sin = image_rotary_emb[1].reshape(video_length, height // self.patch_size, width // self.patch_size, embed_dim) + + freq_cos = torch.chunk(freq_cos, self.sp_world_size, dim=split_dim-1)[self.sp_world_rank] + freq_sin = torch.chunk(freq_sin, self.sp_world_size, dim=split_dim-1)[self.sp_world_rank] + + freq_cos = freq_cos.reshape(-1, embed_dim) + freq_sin = freq_sin.reshape(-1, embed_dim) + + image_rotary_emb = (freq_cos, freq_sin) + + if split_dim == -2: + height = split_height + elif split_dim == -1: + width = split_width + + # 1. 
Time embedding temb = self.time_proj(timestep).to(dtype=hidden_states.dtype) temb = self.time_embedding(temb, timestep_cond) @@ -1486,6 +1545,9 @@ def custom_forward(*inputs): output = hidden_states.reshape(batch_size, video_length, height // p, width // p, channels, p, p) output = output.permute(0, 4, 1, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + if xfuser is not None and self.sp_world_size > 1: + output = get_sp_group().all_gather(output, dim=split_dim) + if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) @@ -1606,4 +1668,4 @@ def from_pretrained_2d( print(f"### attn1 Parameters: {sum(params) / 1e6} M") model = model.to(torch_dtype) - return model \ No newline at end of file + return model diff --git a/predict_i2v_multi_gpus.py b/predict_i2v_multi_gpus.py new file mode 100644 index 000000000..ea078ca4e --- /dev/null +++ b/predict_i2v_multi_gpus.py @@ -0,0 +1,401 @@ +import os + +import numpy as np +import torch +import torch.distributed as dist +from diffusers import (DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, + PNDMScheduler) +from omegaconf import OmegaConf +from PIL import Image +from transformers import (BertModel, BertTokenizer, CLIPImageProcessor, + CLIPVisionModelWithProjection, + T5EncoderModel, T5Tokenizer) + +from easyanimate.models import (name_to_autoencoder_magvit, + name_to_transformer3d) +from easyanimate.pipeline.pipeline_easyanimate_inpaint import \ + EasyAnimateInpaintPipeline +from easyanimate.pipeline.pipeline_easyanimate_multi_text_encoder_inpaint import \ + EasyAnimatePipeline_Multi_Text_Encoder_Inpaint +from easyanimate.utils.lora_utils import merge_lora, unmerge_lora +from easyanimate.utils.utils import get_image_to_video_latent, save_videos_grid +from easyanimate.utils.fp8_optimization import convert_weight_dtype_wrapper + + +try: + import xfuser + from xfuser.core.distributed import ( + get_sequence_parallel_world_size, + get_sequence_parallel_rank, + get_sp_group, + initialize_model_parallel, + init_distributed_environment + ) +except: + xfuser = None + get_sequence_parallel_world_size = None + get_sequence_parallel_rank = None + get_sp_group = None + initialize_model_parallel = None + init_distributed_environment = None + +ulysses_degree = 2 +ring_degree = 2 + +if ulysses_degree > 1 or ring_degree > 1: + dist.init_process_group("nccl") + print('parallel inference enabled: ulysses_degree=%d ring_degree=%d rank=%d world_size=%d' % ( + ulysses_degree, ring_degree, dist.get_rank(), + dist.get_world_size())) + assert dist.get_world_size() == ring_degree * ulysses_degree, \ + "number of GPUs should be equal to ring_degree * ulysses_degree." + init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size()) + initialize_model_parallel(sequence_parallel_degree=dist.get_world_size(), + ring_degree=ring_degree, + ulysses_degree=ulysses_degree) + device = torch.device("cuda:%d" % dist.get_rank()) + print('rank=%d device=%s' % (dist.get_rank(), str(device))) +else: + device = "cuda" + +# GPU memory mode, which can be choosen in [model_cpu_offload, model_cpu_offload_and_qfloat8, sequential_cpu_offload]. +# model_cpu_offload means that the entire model will be moved to the CPU after use, which can save some GPU memory. +# +# model_cpu_offload_and_qfloat8 indicates that the entire model will be moved to the CPU after use, +# and the transformer model has been quantized to float8, which can save more GPU memory. 
+#
+# sequential_cpu_offload means that each layer of the model will be moved to the CPU after use,
+# resulting in slower speeds but saving a large amount of GPU memory.
+GPU_memory_mode = "model_cpu_offload"
+
+# Config and model path
+config_path = "config/easyanimate_video_v5_magvit_multi_text_encoder.yaml"
+model_name = "models/Diffusion_Transformer/EasyAnimateV5-12b-zh-InP"
+
+# Choose the sampler from "Euler", "Euler A", "DPM++", "PNDM" and "DDIM".
+# EasyAnimateV1, V2 and V3 cannot use DDIM.
+# EasyAnimateV4 and V5 support DDIM.
+sampler_name = "DDIM"
+
+# Load the pretrained model if needed
+transformer_path = None
+# Only EasyAnimateV1 needs a motion module
+motion_module_path = None
+vae_path = None
+lora_path = None
+
+# Other params
+sample_size = [384, 672]
+# In EasyAnimateV1, the video_length of video is 40 ~ 80.
+# In EasyAnimateV2, V3, V4, the video_length of video is 1 ~ 144.
+# In EasyAnimateV5, the video_length of video is 1 ~ 49.
+# If you want to generate an image, please set video_length = 1.
+video_length = 49
+fps = 8
+
+# If you want to generate ultra-long videos, set partial_video_length to the length of each sub-video segment
+partial_video_length = None
+overlap_video_length = 4
+
+# Use torch.float16 if the GPU does not support torch.bfloat16.
+# Some graphics cards, such as V100 and 2080Ti, do not support torch.bfloat16.
+weight_dtype = torch.bfloat16
+# If you want to generate from text, please set validation_image_start = None and validation_image_end = None
+validation_image_start = "asset/1.png"
+validation_image_end = None
+
+# EasyAnimateV1, V2 and V3 support English.
+# EasyAnimateV4 and V5 support English and Chinese.
+# Using a longer negative prompt such as "模糊,突变,变形,失真,画面暗,文本字幕,画面固定,连环画,漫画,线稿,没有主体。" can increase stability.
+# Adding words such as "安静,固定" to the negative prompt can increase dynamism.
+prompt = "一条狗正在摇头。质量高、杰作、最佳品质、高分辨率、超精细、梦幻般。"
+negative_prompt = "扭曲的身体,肢体残缺,文本字幕,漫画,静止,丑陋,错误,乱码。"
+#
+# Using a longer negative prompt such as "Blurring, mutation, deformation, distortion, dark and solid, comics, text subtitles, line art." can increase stability.
+# Adding words such as "quiet, solid" to the negative prompt can increase dynamism.
+# prompt = "The dog is shaking head. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic."
+# negative_prompt = "Twisted body, limb deformities, text captions, comic, static, ugly, error, messy code."
+guidance_scale = 6.0 +seed = 43 +num_inference_steps = 50 +lora_weight = 0.60 +save_path = "samples/easyanimate-videos_i2v" + +config = OmegaConf.load(config_path) + +# Get Transformer +Choosen_Transformer3DModel = name_to_transformer3d[ + config['transformer_additional_kwargs'].get('transformer_type', 'Transformer3DModel') +] + +transformer_additional_kwargs = OmegaConf.to_container(config['transformer_additional_kwargs']) +if weight_dtype == torch.float16: + transformer_additional_kwargs["upcast_attention"] = True + +transformer = Choosen_Transformer3DModel.from_pretrained_2d( + model_name, + subfolder="transformer", + transformer_additional_kwargs=transformer_additional_kwargs, + torch_dtype=torch.float8_e4m3fn if GPU_memory_mode == "model_cpu_offload_and_qfloat8" else weight_dtype, + low_cpu_mem_usage=True, +) + +transformer = transformer.to(device) + +if transformer_path is not None: + print(f"From checkpoint: {transformer_path}") + if transformer_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(transformer_path) + else: + state_dict = torch.load(transformer_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = transformer.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + +if motion_module_path is not None: + print(f"From Motion Module: {motion_module_path}") + if motion_module_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(motion_module_path) + else: + state_dict = torch.load(motion_module_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = transformer.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}, {u}") + +# Get Vae +Choosen_AutoencoderKL = name_to_autoencoder_magvit[ + config['vae_kwargs'].get('vae_type', 'AutoencoderKL') +] +vae = Choosen_AutoencoderKL.from_pretrained( + model_name, + subfolder="vae", + vae_additional_kwargs=OmegaConf.to_container(config['vae_kwargs']) +).to(weight_dtype).to(device) +if config['vae_kwargs'].get('vae_type', 'AutoencoderKL') == 'AutoencoderKLMagvit' and weight_dtype == torch.float16: + vae.upcast_vae = True + +if vae_path is not None: + print(f"From checkpoint: {vae_path}") + if vae_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(vae_path) + else: + state_dict = torch.load(vae_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = vae.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + +if config['text_encoder_kwargs'].get('enable_multi_text_encoder', False): + tokenizer = BertTokenizer.from_pretrained( + model_name, subfolder="tokenizer" + ) + tokenizer_2 = T5Tokenizer.from_pretrained( + model_name, subfolder="tokenizer_2" + ) +else: + tokenizer = T5Tokenizer.from_pretrained( + model_name, subfolder="tokenizer" + ) + tokenizer_2 = None + +if config['text_encoder_kwargs'].get('enable_multi_text_encoder', False): + text_encoder = BertModel.from_pretrained( + model_name, subfolder="text_encoder", torch_dtype=weight_dtype + ).to(device) + text_encoder_2 = T5EncoderModel.from_pretrained( + model_name, subfolder="text_encoder_2", torch_dtype=weight_dtype + ).to(device) +else: + text_encoder = 
T5EncoderModel.from_pretrained( + model_name, subfolder="text_encoder", torch_dtype=weight_dtype + ).to(device) + text_encoder_2 = None + +if transformer.config.in_channels != vae.config.latent_channels and config['transformer_additional_kwargs'].get('enable_clip_in_inpaint', True): + clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained( + model_name, subfolder="image_encoder" + ).to(device, weight_dtype) + clip_image_processor = CLIPImageProcessor.from_pretrained( + model_name, subfolder="image_encoder" + ) +else: + clip_image_encoder = None + clip_image_processor = None + +# Get Scheduler +Choosen_Scheduler = scheduler_dict = { + "Euler": EulerDiscreteScheduler, + "Euler A": EulerAncestralDiscreteScheduler, + "DPM++": DPMSolverMultistepScheduler, + "PNDM": PNDMScheduler, + "DDIM": DDIMScheduler, +}[sampler_name] + +scheduler = Choosen_Scheduler.from_pretrained( + model_name, + subfolder="scheduler" +) +if config['text_encoder_kwargs'].get('enable_multi_text_encoder', False): + pipeline = EasyAnimatePipeline_Multi_Text_Encoder_Inpaint.from_pretrained( + model_name, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + vae=vae, + transformer=transformer, + scheduler=scheduler, + torch_dtype=weight_dtype, + clip_image_encoder=clip_image_encoder, + clip_image_processor=clip_image_processor, + ) +else: + pipeline = EasyAnimateInpaintPipeline.from_pretrained( + model_name, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + transformer=transformer, + scheduler=scheduler, + torch_dtype=weight_dtype, + clip_image_encoder=clip_image_encoder, + clip_image_processor=clip_image_processor, + ) + +if GPU_memory_mode == "sequential_cpu_offload": + pipeline.enable_sequential_cpu_offload(device=device) +elif GPU_memory_mode == "model_cpu_offload_and_qfloat8": + pipeline.enable_model_cpu_offload(device=device) + convert_weight_dtype_wrapper(transformer, weight_dtype) +else: + pipeline.enable_model_cpu_offload(device=device) + +# print('pipeline to device=%s' % str(device)) +# pipeline.to(device) +# print('pipeline.device=%s' % str(pipeline.device)) + + +generator = torch.Generator(device=device).manual_seed(seed) + +if lora_path is not None: + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) + +if partial_video_length is not None: + init_frames = 0 + last_frames = init_frames + partial_video_length + while init_frames < video_length: + if last_frames >= video_length: + if pipeline.vae.quant_conv.weight.ndim==5: + mini_batch_encoder = pipeline.vae.mini_batch_encoder + _partial_video_length = video_length - init_frames + if vae.cache_mag_vae: + _partial_video_length = int((_partial_video_length - 1) // vae.mini_batch_encoder * vae.mini_batch_encoder) + 1 + else: + _partial_video_length = int(_partial_video_length // vae.mini_batch_encoder * vae.mini_batch_encoder) + else: + _partial_video_length = video_length - init_frames + + if _partial_video_length <= 0: + break + else: + _partial_video_length = partial_video_length + + input_video, input_video_mask, clip_image = get_image_to_video_latent(validation_image, None, video_length=_partial_video_length, sample_size=sample_size) + + with torch.no_grad(): + sample = pipeline( + prompt, + video_length = _partial_video_length, + negative_prompt = negative_prompt, + height = sample_size[0], + width = sample_size[1], + generator = generator, + guidance_scale = guidance_scale, + num_inference_steps = num_inference_steps, + + video = 
input_video, + mask_video = input_video_mask, + clip_image = clip_image, + ).videos + + if init_frames != 0: + mix_ratio = torch.from_numpy( + np.array([float(_index) / float(overlap_video_length) for _index in range(overlap_video_length)], np.float32) + ).unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + + new_sample[:, :, -overlap_video_length:] = new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) + \ + sample[:, :, :overlap_video_length] * mix_ratio + new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim = 2) + + sample = new_sample + else: + new_sample = sample + + if last_frames >= video_length: + break + + validation_image = [ + Image.fromarray( + (sample[0, :, _index].transpose(0, 1).transpose(1, 2) * 255).numpy().astype(np.uint8) + ) for _index in range(-overlap_video_length, 0) + ] + + init_frames = init_frames + _partial_video_length - overlap_video_length + last_frames = init_frames + _partial_video_length +else: + if vae.cache_mag_vae: + video_length = int((video_length - 1) // vae.mini_batch_encoder * vae.mini_batch_encoder) + 1 if video_length != 1 else 1 + else: + video_length = int(video_length // vae.mini_batch_encoder * vae.mini_batch_encoder) if video_length != 1 else 1 + input_video, input_video_mask, clip_image = get_image_to_video_latent(validation_image_start, validation_image_end, video_length=video_length, sample_size=sample_size) + + with torch.no_grad(): + sample = pipeline( + prompt, + video_length = video_length, + negative_prompt = negative_prompt, + height = sample_size[0], + width = sample_size[1], + generator = generator, + guidance_scale = guidance_scale, + num_inference_steps = num_inference_steps, + + video = input_video, + mask_video = input_video_mask, + clip_image = clip_image, + ).videos + +if lora_path is not None: + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) + +if not os.path.exists(save_path): + os.makedirs(save_path, exist_ok=True) + +index = len([path for path in os.listdir(save_path)]) + 1 +prefix = str(index).zfill(8) + +if video_length == 1: + save_sample_path = os.path.join(save_path, prefix + f".png") + + image = sample[0, :, 0] + image = image.transpose(0, 1).transpose(1, 2) + image = (image * 255).numpy().astype(np.uint8) + image = Image.fromarray(image) + image.save(save_sample_path) +else: + if ulysses_degree * ring_degree > 1: + if dist.get_rank() == 0: + video_path = os.path.join(save_path, prefix + ".mp4") + save_videos_grid(sample, video_path, fps=fps) + print('save video to %s' % video_path) + else: + video_path = os.path.join(save_path, prefix + ".mp4") + save_videos_grid(sample, video_path, fps=fps) + print('save video to %s' % video_path) From a35d7bbf0e3f49628dde7275d3db90a165b09dab Mon Sep 17 00:00:00 2001 From: "mengli.cml" Date: Thu, 9 Jan 2025 10:29:23 +0800 Subject: [PATCH 2/3] fix bug --- easyanimate/models/transformer3d.py | 2 +- predict_i2v_multi_gpus.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/easyanimate/models/transformer3d.py b/easyanimate/models/transformer3d.py index 7854916f0..6b9ededf9 100644 --- a/easyanimate/models/transformer3d.py +++ b/easyanimate/models/transformer3d.py @@ -1428,7 +1428,7 @@ def forward( if hidden_states.shape[-2] // self.patch_size % self.sp_world_size == 0: split_height = height // self.sp_world_size split_dim = -2 - elif hidden_states.shape[-2] // self.patch_size % self.sp_world_size == 0: + elif hidden_states.shape[-1] // self.patch_size % self.sp_world_size == 0: 
                 split_width = width // self.sp_world_size
                 split_dim = -1
             else:
diff --git a/predict_i2v_multi_gpus.py b/predict_i2v_multi_gpus.py
index ea078ca4e..316e309b0 100644
--- a/predict_i2v_multi_gpus.py
+++ b/predict_i2v_multi_gpus.py
@@ -50,7 +50,7 @@
         ulysses_degree, ring_degree, dist.get_rank(),
         dist.get_world_size()))
     assert dist.get_world_size() == ring_degree * ulysses_degree, \
-        "number of GPUs should be equal to ring_degree * ulysses_degree."
+        "number of GPUs(%d) should be equal to ring_degree * ulysses_degree." % dist.get_world_size()
     init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
     initialize_model_parallel(sequence_parallel_degree=dist.get_world_size(),
                               ring_degree=ring_degree,
                               ulysses_degree=ulysses_degree)
@@ -87,7 +87,9 @@
 lora_path = None
 
 # Other params
-sample_size = [384, 672]
+# sample_size = [384, 672]
+# sample_size = [576, 1008]
+sample_size = [720, 1280]
 # In EasyAnimateV1, the video_length of video is 40 ~ 80.
 # In EasyAnimateV2, V3, V4, the video_length of video is 1 ~ 144.
 # In EasyAnimateV5, the video_length of video is 1 ~ 49.

From 89cae4def04d463b2d6518e6634c70bdcc07d10c Mon Sep 17 00:00:00 2001
From: "mengli.cml"
Date: Tue, 14 Jan 2025 18:13:11 +0800
Subject: [PATCH 3/3] add gradient norm tensorboard visualization for deepspeed mode

---
 scripts/train.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/scripts/train.py b/scripts/train.py
index 10b2ca705..ab469ac32 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -1989,6 +1989,10 @@ def custom_mse_loss(noise_pred, target, threshold=50):
                 lr_scheduler.step()
                 optimizer.zero_grad()
 
+                if args.use_deepspeed and hasattr(optimizer, 'optimizer') and hasattr(optimizer.optimizer, '_global_grad_norm') and accelerator.is_main_process:
+                    writer.add_scalar(f'gradients/norm_sum', optimizer.optimizer._global_grad_norm,
+                                      global_step=global_step)
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
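Two practical notes on the changes above. First, the transformer3d.py forward in PATCH 1/3 shards the latent along height (or width) before patch embedding and reassembles the full output afterwards with get_sp_group().all_gather(output, dim=split_dim). The round trip can be illustrated with plain tensor operations, no distributed setup required; the latent size below is hypothetical:

import torch

# Hypothetical latent of shape (batch, channels, frames, height, width).
latents = torch.randn(1, 16, 13, 48, 84)
sp_world_size = 2
split_dim = -2  # split along height, as the patch does when height is evenly divisible

# What each sequence-parallel rank would receive as its local input.
shards = torch.chunk(latents, sp_world_size, dim=split_dim)

# Each rank produces an output with the same spatial shape as its shard;
# an identity op stands in for the transformer forward pass here.
local_outputs = [shard.clone() for shard in shards]

# all_gather along split_dim is equivalent to concatenating the per-rank
# outputs back together, restoring the full latent on every rank.
reassembled = torch.cat(local_outputs, dim=split_dim)
assert reassembled.shape == latents.shape

Second, predict_i2v_multi_gpus.py calls dist.init_process_group("nccl") and asserts that the world size equals ring_degree * ulysses_degree, so it has to be started with one process per GPU via a distributed launcher; with the defaults ulysses_degree = 2 and ring_degree = 2 that means four processes (for example, torchrun --nproc_per_node=4 predict_i2v_multi_gpus.py, assuming a torchrun-style launch; the patch itself does not prescribe a launch command).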