
Commit 2c03884

Skip fp4 matrix mult on devices that don't support it. (#11677)
1 parent 6e9ee55

File tree

comfy/model_management.py
comfy/ops.py

2 files changed: +27 -4 lines changed


comfy/model_management.py

Lines changed: 10 additions & 0 deletions
@@ -1504,6 +1504,16 @@ def supports_fp8_compute(device=None):
 
     return True
 
+def supports_nvfp4_compute(device=None):
+    if not is_nvidia():
+        return False
+
+    props = torch.cuda.get_device_properties(device)
+    if props.major < 10:
+        return False
+
+    return True
+
 def extended_fp16_support():
     # TODO: check why some models work with fp16 on newer torch versions but not on older
     if torch_version_numeric < (2, 7):
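
The new helper mirrors supports_fp8_compute(): NVFP4 matrix multiplication is only available on NVIDIA GPUs whose CUDA compute capability major version is at least 10 (Blackwell generation and newer), so older devices report False and are routed to full-precision matrix multiplication by the comfy/ops.py change below. A minimal standalone sketch of the same gate, assuming only PyTorch is available (the helper name is illustrative; ComfyUI's version additionally goes through its own is_nvidia() check):

import torch

def gpu_supports_nvfp4(device=None):
    # Illustrative stand-in for supports_nvfp4_compute(): NVFP4 matmul needs an
    # NVIDIA GPU with CUDA compute capability >= 10.0 (Blackwell and newer).
    if not torch.cuda.is_available():
        return False
    if device is None:
        device = torch.cuda.current_device()
    props = torch.cuda.get_device_properties(device)
    return props.major >= 10

if __name__ == "__main__":
    # Prints False on pre-Blackwell cards (an RTX 4090 reports major=8, minor=9).
    print("NVFP4 compute supported:", gpu_supports_nvfp4())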

comfy/ops.py

Lines changed: 17 additions & 4 deletions
@@ -493,11 +493,12 @@ def forward(self, *args, **kwargs):
             )
 
 
-def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
+def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
     class MixedPrecisionOps(manual_cast):
         _quant_config = quant_config
         _compute_dtype = compute_dtype
         _full_precision_mm = full_precision_mm
+        _disabled = disabled
 
         class Linear(torch.nn.Module, CastWeightBiasOp):
             def __init__(
@@ -522,6 +523,7 @@ def __init__(
 
                 self.tensor_class = None
                 self._full_precision_mm = MixedPrecisionOps._full_precision_mm
+                self._full_precision_mm_config = False
 
             def reset_parameters(self):
                 return None
@@ -556,8 +558,12 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                     self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
                 else:
                     self.quant_format = layer_conf.get("format", None)
+                    self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
                     if not self._full_precision_mm:
-                        self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)
+                        self._full_precision_mm = self._full_precision_mm_config
+
+                    if self.quant_format in MixedPrecisionOps._disabled:
+                        self._full_precision_mm = True
 
                     if self.quant_format is None:
                         raise ValueError(f"Unknown quantization format for layer {layer_name}")
@@ -630,7 +636,7 @@ def state_dict(self, *args, destination=None, prefix="", **kwargs):
                 sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale
 
                 quant_conf = {"format": self.quant_format}
-                if self._full_precision_mm:
+                if self._full_precision_mm_config:
                     quant_conf["full_precision_matrix_mult"] = True
                 sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
                 return sd
@@ -711,10 +717,17 @@ def _apply(self, fn, recurse=True): # This is to get torch.compile + moving wei
 
 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
+    nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
 
     if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
         logging.info("Using mixed precision operations")
-        return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)
+        disabled = set()
+        if not nvfp4_compute:
+            disabled.add("nvfp4")
+        if not fp8_compute:
+            disabled.add("float8_e4m3fn")
+            disabled.add("float8_e5m2")
+        return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled)
 
     if (
         fp8_compute and
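
Taken together, the ops.py changes let pick_operations() pass the set of quantization formats the current device cannot execute down to mixed_precision_ops(); any layer whose quant_format lands in that disabled set has its matrix multiplication forced to full precision instead of hitting an unsupported fp4/fp8 kernel. Keeping _full_precision_mm_config separate from _full_precision_mm means state_dict() re-serializes only the setting that came from the checkpoint, not the runtime fallback. A rough sketch of that decision flow, with hypothetical names and no real kernels, just to show the control logic:

# Illustrative only: build_disabled_formats and QuantLayerDecision are not ComfyUI
# APIs; they just replay the logic the diff adds to pick_operations() and Linear.

def build_disabled_formats(fp8_compute: bool, nvfp4_compute: bool) -> set:
    # Same format strings the diff uses when building the disabled set.
    disabled = set()
    if not nvfp4_compute:
        disabled.add("nvfp4")
    if not fp8_compute:
        disabled.update({"float8_e4m3fn", "float8_e5m2"})
    return disabled

class QuantLayerDecision:
    def __init__(self, quant_format, configured_full_precision, disabled):
        # What the checkpoint asked for; this is the value that gets written back
        # out by state_dict() as "full_precision_matrix_mult".
        self.full_precision_mm_config = configured_full_precision
        # Runtime choice: honor the config, then force full precision whenever
        # the format is disabled on this device.
        self.full_precision_mm = configured_full_precision or (quant_format in disabled)

# Example: an nvfp4-quantized layer loaded on a GPU without NVFP4 support.
disabled = build_disabled_formats(fp8_compute=True, nvfp4_compute=False)
layer = QuantLayerDecision("nvfp4", configured_full_precision=False, disabled=disabled)
print(layer.full_precision_mm)         # True  -> matmul runs in full precision here
print(layer.full_precision_mm_config)  # False -> a re-saved checkpoint is unchanged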
