@@ -412,26 +412,34 @@ def fp8_linear(self, input):
         return None
 
     input_dtype = input.dtype
+    input_shape = input.shape
+    tensor_3d = input.ndim == 3
 
-    if input.ndim == 3 or input.ndim == 2:
-        w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
-        scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
+    if tensor_3d:
+        input = input.reshape(-1, input_shape[2])
 
-        scale_input = torch.ones((), device=input.device, dtype=torch.float32)
-        input = torch.clamp(input, min=-448, max=448, out=input)
-        layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
-        quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
+    if input.ndim != 2:
+        return None
+    w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
+    scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
+
+    scale_input = torch.ones((), device=input.device, dtype=torch.float32)
+    input = torch.clamp(input, min=-448, max=448, out=input)
+    input_fp8 = input.to(dtype).contiguous()
+    layout_params_input = TensorCoreFP8Layout.Params(scale=scale_input, orig_dtype=input_dtype, orig_shape=tuple(input_fp8.shape))
+    quantized_input = QuantizedTensor(input_fp8, TensorCoreFP8Layout, layout_params_input)
 
-        # Wrap weight in QuantizedTensor - this enables unified dispatch
-        # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
-        layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
-        quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
-        o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
+    # Wrap weight in QuantizedTensor - this enables unified dispatch
+    # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
+    layout_params_weight = TensorCoreFP8Layout.Params(scale=scale_weight, orig_dtype=input_dtype, orig_shape=tuple(w.shape))
+    quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
+    o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
 
-        uncast_bias_weight(self, w, bias, offload_stream)
-        return o
+    uncast_bias_weight(self, w, bias, offload_stream)
+    if tensor_3d:
+        o = o.reshape((-1, input_shape[1], w.shape[0]))
 
-    return None
+    return o
 
 class fp8_ops(manual_cast):
     class Linear(manual_cast.Linear):
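
Note on the clamp above: 448 is the largest finite value representable in `torch.float8_e4m3fn`, so activations are saturated into range before the storage cast. A standalone sketch, not part of the patch (assumes a PyTorch build with float8 dtypes):

```python
import torch

# Deliberately out-of-range activations, clamped the same way fp8_linear does.
x = torch.randn(16, 32) * 1000.0
x = torch.clamp(x, min=-448, max=448)          # saturate instead of overflowing
x_fp8 = x.to(torch.float8_e4m3fn)              # same storage cast as the patch
print(x_fp8.dtype, x_fp8.float().abs().max().item())  # torch.float8_e4m3fn 448.0
```
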
@@ -477,7 +485,15 @@ def forward(self, *args, **kwargs):
 # ==============================================================================
 # Mixed Precision Operations
 # ==============================================================================
-from .quant_ops import QuantizedTensor, QUANT_ALGOS
+from .quant_ops import (
+    QuantizedTensor,
+    QUANT_ALGOS,
+    LAYOUTS,
+    TensorCoreFP8Layout,
+    TensorCoreFP8E4M3Layout,
+    TensorCoreFP8E5M2Layout,
+    TensorCoreNVFP4Layout
+)
 
 
 def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
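
The hunks below replace string layout names plus ad-hoc `layout_params` dicts with layout classes carrying a structured `Params` object. The real definition lives in quant_ops.py and is not shown in this diff; inferred purely from the call sites, it behaves roughly like this hypothetical sketch:

```python
from dataclasses import dataclass
from typing import Optional, Tuple
import torch

# Hypothetical reconstruction from call sites in this diff; the actual
# Params in quant_ops.py may carry more fields or validation.
@dataclass
class Params:
    scale: Optional[torch.Tensor]                 # per-tensor scale
    orig_dtype: torch.dtype                       # dtype dequantized outputs use
    orig_shape: Tuple[int, ...]                   # logical (out_features, in_features)
    block_scale: Optional[torch.Tensor] = None    # NVFP4-only per-block scales
```
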
@@ -497,21 +513,32 @@ def __init__(
             ) -> None:
                 super().__init__()
 
-                if dtype is None:
-                    dtype = MixedPrecisionOps._compute_dtype
-
-                self.factory_kwargs = {"device": device, "dtype": dtype}
+                self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
+                # self.factory_kwargs = {"device": device, "dtype": dtype}
 
                 self.in_features = in_features
                 self.out_features = out_features
-                self._has_bias = bias
+                if bias:
+                    self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
+                else:
+                    self.register_parameter("bias", None)
 
                 self.tensor_class = None
                 self._full_precision_mm = MixedPrecisionOps._full_precision_mm
 
             def reset_parameters(self):
                 return None
 
+            def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None):
+                key = f"{prefix}{param_name}"
+                value = state_dict.pop(key, None)
+                if value is not None:
+                    value = value.to(device=device)
+                    if dtype is not None:
+                        value = value.view(dtype=dtype)
+                    manually_loaded_keys.append(key)
+                return value
+
             def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                                       strict, missing_keys, unexpected_keys, error_msgs):
 
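
The new `_load_scale_param` helper centralizes the pop-from-state-dict, move-to-device, optional dtype-reinterpret, and bookkeeping sequence that the loading code below repeats per scale. A standalone copy of its logic (with `self` dropped and made-up key names) to show the contract:

```python
import torch

def load_scale_param(state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None):
    key = f"{prefix}{param_name}"
    value = state_dict.pop(key, None)        # consume the key so it is not reported as unexpected
    if value is not None:
        value = value.to(device=device)
        if dtype is not None:
            value = value.view(dtype=dtype)  # bit reinterpretation, not conversion
        manually_loaded_keys.append(key)
    return value

loaded = []
sd = {"blocks.0.weight_scale": torch.tensor(0.5)}
scale = load_scale_param(sd, "blocks.0.", "weight_scale", "cpu", loaded)
assert scale.item() == 0.5
assert "blocks.0.weight_scale" not in sd and loaded == ["blocks.0.weight_scale"]
```
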
@@ -529,14 +556,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                 layer_conf = json.loads(layer_conf.numpy().tobytes())
 
                 if layer_conf is None:
-                    dtype = self.factory_kwargs["dtype"]
-                    self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
-                    if dtype != MixedPrecisionOps._compute_dtype:
-                        self.comfy_cast_weights = True
-                    if self._has_bias:
-                        self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype))
-                    else:
-                        self.register_parameter("bias", None)
+                    self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
                 else:
                     self.quant_format = layer_conf.get("format", None)
                     if not self._full_precision_mm:
@@ -547,31 +567,46 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
 
                     qconfig = QUANT_ALGOS[self.quant_format]
                     self.layout_type = qconfig["comfy_tensor_layout"]
-
-                    weight_scale_key = f"{prefix}weight_scale"
-                    scale = state_dict.pop(weight_scale_key, None)
-                    if scale is not None:
-                        scale = scale.to(device)
-                    layout_params = {
-                        'scale': scale,
-                        'orig_dtype': MixedPrecisionOps._compute_dtype,
-                        'block_size': qconfig.get("group_size", None),
-                    }
-
-                    if scale is not None:
-                        manually_loaded_keys.append(weight_scale_key)
+                    layout_cls = LAYOUTS[self.layout_type]
+
+                    # Load format-specific parameters
+                    if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]:
+                        # FP8: single tensor scale
+                        scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
+
+                        params = layout_cls.Params(
+                            scale=scale,
+                            orig_dtype=MixedPrecisionOps._compute_dtype,
+                            orig_shape=(self.out_features, self.in_features),
+                        )
+
+                    elif self.quant_format == "nvfp4":
+                        # NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
+                        tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
+                        block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
+                                                             dtype=torch.float8_e4m3fn)
+
+                        if tensor_scale is None or block_scale is None:
+                            raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
+
+                        params = layout_cls.Params(
+                            scale=tensor_scale,
+                            block_scale=block_scale,
+                            orig_dtype=MixedPrecisionOps._compute_dtype,
+                            orig_shape=(self.out_features, self.in_features),
+                        )
+                    else:
+                        raise ValueError(f"Unsupported quantization format: {self.quant_format}")
 
                     self.weight = torch.nn.Parameter(
-                        QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
+                        QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), layout_cls, params),
                         requires_grad=False
                     )
 
-                    if self._has_bias:
-                        self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype))
-                    else:
-                        self.register_parameter("bias", None)
-
                     for param_name in qconfig["parameters"]:
+                        if param_name in {"weight_scale", "weight_scale_2"}:
+                            continue  # Already handled above
+
                         param_key = f"{prefix}{param_name}"
                         _v = state_dict.pop(param_key, None)
                         if _v is None:
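
Note the `dtype=torch.float8_e4m3fn` passed when loading the NVFP4 `block_scale`: inside the helper it goes through `Tensor.view(dtype=...)`, which reinterprets the stored bytes rather than converting values the way `.to(dtype=...)` would. A minimal sketch (assumes float8 support):

```python
import torch

# view(dtype=...) is a zero-copy bit reinterpretation; it is legal here
# because uint8 and float8_e4m3fn have the same 1-byte element size.
raw = torch.randint(0, 256, (128,), dtype=torch.uint8)   # checkpoint bytes
block_scale = raw.view(dtype=torch.float8_e4m3fn)
print(block_scale.dtype, block_scale.shape)              # torch.float8_e4m3fn torch.Size([128])
```
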
@@ -588,11 +623,20 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
             def state_dict(self, *args, destination=None, prefix="", **kwargs):
                 sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
                 if isinstance(self.weight, QuantizedTensor):
-                    sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
+                    layout_cls = self.weight._layout_cls
+
+                    # Check if it's any FP8 variant (E4M3 or E5M2)
+                    if layout_cls in (TensorCoreFP8E4M3Layout, TensorCoreFP8E5M2Layout) or \
+                       layout_cls.__name__ in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout"):
+                        sd["{}weight_scale".format(prefix)] = self.weight._params.scale
+                    elif layout_cls == TensorCoreNVFP4Layout or layout_cls.__name__ == "TensorCoreNVFP4Layout":
+                        sd["{}weight_scale_2".format(prefix)] = self.weight._params.scale
+                        sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale
+
                     quant_conf = {"format": self.quant_format}
                     if self._full_precision_mm:
                         quant_conf["full_precision_matrix_mult"] = True
-                    sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
+                    sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)
                 return sd
 
             def _forward(self, input, weight, bias):
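
The `state_dict` change swaps `torch.tensor(list(bytes))` for `torch.frombuffer`, which wraps the byte string directly instead of first materializing a Python list of ints. A round-trip sketch matching the decode path `_load_from_state_dict` uses:

```python
import json
import torch

conf = {"format": "float8_e4m3fn"}
buf = bytearray(json.dumps(conf).encode('utf-8'))  # bytearray avoids frombuffer's read-only-memory warning
t = torch.frombuffer(buf, dtype=torch.uint8)
assert json.loads(t.numpy().tobytes()) == conf     # same decode as the loader above
```
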
@@ -607,12 +651,34 @@ def forward_comfy_cast_weights(self, input):
             def forward(self, input, *args, **kwargs):
                 run_every_op()
 
+                input_shape = input.shape
+                tensor_3d = input.ndim == 3
+
                 if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                     return self.forward_comfy_cast_weights(input, *args, **kwargs)
+
                 if (getattr(self, 'layout_type', None) is not None and
                     not isinstance(input, QuantizedTensor)):
-                    input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
-                return self._forward(input, self.weight, self.bias)
+                    layout_cls = LAYOUTS[self.layout_type]
+
+                    # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
+                    if tensor_3d:
+                        input = input.reshape(-1, input_shape[2])
+
+                    if input.ndim != 2:
+                        # Fall back to comfy_cast_weights for non-2D tensors
+                        return self.forward_comfy_cast_weights(input.reshape(input_shape), *args, **kwargs)
+
+                    # dtype is now implicit in the layout class
+                    input = QuantizedTensor.from_float(input, layout_cls, scale=getattr(self, 'input_scale', None))
+
+                output = self._forward(input, self.weight, self.bias)
+
+                # Reshape output back to 3D if input was 3D
+                if tensor_3d:
+                    output = output.reshape((-1, input_shape[1], self.weight.shape[0]))
+
+                return output
 
             def convert_weight(self, weight, inplace=False, **kwargs):
                 if isinstance(weight, QuantizedTensor):
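
Both `fp8_linear` and this `forward` flatten 3D activations because the quantized matmul path expects 2D inputs; the trailing reshape restores the original leading dimensions. The shape arithmetic in isolation:

```python
import torch

batch, seq, in_features, out_features = 2, 5, 8, 4
x = torch.randn(batch, seq, in_features)
w = torch.randn(out_features, in_features)

x2d = x.reshape(-1, x.shape[2])                  # (batch*seq, in_features)
y2d = torch.nn.functional.linear(x2d, w)         # stand-in for the quantized matmul
y = y2d.reshape((-1, x.shape[1], w.shape[0]))    # back to (batch, seq, out_features)
assert y.shape == (batch, seq, out_features)
```
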
@@ -622,7 +688,9 @@ def convert_weight(self, weight, inplace=False, **kwargs):
 
             def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
                 if getattr(self, 'layout_type', None) is not None:
-                    weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
+                    layout_cls = LAYOUTS[self.layout_type]
+                    # dtype is now implicit in the layout class
+                    weight = QuantizedTensor.from_float(weight, layout_cls, scale="recalculate", stochastic_rounding=seed, inplace_ops=True)
                 else:
                     weight = weight.to(self.weight.dtype)
                 if return_weight:
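
Background on the `stochastic_rounding=seed` argument that `set_weight` forwards: stochastic rounding rounds up with probability equal to the fractional remainder, so quantized values are unbiased in expectation. The `from_float` internals are not shown in this diff; an illustration of the general technique in plain float32:

```python
import torch

def stochastic_round(x: torch.Tensor, step: float, generator=None) -> torch.Tensor:
    # Round x to multiples of `step`, rounding up with probability
    # proportional to the fractional remainder.
    scaled = x / step
    floor = scaled.floor()
    frac = scaled - floor
    up = torch.rand(x.shape, generator=generator) < frac
    return (floor + up.to(x.dtype)) * step

g = torch.Generator().manual_seed(0)
x = torch.full((100_000,), 0.3)
q = stochastic_round(x, step=1.0, generator=g)
print(q.mean().item())  # ~0.3 on average, though each element is exactly 0.0 or 1.0
```
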