Commit cf010fc
apply suggestions from review
1 parent d9eabf8 commit cf010fc
2 files changed: +38 −38 lines

src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
3 additions, 3 deletions

@@ -300,7 +300,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states


-class UNetMidBlock3DConv(nn.Module):
+class AllegroMidBlock3DConv(nn.Module):
     def __init__(
         self,
         in_channels: int,
@@ -473,7 +473,7 @@ def __init__(
         self.down_blocks.append(down_block)

         # mid
-        self.mid_block = UNetMidBlock3DConv(
+        self.mid_block = AllegroMidBlock3DConv(
             in_channels=block_out_channels[-1],
             resnet_eps=1e-6,
             resnet_act_fn=act_fn,
@@ -581,7 +581,7 @@ def __init__(
         temb_channels = in_channels if norm_type == "spatial" else None

         # mid
-        self.mid_block = UNetMidBlock3DConv(
+        self.mid_block = AllegroMidBlock3DConv(
             in_channels=block_out_channels[-1],
             resnet_eps=1e-6,
             resnet_act_fn=act_fn,
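
The only functional change in this file is renaming the VAE mid block from the generic UNetMidBlock3DConv to AllegroMidBlock3DConv, in line with diffusers' convention of prefixing model-specific blocks with the model name. A minimal sanity check of the rename, assuming an installed diffusers build that contains this commit:

import diffusers.models.autoencoders.autoencoder_kl_allegro as allegro_vae

# The rename keeps no backward-compatible alias, so any code that constructed
# UNetMidBlock3DConv from this module must switch to the new name.
assert hasattr(allegro_vae, "AllegroMidBlock3DConv")
assert not hasattr(allegro_vae, "UNetMidBlock3DConv")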

tests/pipelines/allegro/test_allegro.py
35 additions, 35 deletions

@@ -206,40 +206,40 @@ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
     def test_inference_batch_single_identical(self):
         self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)

-    # def test_attention_slicing_forward_pass(
-    #     self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
-    # ):
-    #     if not self.test_attention_slicing:
-    #         return
-
-    #     components = self.get_dummy_components()
-    #     pipe = self.pipeline_class(**components)
-    #     for component in pipe.components.values():
-    #         if hasattr(component, "set_default_attn_processor"):
-    #             component.set_default_attn_processor()
-    #     pipe.to(torch_device)
-    #     pipe.set_progress_bar_config(disable=None)
-
-    #     generator_device = "cpu"
-    #     inputs = self.get_dummy_inputs(generator_device)
-    #     output_without_slicing = pipe(**inputs)[0]
-
-    #     pipe.enable_attention_slicing(slice_size=1)
-    #     inputs = self.get_dummy_inputs(generator_device)
-    #     output_with_slicing1 = pipe(**inputs)[0]
-
-    #     pipe.enable_attention_slicing(slice_size=2)
-    #     inputs = self.get_dummy_inputs(generator_device)
-    #     output_with_slicing2 = pipe(**inputs)[0]
-
-    #     if test_max_difference:
-    #         max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
-    #         max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
-    #         self.assertLess(
-    #             max(max_diff1, max_diff2),
-    #             expected_max_diff,
-    #             "Attention slicing should not affect the inference results",
-    #         )
+    def test_attention_slicing_forward_pass(
+        self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+    ):
+        if not self.test_attention_slicing:
+            return
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        for component in pipe.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
+        output_without_slicing = pipe(**inputs)[0]
+
+        pipe.enable_attention_slicing(slice_size=1)
+        inputs = self.get_dummy_inputs(generator_device)
+        output_with_slicing1 = pipe(**inputs)[0]
+
+        pipe.enable_attention_slicing(slice_size=2)
+        inputs = self.get_dummy_inputs(generator_device)
+        output_with_slicing2 = pipe(**inputs)[0]
+
+        if test_max_difference:
+            max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+            max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+            self.assertLess(
+                max(max_diff1, max_diff2),
+                expected_max_diff,
+                "Attention slicing should not affect the inference results",
+            )

     def test_vae_tiling(self, expected_diff_max: float = 0.2):
         generator_device = "cpu"
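
The re-enabled test checks that attention slicing (slice_size=1 and 2) reproduces the unsliced output to within expected_max_diff. Below is a standalone sketch of that comparison; to_np here is a local re-implementation for illustration, not the helper the test suite imports:

import numpy as np
import torch

def to_np(t):
    # Accept a torch tensor or any array-like and return a numpy array.
    return t.detach().cpu().numpy() if isinstance(t, torch.Tensor) else np.asarray(t)

def max_abs_diff(a, b):
    # The metric the test bounds: worst-case per-element deviation.
    return np.abs(to_np(a) - to_np(b)).max()

# Two outputs that agree far below the 1e-3 tolerance pass the check.
reference = torch.randn(1, 8, 16, 16, generator=torch.Generator().manual_seed(0))
sliced = reference + 1e-5
assert max_abs_diff(sliced, reference) < 1e-3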
@@ -287,7 +287,7 @@ def tearDown(self):
         gc.collect()
         torch.cuda.empty_cache()

-    def test_cogvideox(self):
+    def test_allegro(self):
         generator = torch.Generator("cpu").manual_seed(0)

         pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", torch_dtype=torch.float16)
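
The slow test was evidently copied from the CogVideoX suite and is renamed here to match the pipeline under test. A condensed sketch of the same setup for direct inference; the prompt, step count, and the .frames output field are assumptions, since the test body past from_pretrained is not shown in this diff:

import torch
from diffusers import AllegroPipeline

generator = torch.Generator("cpu").manual_seed(0)  # seeded exactly as in the test

pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", torch_dtype=torch.float16)
pipe.to("cuda")

video = pipe(
    prompt="A seaside town at golden hour",  # placeholder prompt
    num_inference_steps=20,  # placeholder value, not taken from the test
    generator=generator,
).frames[0]  # assumed output field, following other diffusers video pipelines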
