File tree Expand file tree Collapse file tree 1 file changed +15
-7
lines changed
Expand file tree Collapse file tree 1 file changed +15
-7
lines changed Original file line number Diff line number Diff line change @@ -912,14 +912,22 @@ def remove_model(old_ckpt_name):
912912 if "latents" in batch and batch ["latents" ] is not None :
913913 latents = batch ["latents" ].to (accelerator .device ).to (dtype = weight_dtype )
914914 else :
915- with torch .no_grad ():
916- # latentに変換
917- latents = vae .encode (batch ["images" ].to (dtype = vae_dtype )).latent_dist .sample ().to (dtype = weight_dtype )
918-
915+ if args .vae_batch_size is None or len (batch ["images" ]) <= args .vae_batch_size :
916+ with torch .no_grad ():
917+ # latentに変換
918+ latents = vae .encode (batch ["images" ].to (dtype = vae_dtype )).latent_dist .sample ().to (dtype = weight_dtype )
919+ else :
920+ chunks = [batch ["images" ][i :i + args .vae_batch_size ] for i in range (0 , len (batch ["images" ]), args .vae_batch_size )]
921+ list_latents = []
922+ for chunk in chunks :
923+ with torch .no_grad ():
924+ # latentに変換
925+ list_latents .append (vae .encode (chunk .to (dtype = vae_dtype )).latent_dist .sample ().to (dtype = weight_dtype ))
926+ latents = torch .cat (list_latents , dim = 0 )
919927 # NaNが含まれていれば警告を表示し0に置き換える
920- if torch .any (torch .isnan (latents )):
921- accelerator .print ("NaN found in latents, replacing with zeros" )
922- latents = torch .nan_to_num (latents , 0 , out = latents )
928+ if torch .any (torch .isnan (latents )):
929+ accelerator .print ("NaN found in latents, replacing with zeros" )
930+ latents = torch .nan_to_num (latents , 0 , out = latents )
923931 latents = latents * self .vae_scale_factor
924932
925933 # get multiplier for each sample
You can’t perform that action at this time.
0 commit comments