File tree Expand file tree Collapse file tree 1 file changed +10
-9
lines changed
ppdiffusers/examples/text_to_image_laion400m/ldm Expand file tree Collapse file tree 1 file changed +10
-9
lines changed Original file line number Diff line number Diff line change @@ -168,15 +168,16 @@ def on_train_batch_end(self):
168
168
169
169
def forward (self , input_ids = None , pixel_values = None , ** kwargs ):
170
170
self .train ()
171
- with paddle .no_grad ():
172
- self .vae .eval ()
173
- latents = self .vae .encode (pixel_values ).latent_dist .sample ()
174
- latents = latents * 0.18215
175
- noise = paddle .randn (latents .shape )
176
- timesteps = paddle .randint (0 , self .noise_scheduler .num_train_timesteps , (latents .shape [0 ],)).astype (
177
- "int64"
178
- )
179
- noisy_latents = self .noise_scheduler .add_noise (latents , noise , timesteps )
171
+ with paddle .amp .auto_cast (enable = False ):
172
+ with paddle .no_grad ():
173
+ self .vae .eval ()
174
+ latents = self .vae .encode (pixel_values ).latent_dist .sample ()
175
+ latents = latents * 0.18215
176
+ noise = paddle .randn (latents .shape )
177
+ timesteps = paddle .randint (0 , self .noise_scheduler .num_train_timesteps , (latents .shape [0 ],)).astype (
178
+ "int64"
179
+ )
180
+ noisy_latents = self .noise_scheduler .add_noise (latents , noise , timesteps )
180
181
181
182
encoder_hidden_states = self .text_encoder (input_ids )[0 ]
182
183
noise_pred = self .unet (noisy_latents , timesteps , encoder_hidden_states ).sample
You can’t perform that action at this time.
0 commit comments