Skip to content

Commit 6be85c7

Browse files
authored
mp: use look-ahead actuals for stream offload VRAM calculation (#11096)
TIL that the WAN TE has a 2GB weight, with the next size down being only 16MB. This meant that a machine with 8GB of VRAM would fully offload the TE in async offload mode, because the old logic just multiplied this giant weight size by the number of streams. Do the more complex thing instead: sum up the sizes of the upcoming to-load weights, so this massive weight isn't triple-counted. Partial unload does the converse, recording the NS most recent unload sizes as it goes.
1 parent ea17add commit 6be85c7

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

comfy/model_patcher.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -699,12 +699,12 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
699699
offloaded = []
700700
offload_buffer = 0
701701
loading.sort(reverse=True)
702-
for x in loading:
702+
for i, x in enumerate(loading):
703703
module_offload_mem, module_mem, n, m, params = x
704704

705705
lowvram_weight = False
706706

707-
potential_offload = max(offload_buffer, module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem))
707+
potential_offload = max(offload_buffer, module_offload_mem + sum([ x1[1] for x1 in loading[i+1:i+1+comfy.model_management.NUM_STREAMS]]))
708708
lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
709709

710710
weight_key = "{}.weight".format(n)
@@ -876,14 +876,18 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals
876876
patch_counter = 0
877877
unload_list = self._load_list()
878878
unload_list.sort()
879+
879880
offload_buffer = self.model.model_offload_buffer_memory
881+
if len(unload_list) > 0:
882+
NS = comfy.model_management.NUM_STREAMS
883+
offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS
880884

881885
for unload in unload_list:
882886
if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
883887
break
884888
module_offload_mem, module_mem, n, m, params = unload
885889

886-
potential_offload = module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem)
890+
potential_offload = module_offload_mem + sum(offload_weight_factor)
887891

888892
lowvram_possible = hasattr(m, "comfy_cast_weights")
889893
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
@@ -935,6 +939,8 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals
935939
m.comfy_patched_weights = False
936940
memory_freed += module_mem
937941
offload_buffer = max(offload_buffer, potential_offload)
942+
offload_weight_factor.append(module_mem)
943+
offload_weight_factor.pop(0)
938944
logging.debug("freed {}".format(n))
939945

940946
for param in params:

0 commit comments

Comments
 (0)