
Commit fc3c52b

Initial ops changes to use comfy_kitchen; more changes are needed. Right now this requires the comfy_kitchen package to be installed, but the pip package is not out yet.
1 parent 809ce68 · commit fc3c52b

3 files changed: 236 additions, 587 deletions

comfy/model_management.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -1146,7 +1146,7 @@ def pin_memory(tensor):
     if not tensor.is_contiguous():
         return False
 
-    size = tensor.numel() * tensor.element_size()
+    size = tensor.nbytes
     if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
         return False
 
@@ -1170,7 +1170,7 @@ def unpin_memory(tensor):
         return False
 
     ptr = tensor.data_ptr()
-    size = tensor.numel() * tensor.element_size()
+    size = tensor.nbytes
 
     size_stored = PINNED_MEMORY.get(ptr, None)
     if size_stored is None:
```
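Side note on the change itself: `Tensor.nbytes` reports exactly the byte count the removed `numel() * element_size()` expression computed by hand, so the pinned-memory accounting is unchanged. A quick sanity check:

```python
import torch

t = torch.empty(4, 8, dtype=torch.float16)
# nbytes reports the same total as the manual product the old code used.
assert t.nbytes == t.numel() * t.element_size() == 64
```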

comfy/ops.py

Lines changed: 121 additions & 53 deletions
```diff
@@ -412,26 +412,34 @@ def fp8_linear(self, input):
         return None
 
     input_dtype = input.dtype
+    input_shape = input.shape
+    tensor_3d = input.ndim == 3
 
-    if input.ndim == 3 or input.ndim == 2:
-        w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
-        scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
+    if tensor_3d:
+        input = input.reshape(-1, input_shape[2])
 
-        scale_input = torch.ones((), device=input.device, dtype=torch.float32)
-        input = torch.clamp(input, min=-448, max=448, out=input)
-        layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
-        quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
+    if input.ndim != 2:
+        return None
+    w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
+    scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
+
+    scale_input = torch.ones((), device=input.device, dtype=torch.float32)
+    input = torch.clamp(input, min=-448, max=448, out=input)
+    input_fp8 = input.to(dtype).contiguous()
+    layout_params_input = TensorCoreFP8Layout.Params(scale=scale_input, orig_dtype=input_dtype, orig_shape=tuple(input_fp8.shape))
+    quantized_input = QuantizedTensor(input_fp8, TensorCoreFP8Layout, layout_params_input)
 
-        # Wrap weight in QuantizedTensor - this enables unified dispatch
-        # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
-        layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
-        quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
-        o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
+    # Wrap weight in QuantizedTensor - this enables unified dispatch
+    # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
+    layout_params_weight = TensorCoreFP8Layout.Params(scale=scale_weight, orig_dtype=input_dtype, orig_shape=tuple(w.shape))
+    quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
+    o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
 
-        uncast_bias_weight(self, w, bias, offload_stream)
-        return o
+    uncast_bias_weight(self, w, bias, offload_stream)
+    if tensor_3d:
+        o = o.reshape((-1, input_shape[1], w.shape[0]))
 
-    return None
+    return o
 
 class fp8_ops(manual_cast):
     class Linear(manual_cast.Linear):
```
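The inline comments lean on `__torch_dispatch__` routing. As a rough, self-contained illustration of that mechanism (a toy wrapper subclass, not comfy_kitchen's actual `QuantizedTensor`), wrapping both operands lets a plain `F.linear` call be intercepted at the aten level:

```python
import torch
from torch.utils._pytree import tree_map

class Passthrough(torch.Tensor):
    """Toy wrapper subclass: intercepts every aten op, unwraps, re-dispatches.
    A real QuantizedTensor would select a quantized kernel here instead."""
    @staticmethod
    def __new__(cls, elem):
        r = torch.Tensor._make_wrapper_subclass(cls, elem.shape, dtype=elem.dtype, device=elem.device)
        r.elem = elem
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        unwrap = lambda t: t.elem if isinstance(t, Passthrough) else t
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))

x, w = Passthrough(torch.randn(4, 8)), Passthrough(torch.randn(3, 8))
out = torch.nn.functional.linear(x, w)  # routed through __torch_dispatch__
assert out.shape == (4, 3)
```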
```diff
@@ -477,7 +485,15 @@ def forward(self, *args, **kwargs):
 # ==============================================================================
 # Mixed Precision Operations
 # ==============================================================================
-from .quant_ops import QuantizedTensor, QUANT_ALGOS
+from .quant_ops import (
+    QuantizedTensor,
+    QUANT_ALGOS,
+    LAYOUTS,
+    TensorCoreFP8Layout,
+    TensorCoreFP8E4M3Layout,
+    TensorCoreFP8E5M2Layout,
+    TensorCoreNVFP4Layout
+)
 
 
 def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
```
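The switch from string layout tags to classes implies `LAYOUTS` is a name-to-class registry whose classes each expose a `Params` container. A hypothetical mirror of that pattern (illustrative names only, not comfy_kitchen's real definitions):

```python
from dataclasses import dataclass

@dataclass
class FP8Params:
    # Scale metadata a layout needs to dequantize its storage tensor.
    scale: object
    orig_dtype: object
    orig_shape: tuple

class TensorCoreFP8Layout:
    Params = FP8Params

LAYOUTS = {"TensorCoreFP8Layout": TensorCoreFP8Layout}

layout_cls = LAYOUTS["TensorCoreFP8Layout"]  # a class from here on, not a string
params = layout_cls.Params(scale=1.0, orig_dtype="bfloat16", orig_shape=(32, 16))
```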
```diff
@@ -497,21 +513,32 @@ def __init__(
         ) -> None:
             super().__init__()
 
-            if dtype is None:
-                dtype = MixedPrecisionOps._compute_dtype
-
-            self.factory_kwargs = {"device": device, "dtype": dtype}
+            self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
+            # self.factory_kwargs = {"device": device, "dtype": dtype}
 
             self.in_features = in_features
             self.out_features = out_features
-            self._has_bias = bias
+            if bias:
+                self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
+            else:
+                self.register_parameter("bias", None)
 
             self.tensor_class = None
             self._full_precision_mm = MixedPrecisionOps._full_precision_mm
 
         def reset_parameters(self):
             return None
 
+        def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None):
+            key = f"{prefix}{param_name}"
+            value = state_dict.pop(key, None)
+            if value is not None:
+                value = value.to(device=device)
+                if dtype is not None:
+                    value = value.view(dtype=dtype)
+                manually_loaded_keys.append(key)
+            return value
+
         def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                                   strict, missing_keys, unexpected_keys, error_msgs):
 
```
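`_load_scale_param` centralizes the pop/move/reinterpret/record steps that were previously inlined. The same flow, reproduced standalone to show the state-dict bookkeeping:

```python
import torch

def load_scale_param(state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None):
    # Pop the key, move to device, optionally reinterpret raw bytes, record the key.
    key = f"{prefix}{param_name}"
    value = state_dict.pop(key, None)
    if value is not None:
        value = value.to(device=device)
        if dtype is not None:
            value = value.view(dtype=dtype)
        manually_loaded_keys.append(key)
    return value

sd = {"layer.weight_scale": torch.ones(())}
keys = []
scale = load_scale_param(sd, "layer.", "weight_scale", "cpu", keys)
assert scale is not None and keys == ["layer.weight_scale"] and not sd
```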

```diff
@@ -529,14 +556,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                 layer_conf = json.loads(layer_conf.numpy().tobytes())
 
             if layer_conf is None:
-                dtype = self.factory_kwargs["dtype"]
-                self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
-                if dtype != MixedPrecisionOps._compute_dtype:
-                    self.comfy_cast_weights = True
-                if self._has_bias:
-                    self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype))
-                else:
-                    self.register_parameter("bias", None)
+                self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
             else:
                 self.quant_format = layer_conf.get("format", None)
                 if not self._full_precision_mm:
```
```diff
@@ -547,31 +567,46 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
 
                 qconfig = QUANT_ALGOS[self.quant_format]
                 self.layout_type = qconfig["comfy_tensor_layout"]
-
-                weight_scale_key = f"{prefix}weight_scale"
-                scale = state_dict.pop(weight_scale_key, None)
-                if scale is not None:
-                    scale = scale.to(device)
-                layout_params = {
-                    'scale': scale,
-                    'orig_dtype': MixedPrecisionOps._compute_dtype,
-                    'block_size': qconfig.get("group_size", None),
-                }
-
-                if scale is not None:
-                    manually_loaded_keys.append(weight_scale_key)
+                layout_cls = LAYOUTS[self.layout_type]
+
+                # Load format-specific parameters
+                if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]:
+                    # FP8: single tensor scale
+                    scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
+
+                    params = layout_cls.Params(
+                        scale=scale,
+                        orig_dtype=MixedPrecisionOps._compute_dtype,
+                        orig_shape=(self.out_features, self.in_features),
+                    )
+
+                elif self.quant_format == "nvfp4":
+                    # NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
+                    tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
+                    block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
+                                                         dtype=torch.float8_e4m3fn)
+
+                    if tensor_scale is None or block_scale is None:
+                        raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
+
+                    params = layout_cls.Params(
+                        scale=tensor_scale,
+                        block_scale=block_scale,
+                        orig_dtype=MixedPrecisionOps._compute_dtype,
+                        orig_shape=(self.out_features, self.in_features),
+                    )
+                else:
+                    raise ValueError(f"Unsupported quantization format: {self.quant_format}")
 
                 self.weight = torch.nn.Parameter(
-                    QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
+                    QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), layout_cls, params),
                     requires_grad=False
                 )
 
-                if self._has_bias:
-                    self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype))
-                else:
-                    self.register_parameter("bias", None)
-
                 for param_name in qconfig["parameters"]:
+                    if param_name in {"weight_scale", "weight_scale_2"}:
+                        continue  # Already handled above
+
                     param_key = f"{prefix}{param_name}"
                     _v = state_dict.pop(param_key, None)
                     if _v is None:
```
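Note the `dtype=torch.float8_e4m3fn` argument on the NVFP4 block-scale load: the checkpoint stores those scales as raw bytes, and `view(dtype=...)` reinterprets them rather than casting. Minimal check of that reinterpretation (assumes a PyTorch build with float8 dtypes):

```python
import torch

raw = torch.randint(0, 256, (4,), dtype=torch.uint8)  # bytes as stored in the checkpoint
block_scale = raw.view(dtype=torch.float8_e4m3fn)     # reinterpret, no copy and no numeric cast
assert block_scale.dtype == torch.float8_e4m3fn and block_scale.shape == raw.shape
```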
```diff
@@ -588,11 +623,20 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
         def state_dict(self, *args, destination=None, prefix="", **kwargs):
             sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
             if isinstance(self.weight, QuantizedTensor):
-                sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
+                layout_cls = self.weight._layout_cls
+
+                # Check if it's any FP8 variant (E4M3 or E5M2)
+                if layout_cls in (TensorCoreFP8E4M3Layout, TensorCoreFP8E5M2Layout) or \
+                   layout_cls.__name__ in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout"):
+                    sd["{}weight_scale".format(prefix)] = self.weight._params.scale
+                elif layout_cls == TensorCoreNVFP4Layout or layout_cls.__name__ == "TensorCoreNVFP4Layout":
+                    sd["{}weight_scale_2".format(prefix)] = self.weight._params.scale
+                    sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale
+
                 quant_conf = {"format": self.quant_format}
                 if self._full_precision_mm:
                     quant_conf["full_precision_matrix_mult"] = True
-                sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
+                sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)
             return sd
 
         def _forward(self, input, weight, bias):
```
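The `torch.frombuffer` swap builds the same uint8 metadata tensor without materializing a Python list of byte values first. Round-trip, mirroring how `_load_from_state_dict` decodes `comfy_quant`:

```python
import json
import torch

conf = {"format": "nvfp4"}
buf = bytearray(json.dumps(conf).encode("utf-8"))  # bytearray avoids the read-only-buffer warning
t = torch.frombuffer(buf, dtype=torch.uint8)       # zero-copy view over the bytes
assert json.loads(t.numpy().tobytes()) == conf     # decodes back, as the loader does
```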
```diff
@@ -607,12 +651,34 @@ def forward_comfy_cast_weights(self, input):
         def forward(self, input, *args, **kwargs):
             run_every_op()
 
+            input_shape = input.shape
+            tensor_3d = input.ndim == 3
+
             if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(input, *args, **kwargs)
+
             if (getattr(self, 'layout_type', None) is not None and
                 not isinstance(input, QuantizedTensor)):
-                input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
-            return self._forward(input, self.weight, self.bias)
+                layout_cls = LAYOUTS[self.layout_type]
+
+                # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
+                if tensor_3d:
+                    input = input.reshape(-1, input_shape[2])
+
+                if input.ndim != 2:
+                    # Fall back to comfy_cast_weights for non-2D tensors
+                    return self.forward_comfy_cast_weights(input.reshape(input_shape), *args, **kwargs)
+
+                # dtype is now implicit in the layout class
+                input = QuantizedTensor.from_float(input, layout_cls, scale=getattr(self, 'input_scale', None))
+
+            output = self._forward(input, self.weight, self.bias)
+
+            # Reshape output back to 3D if input was 3D
+            if tensor_3d:
+                output = output.reshape((-1, input_shape[1], self.weight.shape[0]))
+
+            return output
 
         def convert_weight(self, weight, inplace=False, **kwargs):
             if isinstance(weight, QuantizedTensor):
```
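`forward` now mirrors the reshape dance added to `fp8_linear`: flatten a (batch, seq, features) activation to 2D for the quantized matmul, then restore the leading dimensions on the way out. The shape bookkeeping in isolation:

```python
import torch

x = torch.randn(2, 5, 16)   # (batch, seq, in_features)
w = torch.randn(32, 16)     # (out_features, in_features)

input_shape = x.shape
x2d = x.reshape(-1, input_shape[2])              # (batch*seq, in_features) for the 2D kernel
o = torch.nn.functional.linear(x2d, w)
o = o.reshape((-1, input_shape[1], w.shape[0]))  # back to (batch, seq, out_features)
assert o.shape == (2, 5, 32)
```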
```diff
@@ -622,7 +688,9 @@ def convert_weight(self, weight, inplace=False, **kwargs):
 
         def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
             if getattr(self, 'layout_type', None) is not None:
-                weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
+                layout_cls = LAYOUTS[self.layout_type]
+                # dtype is now implicit in the layout class
+                weight = QuantizedTensor.from_float(weight, layout_cls, scale="recalculate", stochastic_rounding=seed, inplace_ops=True)
             else:
                 weight = weight.to(self.weight.dtype)
             if return_weight:
```
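`set_weight` threads `seed` into `stochastic_rounding`. The quantizer itself lives in comfy_kitchen and is not part of this diff; the general technique, sketched on an integer grid for illustration, rounds up with probability equal to the fractional part so the quantization error is unbiased in expectation:

```python
import torch

def stochastic_round(x: torch.Tensor, seed: int) -> torch.Tensor:
    # Round down, then add 1 with probability equal to the fractional remainder.
    g = torch.Generator(device=x.device).manual_seed(seed)
    floor = torch.floor(x)
    frac = x - floor
    return floor + (torch.rand(x.shape, generator=g, device=x.device) < frac).to(x.dtype)

x = torch.full((10000,), 0.25)
print(stochastic_round(x, seed=0).mean())  # ~0.25; round-to-nearest would give 0.0 everywhere
```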
