1 parent 2d7c198 commit e35dd67
src/diffusers/models/attention_processor.py
@@ -316,7 +316,7 @@ def set_use_xla_flash_attention(
             elif is_spmd() and is_torch_xla_version("<", "2.4"):
                 raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4"
             else:
-                if len(kwargs) > 0 and kwargs.get("is_flux", None):
+                if is_flux:
                     processor = XLAFluxFlashAttnProcessor2_0(partition_spec)
                 else:
                     processor = XLAFlashAttnProcessor2_0(partition_spec)
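
For context, a minimal, hypothetical usage sketch of the method this hunk touches: after the change, a caller passes is_flux as an explicit keyword argument instead of routing it through **kwargs. Only partition_spec and is_flux are visible in the hunk, so the full signature (use_xla_flash_attention, partition_spec, is_flux) and the partition_spec value shown here are assumptions; actually running the call also requires torch_xla, since the guards above raise otherwise.

# Hypothetical call site; signature details outside the hunk are assumed.
from diffusers.models.attention_processor import Attention

attn = Attention(query_dim=3072, heads=24, dim_head=128)

# Before this commit the flag travelled through **kwargs and was read with
# kwargs.get("is_flux"); now is_flux is an explicit parameter that selects
# XLAFluxFlashAttnProcessor2_0 instead of XLAFlashAttnProcessor2_0.
attn.set_use_xla_flash_attention(
    use_xla_flash_attention=True,
    partition_spec=("data", None, None, None),  # example SPMD spec; exact shape is an assumption
    is_flux=True,
)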