Commit deb2b93

Move code to empty gpu cache to model_management.py
1 parent f4c689e commit deb2b93

2 files changed (+11 -7 lines)

comfy/model_management.py

Lines changed: 9 additions & 0 deletions

@@ -307,6 +307,15 @@ def should_use_fp16():
 
     return True
 
+def soft_empty_cache():
+    global xpu_available
+    if xpu_available:
+        torch.xpu.empty_cache()
+    elif torch.cuda.is_available():
+        if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+
 #TODO: might be cleaner to put this somewhere else
 import threading
 
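A note on the torch.version.cuda guard: ROCm builds of PyTorch report torch.cuda.is_available() as True but leave torch.version.cuda as None (torch.version.hip is set instead), which is how the flush gets skipped on ROCm while still running on genuine CUDA builds. A minimal standalone sketch (not part of the commit) that reports which branch the CUDA side of soft_empty_cache() would take on a given install:

import torch

# Covers only the CUDA/ROCm branch; the xpu branch is checked first
# in soft_empty_cache() itself. torch.version.cuda is None on ROCm
# wheels even though torch.cuda.is_available() returns True there.
if torch.cuda.is_available():
    if torch.version.cuda:
        print("CUDA build %s: cache would be emptied" % torch.version.cuda)
    else:
        print("ROCm build %s: cache flush skipped" % torch.version.hip)
else:
    print("no CUDA/ROCm device visible (xpu or cpu path)")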

execution.py

Lines changed: 2 additions & 7 deletions

@@ -10,7 +10,7 @@
 import torch
 import nodes
 
-from model_management import xpu_available
+import comfy.model_management
 
 def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
     valid_inputs = class_def.INPUT_TYPES()
@@ -204,12 +204,7 @@ def execute(self, prompt, extra_data={}):
         self.server.send_sync("executing", { "node": None }, self.server.client_id)
 
         gc.collect()
-        if torch.cuda.is_available():
-            if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda
-                torch.cuda.empty_cache()
-                torch.cuda.ipc_collect()
-        elif xpu_available:
-            torch.xpu.empty_cache()
+        comfy.model_management.soft_empty_cache()
 
 
 def validate_inputs(prompt, item):
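With this change the backend-specific cache logic lives in one place and callers need only the one-liner. A minimal sketch of the pattern any caller can now follow (assuming ComfyUI's package layout, where comfy.model_management is importable; after_heavy_gpu_work is a hypothetical helper, not from the commit):

import gc

import comfy.model_management

def after_heavy_gpu_work():
    # Drop Python-level references first so the allocator actually has
    # unused blocks to hand back, then ask soft_empty_cache() to return
    # cached GPU memory to the driver on whichever backend is active.
    gc.collect()
    comfy.model_management.soft_empty_cache()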
