Add flash-attn support for Windows #1107
Conversation
I re-wrote the PR to support older models that directly use [...]. Below are the benchmark results: SD1.5, SDXL and FLUX.1 enjoy a 12-20% speedup in performance. I previously thought SD1.5 was slower, until I realised the units were different (iterations per second rather than seconds per iteration).
I've also updated requirements-cuda.txt to automatically include the precompiled wheels for flash-attn (pinned to PyTorch 2.7.x), for ease of use for Windows users.
modules/util/attn/flash_attn_win.py (Outdated)
if supported is None:
    # FlashAttention requires an SM major version >= 8 (Ampere or newer)
    supported = torch.cuda.get_device_properties(index).major >= 8
    SUPPORTED_DEVICES[index] = supported  # cache the result for this device
return supported
This resulted in random torch.compile errors (sometimes it worked, other times it errored) saying that it can't trace: SUPPORTED_DEVICES[index] = supported.
Replacing it with a direct return True fixed it for me.
Strange, I haven't had issues with it yet. I might do away with this check once a toggle is added to the UI.
For now I've precomputed the hardware support, which should hopefully play more nicely with torch.compile.
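As an illustration of that approach, here is a minimal sketch of precomputing the support flags at import time (the module layout and the `is_device_supported` helper name are assumptions, not the PR's actual code). Because the function called during training only reads an immutable tuple, torch.compile never has to trace a dict mutation:

```python
import torch

# Computed once at import time. Hypothetical sketch; the PR's real layout may differ.
SUPPORTED_DEVICES: tuple[bool, ...] = tuple(
    # FlashAttention needs an SM major version >= 8 (Ampere or newer)
    torch.cuda.get_device_properties(i).major >= 8
    for i in range(torch.cuda.device_count())
) if torch.cuda.is_available() else ()


def is_device_supported(index: int) -> bool:
    # Pure read of a constant tuple, no side effects -- safe under torch.compile.
    return 0 <= index < len(SUPPORTED_DEVICES) and SUPPORTED_DEVICES[index]
```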
Yes, as of only 5 hours ago... I have updated the dependencies.
I've now migrated the patch into the UI as a toggle. It's now possible to turn the toggle on and off while training and immediately see the difference in speed with and without the feature enabled. 😄
Can we remove this? I don't see being able to toggle this whilst training as useful, and it will only introduce more potential failure cases.
@O-J1 Fair enough, I've moved it into the GenericTrainer.
Any updates?
I marked this as a draft, but I don't remember why. Is it ready for review/merge?
Because of the torch.compile stuff, worries about potential issues, and I think something to do with discrepancies in testing.
I don't remember the exact issues, but a cleaner way to implement this would certainly be to submit a PR to diffusers. They already have a flash-attn backend, but they have a bug (imho): it is used even when an attention mask is passed, although flash attention ignores masks. The fallback you've implemented here is the right thing to do, but I'm not sure OneTrainer is the right place for it.
Where did you take these files from?
I have implemented flash-attn support in this PR: #1227. @zzlol63, could you please explain why you felt the need to do all these other checks in your PR? Are (some of) these also silently ignored by diffusers, like the attention mask previously?
@dxqb It's intended to be a drop-in replacement/enhancement for Torch SDPA, so it accounts for other potential inputs that Torch SDPA can handle but flash-attn cannot (such as nested tensors, FP32 tensors, CPU tensors, etc.). This is not a diffusers-specific implementation per se.
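To illustrate that kind of gating, here is a hedged sketch of an eligibility check. The function name `can_use_flash_attn` and the exact set of conditions are illustrative assumptions, not the PR's actual code; it only encodes the constraints mentioned above (no attention mask, no nested tensors, CUDA-only, half precision), routing everything else back to standard SDPA:

```python
import torch


def can_use_flash_attn(query: torch.Tensor,
                       key: torch.Tensor,
                       value: torch.Tensor,
                       attn_mask: torch.Tensor | None) -> bool:
    # Illustrative checks only; the PR's actual conditions may differ.
    if attn_mask is not None:
        return False  # flash-attn has no general attention-mask support
    for t in (query, key, value):
        if t.is_nested:
            return False  # nested tensors are unsupported
        if not t.is_cuda:
            return False  # CPU tensors must use regular SDPA
        if t.dtype not in (torch.float16, torch.bfloat16):
            return False  # FP32 and other dtypes are unsupported
    return True
```

Keeping the gating in a single predicate like this makes the fallback decision cheap to evaluate and easy to audit.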
On Windows, PyTorch does not come pre-compiled with support for FlashAttention in the torch.nn.functional.scaled_dot_product_attention method, whereas on Linux it does, meaning there is a performance gap between the two platforms. This PR attempts to patch FlashAttention-2 support into OneTrainer by detecting the presence of the flash-attn package on Windows and checking that the current install of PyTorch does not already support FlashAttention (just in case it does in the future).
It then overwrites the default Torch SDPA method (via a monkey-patch) with a version that dynamically determines whether the current workload can be safely processed with flash-attention, using logic similar to what Torch SDPA itself employs to decide whether its FlashAttention backend can be used (see here).
If FlashAttention cannot be used, it falls back to the default Torch SDPA implementation, which routes the call to one of the supported backends (e.g. cuDNN, memory-efficient, math, etc.).
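A minimal sketch of that monkey-patch-plus-fallback pattern, assuming the `can_use_flash_attn` helper sketched earlier is defined in the same module and using the `flash_attn_func` entry point from the flash-attn package (tensor-layout handling and some newer keyword arguments are simplified; this is not the PR's exact code):

```python
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func  # provided by the flash-attn package

_original_sdpa = F.scaled_dot_product_attention  # keep the stock implementation as a fallback


def _patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                  is_causal=False, scale=None, enable_gqa=False):
    if not enable_gqa and can_use_flash_attn(query, key, value, attn_mask):
        # SDPA uses (batch, heads, seq, head_dim); flash_attn_func expects
        # (batch, seq, heads, head_dim), so transpose going in and coming out.
        out = flash_attn_func(
            query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2),
            dropout_p=dropout_p, softmax_scale=scale, causal=is_causal,
        )
        return out.transpose(1, 2)
    # Anything flash-attn cannot handle is routed back to the stock backends
    # (cuDNN, memory-efficient, math, ...).
    return _original_sdpa(query, key, value, attn_mask=attn_mask,
                          dropout_p=dropout_p, is_causal=is_causal,
                          scale=scale, enable_gqa=enable_gqa)


F.scaled_dot_product_attention = _patched_sdpa
```

Because the original function is captured before patching, the fallback path behaves exactly like an unpatched install.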
I have run some benchmarks on Flux.1-dev, and with the equivalent PyTorch and Python versions I am matching the performance of the same workload on Linux.
flash-attn can either be built from source or installed from precompiled wheels that match the current configuration (e.g. from here).