@@ -1112,11 +1112,11 @@ def get_cast_buffer(offload_stream, device, size, ref):
11121112 return None
11131113 if cast_buffer is not None and cast_buffer .numel () > 50 * (1024 ** 2 ):
11141114 #I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now
1115- torch . cuda . synchronize ()
1115+ synchronize ()
11161116 del STREAM_CAST_BUFFERS [offload_stream ]
11171117 del cast_buffer
11181118 #FIXME: This doesn't work in Aimdo because mempool cant clear cache
1119- torch . cuda . empty_cache ()
1119+ soft_empty_cache ()
11201120 with wf_context :
11211121 cast_buffer = torch .empty ((size ), dtype = torch .int8 , device = device )
11221122 STREAM_CAST_BUFFERS [offload_stream ] = cast_buffer
@@ -1132,9 +1132,7 @@ def reset_cast_buffers():
11321132 for offload_stream in STREAM_CAST_BUFFERS :
11331133 offload_stream .synchronize ()
11341134 STREAM_CAST_BUFFERS .clear ()
1135- if comfy .memory_management .aimdo_allocator is None :
1136- #Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist
1137- torch .cuda .empty_cache ()
1135+ soft_empty_cache ()
11381136
11391137def get_offload_stream (device ):
11401138 stream_counter = stream_counters .get (device , 0 )
@@ -1284,7 +1282,7 @@ def discard_cuda_async_error():
12841282 a = torch .tensor ([1 ], dtype = torch .uint8 , device = get_torch_device ())
12851283 b = torch .tensor ([1 ], dtype = torch .uint8 , device = get_torch_device ())
12861284 _ = a + b
1287- torch . cuda . synchronize ()
1285+ synchronize ()
12881286 except torch .AcceleratorError :
12891287 #Dump it! We already know about it from the synchronous return
12901288 pass
@@ -1688,6 +1686,12 @@ def lora_compute_dtype(device):
16881686 LORA_COMPUTE_DTYPES [device ] = dtype
16891687 return dtype
16901688
1689+ def synchronize ():
1690+ if is_intel_xpu ():
1691+ torch .xpu .synchronize ()
1692+ elif torch .cuda .is_available ():
1693+ torch .cuda .synchronize ()
1694+
16911695def soft_empty_cache (force = False ):
16921696 global cpu_state
16931697 if cpu_state == CPUState .MPS :
0 commit comments