
Commit 37f711d

mm: Fix cast buffers with intel offloading (#12229)
Intel has offloading support, but the new cast-buffer code was still making some NVIDIA-only torch.cuda calls.
1 parent dd86b15 · commit 37f711d
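For context: torch.cuda.* calls raise on PyTorch builds that have XPU support but no CUDA runtime, which is the failure mode this commit fixes. A minimal sketch of the problem and the backend-appropriate call (assumes an XPU-only build; this is an illustration, not code from the commit):

    import torch

    if not torch.cuda.is_available():
        try:
            torch.cuda.synchronize()  # the old unconditional call in the cast-buffer code
        except Exception as e:
            print("fails without CUDA:", e)  # raises instead of syncing the XPU

    if hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.synchronize()  # the call this commit dispatches to on Intel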

File tree

1 file changed: +10 −6 lines


comfy/model_management.py

Lines changed: 10 additions & 6 deletions
@@ -1112,11 +1112,11 @@ def get_cast_buffer(offload_stream, device, size, ref):
         return None
     if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2):
         #I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now
-        torch.cuda.synchronize()
+        synchronize()
         del STREAM_CAST_BUFFERS[offload_stream]
         del cast_buffer
         #FIXME: This doesn't work in Aimdo because mempool cant clear cache
-        torch.cuda.empty_cache()
+        soft_empty_cache()
     with wf_context:
         cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
     STREAM_CAST_BUFFERS[offload_stream] = cast_buffer
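The ordering in this hunk matters: the oversized buffer may still back an in-flight asynchronous cast on the offload stream, so the device must drain before the memory is released. A rough sketch of the pattern, written as if inside the module (illustrative names; only synchronize(), soft_empty_cache(), and the 50 MB threshold come from the diff):

    import torch

    def replace_staging_buffer(old_buf, new_numel, device):
        # Illustrative helper, not from the commit: free an oversized staging
        # buffer safely, then allocate a right-sized one.
        synchronize()        # drain pending async copies that may still read old_buf
        del old_buf          # drop the last reference so the allocator can reclaim it
        soft_empty_cache()   # backend-neutral replacement for torch.cuda.empty_cache()
        return torch.empty((new_numel,), dtype=torch.int8, device=device)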
@@ -1132,9 +1132,7 @@ def reset_cast_buffers():
     for offload_stream in STREAM_CAST_BUFFERS:
         offload_stream.synchronize()
     STREAM_CAST_BUFFERS.clear()
-    if comfy.memory_management.aimdo_allocator is None:
-        #Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist
-        torch.cuda.empty_cache()
+    soft_empty_cache()

 def get_offload_stream(device):
     stream_counter = stream_counters.get(device, 0)
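reset_cast_buffers() previously carried the Aimdo/mempool guard around torch.cuda.empty_cache() itself; after this change it defers to soft_empty_cache(), which presumably centralizes those guards and the per-backend dispatch. An abridged guess at the dispatch shape (assuming torch is imported; see comfy/model_management.py for the real guards around MPS, Ascend NPU, and mempools):

    def soft_empty_cache_sketch(force=False):
        # Abridged, hypothetical sketch of the backend dispatch, not the
        # module's exact body.
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            torch.xpu.empty_cache()
        elif torch.cuda.is_available():
            torch.cuda.empty_cache()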
@@ -1284,7 +1282,7 @@ def discard_cuda_async_error():
         a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
         b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
         _ = a + b
-        torch.cuda.synchronize()
+        synchronize()
     except torch.AcceleratorError:
         #Dump it! We already know about it from the synchronous return
         pass
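discard_cuda_async_error() uses a flush-and-discard pattern: queue a trivial kernel, synchronize so any pending asynchronous accelerator error surfaces at a known point, and swallow it because it was already reported. A condensed sketch using the module's own helpers (requires a PyTorch version that defines torch.AcceleratorError, as the diff above does):

    def flush_pending_device_error():
        try:
            t = torch.ones(1, dtype=torch.uint8, device=get_torch_device())
            _ = t + t           # queue a trivial kernel
            synchronize()       # force any stale async error to raise here
        except torch.AcceleratorError:
            pass                # already known from the synchronous return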
@@ -1688,6 +1686,12 @@ def lora_compute_dtype(device):
     LORA_COMPUTE_DTYPES[device] = dtype
     return dtype

+def synchronize():
+    if is_intel_xpu():
+        torch.xpu.synchronize()
+    elif torch.cuda.is_available():
+        torch.cuda.synchronize()
+
 def soft_empty_cache(force=False):
     global cpu_state
     if cpu_state == CPUState.MPS:
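With the new helper, callers no longer need to know the backend. A hypothetical usage from outside the module (assuming the comfy.model_management import path used elsewhere in ComfyUI):

    import torch
    import comfy.model_management as mm

    x = torch.randn(1024, device=mm.get_torch_device())
    y = x * 2          # queued asynchronously on CUDA or XPU
    mm.synchronize()   # waits on whichever backend is active; no-op on CPU-only builds
    print(y[:4].cpu())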
