Commit df31c9d

naming nit
1 parent 46baa56 commit df31c9d

2 files changed: +3, -3 lines


examples/research_projects/pytorch_xla/train_text_to_image_xla.py

Lines changed: 1 addition & 1 deletion
@@ -520,7 +520,7 @@ def main(args):
     from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear

     unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward)
-    unet.enable_use_xla_flash_attention(partition_spec=("data", None, None, None))
+    unet.enable_xla_flash_attention(partition_spec=("data", None, None, None))

     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)

src/diffusers/models/modeling_utils.py

Lines changed: 2 additions & 2 deletions
@@ -225,13 +225,13 @@ def fn_recursive_set_flash_attention(module: torch.nn.Module):
             if isinstance(module, torch.nn.Module):
                 fn_recursive_set_flash_attention(module)

-    def enable_use_xla_flash_attention(self, partition_spec: Optional[Callable] = None):
+    def enable_xla_flash_attention(self, partition_spec: Optional[Callable] = None):
         r"""
         Enable the flash attention pallals kernel for torch_xla.
         """
         self.set_use_xla_flash_attention(True, partition_spec)

-    def disable_use_xla_flash_attention(self):
+    def disable_xla_flash_attention(self):
         r"""
         Disable the flash attention pallals kernel for torch_xla.
         """
