@sayakpaul sayakpaul commented Nov 20, 2024

What does this PR do?

See: https://github.com/huggingface/diffusers/actions/runs/11925826067/job/33238594250?pr=9943#step:7:261. This PR skips the NaN fuse_lora() tests when the detected PyTorch version is 2.5 or above and the device is CPU.

It's likely a PyTorch bug, as the following passes on PyTorch 2.4.1 CPU but not on PyTorch 2.5.1 CPU. Thanks to Ben for helping to verify.

import torch
import torch.nn.functional as F

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}, torch version: {torch.__version__}")

with torch.device(device):
    query = torch.randn(2, 4, 16, 8)
    key = torch.randn(2, 4, 16, 8)
    value = torch.randn(2, 4, 16, 8)

    # Finite inputs should produce finite outputs.
    hidden_states = F.scaled_dot_product_attention(
        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
    )
    assert torch.isfinite(hidden_states).all()

    # Poison the query with NaNs; the attention output should then be all NaN.
    # This assertion holds on PyTorch 2.4.1 CPU but fails on 2.5.1 CPU.
    query += torch.nan
    hidden_states = F.scaled_dot_product_attention(
        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
    )
    assert torch.isnan(hidden_states).all()

See relevant PyTorch thread: pytorch/pytorch#141128.
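For reference, the expected semantics are just a consequence of the attention math: the output is softmax(QKᵀ/√d)V, so a NaN anywhere in Q propagates through the scores, the softmax, and the final matmul. A NumPy sketch of that reference computation (not PyTorch's fused kernel, which is where the 2.5 CPU regression appears to live; the helper name is ours):

```python
import numpy as np

def naive_sdpa(q, k, v):
    # Reference scaled dot-product attention: softmax(q k^T / sqrt(d)) v.
    d = q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d)  # NaN in q -> NaN scores
    scores -= scores.max(axis=-1, keepdims=True)      # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)    # NaN scores -> NaN weights
    return weights @ v                                # NaN weights -> NaN output

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((2, 4, 16, 8)) for _ in range(3))
q += np.nan  # poison the query, mirroring the torch repro above
out = naive_sdpa(q, k, v)
assert np.isnan(out).all()  # NaN input must yield NaN output
```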

Comment on lines 130 to 133
@unittest.skipIf(
    torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
    "Test not supported on PyTorch 2.5 and CPU.",
)
@sayakpaul (Member, Author)
@a-r-r-o-w not important for this PR, but any reason this had to be rewritten for Cog?

@BenjaminBossan BenjaminBossan (Member) left a comment

Thanks, LGTM. My only concern is that if this gets fixed in a future PyTorch version, the fix can easily go unnoticed and the tests will continue to be skipped. That's why I usually prefer pytest.mark.xfail(..., strict=True) for such cases.

@DN6 DN6 (Collaborator) commented Nov 20, 2024

I agree with @BenjaminBossan; marking with xfail is better.

@sayakpaul sayakpaul (Member, Author)

@BenjaminBossan @DN6 how about now?

@BenjaminBossan BenjaminBossan (Member) left a comment

Thanks, LGTM.

@sayakpaul sayakpaul (Member, Author)

@DN6 okay to merge? Pinging because of the CI failures across the board.

@sayakpaul sayakpaul merged commit 2e86a3f into main Nov 22, 2024
14 checks passed
@sayakpaul sayakpaul deleted the skip-nan-tests branch November 22, 2024 07:15
lawrence-cj pushed a commit to lawrence-cj/diffusers that referenced this pull request Nov 26, 2024
* skip nan lora tests on PyTorch 2.5.1 CPU.

* cog

* use xfail

* correct xfail

* add condition

* tests
sayakpaul added a commit that referenced this pull request Dec 23, 2024