Skip to content

Commit 34c90db

Browse files
authored
fix OOM for test_vae_tiling (#7510)
use float16 and add torch.no_grad()
1 parent e49c04d commit 34c90db

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

tests/models/autoencoders/test_models_vae.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,8 +1118,10 @@ def test_sd_f16(self):
11181118
assert torch_all_close(actual_output, expected_output, atol=5e-3)
11191119

11201120
def test_vae_tiling(self):
1121-
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder")
1122-
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None)
1121+
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
1122+
pipe = StableDiffusionPipeline.from_pretrained(
1123+
"runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None, torch_dtype=torch.float16
1124+
)
11231125
pipe.to(torch_device)
11241126
pipe.set_progress_bar_config(disable=None)
11251127

@@ -1143,6 +1145,7 @@ def test_vae_tiling(self):
11431145

11441146
# test that tiled decode works with various shapes
11451147
shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
1146-
for shape in shapes:
1147-
image = torch.zeros(shape, device=torch_device)
1148-
pipe.vae.decode(image)
1148+
with torch.no_grad():
1149+
for shape in shapes:
1150+
image = torch.zeros(shape, device=torch_device)
1151+
pipe.vae.decode(image)

tests/pipelines/test_pipelines_common.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,10 @@ def test_vae_tiling(self):
124124

125125
# test that tiled decode works with various shapes
126126
shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
127-
for shape in shapes:
128-
zeros = torch.zeros(shape).to(torch_device)
129-
pipe.vae.decode(zeros)
127+
with torch.no_grad():
128+
for shape in shapes:
129+
zeros = torch.zeros(shape).to(torch_device)
130+
pipe.vae.decode(zeros)
130131

131132
def test_freeu_enabled(self):
132133
components = self.get_dummy_components()

0 commit comments

Comments
 (0)