Skip to content

Commit 477937e

Browse files
committed
add vae tiling test
1 parent 84854b4 commit 477937e

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

tests/pipelines/sana/test_sana.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,36 @@ def test_attention_slicing_forward_pass(
254254
"Attention slicing should not affect the inference results",
255255
)
256256

257+
def test_vae_tiling(self, expected_diff_max: float = 0.2):
258+
generator_device = "cpu"
259+
components = self.get_dummy_components()
260+
261+
pipe = self.pipeline_class(**components)
262+
pipe.to("cpu")
263+
pipe.set_progress_bar_config(disable=None)
264+
265+
# Without tiling
266+
inputs = self.get_dummy_inputs(generator_device)
267+
inputs["height"] = inputs["width"] = 128
268+
output_without_tiling = pipe(**inputs)[0]
269+
270+
# With tiling
271+
pipe.vae.enable_tiling(
272+
tile_sample_min_height=96,
273+
tile_sample_min_width=96,
274+
tile_sample_stride_height=64,
275+
tile_sample_stride_width=64,
276+
)
277+
inputs = self.get_dummy_inputs(generator_device)
278+
inputs["height"] = inputs["width"] = 128
279+
output_with_tiling = pipe(**inputs)[0]
280+
281+
self.assertLess(
282+
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
283+
expected_diff_max,
284+
"VAE tiling should not affect the inference results",
285+
)
286+
257287
# TODO(aryan): Create a dummy gemma model with smol vocab size
258288
@unittest.skip(
259289
"A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."

0 commit comments

Comments
 (0)