
Commit 519c941

Prs/lora reservations (reduce massive Lora reservations especially on Flux2) (#11069)
* mp: only count the offload cost of math once

This was previously bundling the combined weight storage and computation cost.

* ops: put all post async transfer compute on the main stream

Some models have massive weights that need either complex dequantization or LoRA patching. Don't do this patching on the offload stream; instead do it on the main stream, so the potentially large VRAM spikes from these compute steps are synchronized. This avoids having to assume a worst-case scenario of multiple offload streams all spiking VRAM in parallel with whatever the main stream is doing.
1 parent 861817d commit 519c941
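As a rough sketch of the second change (illustrative code, not from this commit; the function name, the pinned-memory copy, and the patch list are assumptions), the pattern is: transfer asynchronously on a side stream, make the main stream wait for the copy, then run the heavy dequantization/LoRA patching on the main stream so these spikes are serialized with the main stream's other work:

import torch

def cast_weight_on_main_stream(weight_cpu, weight_patches, device, dtype=torch.bfloat16):
    # Copy the raw weight to the GPU asynchronously on a side ("offload") stream.
    offload_stream = torch.cuda.Stream(device=device)
    with torch.cuda.stream(offload_stream):
        weight = weight_cpu.to(device, non_blocking=True)

    # The main stream waits for the transfer before touching the tensor.
    torch.cuda.current_stream(device).wait_stream(offload_stream)

    # The expensive part (dtype conversion, dequantization, LoRA-style patching)
    # runs on the main stream, so only one such VRAM spike is in flight at a
    # time instead of one per offload stream.
    weight = weight.to(dtype=dtype)
    for patch in weight_patches:
        weight = patch(weight)
    return weight, offload_stream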

2 files changed: +24 −19 lines changed
comfy/model_patcher.py
Lines changed: 2 additions & 2 deletions

@@ -704,7 +704,7 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False

             lowvram_weight = False

-            potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1))
+            potential_offload = max(offload_buffer, module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem))
             lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory

             weight_key = "{}.weight".format(n)
@@ -883,7 +883,7 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals
                 break
             module_offload_mem, module_mem, n, m, params = unload

-            potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem
+            potential_offload = module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem)

             lowvram_possible = hasattr(m, "comfy_cast_weights")
             if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
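A back-of-the-envelope comparison of the old and new potential_offload estimates, with made-up sizes (module_mem as plain weight storage, module_offload_mem as the offload-plus-patching cost; only the formulas come from the diff above):

NUM_STREAMS = 2                     # number of offload streams (illustrative)
module_mem = 100 * 2**20            # raw weight storage for one module: 100 MiB
module_offload_mem = 400 * 2**20    # offload + dequant/LoRA patching cost: 400 MiB

# Old estimate: every stream is assumed to pay the full patching cost at once.
old_reservation = module_offload_mem * (NUM_STREAMS + 1)           # 1200 MiB

# New estimate: the patching cost is counted once (it now runs on the main
# stream) and each offload stream only holds a plain copy of the weights.
new_reservation = module_offload_mem + NUM_STREAMS * module_mem    # 600 MiB

print(old_reservation // 2**20, new_reservation // 2**20)          # 1200 600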

comfy/ops.py
Lines changed: 22 additions & 17 deletions

@@ -111,22 +111,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     if s.bias is not None:
         bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)

-    if bias_has_function:
-        with wf_context:
-            for f in s.bias_function:
-                bias = f(bias)
+    comfy.model_management.sync_stream(device, offload_stream)
+
+    bias_a = bias
+    weight_a = weight
+
+    if s.bias is not None:
+        for f in s.bias_function:
+            bias = f(bias)

     if weight_has_function or weight.dtype != dtype:
-        with wf_context:
-            weight = weight.to(dtype=dtype)
-            if isinstance(weight, QuantizedTensor):
-                weight = weight.dequantize()
-            for f in s.weight_function:
-                weight = f(weight)
+        weight = weight.to(dtype=dtype)
+        if isinstance(weight, QuantizedTensor):
+            weight = weight.dequantize()
+        for f in s.weight_function:
+            weight = f(weight)

-    comfy.model_management.sync_stream(device, offload_stream)
     if offloadable:
-        return weight, bias, offload_stream
+        return weight, bias, (offload_stream, weight_a, bias_a)
     else:
         #Legacy function signature
         return weight, bias
@@ -135,13 +137,16 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
 def uncast_bias_weight(s, weight, bias, offload_stream):
     if offload_stream is None:
         return
-    if weight is not None:
-        device = weight.device
+    os, weight_a, bias_a = offload_stream
+    if os is None:
+        return
+    if weight_a is not None:
+        device = weight_a.device
     else:
-        if bias is None:
+        if bias_a is None:
             return
-        device = bias.device
-    offload_stream.wait_stream(comfy.model_management.current_stream(device))
+        device = bias_a.device
+    os.wait_stream(comfy.model_management.current_stream(device))


 class CastWeightBiasOp:
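The release half of the new handshake, mirrored with plain torch streams (a sketch assuming the handle carries the stream plus the staged pre-patch tensors, as the new return value suggests; this is not ComfyUI's own helper):

import torch

def release_staged(handle):
    # handle mirrors the new (offload_stream, weight_a, bias_a) return value.
    if handle is None:
        return
    offload_stream, weight_a, bias_a = handle
    if offload_stream is None:
        return
    anchor = weight_a if weight_a is not None else bias_a
    if anchor is None:
        return
    # Future work on the offload stream waits for everything already enqueued
    # on the main stream, so the staged tensors are not recycled while the
    # main stream may still be reading them.
    offload_stream.wait_stream(torch.cuda.current_stream(anchor.device))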
