@@ -190,6 +190,7 @@ def __init__(self, sd=None, device=None, config=None):
         offload_device = model_management.vae_offload_device()
         self.vae_dtype = model_management.vae_dtype()
         self.first_stage_model.to(self.vae_dtype)
+        self.output_device = model_management.intermediate_device()
 
         self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
 
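The new output_device attribute comes from model_management.intermediate_device(), which decides where intermediate results should live. A minimal sketch of that pattern, assuming a helper that returns the GPU only when the user has opted into keeping intermediates there (ComfyUI's actual implementation may differ):

import torch

def intermediate_device(keep_on_gpu=False):
    # Hypothetical sketch: default to CPU so large decoded images do not
    # pin VRAM, and hand back the CUDA device only on explicit request
    # (e.g. a --gpu-only style flag).
    if keep_on_gpu and torch.cuda.is_available():
        return torch.device(torch.cuda.current_device())
    return torch.device("cpu")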
@@ -201,9 +202,9 @@ def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
 
         decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
         output = torch.clamp((
-            (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=8, pbar=pbar) +
-            comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=8, pbar=pbar) +
-            comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount=8, pbar=pbar))
+            (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=8, output_device=self.output_device, pbar=pbar) +
+            comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=8, output_device=self.output_device, pbar=pbar) +
+            comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount=8, output_device=self.output_device, pbar=pbar))
             / 3.0) / 2.0, min=0.0, max=1.0)
         return output
 
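All three tiled_scale passes now receive output_device so the stitched result accumulates there instead of implicitly on CPU. A rough sketch of where that parameter plugs in, assuming an integer upscale factor and ignoring the seam feathering and progress-bar updates the real comfy.utils.tiled_scale handles:

import torch

def tiled_scale_sketch(samples, fn, tile_x, tile_y, upscale_amount, out_channels, output_device="cpu"):
    b, c, h, w = samples.shape
    # Accumulate on output_device so each finished tile leaves the compute
    # device immediately instead of piling up in VRAM.
    out = torch.empty((b, out_channels, h * upscale_amount, w * upscale_amount), device=output_device)
    for y in range(0, h, tile_y):
        for x in range(0, w, tile_x):
            tile = samples[:, :, y:y + tile_y, x:x + tile_x]
            # fn runs on the compute device; only the result crosses over.
            out[:, :,
                y * upscale_amount:(y + tile_y) * upscale_amount,
                x * upscale_amount:(x + tile_x) * upscale_amount] = fn(tile).to(output_device)
    return out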
@@ -214,9 +215,9 @@ def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
         pbar = comfy.utils.ProgressBar(steps)
 
         encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
-        samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount=(1/8), out_channels=4, pbar=pbar)
-        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=(1/8), out_channels=4, pbar=pbar)
-        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=(1/8), out_channels=4, pbar=pbar)
+        samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount=(1/8), out_channels=4, output_device=self.output_device, pbar=pbar)
+        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=(1/8), out_channels=4, output_device=self.output_device, pbar=pbar)
+        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=(1/8), out_channels=4, output_device=self.output_device, pbar=pbar)
         samples /= 3.0
         return samples
 
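As in decode_tiled_, the encoder runs three passes with different tile aspect ratios and averages them (samples /= 3.0). Varying the shapes means each pass places most of its tile seams in different spots, so averaging dilutes any single pass's seam artifacts. A quick illustration of the vertical seam x-positions for a hypothetical 2048-px-wide input:

tile_x = 512
for w in (tile_x, tile_x * 2, tile_x // 2):
    print(w, [x for x in range(w, 2048, w)])
# 512  -> [512, 1024, 1536]
# 1024 -> [1024]
# 256  -> [256, 512, 768, 1024, 1280, 1536, 1792]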
@@ -228,15 +229,15 @@ def decode(self, samples_in):
             batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)
 
-            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
+            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device=self.output_device)
             for x in range(0, samples_in.shape[0], batch_number):
                 samples = samples_in[x:x + batch_number].to(self.vae_dtype).to(self.device)
-                pixel_samples[x:x + batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0)
+                pixel_samples[x:x + batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0)
         except model_management.OOM_EXCEPTION as e:
             print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             pixel_samples = self.decode_tiled_(samples_in)
 
-        pixel_samples = pixel_samples.cpu().movedim(1, -1)
+        pixel_samples = pixel_samples.to(self.output_device).movedim(1, -1)
         return pixel_samples
 
     def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap=16):
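decode sizes its batches dynamically: free VRAM divided by the per-image cost estimate, floored at one. A condensed sketch of that pattern, with memory_per_item standing in for the memory_used estimate computed earlier in the method:

import torch

def decode_in_batches(samples_in, decode_batch, free_memory, memory_per_item):
    # How many latents fit on the device at once; never fewer than one.
    batch_number = max(1, int(free_memory / memory_per_item))
    out = []
    for x in range(0, samples_in.shape[0], batch_number):
        out.append(decode_batch(samples_in[x:x + batch_number]))
    return torch.cat(out, dim=0)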
@@ -252,10 +253,10 @@ def encode(self, pixel_samples):
             free_memory = model_management.get_free_memory(self.device)
             batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)
-            samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
+            samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device=self.output_device)
             for x in range(0, pixel_samples.shape[0], batch_number):
                 pixels_in = (2. * pixel_samples[x:x + batch_number] - 1.).to(self.vae_dtype).to(self.device)
-                samples[x:x + batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
+                samples[x:x + batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
 
         except model_management.OOM_EXCEPTION as e:
             print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
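With both paths updated, decoded and encoded tensors land on self.output_device rather than being forced through CPU. A hedged usage sketch, assuming vae is an already-loaded comfy.sd.VAE:

import torch

latent = torch.randn(1, 4, 64, 64)
image = vae.decode(latent)        # ends up on vae.output_device
print(image.device, image.shape)  # e.g. cuda:0 under --gpu-only; shape (1, 512, 512, 3)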