@@ -79,7 +79,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     if input is not None:
         if dtype is None:
             if isinstance(input, QuantizedTensor):
-                dtype = input._layout_params["orig_dtype"]
+                dtype = input.params.orig_dtype
             else:
                 dtype = input.dtype
         if bias_dtype is None:
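
Note: the change above replaces dict-style _layout_params["orig_dtype"] with attribute access on a structured params object. A minimal sketch of what that Params container implies, inferred purely from the call sites in this diff (the real definition lives in quant_ops.py and may differ):

    from dataclasses import dataclass
    from typing import Optional, Tuple
    import torch

    @dataclass
    class Params:
        scale: Optional[torch.Tensor] = None          # per-tensor scale
        orig_dtype: Optional[torch.dtype] = None      # dtype before quantization
        orig_shape: Optional[Tuple[int, ...]] = None  # shape before any packing
        block_scale: Optional[torch.Tensor] = None    # NVFP4 only: per-block scales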
@@ -412,26 +412,34 @@ def fp8_linear(self, input):
         return None
 
     input_dtype = input.dtype
+    input_shape = input.shape
+    tensor_3d = input.ndim == 3
 
-    if input.ndim == 3 or input.ndim == 2:
-        w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
-        scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
+    if tensor_3d:
+        input = input.reshape(-1, input_shape[2])
 
-        scale_input = torch.ones((), device=input.device, dtype=torch.float32)
-        input = torch.clamp(input, min=-448, max=448, out=input)
-        layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
-        quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
+    if input.ndim != 2:
+        return None
+    w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
+    scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
+
+    scale_input = torch.ones((), device=input.device, dtype=torch.float32)
+    input = torch.clamp(input, min=-448, max=448, out=input)
+    input_fp8 = input.to(dtype).contiguous()
+    layout_params_input = TensorCoreFP8Layout.Params(scale=scale_input, orig_dtype=input_dtype, orig_shape=tuple(input_fp8.shape))
+    quantized_input = QuantizedTensor(input_fp8, TensorCoreFP8Layout, layout_params_input)
 
-        # Wrap weight in QuantizedTensor - this enables unified dispatch
-        # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
-        layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
-        quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
-        o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
+    # Wrap weight in QuantizedTensor - this enables unified dispatch.
+    # Call F.linear - __torch_dispatch__ routes to the fp8_linear handler in quant_ops.py.
+    layout_params_weight = TensorCoreFP8Layout.Params(scale=scale_weight, orig_dtype=input_dtype, orig_shape=tuple(w.shape))
+    quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
+    o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
 
-        uncast_bias_weight(self, w, bias, offload_stream)
-        return o
+    uncast_bias_weight(self, w, bias, offload_stream)
+    if tensor_3d:
+        o = o.reshape((input_shape[0], input_shape[1], w.shape[0]))
 
-    return None
+    return o
 
 class fp8_ops(manual_cast):
     class Linear(manual_cast.Linear):
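
Note on the clamp retained above: 448 is the largest finite magnitude representable in torch.float8_e4m3fn, so clamping before the cast keeps out-of-range activations from overflowing. Quick check:

    import torch
    print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0
    print(torch.finfo(torch.float8_e5m2).max)    # 57344.0 (more range, less precision)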
@@ -477,7 +485,12 @@ def forward(self, *args, **kwargs):
 # ==============================================================================
 # Mixed Precision Operations
 # ==============================================================================
-from .quant_ops import QuantizedTensor, QUANT_ALGOS
+from .quant_ops import (
+    QuantizedTensor,
+    QUANT_ALGOS,
+    TensorCoreFP8Layout,
+    get_layout_class,
+)
 
 
 def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
@@ -497,21 +510,32 @@ def __init__(
     ) -> None:
         super().__init__()
 
-        if dtype is None:
-            dtype = MixedPrecisionOps._compute_dtype
-
-        self.factory_kwargs = {"device": device, "dtype": dtype}
+        self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
+        # self.factory_kwargs = {"device": device, "dtype": dtype}
 
         self.in_features = in_features
         self.out_features = out_features
-        self._has_bias = bias
+        if bias:
+            self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
+        else:
+            self.register_parameter("bias", None)
 
         self.tensor_class = None
         self._full_precision_mm = MixedPrecisionOps._full_precision_mm
 
     def reset_parameters(self):
         return None
 
+    def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None):
+        key = f"{prefix}{param_name}"
+        value = state_dict.pop(key, None)
+        if value is not None:
+            value = value.to(device=device)
+            if dtype is not None:
+                value = value.view(dtype=dtype)
+            manually_loaded_keys.append(key)
+        return value
+
     def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                               strict, missing_keys, unexpected_keys, error_msgs):
 
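The dtype= branch of the new _load_scale_param helper relies on Tensor.view(dtype=...), which reinterprets the underlying bytes instead of converting values; this is what lets NVFP4 block scales stored under another one-byte dtype in the checkpoint come back as float8_e4m3fn. An illustrative reinterpret (the byte value is arbitrary):

    import torch
    raw = torch.tensor([0x3C], dtype=torch.uint8)  # one raw byte
    fp8 = raw.view(dtype=torch.float8_e4m3fn)      # same storage, no numeric conversion
    print(fp8.float())                             # tensor([1.5]); 0x3C decodes to 1.5 in e4m3fn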
@@ -529,14 +553,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
             layer_conf = json.loads(layer_conf.numpy().tobytes())
 
         if layer_conf is None:
-            dtype = self.factory_kwargs["dtype"]
-            self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
-            if dtype != MixedPrecisionOps._compute_dtype:
-                self.comfy_cast_weights = True
-            if self._has_bias:
-                self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype))
-            else:
-                self.register_parameter("bias", None)
+            self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
         else:
             self.quant_format = layer_conf.get("format", None)
             if not self._full_precision_mm:
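
For context, the json.loads(layer_conf.numpy().tobytes()) call above means the per-layer quant config travels through the state dict as a JSON blob packed into a raw byte tensor. A round-trip sketch of that convention (the encode side is inferred from the decode path, not shown in this diff):

    import json
    import torch

    conf = {"format": "nvfp4"}
    blob = torch.frombuffer(bytearray(json.dumps(conf).encode("utf-8")), dtype=torch.uint8)
    assert json.loads(blob.numpy().tobytes()) == conf  # matches the loader's decode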
@@ -547,31 +564,46 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
 
             qconfig = QUANT_ALGOS[self.quant_format]
             self.layout_type = qconfig["comfy_tensor_layout"]
-
-            weight_scale_key = f"{prefix}weight_scale"
-            scale = state_dict.pop(weight_scale_key, None)
-            if scale is not None:
-                scale = scale.to(device)
-            layout_params = {
-                'scale': scale,
-                'orig_dtype': MixedPrecisionOps._compute_dtype,
-                'block_size': qconfig.get("group_size", None),
-            }
-
-            if scale is not None:
-                manually_loaded_keys.append(weight_scale_key)
+            layout_cls = get_layout_class(self.layout_type)
+
+            # Load format-specific parameters
+            if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]:
+                # FP8: single tensor scale
+                scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
+
+                params = layout_cls.Params(
+                    scale=scale,
+                    orig_dtype=MixedPrecisionOps._compute_dtype,
+                    orig_shape=(self.out_features, self.in_features),
+                )
+
+            elif self.quant_format == "nvfp4":
+                # NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
+                tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
+                block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
+                                                     dtype=torch.float8_e4m3fn)
+
+                if tensor_scale is None or block_scale is None:
+                    raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
+
+                params = layout_cls.Params(
+                    scale=tensor_scale,
+                    block_scale=block_scale,
+                    orig_dtype=MixedPrecisionOps._compute_dtype,
+                    orig_shape=(self.out_features, self.in_features),
+                )
+            else:
+                raise ValueError(f"Unsupported quantization format: {self.quant_format}")
 
             self.weight = torch.nn.Parameter(
-                QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
+                QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
                 requires_grad=False
             )
 
-            if self._has_bias:
-                self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype))
-            else:
-                self.register_parameter("bias", None)
-
             for param_name in qconfig["parameters"]:
+                if param_name in {"weight_scale", "weight_scale_2"}:
+                    continue  # Already handled above
+
                 param_key = f"{prefix}{param_name}"
                 _v = state_dict.pop(param_key, None)
                 if _v is None:
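
On the two NVFP4 scales loaded above: in NVIDIA's NVFP4 scheme, each 16-element block of FP4 values carries its own e4m3 block scale, with one higher-precision per-tensor scale on top. Roughly, dequantization composes them as in the sketch below; the authoritative math lives in the layout class in quant_ops.py:

    import torch

    def nvfp4_dequant_sketch(fp4_vals, block_scale, tensor_scale, block=16):
        # fp4_vals: decoded FP4 values as float32, shape (N,)
        # block_scale: one e4m3 scale per block, shape (N // block,)
        # tensor_scale: single global scale
        per_elem = block_scale.float().repeat_interleave(block)
        return fp4_vals * per_elem * tensor_scale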
@@ -588,7 +620,15 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
     def state_dict(self, *args, destination=None, prefix="", **kwargs):
         sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
         if isinstance(self.weight, QuantizedTensor):
-            sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
+            layout_cls = self.weight._layout_cls
+
+            # Check if it's any FP8 variant (E4M3 or E5M2)
+            if layout_cls in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout"):
+                sd["{}weight_scale".format(prefix)] = self.weight._params.scale
+            elif layout_cls == "TensorCoreNVFP4Layout":
+                sd["{}weight_scale_2".format(prefix)] = self.weight._params.scale
+                sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale
+
             quant_conf = {"format": self.quant_format}
             if self._full_precision_mm:
                 quant_conf["full_precision_matrix_mult"] = True
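
Taken together, the load path above and this save path fix the on-disk naming convention. For a hypothetical layer prefix "blocks.0.ff.", the keys are:

    FP8 (float8_e4m3fn / float8_e5m2):
        blocks.0.ff.weight          fp8 storage tensor
        blocks.0.ff.weight_scale    per-tensor scale             -> Params.scale

    NVFP4:
        blocks.0.ff.weight          packed fp4 storage
        blocks.0.ff.weight_scale_2  per-tensor scale             -> Params.scale
        blocks.0.ff.weight_scale    per-block scales (e4m3 view) -> Params.block_scale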
@@ -607,12 +647,33 @@ def forward_comfy_cast_weights(self, input):
     def forward(self, input, *args, **kwargs):
         run_every_op()
 
+        input_shape = input.shape
+        tensor_3d = input.ndim == 3
+
         if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
             return self.forward_comfy_cast_weights(input, *args, **kwargs)
+
         if (getattr(self, 'layout_type', None) is not None and
                 not isinstance(input, QuantizedTensor)):
-            input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
-        return self._forward(input, self.weight, self.bias)
+
+            # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
+            if tensor_3d:
+                input = input.reshape(-1, input_shape[2])
+
+            if input.ndim != 2:
+                # Fall back to comfy_cast_weights for non-2D tensors
+                return self.forward_comfy_cast_weights(input.reshape(input_shape), *args, **kwargs)
+
+            # dtype is now implicit in the layout class
+            input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None))
+
+        output = self._forward(input, self.weight, self.bias)
+
+        # Reshape output back to 3D if input was 3D
+        if tensor_3d:
+            output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0]))
+
+        return output
 
     def convert_weight(self, weight, inplace=False, **kwargs):
         if isinstance(weight, QuantizedTensor):
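
A shape walk-through of the new 3D handling in forward, with an illustrative (batch, seq, features) activation through a 4096 -> 3072 layer:

    input_shape = (2, 77, 4096)      # (B, T, C_in)
    input.reshape(-1, 4096)          # -> (154, 4096), quantized as a 2D tensor
    self._forward(...)               # -> (154, 3072)
    output.reshape((2, 77, 3072))    # back to 3D; self.weight.shape[0] == 3072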
@@ -622,7 +683,8 @@ def convert_weight(self, weight, inplace=False, **kwargs):
 
     def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
         if getattr(self, 'layout_type', None) is not None:
-            weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
+            # dtype is now implicit in the layout class
+            weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True)
         else:
             weight = weight.to(self.weight.dtype)
         if return_weight:
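
On stochastic_rounding=seed above: stochastic rounding quantizes each value up or down with probability proportional to its distance from the two nearest representable neighbors, so the quantization error is zero in expectation. A toy sketch of the idea on an integer grid (the actual quant_ops.py implementation operates on the layout's own quantized grid):

    import torch

    def stochastic_round(x, seed):
        g = torch.Generator().manual_seed(seed)
        lo = torch.floor(x)
        frac = x - lo                                    # distance to the lower neighbor
        round_up = torch.rand(x.shape, generator=g) < frac
        return lo + round_up.to(x.dtype)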