
Commit 3b04cdc

shinetzh and neoshang authored
fix loop bug in SlicedAttnProcessor (#8836)
* fix loop bug in SlicedAttnProcessor --------- Co-authored-by: neoshang <[email protected]>
1 parent c009c20 commit 3b04cdc

File tree

2 files changed: +16 -6 lines changed


src/diffusers/models/attention_processor.py
Lines changed: 2 additions & 2 deletions

@@ -2190,7 +2190,7 @@ def __call__(
             (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
         )
 
-        for i in range(batch_size_attention // self.slice_size):
+        for i in range((batch_size_attention - 1) // self.slice_size + 1):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size
 
@@ -2287,7 +2287,7 @@ def __call__(
             (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
         )
 
-        for i in range(batch_size_attention // self.slice_size):
+        for i in range((batch_size_attention - 1) // self.slice_size + 1):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size
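
Why the change: the old bound used floor division, so when batch_size_attention is not an exact multiple of self.slice_size the trailing partial slice was never processed and the corresponding rows of the preallocated attention buffer were left untouched. The new bound, (batch_size_attention - 1) // self.slice_size + 1, is a ceiling division that always covers the remainder. A minimal standalone sketch of the two bounds (illustrative only, not the library code; the names mirror the processor):

# Sketch of the loop-bound fix; not the diffusers source itself.
def slice_ranges(batch_size_attention: int, slice_size: int, use_ceil: bool):
    if use_ceil:
        # New bound: ceiling division also covers a trailing partial slice.
        n_slices = (batch_size_attention - 1) // slice_size + 1
    else:
        # Old bound: floor division drops the remainder entries.
        n_slices = batch_size_attention // slice_size
    for i in range(n_slices):
        start_idx = i * slice_size
        end_idx = min((i + 1) * slice_size, batch_size_attention)
        yield start_idx, end_idx

# batch_size_attention=5, slice_size=2
print(list(slice_ranges(5, 2, use_ceil=False)))  # [(0, 2), (2, 4)] -> index 4 is skipped
print(list(slice_ranges(5, 2, use_ceil=True)))   # [(0, 2), (2, 4), (4, 5)] -> all rows covered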

tests/pipelines/test_pipelines_common.py
Lines changed: 14 additions & 4 deletions

@@ -1351,14 +1351,24 @@ def _test_attention_slicing_forward_pass(
 
         pipe.enable_attention_slicing(slice_size=1)
         inputs = self.get_dummy_inputs(generator_device)
-        output_with_slicing = pipe(**inputs)[0]
+        output_with_slicing1 = pipe(**inputs)[0]
+
+        pipe.enable_attention_slicing(slice_size=2)
+        inputs = self.get_dummy_inputs(generator_device)
+        output_with_slicing2 = pipe(**inputs)[0]
 
         if test_max_difference:
-            max_diff = np.abs(to_np(output_with_slicing) - to_np(output_without_slicing)).max()
-            self.assertLess(max_diff, expected_max_diff, "Attention slicing should not affect the inference results")
+            max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+            max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+            self.assertLess(
+                max(max_diff1, max_diff2),
+                expected_max_diff,
+                "Attention slicing should not affect the inference results",
+            )
 
         if test_mean_pixel_difference:
-            assert_mean_pixel_difference(to_np(output_with_slicing[0]), to_np(output_without_slicing[0]))
+            assert_mean_pixel_difference(to_np(output_with_slicing1[0]), to_np(output_without_slicing[0]))
+            assert_mean_pixel_difference(to_np(output_with_slicing2[0]), to_np(output_without_slicing[0]))
 
     @unittest.skipIf(
         torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),