35 | 35 | import comfy.patcher_extension
36 | 36 | import comfy.utils
37 | 37 | from comfy.comfy_types import UnetWrapperFunction
| 38 | +from comfy.quant_ops import QuantizedTensor
38 | 39 | from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
39 | 40 |
40 | 41 |
@@ -132,14 +133,17 @@ def __init__(self, key, patches, convert_func=None, set_func=None):
132 | 133 |     def __call__(self, weight):
133 | 134 |         return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
134 | 135 |
135 | | -#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
136 | | -LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
| 136 | +LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
137 | 137 |
138 | 138 | def low_vram_patch_estimate_vram(model, key):
139 | 139 |     weight, set_func, convert_func = get_key_weight(model, key)
140 | 140 |     if weight is None:
141 | 141 |         return 0
142 | | -    return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
| 142 | +    model_dtype = getattr(model, "manual_cast_dtype", torch.float32)
| 143 | +    if model_dtype is None:
| 144 | +        model_dtype = weight.dtype
| 145 | +
| 146 | +    return weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
143 | 147 |
144 | 148 | def get_key_weight(model, key):
145 | 149 |     set_func = None
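Note on the revised estimate: the low-VRAM patch budget is now sized from the model's compute dtype (manual_cast_dtype when set, otherwise the weight's own dtype) with a factor of 2, instead of a fixed fp32 x 3. A quick back-of-the-envelope comparison, using an illustrative weight shape that is not from the patch:

    import torch

    # Hypothetical case: a 4096x4096 linear weight on a model that computes in fp16.
    numel = 4096 * 4096
    old_estimate = numel * torch.float32.itemsize * 3  # previous rule: fp32 x 3 -> 192 MiB
    new_estimate = numel * torch.float16.itemsize * 2  # new rule: compute dtype x 2 -> 64 MiB
    print(old_estimate // 2**20, new_estimate // 2**20)  # 192 64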
@@ -614,10 +618,11 @@ def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
614 | 618 |         if key not in self.backup:
615 | 619 |             self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
616 | 620 |
| 621 | +        temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
617 | 622 |         if device_to is not None:
618 | | -            temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
| 623 | +            temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True)
619 | 624 |         else:
620 | | -            temp_weight = weight.to(torch.float32, copy=True)
| 625 | +            temp_weight = weight.to(temp_dtype, copy=True)
621 | 626 |         if convert_func is not None:
622 | 627 |             temp_weight = convert_func(temp_weight, inplace=True)
623 | 628 |
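The temporary weight used for applying patches is now cast to comfy.model_management.lora_compute_dtype(device_to) rather than always fp32. That helper's body is not part of this diff; as a rough mental model only, it can be read as picking a cheaper compute dtype where the target device handles it well. A sketch under that assumption (hypothetical stand-in, not comfy's implementation):

    import torch

    def lora_compute_dtype_sketch(device):
        # Assumption: CPU targets keep full fp32 precision, while CUDA targets do
        # the patch math in a 16-bit dtype to halve the temporary buffer.
        if device is None or device.type != "cuda":
            return torch.float32
        if torch.cuda.is_bf16_supported():
            return torch.bfloat16
        return torch.float16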
@@ -661,12 +666,18 @@ def _load_list(self):
661 | 666 |                 module_mem = comfy.model_management.module_size(m)
662 | 667 |                 module_offload_mem = module_mem
663 | 668 |                 if hasattr(m, "comfy_cast_weights"):
664 | | -                    weight_key = "{}.weight".format(n)
665 | | -                    bias_key = "{}.bias".format(n)
666 | | -                    if weight_key in self.patches:
667 | | -                        module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
668 | | -                    if bias_key in self.patches:
669 | | -                        module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
| 669 | +                    def check_module_offload_mem(key):
| 670 | +                        if key in self.patches:
| 671 | +                            return low_vram_patch_estimate_vram(self.model, key)
| 672 | +                        model_dtype = getattr(self.model, "manual_cast_dtype", None)
| 673 | +                        weight, _, _ = get_key_weight(self.model, key)
| 674 | +                        if model_dtype is None or weight is None:
| 675 | +                            return 0
| 676 | +                        if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)):
| 677 | +                            return weight.numel() * model_dtype.itemsize
| 678 | +                        return 0
| 679 | +                    module_offload_mem += check_module_offload_mem("{}.weight".format(n))
| 680 | +                    module_offload_mem += check_module_offload_mem("{}.bias".format(n))
670 | 681 |                 loading.append((module_offload_mem, module_mem, n, m, params))
671 | 682 |         return loading
672 | 683 |
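The new check_module_offload_mem closure keeps the old estimate for patched keys and additionally budgets for unpatched weights that still need a cast at load time: when the stored dtype differs from the model's manual_cast_dtype, or the weight is a QuantizedTensor, one buffer in the compute dtype is added. The same accounting restated as a standalone helper (illustrative, not part of the patch):

    import torch

    def cast_buffer_bytes(weight_numel, weight_dtype, model_dtype, is_quantized=False):
        # A weight with no pending patches only needs extra VRAM when it has to
        # be converted to the compute dtype while loading.
        if model_dtype is None:
            return 0
        if weight_dtype != model_dtype or is_quantized:
            return weight_numel * model_dtype.itemsize
        return 0

    # Example: an fp32-stored 4096x4096 weight on a model computing in fp16.
    print(cast_buffer_bytes(4096 * 4096, torch.float32, torch.float16))  # 33554432 bytes (32 MiB)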
@@ -761,6 +772,8 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
761 | 772 |                     key = "{}.{}".format(n, param)
762 | 773 |                     self.unpin_weight(key)
763 | 774 |                     self.patch_weight_to_device(key, device_to=device_to)
| 775 | +                    if comfy.model_management.is_device_cuda(device_to):
| 776 | +                        torch.cuda.synchronize()
764 | 777 |
765 | 778 |                 logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
766 | 779 |                 m.comfy_patched_weights = True
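The added synchronize runs only when the target device is CUDA. Device copies and the patch math queued by patch_weight_to_device execute asynchronously, so waiting here ensures the queued work for this module has actually finished before the loop continues; the precise motivation is not stated in the diff. A minimal sketch of the guard pattern, assuming a torch.device-like target:

    import torch

    def sync_if_cuda(device):
        # Only CUDA targets take an explicit synchronize; other targets skip it.
        dev = torch.device(device) if device is not None else None
        if dev is not None and dev.type == "cuda":
            torch.cuda.synchronize(dev)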
@@ -917,7 +930,7 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False
917 | 930 |                                     patch_counter += 1
918 | 931 |                                     cast_weight = True
919 | 932 |
920 | | -                        if cast_weight:
| 933 | +                        if cast_weight and hasattr(m, "comfy_cast_weights"):
921 | 934 |                             m.prev_comfy_cast_weights = m.comfy_cast_weights
922 | 935 |                             m.comfy_cast_weights = True
923 | 936 |                         m.comfy_patched_weights = False
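In partially_unload, the cast-weight toggle now applies only to modules that actually expose comfy_cast_weights; modules outside comfy's weight-casting mechanism no longer get the attribute pair set on them. The guard in isolation (illustrative only):

    def enable_cast_weights(module):
        # Only touch modules that already participate in the cast-weights
        # mechanism; plain modules are left alone instead of failing on the
        # comfy_cast_weights read.
        if hasattr(module, "comfy_cast_weights"):
            module.prev_comfy_cast_weights = module.comfy_cast_weights
            module.comfy_cast_weights = True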