
Commit 43071e3

Make old scaled fp8 format use the new mixed quant ops system. (#11000)
1 parent 0ec05b1 commit 43071e3

24 files changed: +278 additions, −275 deletions

comfy/model_base.py

Lines changed: 1 addition & 13 deletions

@@ -134,7 +134,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod
         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
                 fp8 = model_config.optimizations.get("fp8", False)
-                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
+                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, model_config=model_config)
             else:
                 operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@@ -329,18 +329,6 @@ def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_
             extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))

         unet_state_dict = self.diffusion_model.state_dict()
-
-        if self.model_config.scaled_fp8 is not None:
-            unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
-
-        # Save mixed precision metadata
-        if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
-            metadata = {
-                "format_version": "1.0",
-                "layers": self.model_config.layer_quant_config
-            }
-            unet_state_dict["_quantization_metadata"] = metadata
-
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)

         if self.model_type == ModelType.V_PREDICTION:
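
Note: after this change nothing quantization-related is injected into the saved UNet state dict at the model level; the old top-level scaled_fp8 marker tensor and the _quantization_metadata blob are gone, and quantization info is emitted per layer as "comfy_quant" entries by the ops' own state_dict() (see comfy/ops.py below). A rough sketch of how the two checkpoint flavours can be told apart when inspecting a loaded state dict; describe_quant_format is a hypothetical helper, not part of ComfyUI, and the "scaled_fp8" key is shown here without the usual model prefix:

    import json

    def describe_quant_format(state_dict):
        # Hypothetical inspection helper, for illustration only.
        if "scaled_fp8" in state_dict:
            # Old format: a (usually empty) marker tensor whose dtype names the fp8 type.
            return "old scaled_fp8 marker, dtype=%s" % state_dict["scaled_fp8"].dtype
        quant_keys = [k for k in state_dict if k.endswith("comfy_quant")]
        if quant_keys:
            # New format: one small uint8 tensor of JSON per quantized layer.
            conf = json.loads(state_dict[quant_keys[0]].numpy().tobytes())
            return "per-layer comfy_quant, e.g. %s -> %s" % (quant_keys[0], conf)
        return "no quantization markers"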

comfy/model_detection.py

Lines changed: 4 additions & 29 deletions

@@ -6,20 +6,6 @@
 import logging
 import torch

-
-def detect_layer_quantization(metadata):
-    quant_key = "_quantization_metadata"
-    if metadata is not None and quant_key in metadata:
-        quant_metadata = metadata.pop(quant_key)
-        quant_metadata = json.loads(quant_metadata)
-        if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
-            logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
-            return quant_metadata["layers"]
-        else:
-            raise ValueError("Invalid quantization metadata format")
-    return None
-
-
 def count_blocks(state_dict_keys, prefix_string):
     count = 0
     while True:
@@ -767,22 +753,11 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
     if model_config is None and use_base_if_no_match:
         model_config = comfy.supported_models_base.BASE(unet_config)

-    scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix)
-    if scaled_fp8_key in state_dict:
-        scaled_fp8_weight = state_dict.pop(scaled_fp8_key)
-        model_config.scaled_fp8 = scaled_fp8_weight.dtype
-        if model_config.scaled_fp8 == torch.float32:
-            model_config.scaled_fp8 = torch.float8_e4m3fn
-        if scaled_fp8_weight.nelement() == 2:
-            model_config.optimizations["fp8"] = False
-        else:
-            model_config.optimizations["fp8"] = True
-
     # Detect per-layer quantization (mixed precision)
-    layer_quant_config = detect_layer_quantization(metadata)
-    if layer_quant_config:
-        model_config.layer_quant_config = layer_quant_config
-        logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
+    quant_config = comfy.utils.detect_layer_quantization(state_dict, unet_key_prefix)
+    if quant_config:
+        model_config.quant_config = quant_config
+        logging.info("Detected mixed precision quantization")

     return model_config
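
Note: detect_layer_quantization has moved to comfy.utils and now inspects the state dict itself instead of the safetensors metadata; its implementation is not part of this excerpt. A plausible sketch of what it has to do, assuming it keys off the per-layer "comfy_quant" tensors that comfy/ops.py reads and writes (the return shape below is a guess; any truthy value satisfies the caller shown above):

    import json

    def detect_layer_quantization(state_dict, prefix=""):
        # Sketch only: the real helper lives in comfy/utils.py and may differ.
        quant_config = {}
        for key in state_dict:
            if key.startswith(prefix) and key.endswith(".comfy_quant"):
                layer = key[len(prefix):-len(".comfy_quant")]
                quant_config[layer] = json.loads(state_dict[key].numpy().tobytes())
        return quant_config or None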

comfy/model_patcher.py

Lines changed: 2 additions & 18 deletions

@@ -126,27 +126,11 @@ class LowVramPatch:
     def __init__(self, key, patches, convert_func=None, set_func=None):
         self.key = key
         self.patches = patches
-        self.convert_func = convert_func
+        self.convert_func = convert_func # TODO: remove
         self.set_func = set_func

     def __call__(self, weight):
-        intermediate_dtype = weight.dtype
-        if self.convert_func is not None:
-            weight = self.convert_func(weight, inplace=False)
-
-        if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
-            intermediate_dtype = torch.float32
-            out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
-            if self.set_func is None:
-                return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
-            else:
-                return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
-
-        out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
-        if self.set_func is not None:
-            return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
-        else:
-            return out
+        return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)

 #The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
 LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
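
Note: the dtype juggling that used to live here (dequantize via convert_func, compute in a math-friendly dtype, re-round via set_func or stochastic rounding) is now handled inside the quantized ops themselves, so the low-VRAM patch reduces to a single calculate_weight call at the weight's own dtype. A minimal usage sketch with a placeholder key and an empty patch list, assuming a working ComfyUI install:

    import torch
    from comfy.model_patcher import LowVramPatch

    # Placeholder patch table; in ComfyUI this comes from ModelPatcher.patches.
    patches = {"diffusion_model.some_linear.weight": []}
    weight = torch.zeros(4, 4, dtype=torch.bfloat16)

    patch = LowVramPatch("diffusion_model.some_linear.weight", patches)
    patched = patch(weight)  # comfy.lora.calculate_weight(..., intermediate_dtype=weight.dtype)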

comfy/ops.py

Lines changed: 56 additions & 89 deletions

@@ -23,6 +23,7 @@
 import comfy.float
 import comfy.rmsnorm
 import contextlib
+import json

 def run_every_op():
     if torch.compiler.is_compiling():
@@ -422,22 +423,12 @@ def fp8_linear(self, input):

     if input.ndim == 3 or input.ndim == 2:
         w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
+        scale_weight = torch.ones((), device=input.device, dtype=torch.float32)

-        scale_weight = self.scale_weight
-        scale_input = self.scale_input
-        if scale_weight is None:
-            scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
-        else:
-            scale_weight = scale_weight.to(input.device)
-
-        if scale_input is None:
-            scale_input = torch.ones((), device=input.device, dtype=torch.float32)
-            input = torch.clamp(input, min=-448, max=448, out=input)
-            layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
-            quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
-        else:
-            scale_input = scale_input.to(input.device)
-            quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
+        scale_input = torch.ones((), device=input.device, dtype=torch.float32)
+        input = torch.clamp(input, min=-448, max=448, out=input)
+        layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
+        quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)

         # Wrap weight in QuantizedTensor - this enables unified dispatch
         # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
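
Note: the rewritten branch always uses unit scales now; the activation is clamped to ±448, the largest finite value of float8_e4m3fn, then cast to fp8 and wrapped in a QuantizedTensor. The clamp-and-cast step on its own looks roughly like this (a sketch outside the QuantizedTensor machinery):

    import torch

    def fp8_cast_unit_scale(x, dtype=torch.float8_e4m3fn):
        # No calibration: the scale is fixed at 1.0, values are clamped into the fp8 range.
        scale = torch.ones((), device=x.device, dtype=torch.float32)
        x = torch.clamp(x, min=-448, max=448)  # 448 == max finite float8_e4m3fn value
        return x.to(dtype).contiguous(), scale

    q, s = fp8_cast_unit_scale(torch.randn(2, 16) * 1000)
    print(q.dtype, s.item())  # torch.float8_e4m3fn 1.0
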
@@ -458,7 +449,7 @@ def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
-            if not self.training:
+            if len(self.weight_function) == 0 and len(self.bias_function) == 0:
                try:
                    out = fp8_linear(self, input)
                    if out is not None:
@@ -471,59 +462,6 @@ def forward_comfy_cast_weights(self, input):
            uncast_bias_weight(self, weight, bias, offload_stream)
            return x

-def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
-    logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
-    class scaled_fp8_op(manual_cast):
-        class Linear(manual_cast.Linear):
-            def __init__(self, *args, **kwargs):
-                if override_dtype is not None:
-                    kwargs['dtype'] = override_dtype
-                super().__init__(*args, **kwargs)
-
-            def reset_parameters(self):
-                if not hasattr(self, 'scale_weight'):
-                    self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
-
-                if not scale_input:
-                    self.scale_input = None
-
-                if not hasattr(self, 'scale_input'):
-                    self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
-                return None
-
-            def forward_comfy_cast_weights(self, input):
-                if fp8_matrix_mult:
-                    out = fp8_linear(self, input)
-                    if out is not None:
-                        return out
-
-                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
-
-                if weight.numel() < input.numel(): #TODO: optimize
-                    x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
-                else:
-                    x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
-                uncast_bias_weight(self, weight, bias, offload_stream)
-                return x
-
-            def convert_weight(self, weight, inplace=False, **kwargs):
-                if inplace:
-                    weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
-                    return weight
-                else:
-                    return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32)
-
-            def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
-                weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
-                if return_weight:
-                    return weight
-                if inplace_update:
-                    self.weight.data.copy_(weight)
-                else:
-                    self.weight = torch.nn.Parameter(weight, requires_grad=False)
-
-    return scaled_fp8_op
-
 CUBLAS_IS_AVAILABLE = False
 try:
    from cublas_ops import CublasLinear
@@ -550,9 +488,9 @@ def forward(self, *args, **kwargs):
 from .quant_ops import QuantizedTensor, QUANT_ALGOS


-def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
+def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
     class MixedPrecisionOps(manual_cast):
-        _layer_quant_config = layer_quant_config
+        _quant_config = quant_config
         _compute_dtype = compute_dtype
         _full_precision_mm = full_precision_mm

@@ -595,27 +533,36 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,

                manually_loaded_keys = [weight_key]

-                if layer_name not in MixedPrecisionOps._layer_quant_config:
+                layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
+                if layer_conf is not None:
+                    layer_conf = json.loads(layer_conf.numpy().tobytes())
+
+                if layer_conf is None:
                    self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
                else:
-                    quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
-                    if quant_format is None:
+                    self.quant_format = layer_conf.get("format", None)
+                    if not self._full_precision_mm:
+                        self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)
+
+                    if self.quant_format is None:
                        raise ValueError(f"Unknown quantization format for layer {layer_name}")

-                    qconfig = QUANT_ALGOS[quant_format]
+                    qconfig = QUANT_ALGOS[self.quant_format]
                    self.layout_type = qconfig["comfy_tensor_layout"]

                    weight_scale_key = f"{prefix}weight_scale"
+                    scale = state_dict.pop(weight_scale_key, None)
                    layout_params = {
-                        'scale': state_dict.pop(weight_scale_key, None),
+                        'scale': scale,
                        'orig_dtype': MixedPrecisionOps._compute_dtype,
                        'block_size': qconfig.get("group_size", None),
                    }
-                    if layout_params['scale'] is not None:
+
+                    if scale is not None:
                        manually_loaded_keys.append(weight_scale_key)

                    self.weight = torch.nn.Parameter(
-                        QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
+                        QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
                        requires_grad=False
                    )

@@ -624,7 +571,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                    _v = state_dict.pop(param_key, None)
                    if _v is None:
                        continue
-                    setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
+                    self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
                    manually_loaded_keys.append(param_key)

                super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
@@ -633,6 +580,16 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                    if key in missing_keys:
                        missing_keys.remove(key)

+        def state_dict(self, *args, destination=None, prefix="", **kwargs):
+            sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
+            if isinstance(self.weight, QuantizedTensor):
+                sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
+                quant_conf = {"format": self.quant_format}
+                if self._full_precision_mm:
+                    quant_conf["full_precision_matrix_mult"] = True
+                sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)
+            return sd
+
        def _forward(self, input, weight, bias):
            return torch.nn.functional.linear(input, weight, bias)
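
Note: the per-layer descriptor is stored as a uint8 tensor holding UTF-8 JSON. state_dict() packs it with torch.frombuffer(...) above, and _load_from_state_dict decodes it with json.loads(tensor.numpy().tobytes()). The round trip in isolation (the payload below is only an example; real format names come from QUANT_ALGOS):

    import json
    import torch

    quant_conf = {"format": "example_fp8_format", "full_precision_matrix_mult": True}
    packed = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)

    # ...what the loader does with the "comfy_quant" entry it pops from the state dict:
    decoded = json.loads(packed.numpy().tobytes())
    assert decoded == quant_conf
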

@@ -648,9 +605,8 @@ def forward(self, input, *args, **kwargs):
            if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(input, *args, **kwargs)
            if (getattr(self, 'layout_type', None) is not None and
-                getattr(self, 'input_scale', None) is not None and
                not isinstance(input, QuantizedTensor)):
-                input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
+                input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
            return self._forward(input, self.weight, self.bias)

        def convert_weight(self, weight, inplace=False, **kwargs):
@@ -661,7 +617,7 @@ def convert_weight(self, weight, inplace=False, **kwargs):

        def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
            if getattr(self, 'layout_type', None) is not None:
-                weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
+                weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
            else:
                weight = weight.to(self.weight.dtype)
            if return_weight:
@@ -670,17 +626,28 @@ def set_weight(self, weight, inplace_update=False, seed=None, return_weight=Fals
            assert inplace_update is False # TODO: eventually remove the inplace_update stuff
            self.weight = torch.nn.Parameter(weight, requires_grad=False)

+        def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working
+            if recurse:
+                for module in self.children():
+                    module._apply(fn)
+
+            for key, param in self._parameters.items():
+                if param is None:
+                    continue
+                self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
+            for key, buf in self._buffers.items():
+                if buf is not None:
+                    self._buffers[key] = fn(buf)
+            return self
+
    return MixedPrecisionOps

-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
    fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular

-    if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
-        logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
-        return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute)
-
-    if scaled_fp8 is not None:
-        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
+    if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
+        logging.info("Using mixed precision operations")
+        return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)

    if (
        fp8_compute and
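
Note: pick_operations has lost its scaled_fp8 parameter; checkpoints in the old scaled fp8 format are now expected to surface as model_config.quant_config and go through mixed_precision_ops like any other mixed-precision model. A hedged sketch of the call from the caller's side, assuming a working ComfyUI environment and using a stand-in model_config (any truthy quant_config triggers the branch added above):

    import types
    import torch
    import comfy.ops

    # Stand-in for what comfy.model_detection.model_config_from_unet now produces.
    model_config = types.SimpleNamespace(quant_config={"some.layer": {"format": "example_fp8_format"}})

    ops = comfy.ops.pick_operations(
        torch.float8_e4m3fn,   # weight_dtype
        torch.bfloat16,        # compute_dtype
        fp8_optimizations=False,
        model_config=model_config,
    )
    # ops is the MixedPrecisionOps class returned by mixed_precision_ops(); layers built as
    # ops.Linear(...) then pick up their per-layer "comfy_quant" config at load time.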
