@@ -493,11 +493,12 @@ def forward(self, *args, **kwargs):
 )
 
 
-def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
+def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
     class MixedPrecisionOps(manual_cast):
         _quant_config = quant_config
         _compute_dtype = compute_dtype
         _full_precision_mm = full_precision_mm
+        _disabled = disabled
 
         class Linear(torch.nn.Module, CastWeightBiasOp):
             def __init__(
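
The new disabled argument threads a set of quantization-format names into the generated op class; any layer whose format appears in the set is forced onto the full-precision matrix-multiplication path (see the _load_from_state_dict hunk below). As a reduced, self-contained sketch of the pattern in use here, a factory that bakes configuration into class attributes of a freshly created class so each model load gets its own op set (names below are illustrative, not from the patch):

    import torch

    def make_ops(compute_dtype=torch.bfloat16, disabled=()):
        class Ops:
            _compute_dtype = compute_dtype
            _disabled = set(disabled)

            class Linear:
                def describe(self, quant_format):
                    # Disabled formats fall back to a full-precision matmul.
                    if quant_format in Ops._disabled:
                        return "full-precision matmul fallback"
                    return "quantized matmul, compute dtype " + str(Ops._compute_dtype)
        return Ops

    ops = make_ops(disabled={"nvfp4"})
    print(ops.Linear().describe("nvfp4"))  # -> full-precision matmul fallback
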
@@ -522,6 +523,7 @@ def __init__(
 
                 self.tensor_class = None
                 self._full_precision_mm = MixedPrecisionOps._full_precision_mm
+                self._full_precision_mm_config = False
 
             def reset_parameters(self):
                 return None
@@ -556,8 +558,12 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                     self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
                 else:
                     self.quant_format = layer_conf.get("format", None)
+                    self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
                     if not self._full_precision_mm:
-                        self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)
+                        self._full_precision_mm = self._full_precision_mm_config
+
+                    if self.quant_format in MixedPrecisionOps._disabled:
+                        self._full_precision_mm = True
 
                 if self.quant_format is None:
                     raise ValueError(f"Unknown quantization format for layer {layer_name}")
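
The load path now keeps two flags apart: _full_precision_mm_config records what the checkpoint's per-layer config requested (defaulting to False, per the __init__ hunk above), while _full_precision_mm is the effective runtime choice, additionally forced on when the layer's format sits in MixedPrecisionOps._disabled. A minimal stand-in showing how the two can diverge (FakeLayer and load_flags are illustrative only):

    class FakeLayer:
        pass

    def load_flags(layer, layer_conf, disabled, global_full_precision=False):
        layer.quant_format = layer_conf.get("format", None)
        layer._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
        layer._full_precision_mm = global_full_precision or layer._full_precision_mm_config
        if layer.quant_format in disabled:
            layer._full_precision_mm = True  # device cannot run this format natively

    layer = FakeLayer()
    load_flags(layer, {"format": "nvfp4"}, disabled={"nvfp4"})
    assert layer._full_precision_mm is True           # forced by the device
    assert layer._full_precision_mm_config is False   # the checkpoint never asked for it
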
@@ -630,7 +636,7 @@ def state_dict(self, *args, destination=None, prefix="", **kwargs):
                     sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale
 
                 quant_conf = {"format": self.quant_format}
-                if self._full_precision_mm:
+                if self._full_precision_mm_config:
                     quant_conf["full_precision_matrix_mult"] = True
                 sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
                 return sd
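
Serialization now writes the requested flag rather than the effective one, so a checkpoint saved on a device that forced the fallback does not permanently gain full_precision_matrix_mult: true. For reference, a round-trip of the comfy_quant encoding used on the last line above; the encode side matches the hunk, the decode side is an assumption about how the loader reads it back, and the values are made up:

    import json
    import torch

    quant_conf = {"format": "nvfp4", "full_precision_matrix_mult": True}  # example values
    blob = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8)

    decoded = json.loads(bytes(blob.tolist()).decode("utf-8"))
    assert decoded == quant_conf
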
@@ -711,10 +717,17 @@ def _apply(self, fn, recurse=True):  # This is to get torch.compile + moving wei
 
 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device)  # TODO: if we support more ops this needs to be more granular
+    nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
 
     if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
         logging.info("Using mixed precision operations")
-        return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)
+        disabled = set()
+        if not nvfp4_compute:
+            disabled.add("nvfp4")
+        if not fp8_compute:
+            disabled.add("float8_e4m3fn")
+            disabled.add("float8_e5m2")
+        return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled)
 
     if (
         fp8_compute and
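
pick_operations now probes per-format compute support and disables only the formats the device cannot run, replacing the old blanket full_precision_mm=not fp8_compute, which pushed every quantized layer to the fallback whenever fp8 compute was missing. A self-contained sketch of that capability-to-format mapping, with the supports_* probes stubbed out as booleans (the real checks live in comfy.model_management):

    def disabled_formats(fp8_compute, nvfp4_compute):
        # Map missing device capabilities to quant formats that must fall back.
        disabled = set()
        if not nvfp4_compute:
            disabled.add("nvfp4")
        if not fp8_compute:
            disabled.update({"float8_e4m3fn", "float8_e5m2"})
        return disabled

    # e.g. a GPU with fp8 compute but no nvfp4 support:
    assert disabled_formats(fp8_compute=True, nvfp4_compute=False) == {"nvfp4"}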