Skip to content

Commit c877a21

Browse files
authored
freeze vae part in fp32 (#4214)
1 parent 36e78fa commit c877a21

File tree

1 file changed

+10
-9
lines changed
  • ppdiffusers/examples/text_to_image_laion400m/ldm

1 file changed

+10
-9
lines changed

ppdiffusers/examples/text_to_image_laion400m/ldm/model.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -168,15 +168,16 @@ def on_train_batch_end(self):
168168

169169
def forward(self, input_ids=None, pixel_values=None, **kwargs):
170170
self.train()
171-
with paddle.no_grad():
172-
self.vae.eval()
173-
latents = self.vae.encode(pixel_values).latent_dist.sample()
174-
latents = latents * 0.18215
175-
noise = paddle.randn(latents.shape)
176-
timesteps = paddle.randint(0, self.noise_scheduler.num_train_timesteps, (latents.shape[0],)).astype(
177-
"int64"
178-
)
179-
noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
171+
with paddle.amp.auto_cast(enable=False):
172+
with paddle.no_grad():
173+
self.vae.eval()
174+
latents = self.vae.encode(pixel_values).latent_dist.sample()
175+
latents = latents * 0.18215
176+
noise = paddle.randn(latents.shape)
177+
timesteps = paddle.randint(0, self.noise_scheduler.num_train_timesteps, (latents.shape[0],)).astype(
178+
"int64"
179+
)
180+
noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
180181

181182
encoder_hidden_states = self.text_encoder(input_ids)[0]
182183
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample

0 commit comments

Comments
 (0)