
Conversation

zzlol63 commented Nov 7, 2025

On Windows, PyTorch does not come pre-compiled with FlashAttention support in the torch.nn.functional.scaled_dot_product_attention method, whereas on Linux it does, which leaves a performance gap between the two platforms.

This PR patches FlashAttention-2 support into OneTrainer by detecting the presence of the flash-attn package on Windows and checking that the current PyTorch install does not already support FlashAttention (in case it does in the future).

It then overrides the default Torch SDPA method (via a monkey-patch) with a version that dynamically determines whether the current workload can be safely processed with flash-attn, using logic similar to what Torch SDPA itself uses to decide whether its FlashAttention backend can be used (see here).

If FlashAttention cannot be used, it falls back to the default Torch SDPA implementation, which routes the call to one of the supported backends (i.e. cuDNN, memory-efficient, math, etc.).
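
For illustration, a minimal sketch of what such a monkey-patch can look like is shown below. This is not the PR's actual code; it assumes flash-attn's `flash_attn_func`, checks only a handful of eligibility conditions, and lets everything else fall through to the stock SDPA:

```python
# Illustrative sketch only (not the PR's actual implementation): wrap torch SDPA
# so that eligible workloads are routed to flash-attn, everything else falls back.
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func  # requires the flash-attn package

_original_sdpa = F.scaled_dot_product_attention

def _patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                  is_causal=False, scale=None, enable_gqa=False):
    # flash-attn only handles CUDA tensors in fp16/bf16 and ignores attention masks.
    eligible = (
        attn_mask is None
        and not enable_gqa
        and not query.is_nested
        and query.is_cuda
        and query.dtype in (torch.float16, torch.bfloat16)
    )
    if eligible:
        # torch SDPA uses (batch, heads, seq, dim); flash-attn expects
        # (batch, seq, heads, dim), so transpose on the way in and out.
        out = flash_attn_func(
            query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2),
            dropout_p=dropout_p, softmax_scale=scale, causal=is_causal,
        )
        return out.transpose(1, 2)
    # Anything flash-attn cannot handle goes through the stock implementation,
    # which dispatches to cuDNN / memory-efficient / math as usual.
    return _original_sdpa(query, key, value, attn_mask=attn_mask,
                          dropout_p=dropout_p, is_causal=is_causal,
                          scale=scale, enable_gqa=enable_gqa)

F.scaled_dot_product_attention = _patched_sdpa
```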

I have run some benchmarks on FLUX.1-dev, and with equivalent PyTorch and Python versions I am matching the performance of the same workload on Linux.

flash-attn can either be built from source or installed from precompiled wheels that match the current configuration (e.g. from here).

@zzlol63 zzlol63 marked this pull request as draft November 7, 2025 12:04
@zzlol63 zzlol63 marked this pull request as ready for review November 8, 2025 02:39
zzlol63 (Author) commented Nov 8, 2025

I rewrote the PR to support older models that directly use scaled_dot_product_attention, such as SDXL and SD1.5.

Below are the benchmark results: SD1.5, SDXL, and FLUX.1 enjoy a 12-20% speedup. I previously thought SD1.5 was slower until I realised the units were different (iterations per second rather than seconds per iteration).

| Operating System | PyTorch Version | Flash-Attention | Preset | Batch Size | Speed | Note |
| --- | --- | --- | --- | --- | --- | --- |
| Windows 11 | 2.7.1+cu128 | No | #sd 1.5 | 16 | 1.5 it/s | |
| Windows 11 | 2.7.1+cu128 | Yes | #sd 1.5 | 16 | 1.8 it/s | |
| Windows 11 | 2.7.1+cu128 | No | #sdxl 1.0 | 8 | 1.7 s/it | |
| Windows 11 | 2.7.1+cu128 | Yes | #sdxl 1.0 | 8 | 1.5 s/it | |
| Windows 11 | 2.7.1+cu128 | No | #qwen LoRA 24GB + 1024 resolution | 2 | 5.3 s/it | Cannot use FlashAttention due to use of attention masks. |
| Windows 11 | 2.7.1+cu128 | Yes | #qwen LoRA 24GB + 1024 resolution | 2 | 5.3 s/it | Cannot use FlashAttention due to use of attention masks. |
| Windows 11 | 2.7.1+cu128 | No | #flux LoRA + BF16 types + FP32 VAE | 2 | 2.2 s/it | |
| Windows 11 | 2.7.1+cu128 | Yes | #flux LoRA + BF16 types + FP32 VAE | 2 | 1.8 s/it | |
| Windows 11 | 2.7.1+cu128 | No | #flux LoRA | 4 | 5.4 s/it | No improvement due to quantization (attention is in FP32). |
| Windows 11 | 2.7.1+cu128 | Yes | #flux LoRA | 4 | 5.4 s/it | No improvement due to quantization (attention is in FP32). |
| Windows 11 | 2.7.1+cu128 | No | #chroma LoRA 24GB + 1024 resolution | 2 | 4.8 s/it | Cannot use FlashAttention due to use of attention masks. |
| Windows 11 | 2.7.1+cu128 | Yes | #chroma LoRA 24GB + 1024 resolution | 2 | 4.8 s/it | Cannot use FlashAttention due to use of attention masks. |

@zzlol63 zzlol63 force-pushed the master branch 4 times, most recently from f026e6f to 7482b05 on November 8, 2025 04:19
zzlol63 (Author) commented Nov 8, 2025

I have also updated requirements-cuda.txt to automatically include the precompiled wheels for flash-attn (pinned to PyTorch 2.7.x) for the convenience of Windows users.

```python
if supported is None:
    supported = torch.cuda.get_device_properties(index).major >= 8
    SUPPORTED_DEVICES[index] = supported
return supported
```
Contributor commented

This resulted in random torch.compile errors (sometimes it works, other times it errors) saying that it can't track SUPPORTED_DEVICES[index] = supported.
Inserting a direct return True fixed it for me.

zzlol63 (Author) commented

Strange, I haven't had issues with it yet. I might do away with this check once a toggle is added to the UI.

zzlol63 (Author) commented

For now I've precomputed the hardware support, which should hopefully play more nicely with torch.compile.
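
As an illustration (not the PR's actual code), precomputing the capability check at import time might look like the sketch below, so the hot path is a constant lookup rather than a dict mutation that torch.compile has to trace:

```python
# Hypothetical sketch: query device capability once at module import so the
# per-call check never mutates module-level state under torch.compile.
import torch

FLASH_ATTN_CAPABLE = tuple(
    torch.cuda.get_device_properties(i).major >= 8  # compute capability 8.x or newer
    for i in range(torch.cuda.device_count())
)

def device_supports_flash_attn(index: int) -> bool:
    # Plain tuple indexing, no caching side effects at call time.
    return FLASH_ATTN_CAPABLE[index]
```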

A comment from @yamatazen was marked as resolved.

zzlol63 (Author) commented Nov 12, 2025

> I have also updated requirements-cuda.txt to automatically include the precompiled wheels for flash-attn (pinned to PyTorch 2.7.x) for the convenience of Windows users.

> OneTrainer is using PyTorch 2.8 now d3e0808

Yes, as of only 5 hours ago... I have updated the dependencies.

zzlol63 (Author) commented Nov 12, 2025

I have now migrated the patch into the UI as a toggle. It's now possible to turn the toggle on and off while training and immediately see the difference in speed with and without the feature. 😄

O-J1 (Collaborator) commented Nov 14, 2025

> It's now possible to turn the toggle on and off while training and immediately see the difference in speed with and without the feature. 😄

Can we remove this? I don't see being able to toggle this whilst training as being useful, and it will only introduce more potential failure cases.

zzlol63 (Author) commented Nov 14, 2025

@O-J1 Fair enough, I've moved it into the GenericTrainer.

@dxqb dxqb marked this pull request as draft November 15, 2025 14:06
yamatazen commented

Any updates?

dxqb (Collaborator) commented Dec 21, 2025

I marked this as a draft, but I don't remember why. Is it ready for review/merge?

O-J1 (Collaborator) commented Dec 21, 2025

> I marked this as a draft, but I don't remember why. Is it ready for review/merge?

Because of the torch.compile stuff, worries about potential issues, and I think something to do with discrepancies in testing.

dxqb (Collaborator) commented Dec 21, 2025

I don't remember the exact issues, but certainly a cleaner way to implement this would be to submit a PR to diffusers. They already have a flash-attn backend, but it has a bug (imho): it is used even when an attention mask is passed, although flash attention ignores masks.

The fallback you've implemented here is the right thing to do, but I'm not sure OneTrainer is the right place for it.

dxqb (Collaborator) commented Dec 28, 2025

Where did you take these files from?

```
# flash-attn
https://github.com/zzlol63/flash-attention-prebuild-wheels/releases/download/v0.2/flash_attn-2.8.2+cu128torch2.8-cp310-cp310-win_amd64.whl; sys_platform == "win32" and python_version == "3.10"
https://github.com/zzlol63/flash-attention-prebuild-wheels/releases/download/v0.2/flash_attn-2.8.2+cu128torch2.8-cp311-cp311-win_amd64.whl; sys_platform == "win32" and python_version == "3.11"
https://github.com/zzlol63/flash-attention-prebuild-wheels/releases/download/v0.2/flash_attn-2.8.2+cu128torch2.8-cp312-cp312-win_amd64.whl; sys_platform == "win32" and python_version == "3.12"
```

dxqb (Collaborator) commented Jan 5, 2026

> The fallback you've implemented here is the right thing to do, but I'm not sure OneTrainer is the right place for it.

I have implemented flash-attn support in this PR: #1227
It waits for a diffusers PR to check for attention masks.

@zzlol63, could you please explain why you felt the need to add all these other checks in your PR?
https://github.com/Nerogar/OneTrainer/blob/515861f93bb15183b860f30e65aac017bc3a5909/modules/util/attn/flash_attn_win.py#L63C9-L99C1

Are (some of) these also silently ignored by diffusers, like the attention mask was previously?

@dxqb dxqb added the "followup" label (Failure to provide config or other info or needs followup) on Jan 5, 2026
zzlol63 (Author) commented Jan 6, 2026

@dxqb It's intended to be a drop-in replacement/enhancement for Torch SDPA, so it accounts for other potential inputs that Torch SDPA can handle but flash-attn cannot (such as nested tensors, FP32 tensors, CPU tensors, etc.). This is not a diffusers-specific implementation per se.
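
For illustration, a hypothetical version of that kind of eligibility check (names and exact conditions are assumptions, not the code from the linked file) could look like:

```python
# Hypothetical sketch of the defensive checks a drop-in SDPA replacement needs:
# torch SDPA accepts all of these inputs, flash-attn does not.
import torch

def can_use_flash_attn(query, key, value, attn_mask) -> bool:
    if query.is_nested or key.is_nested or value.is_nested:
        return False   # nested tensors are unsupported by flash-attn
    if not (query.is_cuda and key.is_cuda and value.is_cuda):
        return False   # CPU tensors must use the stock SDPA path
    if query.dtype not in (torch.float16, torch.bfloat16):
        return False   # FP32 attention is unsupported by flash-attn
    if attn_mask is not None:
        return False   # flash-attn ignores attention masks
    return True
```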
