Hi @a-r-r-o-w, I would like to ask about an error I'm hitting when using Context Parallelism for inference.
Issue Description
Environment
- Diffusers: 0.36.0.dev0
Problem Description
I'm trying to run image-to-video generation using WanImageToVideoPipeline with model quantization (qfloat8_e4m3fn via Quanto), frozen weights, and Context Parallelism enabled with ulysses_degree=8. The pipeline initializes successfully, but during the first inference step (at 0/20 steps), it raises an AssertionError in the Context Parallel hook:
AssertionError: Tensor size along dimension to be sharded must be divisible by mesh size
This occurs in diffusers/hooks/context_parallel.py during the sharding of hidden_states in the transformer block's forward pass.
Expected Behavior: The pipeline should generate the video frames without crashing, distributing computation across GPUs via Context Parallelism.
Actual Behavior: Crashes immediately at the start of the denoising loop.
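For what it's worth, a rough token count for these settings (my own estimate, assuming Wan 2.1's 4x temporal / 8x spatial VAE compression and the transformer's (1, 2, 2) patch size) suggests the video sequence itself should be divisible by 8, so I am not sure which input actually fails the check:
# Rough sequence-length estimate for height=480, width=832, num_frames=81.
# Assumes 4x temporal / 8x spatial VAE compression and patch size (1, 2, 2);
# these are my assumptions, not values read from the pipeline.
num_frames, height, width = 81, 480, 832
latent_frames = (num_frames - 1) // 4 + 1               # 21
latent_height, latent_width = height // 8, width // 8   # 60, 104
tokens = latent_frames * (latent_height // 2) * (latent_width // 2)
print(tokens, tokens % 8)  # 32760 0 -> divisible by ulysses_degree=8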
Minimal Reproducible Code
Here's the full script that's failing (run with torchrun --nproc_per_node=8 test.py or similar for 8 GPUs):
import torch
import os
from PIL import Image
from diffusers import (
    AutoencoderKLWan, WanPipeline, WanTransformer3DModel, ContextParallelConfig,
    WanImageToVideoPipeline
)
from diffusers.utils import export_to_video
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from optimum.quanto import freeze, qfloat8_e4m3fn, quantize
from transformers import AutoTokenizer, UMT5EncoderModel, CLIPVisionModel
torch.distributed.init_process_group("nccl")
rank = torch.distributed.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.cuda.set_device(device)
dtype = torch.bfloat16
model_id = '/share/models/checkpoints/Wan-AI/Wan2___1-I2V-14B-720P-Diffusers'
transformer = WanTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
text_encoder = UMT5EncoderModel.from_pretrained(
    model_id, subfolder="text_encoder", torch_dtype=dtype
)
# Quantize text_encoder
quantize(text_encoder, weights=qfloat8_e4m3fn)
freeze(text_encoder)
# Quantize transformer
quantize(transformer, weights=qfloat8_e4m3fn)
freeze(transformer)
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    text_encoder=text_encoder,
    torch_dtype=dtype
)
flow_shift = 5.0  # 5.0 for 720P, 3.0 for 480P
pipe.scheduler = UniPCMultistepScheduler.from_config(
    pipe.scheduler.config, flow_shift=flow_shift
)
pipe.to("cuda")
transformer.set_attention_backend("_native_cudnn")
pipe.transformer.enable_parallelism(
    config=ContextParallelConfig(ulysses_degree=8)
)
image = Image.open("/share/common/AIPhoto/3.jpeg").resize((832, 480))
# .resize((800,1280))
prompt = (
    "Modern urban-style photography: a young man in a white printed T-shirt and black shorts "
    "sits on a transparent glass staircase, wearing black-and-white canvas shoes, his posture "
    "casual and natural. He has fair skin and a well-proportioned build, legs slightly apart, "
    "elbows resting on his knees. The background is a towering glass curtain wall and modern "
    "architecture, with the city's high-rise skyline visible through the glass. In a fixed shot, "
    "he slowly raises both hands [making a heart shape with both hands], the motion relaxed and "
    "fluid; the whole frame is full of modern, urban atmosphere. Slow motion shows fine, delicate "
    "motion detail."
)
negative_prompt = (
    "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, "
    "static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, "
    "extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, "
    "fused fingers, still picture, messy background, three legs, many people in the background, "
    "walking backwards"
)
# Must specify generator so all ranks start with same latents (or pass your own)
generator = torch.Generator().manual_seed(42)
output = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=480,
    width=832,
    num_frames=81,
    guidance_scale=5.0,
    num_inference_steps=20,
    generator=generator,
).frames[0]
if rank == 0:
    export_to_video(output, "output.mp4", fps=16)
if torch.distributed.is_initialized():
    torch.distributed.destroy_process_group()
Error Traceback
0%| | 0/20 [00:00<?, ?it/s]
[rank1]: Traceback (most recent call last):
[rank1]: File "/share/gdli7/common/AIPhoto/test.py", line 46, in <module>
[rank1]: output = pipe(
[rank1]: ^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank1]: return func(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/diffusers-0.36.0.dev0-py3.11.egg/diffusers/pipelines/wan/pipeline_wan_i2v.py", line 756, in __call__
[rank1]: noise_pred = current_model(
[rank1]: ^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/diffusers-0.36.0.dev0-py3.11.egg/diffusers/models/transformers/transformer_wan.py", line 680, in forward
[rank1]: hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/diffusers-0.36.0.dev0-py3.11.egg/diffusers/hooks/hooks.py", line 188, in new_forward
[rank1]: args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/diffusers-0.36.0.dev0-py3.11.egg/diffusers/hooks/context_parallel.py", line 157, in pre_forward
[rank1]: input_val = self._prepare_cp_input(input_val, cpm)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/diffusers-0.36.0.dev0-py3.11.egg/diffusers/hooks/context_parallel.py", line 209, in _prepare_cp_input
[rank1]: return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/diffusers-0.36.0.dev0-py3.11.egg/diffusers/hooks/context_parallel.py", line 259, in shard
[rank1]: assert tensor.size()[dim] % mesh.size() == 0, (
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: AssertionError: Tensor size along dimension to be sharded must be divisible by mesh size
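For reference, this is the kind of pre-flight check I would like to run on my inputs before enabling parallelism. It is only a sketch that mirrors the failing assertion above; check_shardable is a hypothetical helper of my own, not an existing diffusers API:
import torch

def check_shardable(tensor: torch.Tensor, dim: int, mesh_size: int) -> None:
    # Mirrors the assertion in diffusers/hooks/context_parallel.py (EquipartitionSharder.shard):
    # the tensor size along the sharded dimension must be divisible by the mesh size.
    size = tensor.size(dim)
    if size % mesh_size != 0:
        raise ValueError(
            f"Size {size} along dim {dim} is not divisible by mesh size {mesh_size}"
        )

# Hypothetical usage on a dummy sequence of 32760 tokens with ulysses_degree=8:
check_shardable(torch.empty(1, 32760, 64), dim=1, mesh_size=8)  # passes
Is there a recommended way to pick height/width/num_frames so that every tensor the hook shards satisfies this constraint, or does this look like a bug in the Context Parallel support for the I2V pipeline?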