
Commit 9e234d8

patil-suraj and pcuenca authored
handle fp16 in UNet2DModel (#1216)
* make sure fp16 runs well
* add fp16 test for superres
* Update src/diffusers/models/unet_2d.py
* generate on cuda
* always run fast inference test on cpu
* run on cpu

Co-authored-by: Pedro Cuenca <[email protected]>
1 parent 8fd3a74 commit 9e234d8

2 files changed (+27, -3 lines)

src/diffusers/models/unet_2d.py (6 additions, 3 deletions)

@@ -209,6 +209,11 @@ def forward(
         timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

         t_emb = self.time_proj(timesteps)
+
+        # timesteps does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=self.dtype)
         emb = self.time_embedding(t_emb)

         # 2. pre-process
@@ -242,9 +247,7 @@ def forward(
             sample = upsample_block(sample, res_samples, emb)

         # 6. post-process
-        # make sure hidden states is in float32
-        # when running in half-precision
-        sample = self.conv_norm_out(sample.float()).type(sample.dtype)
+        sample = self.conv_norm_out(sample)
         sample = self.conv_act(sample)
         sample = self.conv_out(sample)
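The first hunk fixes a dtype mismatch: self.time_proj is a weightless sinusoidal projection, so it emits float32 even after the model has been converted with .half(), and feeding that float32 tensor into an fp16 time_embedding MLP raises a runtime error. The second hunk drops the float32 round-trip around conv_norm_out, so post-processing now also runs in the model's own dtype. Below is a minimal, self-contained sketch of the failure and the fix; sinusoidal_proj is a hypothetical stand-in for self.time_proj, not the diffusers code, and running the fp16 Linear assumes a CUDA device or a recent PyTorch build with fp16 CPU kernels.

import torch
import torch.nn as nn

def sinusoidal_proj(timesteps: torch.Tensor, dim: int = 32) -> torch.Tensor:
    # Hypothetical stand-in for self.time_proj: pure math with no weights,
    # so it returns float32 no matter what dtype the surrounding model uses.
    half = dim // 2
    freqs = torch.exp(-torch.arange(half, dtype=torch.float32, device=timesteps.device) / half)
    args = timesteps.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

device = "cuda" if torch.cuda.is_available() else "cpu"
time_embedding = nn.Linear(32, 128).to(device).half()  # stand-in for self.time_embedding in fp16

t_emb = sinusoidal_proj(torch.tensor([1.0, 2.0], device=device))
try:
    time_embedding(t_emb)  # float32 activations hit fp16 weights
except RuntimeError as err:
    print("without the cast:", err)  # dtype mismatch

# The commit's fix: cast to the model dtype first, mirroring
# `t_emb = t_emb.to(dtype=self.dtype)` in the diff above.
emb = time_embedding(t_emb.to(dtype=time_embedding.weight.dtype))
print("with the cast:", emb.dtype)  # torch.float16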

tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py (21 additions, 0 deletions)

@@ -87,6 +87,27 @@ def test_inference_superresolution(self):
         expected_slice = np.array([0.8678, 0.8245, 0.6381, 0.6830, 0.4385, 0.5599, 0.4641, 0.6201, 0.5150])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

+    @unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
+    def test_inference_superresolution_fp16(self):
+        unet = self.dummy_uncond_unet
+        scheduler = DDIMScheduler()
+        vqvae = self.dummy_vq_model
+
+        # put models in fp16
+        unet = unet.half()
+        vqvae = vqvae.half()
+
+        ldm = LDMSuperResolutionPipeline(unet=unet, vqvae=vqvae, scheduler=scheduler)
+        ldm.to(torch_device)
+        ldm.set_progress_bar_config(disable=None)
+
+        init_image = self.dummy_image.to(torch_device)
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        image = ldm(init_image, generator=generator, num_inference_steps=2, output_type="numpy").images
+
+        assert image.shape == (1, 64, 64, 3)
+
     @slow
     @require_torch
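The new test exercises the fixed code path end to end: it halves the dummy UNet and VQ-VAE, builds the pipeline, and checks that a two-step fp16 run on GPU produces an output of the expected shape. For reference, here is a sketch of how the same fp16 setup might look with the released pipeline; this is an illustration, not part of the commit, and it assumes a CUDA device, the public CompVis/ldm-super-resolution-4x-openimages checkpoint, and a local low_res.png input.

import torch
from diffusers import LDMSuperResolutionPipeline
from PIL import Image

# Load the pipeline with fp16 weights (torch_dtype casts the UNet and VQ-VAE).
pipe = LDMSuperResolutionPipeline.from_pretrained(
    "CompVis/ldm-super-resolution-4x-openimages", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

# Placeholder input image; the 4x model upscales 128x128 -> 512x512.
low_res = Image.open("low_res.png").convert("RGB").resize((128, 128))

generator = torch.Generator(device="cuda").manual_seed(0)
upscaled = pipe(low_res, generator=generator, num_inference_steps=100).images[0]
upscaled.save("upscaled.png")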
