Skip to content

Conversation

@iDAPPA
Copy link

@iDAPPA iDAPPA commented Mar 19, 2025

AMD GPU Compatibility Fix for WanVideoWrapper

Description

This PR adds AMD GPU compatibility to ComfyUI-WanVideoWrapper by addressing tensor device mismatch issues that occur specifically on AMD GPUs running PyTorch with ROCm.

Issue

Users with AMD GPUs encounter the following error when running the WanVideoImageClipEncode node:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument tensors in method wrapper_CUDA_cat)

This occurs because the ROCm backend is less forgiving than CUDA when handling operations between tensors on different devices.

Solution

This PR adds a minimal, non-intrusive patch that:

  1. Detects when running on an AMD GPU
  2. Patches the torch.concat function to ensure all tensors are on the same device before concatenation
  3. Maintains full compatibility with NVIDIA GPUs (patch only activates on AMD)

Testing

Tested on an AMD Radeon RX 7900 XTX GPU with PyTorch 2.4.0+rocm6.3.4.

Changes

  • Added amd_fix.py with AMD-specific tensor device handling
  • Added single import line to __init__.py to load the fix

Impact

  • Fixes operation on AMD GPUs
  • Zero impact on NVIDIA users (patch is only applied when AMD is detected)
  • Minimal overhead (only checks/moves tensors when necessary)

## Additional Notes for the Repository Maintainers

1. **Non-intrusive approach**: The patch doesn't modify any existing code in the repository. It only adds a single new file and one import line, making it easy to review and maintain.

2. **Selective application**: The patch only activates when an AMD GPU is detected, ensuring NVIDIA users aren't affected by any potential overhead.

3. **Root cause**: The issue stems from how ROCm handles tensor operations across devices. While CUDA sometimes implicitly handles device mismatches, ROCm is stricter and requires explicit device management.

4. **Future-proofing**: If AMD's ROCm implementation changes to match NVIDIA's behavior in the future, this patch won't interfere as it only ensures tensors are on the same device, which is good practice regardless.

5. **Broader applicability**: While this PR focuses on fixing the immediate issue in WanVideoImageClipEncode, the pattern could be useful in other parts of the codebase where cross-device tensor operations occur.

This solution is minimally invasive while effectively solving the compatibility issue for AMD GPU users, without any drawbacks for NVIDIA users.

@Qurtison
Copy link

Works great! Hopefully gets merged to main soon

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants