
Commit 9ac0b48

Make --gpu-only put intermediate values in GPU memory instead of cpu.
1 parent cdff081 commit 9ac0b48

File tree: 9 files changed (+36 −29 lines)


comfy/clip_vision.py

Lines changed: 2 additions & 2 deletions

@@ -54,10 +54,10 @@ def encode_image(self, image):
             t = outputs[k]
             if t is not None:
                 if k == 'hidden_states':
-                    outputs["penultimate_hidden_states"] = t[-2].cpu()
+                    outputs["penultimate_hidden_states"] = t[-2].to(comfy.model_management.intermediate_device())
                     outputs["hidden_states"] = None
                 else:
-                    outputs[k] = t.cpu()
+                    outputs[k] = t.to(comfy.model_management.intermediate_device())

         return outputs

comfy/model_management.py

Lines changed: 6 additions & 0 deletions

@@ -508,6 +508,12 @@ def text_encoder_dtype(device=None):
     else:
         return torch.float32

+def intermediate_device():
+    if args.gpu_only:
+        return get_torch_device()
+    else:
+        return torch.device("cpu")
+
 def vae_device():
     return get_torch_device()
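For orientation, here is a minimal sketch (not part of the diff) of how the new helper is meant to be used: call sites that previously hard-coded .cpu() now move their results to whatever intermediate_device() reports, so with --gpu-only they stay on the GPU and otherwise land on the CPU as before. The finish() wrapper below is a hypothetical example, not ComfyUI code.

import torch
import comfy.model_management

def finish(result: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper replacing the old "result.cpu()" pattern. With
    # --gpu-only, intermediate_device() returns the main torch device
    # (e.g. cuda:0); without it, torch.device("cpu") as before.
    return result.to(comfy.model_management.intermediate_device())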

comfy/sample.py

Lines changed: 2 additions & 2 deletions

@@ -98,7 +98,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
     sampler = comfy.samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)

     samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
-    samples = samples.cpu()
+    samples = samples.to(comfy.model_management.intermediate_device())

     cleanup_additional_models(models)
     cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
@@ -111,7 +111,7 @@ def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent
     sigmas = sigmas.to(model.load_device)

     samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
-    samples = samples.cpu()
+    samples = samples.to(comfy.model_management.intermediate_device())
     cleanup_additional_models(models)
     cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
     return samples
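A hedged usage sketch of the effect on comfy.sample.sample(): the returned latents now sit on intermediate_device(), so under --gpu-only the following VAE decode no longer needs a CPU round trip. The objects below (model, noise, positive, negative, latent_image) are placeholders for values a loaded workflow would provide.

import comfy.sample

# Placeholders: model, noise, positive, negative and latent_image are assumed
# to come from an already-loaded workflow.
samples = comfy.sample.sample(model, noise, steps=20, cfg=8.0,
                              sampler_name="euler", scheduler="normal",
                              positive=positive, negative=negative,
                              latent_image=latent_image)
print(samples.device)  # cuda:0 under --gpu-only, cpu otherwise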

comfy/sd.py

Lines changed: 12 additions & 11 deletions

@@ -190,6 +190,7 @@ def __init__(self, sd=None, device=None, config=None):
         offload_device = model_management.vae_offload_device()
         self.vae_dtype = model_management.vae_dtype()
         self.first_stage_model.to(self.vae_dtype)
+        self.output_device = model_management.intermediate_device()

         self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
@@ -201,9 +202,9 @@ def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):

         decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
         output = torch.clamp((
-            (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
-            comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
-            comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar))
+            (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar) +
+            comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar) +
+            comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar))
             / 3.0) / 2.0, min=0.0, max=1.0)
         return output
@@ -214,9 +215,9 @@ def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
         pbar = comfy.utils.ProgressBar(steps)

         encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
-        samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
-        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
-        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
+        samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar)
+        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar)
+        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar)
         samples /= 3.0
         return samples
@@ -228,15 +229,15 @@ def decode(self, samples_in):
             batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)

-            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
+            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device=self.output_device)
             for x in range(0, samples_in.shape[0], batch_number):
                 samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
-                pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0)
+                pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0)
         except model_management.OOM_EXCEPTION as e:
             print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             pixel_samples = self.decode_tiled_(samples_in)

-        pixel_samples = pixel_samples.cpu().movedim(1,-1)
+        pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
         return pixel_samples

     def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
@@ -252,10 +253,10 @@ def encode(self, pixel_samples):
             free_memory = model_management.get_free_memory(self.device)
             batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)
-            samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
+            samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device=self.output_device)
             for x in range(0, pixel_samples.shape[0], batch_number):
                 pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
-                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
+                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()

         except model_management.OOM_EXCEPTION as e:
             print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")

comfy/sd1_clip.py

Lines changed: 3 additions & 3 deletions

@@ -39,7 +39,7 @@ def encode_token_weights(self, token_weight_pairs):

         out, pooled = self.encode(to_encode)
         if pooled is not None:
-            first_pooled = pooled[0:1].cpu()
+            first_pooled = pooled[0:1].to(model_management.intermediate_device())
         else:
             first_pooled = pooled
@@ -56,8 +56,8 @@ def encode_token_weights(self, token_weight_pairs):
             output.append(z)

         if (len(output) == 0):
-            return out[-1:].cpu(), first_pooled
-        return torch.cat(output, dim=-2).cpu(), first_pooled
+            return out[-1:].to(model_management.intermediate_device()), first_pooled
+        return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled


 class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     """Uses the CLIP transformer encoder for text (from huggingface)"""

comfy/utils.py

Lines changed: 6 additions & 6 deletions

@@ -376,7 +376,7 @@ def lanczos(samples, width, height):
     images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
     images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
     result = torch.stack(images)
-    return result
+    return result.to(samples.device, samples.dtype)

 def common_upscale(samples, width, height, upscale_method, crop):
     if crop == "center":
@@ -405,17 +405,17 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
     return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))

 @torch.inference_mode()
-def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None):
-    output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device="cpu")
+def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
+    output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device=output_device)
     for b in range(samples.shape[0]):
         s = samples[b:b+1]
-        out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu")
-        out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu")
+        out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
+        out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
         for y in range(0, s.shape[2], tile_y - overlap):
             for x in range(0, s.shape[3], tile_x - overlap):
                 s_in = s[:,:,y:y+tile_y,x:x+tile_x]

-                ps = function(s_in).cpu()
+                ps = function(s_in).to(output_device)
                 mask = torch.ones_like(ps)
                 feather = round(overlap * upscale_amount)
                 for t in range(feather):
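Since tiled_scale() now takes an output_device keyword (defaulting to "cpu", so existing callers are unchanged), here is a small self-contained example of the new parameter. The nearest-neighbor per-tile function is only a stand-in for the VAE decode/encode lambdas used in comfy/sd.py.

import torch
import comfy.utils
import comfy.model_management

samples = torch.randn(1, 3, 64, 64)
# Stand-in per-tile function; in ComfyUI this would be a VAE decode lambda.
upscale_fn = lambda t: torch.nn.functional.interpolate(t, scale_factor=2, mode="nearest")

# output_device defaults to "cpu"; passing intermediate_device() keeps the
# stitched result on the GPU when --gpu-only is set.
out = comfy.utils.tiled_scale(samples, upscale_fn, tile_x=32, tile_y=32, overlap=8,
                              upscale_amount=2,
                              output_device=comfy.model_management.intermediate_device())
print(out.shape)  # torch.Size([1, 3, 128, 128])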

comfy_extras/nodes_canny.py

Lines changed: 1 addition & 1 deletion

@@ -291,7 +291,7 @@ def INPUT_TYPES(s):

     def detect_edge(self, image, low_threshold, high_threshold):
         output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
-        img_out = output[1].cpu().repeat(1, 3, 1, 1).movedim(1, -1)
+        img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1)
         return (img_out,)

 NODE_CLASS_MAPPINGS = {

comfy_extras/nodes_post_processing.py

Lines changed: 1 addition & 1 deletion

@@ -226,7 +226,7 @@ def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha:
         batch_size, height, width, channels = image.shape

         kernel_size = sharpen_radius * 2 + 1
-        kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10)
+        kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10)
         center = kernel_size // 2
         kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
         kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)

nodes.py

Lines changed: 3 additions & 3 deletions

@@ -947,8 +947,8 @@ def append(self, conditioning_to, clip, gligen_textbox_model, text, width, heigh
         return (c, )

 class EmptyLatentImage:
-    def __init__(self, device="cpu"):
-        self.device = device
+    def __init__(self):
+        self.device = comfy.model_management.intermediate_device()

     @classmethod
     def INPUT_TYPES(s):
@@ -961,7 +961,7 @@ def INPUT_TYPES(s):
     CATEGORY = "latent"

     def generate(self, width, height, batch_size=1):
-        latent = torch.zeros([batch_size, 4, height // 8, width // 8])
+        latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
         return ({"samples":latent}, )
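A hedged sketch of the node-level effect: EmptyLatentImage now allocates its zero latent directly on intermediate_device(), so with --gpu-only the sampler starts from a GPU tensor instead of copying it over first. The shape in the comment assumes the default SD latent layout (4 channels, 1/8 spatial size) shown in the diff above.

# Uses the EmptyLatentImage class from nodes.py shown above.
node = EmptyLatentImage()
(latent,) = node.generate(width=512, height=512, batch_size=1)
print(latent["samples"].shape)   # torch.Size([1, 4, 64, 64])
print(latent["samples"].device)  # cuda:0 under --gpu-only, cpu otherwise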
