
FlashAttention 2 & 3 inference error #35

@nanwanj

Description


Both FlashAttention 2 and FlashAttention 3 hit this error. FlashAttention was compiled from source.

torchrun --nproc_per_node=2 generate.py --task i2v-A14B --size 480*832 --ckpt_dir ./Wan2.2-I2V-A14B --lora_dir ./Wan2.2-Lightning/Wan2.2-I2V-A14B-4steps-lora-rank64-Seko-V1 --dit_fsdp --t5_fsdp --ulysses_size 2 --base_seed 42 --prompt_file examples/i2v_prompt_list.txt --image_path_file examples/i2v_image_path_list.txt

[2025-12-02 21:27:27,389] INFO: Generating video ... index:0 seed:42 prompt:Static camera shot, wide shot, sunrise time, side lighting, warm colors. A dinosaur runs swiftly through a savanna, sunlight casting long shadows on the grassy terrain. The lions cower and scatter as the dinosaur approaches, bushes and trees swaying slightly in the breeze. The sky is a beautiful blend of orange and pink hues, highlighting the dramatic scene.

[rank0]: Traceback (most recent call last):
[rank0]: File "/home/ubuntu/8cards/Wan2.2-Lightning/generate.py", line 497, in
[rank0]: generate(args)
[rank0]: File "/home/ubuntu/8cards/Wan2.2-Lightning/generate.py", line 459, in generate
[rank0]: video = wan_i2v.generate(
[rank0]: ^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ubuntu/8cards/Wan2.2-Lightning/wan/image2video.py", line 443, in generate
[rank0]: noise_pred_cond = model(
[rank0]: ^^^^^^
[rank0]: File "/opt/conda/envs/8cardstest2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/envs/8cardstest2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/envs/8cardstest2/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 854, in forward
[rank0]: output = self._fsdp_wrapped_module(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/envs/8cardstest2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/envs/8cardstest2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ubuntu/8cards/Wan2.2-Lightning/wan/distributed/sequence_parallel.py", line 134, in sp_dit_forward
[rank0]: x = block(x, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/envs/8cardstest2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/envs/8cardstest2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/envs/8cardstest2/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 854, in forward
[rank0]: output = self._fsdp_wrapped_module(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/envs/8cardstest2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/envs/8cardstest2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ubuntu/8cards/Wan2.2-Lightning/wan/modules/model.py", line 243, in forward
[rank0]: y = self.self_attn(
[rank0]: ^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/envs/8cardstest2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/envs/8cardstest2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ubuntu/8cards/Wan2.2-Lightning/wan/distributed/sequence_parallel.py", line 165, in sp_attn_forward
[rank0]: x = distributed_attention(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ubuntu/8cards/Wan2.2-Lightning/wan/distributed/ulysses.py", line 32, in distributed_attention
[rank0]: q = all_to_all(q, scatter_dim=2, gather_dim=1)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ubuntu/8cards/Wan2.2-Lightning/wan/distributed/util.py", line 29, in all_to_all
[rank0]: dist.all_to_all(outputs, inputs, group=group, **kwargs)
[rank0]: File "/opt/conda/envs/8cardstest2/lib/python3.11/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
[rank0]: return func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/envs/8cardstest2/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 4749, in all_to_all
[rank0]: work = group.alltoall(output_tensor_list, input_tensor_list, opts)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: NCCL Error 1: unhandled cuda error (run with NCCL_DEBUG=INFO for details)
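For context on where the crash lands: the Ulysses sequence-parallel path calls `all_to_all(q, scatter_dim=2, gather_dim=1)`, which splits the head dimension across ranks and gathers the full sequence on each rank. Below is a minimal single-process NumPy simulation of those scatter/gather semantics (a hypothetical sketch for illustration only, not the repo's `wan/distributed/util.py` code, and it involves no NCCL):

```python
import numpy as np

def simulated_all_to_all(per_rank, scatter_dim, gather_dim):
    """Simulate dist.all_to_all across len(per_rank) fake ranks in one process."""
    world = len(per_rank)
    # Each rank splits its tensor along scatter_dim into `world` chunks.
    chunks = [np.split(t, world, axis=scatter_dim) for t in per_rank]
    # Rank `dst` receives chunk `dst` from every source rank and
    # concatenates the received chunks along gather_dim.
    return [
        np.concatenate([chunks[src][dst] for src in range(world)], axis=gather_dim)
        for dst in range(world)
    ]

# Two "ranks", each holding a q shard of shape (batch=1, seq_shard=4, heads=4, head_dim=2)
rng = np.random.default_rng(0)
qs = [rng.standard_normal((1, 4, 4, 2)) for _ in range(2)]

out = simulated_all_to_all(qs, scatter_dim=2, gather_dim=1)
# After the exchange each rank holds the full sequence (8) but half the heads (2)
print(out[0].shape)  # (1, 8, 2, 2)
```

This communication pattern is pure data movement, which is consistent with the failure being an NCCL/CUDA environment issue rather than a shape bug in the model code.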
[rank1]: Traceback (most recent call last):
[rank1]: (identical traceback to rank0 above, through the same all_to_all call)
[rank1]: RuntimeError: NCCL Error 1: unhandled cuda error (run with NCCL_DEBUG=INFO for details)
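The error message itself suggests the next diagnostic step: rerun with `NCCL_DEBUG=INFO` to surface the underlying CUDA error. A minimal sketch of the rerun (`NCCL_DEBUG` and `NCCL_DEBUG_SUBSYS` are standard NCCL environment variables; the torchrun arguments are the same as above):

```shell
# Enable NCCL debug logging, as the error message suggests.
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=INIT,NET   # narrow to init/network; use ALL for everything
echo "NCCL_DEBUG=$NCCL_DEBUG"

# Then rerun the same command and inspect the NCCL log lines for the real CUDA error:
# torchrun --nproc_per_node=2 generate.py --task i2v-A14B ... (arguments as above)
```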
