Commit 3a08cf4
code format fixes
Parent: 628f72d

File tree: 2 files changed, 12 additions and 10 deletions
  src/diffusers/models/attention_dispatch.py
  tests/others/test_unified_sp_attention.py

src/diffusers/models/attention_dispatch.py

Lines changed: 2 additions & 3 deletions
@@ -44,7 +44,6 @@
     is_xformers_version,
 )
 from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
-from ..utils import is_torch_version


 if TYPE_CHECKING:
@@ -1076,7 +1075,7 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:

     # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
     x_temp = x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim).transpose(0, 2).contiguous()
-
+

     if group_world_size >1:
         out = _all_to_all_single(x_temp, group=group)
@@ -1365,7 +1364,7 @@ def TemplatedUnifiedAttention(
     forward_op,
     backward_op,
     _parallel_config: Optional["ParallelConfig"] = None,
-    scatter_idx: int =2,
+    scatter_idx: int =2,
     gather_idx: int =1,
 ):
     """

tests/others/test_unified_sp_attention.py

Lines changed: 10 additions & 7 deletions
@@ -1,13 +1,19 @@
 import math
+import os
+
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
+
 from diffusers.models.attention_dispatch import TemplatedUnifiedAttention
-import os
+from diffusers.models._modeling_parallel import (
+    ParallelConfig,
+    ContextParallelConfig
+)

 def run(rank, world_size):
     dist.init_process_group(
-        backend="gloo",
+        backend="gloo",
         rank=rank,
         world_size=world_size
     )
@@ -21,10 +27,7 @@ def run(rank, world_size):

     q.requires_grad_(True)

-    from diffusers.models._modeling_parallel import (
-        ParallelConfig,
-        ContextParallelConfig
-    )
+

     pc = ParallelConfig(
         context_parallel_config=ContextParallelConfig(
@@ -126,4 +129,4 @@ def dummy_backward_op(ctx, grad_out, *args, **kwargs):
     world_size = 4
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "12355"
-    mp.spawn(run, args=(world_size,), nprocs=world_size)
+    mp.spawn(run, args=(world_size,), nprocs=world_size)
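
The test drives everything through torch.multiprocessing with the gloo backend. For reference, a stripped-down sketch of the same spawn/init_process_group harness, with a trivial all_reduce standing in for the TemplatedUnifiedAttention workload (backend, master address and port are taken from the diff; the reduction is only a placeholder):

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank, world_size):
    # Each spawned process joins the same gloo group via the env:// rendezvous.
    dist.init_process_group(
        backend="gloo",
        rank=rank,
        world_size=world_size,
    )
    # Placeholder workload: every rank contributes its rank id; after the
    # (default SUM) all_reduce each rank holds 0 + 1 + ... + (world_size - 1).
    t = torch.tensor([float(rank)])
    dist.all_reduce(t)
    assert t.item() == sum(range(world_size))
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 4
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    mp.spawn(run, args=(world_size,), nprocs=world_size)

Launching it starts four CPU processes that rendezvous over localhost:12355, mirroring the __main__ block of the test.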
