
Commit 6523e6a

comfyanonymous authored and Kosinkadink committed
Initial ops changes to use comfy_kitchen: Initial nvfp4 checkpoint support. (Comfy-Org#11635)
--------- Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
1 parent 5da014c commit 6523e6a

File tree

8 files changed: +219 -791 lines changed

.github/workflows/test-build.yml

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+        python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
     steps:
     - uses: actions/checkout@v4
     - name: Set up Python ${{ matrix.python-version }}

.github/workflows/test-launch.yml

Lines changed: 3 additions & 1 deletion
@@ -32,7 +32,9 @@ jobs:
         working-directory: ComfyUI
       - name: Check for unhandled exceptions in server log
         run: |
-          if grep -qE "Exception|Error" console_output.log; then
+          grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': True, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" console_output.log | grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': False, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" > console_output_filtered.log
+          cat console_output_filtered.log
+          if grep -qE "Exception|Error" console_output_filtered.log; then
            echo "Unhandled exception/error found in server log."
            exit 1
          fi
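
The workflow change above pipes the server log through grep -v to drop two known-benign comfy_kitchen/triton status lines before scanning for "Exception|Error". A minimal sketch of the same filtering idea in Python (the benign substring and file names here are assumptions for illustration, not taken from the workflow):

# Sketch only: filter known-benign lines out of a log before scanning it for errors.
import re
import sys

KNOWN_BENIGN = [
    "Found comfy_kitchen backend triton: {'available': False",  # assumed benign marker
]

def has_unhandled_errors(log_path: str) -> bool:
    error_pattern = re.compile(r"Exception|Error")
    with open(log_path, encoding="utf-8", errors="replace") as f:
        for line in f:
            if any(benign in line for benign in KNOWN_BENIGN):
                continue  # mirrors the `grep -v` filtering step
            if error_pattern.search(line):
                return True
    return False

if __name__ == "__main__":
    if has_unhandled_errors("console_output.log"):
        print("Unhandled exception/error found in server log.")
        sys.exit(1)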

comfy/model_management.py

Lines changed: 2 additions & 2 deletions
@@ -1156,7 +1156,7 @@ def pin_memory(tensor):
     if not tensor.is_contiguous():
         return False

-    size = tensor.numel() * tensor.element_size()
+    size = tensor.nbytes
     if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
         return False

@@ -1183,7 +1183,7 @@ def unpin_memory(tensor):
         return False

     ptr = tensor.data_ptr()
-    size = tensor.numel() * tensor.element_size()
+    size = tensor.nbytes

     size_stored = PINNED_MEMORY.get(ptr, None)
     if size_stored is None:
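
For a dense tensor, tensor.nbytes is exactly numel() * element_size(), so the two hunks above are a direct simplification. A quick sketch:

# Sketch only: .nbytes equals numel() * element_size() for a dense torch.Tensor.
import torch

t = torch.empty(1024, 768, dtype=torch.float16)
assert t.nbytes == t.numel() * t.element_size()  # 1024 * 768 * 2 bytes
print(t.nbytes)  # 1572864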

comfy/ops.py

Lines changed: 115 additions & 53 deletions
@@ -79,7 +79,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     if input is not None:
         if dtype is None:
             if isinstance(input, QuantizedTensor):
-                dtype = input._layout_params["orig_dtype"]
+                dtype = input.params.orig_dtype
             else:
                 dtype = input.dtype
         if bias_dtype is None:
@@ -412,26 +412,34 @@ def fp8_linear(self, input):
         return None

     input_dtype = input.dtype
+    input_shape = input.shape
+    tensor_3d = input.ndim == 3

-    if input.ndim == 3 or input.ndim == 2:
-        w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
-        scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
+    if tensor_3d:
+        input = input.reshape(-1, input_shape[2])

-        scale_input = torch.ones((), device=input.device, dtype=torch.float32)
-        input = torch.clamp(input, min=-448, max=448, out=input)
-        layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
-        quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
+    if input.ndim != 2:
+        return None
+    w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
+    scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
+
+    scale_input = torch.ones((), device=input.device, dtype=torch.float32)
+    input = torch.clamp(input, min=-448, max=448, out=input)
+    input_fp8 = input.to(dtype).contiguous()
+    layout_params_input = TensorCoreFP8Layout.Params(scale=scale_input, orig_dtype=input_dtype, orig_shape=tuple(input_fp8.shape))
+    quantized_input = QuantizedTensor(input_fp8, TensorCoreFP8Layout, layout_params_input)

-        # Wrap weight in QuantizedTensor - this enables unified dispatch
-        # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
-        layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
-        quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
-        o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
+    # Wrap weight in QuantizedTensor - this enables unified dispatch
+    # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
+    layout_params_weight = TensorCoreFP8Layout.Params(scale=scale_weight, orig_dtype=input_dtype, orig_shape=tuple(w.shape))
+    quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
+    o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)

-        uncast_bias_weight(self, w, bias, offload_stream)
-        return o
+    uncast_bias_weight(self, w, bias, offload_stream)
+    if tensor_3d:
+        o = o.reshape((input_shape[0], input_shape[1], w.shape[0]))

-    return None
+    return o

 class fp8_ops(manual_cast):
     class Linear(manual_cast.Linear):
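
The comment kept in the hunk above notes that wrapping both operands in QuantizedTensor lets a plain torch.nn.functional.linear call be intercepted via __torch_dispatch__ and routed to a quantized handler in quant_ops.py. A minimal sketch of that mechanism with a hypothetical wrapper subclass (an illustration of the dispatch idea only, not ComfyUI's QuantizedTensor):

# Sketch only: every aten op touching a WrappedTensor lands in __torch_dispatch__,
# where a real implementation could route to a quantized (e.g. FP8) kernel
# instead of simply unwrapping and re-dispatching.
import torch

class WrappedTensor(torch.Tensor):
    __torch_function__ = torch._C._disabled_torch_function_impl  # keep dispatch at the aten level

    @staticmethod
    def __new__(cls, data):
        return torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=data.dtype, device=data.device)

    def __init__(self, data):
        self._data = data

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        unwrap = lambda a: a._data if isinstance(a, WrappedTensor) else a
        print("dispatched:", func)  # the aten ops F.linear lowers to show up here
        return func(*[unwrap(a) for a in args], **{k: unwrap(v) for k, v in kwargs.items()})

x = WrappedTensor(torch.randn(4, 8))
w = WrappedTensor(torch.randn(16, 8))
out = torch.nn.functional.linear(x, w)  # intercepted through __torch_dispatch__
print(out.shape)  # torch.Size([4, 16])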
@@ -477,7 +485,12 @@ def forward(self, *args, **kwargs):
 # ==============================================================================
 # Mixed Precision Operations
 # ==============================================================================
-from .quant_ops import QuantizedTensor, QUANT_ALGOS
+from .quant_ops import (
+    QuantizedTensor,
+    QUANT_ALGOS,
+    TensorCoreFP8Layout,
+    get_layout_class,
+)


 def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
@@ -497,21 +510,32 @@ def __init__(
     ) -> None:
         super().__init__()

-        if dtype is None:
-            dtype = MixedPrecisionOps._compute_dtype
-
-        self.factory_kwargs = {"device": device, "dtype": dtype}
+        self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
+        # self.factory_kwargs = {"device": device, "dtype": dtype}

         self.in_features = in_features
         self.out_features = out_features
-        self._has_bias = bias
+        if bias:
+            self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
+        else:
+            self.register_parameter("bias", None)

         self.tensor_class = None
         self._full_precision_mm = MixedPrecisionOps._full_precision_mm

     def reset_parameters(self):
         return None

+    def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None):
+        key = f"{prefix}{param_name}"
+        value = state_dict.pop(key, None)
+        if value is not None:
+            value = value.to(device=device)
+            if dtype is not None:
+                value = value.view(dtype=dtype)
+            manually_loaded_keys.append(key)
+        return value
+
     def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                               strict, missing_keys, unexpected_keys, error_msgs):

@@ -529,14 +553,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
             layer_conf = json.loads(layer_conf.numpy().tobytes())

         if layer_conf is None:
-            dtype = self.factory_kwargs["dtype"]
-            self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
-            if dtype != MixedPrecisionOps._compute_dtype:
-                self.comfy_cast_weights = True
-            if self._has_bias:
-                self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype))
-            else:
-                self.register_parameter("bias", None)
+            self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
         else:
             self.quant_format = layer_conf.get("format", None)
             if not self._full_precision_mm:
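
The _load_scale_param helper added earlier pops a scale tensor out of the state dict and can reinterpret its raw bytes as FP8 via view(dtype=...); the next hunk uses it to load NVFP4 block scales. A small sketch of that reinterpretation (the uint8 source tensor is a stand-in for checkpoint data, and a PyTorch build with torch.float8_e4m3fn is assumed):

# Sketch only: reinterpret raw bytes as float8_e4m3fn without copying.
import torch

raw = torch.randint(0, 256, (64,), dtype=torch.uint8)   # stand-in for bytes loaded from a checkpoint
block_scale = raw.view(dtype=torch.float8_e4m3fn)       # same storage, element type reinterpreted
print(block_scale.dtype, block_scale.shape)              # torch.float8_e4m3fn torch.Size([64])
assert raw.data_ptr() == block_scale.data_ptr()          # no copy, just a view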
@@ -547,31 +564,46 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,

             qconfig = QUANT_ALGOS[self.quant_format]
             self.layout_type = qconfig["comfy_tensor_layout"]
-
-            weight_scale_key = f"{prefix}weight_scale"
-            scale = state_dict.pop(weight_scale_key, None)
-            if scale is not None:
-                scale = scale.to(device)
-            layout_params = {
-                'scale': scale,
-                'orig_dtype': MixedPrecisionOps._compute_dtype,
-                'block_size': qconfig.get("group_size", None),
-            }
-
-            if scale is not None:
-                manually_loaded_keys.append(weight_scale_key)
+            layout_cls = get_layout_class(self.layout_type)
+
+            # Load format-specific parameters
+            if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]:
+                # FP8: single tensor scale
+                scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
+
+                params = layout_cls.Params(
+                    scale=scale,
+                    orig_dtype=MixedPrecisionOps._compute_dtype,
+                    orig_shape=(self.out_features, self.in_features),
+                )
+
+            elif self.quant_format == "nvfp4":
+                # NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
+                tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
+                block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
+                                                     dtype=torch.float8_e4m3fn)
+
+                if tensor_scale is None or block_scale is None:
+                    raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
+
+                params = layout_cls.Params(
+                    scale=tensor_scale,
+                    block_scale=block_scale,
+                    orig_dtype=MixedPrecisionOps._compute_dtype,
+                    orig_shape=(self.out_features, self.in_features),
+                )
+            else:
+                raise ValueError(f"Unsupported quantization format: {self.quant_format}")

             self.weight = torch.nn.Parameter(
-                QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
+                QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
                 requires_grad=False
             )

-            if self._has_bias:
-                self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype))
-            else:
-                self.register_parameter("bias", None)
-
             for param_name in qconfig["parameters"]:
+                if param_name in {"weight_scale", "weight_scale_2"}:
+                    continue  # Already handled above
+
                 param_key = f"{prefix}{param_name}"
                 _v = state_dict.pop(param_key, None)
                 if _v is None:
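
For NVFP4 the loader keeps two scales: a per-tensor scale (weight_scale_2) and a per-block scale (weight_scale, reinterpreted as float8_e4m3fn). Dequantization then applies both, roughly value * block_scale * tensor_scale. A sketch of that two-level arithmetic, assuming a 16-element block size (the actual FP4 packing and kernels live in comfy_kitchen / quant_ops, not in this file):

# Sketch only: two-level NVFP4-style dequantization with an assumed block size of 16.
import torch

BLOCK = 16                                           # assumed block size
codes = torch.randn(128, 256)                        # stand-in for already-decoded FP4 values
block_scale = torch.rand(128, 256 // BLOCK) + 0.5    # one scale per block (stored as FP8 in checkpoints)
tensor_scale = torch.tensor(0.02)                    # single per-tensor scale (weight_scale_2)

# Broadcast each block scale across its 16 elements, then apply the global scale.
dequant = codes * block_scale.repeat_interleave(BLOCK, dim=1) * tensor_scale
print(dequant.shape)  # torch.Size([128, 256])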
@@ -588,7 +620,15 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
     def state_dict(self, *args, destination=None, prefix="", **kwargs):
         sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
         if isinstance(self.weight, QuantizedTensor):
-            sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
+            layout_cls = self.weight._layout_cls
+
+            # Check if it's any FP8 variant (E4M3 or E5M2)
+            if layout_cls in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout"):
+                sd["{}weight_scale".format(prefix)] = self.weight._params.scale
+            elif layout_cls == "TensorCoreNVFP4Layout":
+                sd["{}weight_scale_2".format(prefix)] = self.weight._params.scale
+                sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale
+
             quant_conf = {"format": self.quant_format}
             if self._full_precision_mm:
                 quant_conf["full_precision_matrix_mult"] = True
@@ -607,12 +647,33 @@ def forward_comfy_cast_weights(self, input):
     def forward(self, input, *args, **kwargs):
         run_every_op()

+        input_shape = input.shape
+        tensor_3d = input.ndim == 3
+
         if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
             return self.forward_comfy_cast_weights(input, *args, **kwargs)
+
         if (getattr(self, 'layout_type', None) is not None and
             not isinstance(input, QuantizedTensor)):
-            input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
-        return self._forward(input, self.weight, self.bias)
+
+            # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
+            if tensor_3d:
+                input = input.reshape(-1, input_shape[2])
+
+            if input.ndim != 2:
+                # Fall back to comfy_cast_weights for non-2D tensors
+                return self.forward_comfy_cast_weights(input.reshape(input_shape), *args, **kwargs)
+
+            # dtype is now implicit in the layout class
+            input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None))
+
+        output = self._forward(input, self.weight, self.bias)
+
+        # Reshape output back to 3D if input was 3D
+        if tensor_3d:
+            output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0]))
+
+        return output

     def convert_weight(self, weight, inplace=False, **kwargs):
         if isinstance(weight, QuantizedTensor):
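
Both fp8_linear and this forward flatten a 3D (batch, tokens, features) input to 2D before quantizing, then restore the 3D shape on the output. For a linear layer the round trip is lossless because each token is transformed independently. A short sketch:

# Sketch only: the 3D -> 2D -> 3D round trip used around the quantized matmul.
import torch

B, T, C_in, C_out = 2, 5, 8, 16
x = torch.randn(B, T, C_in)
linear = torch.nn.Linear(C_in, C_out, bias=False)

flat = x.reshape(-1, C_in)                   # (B*T, C_in), the 2D layout quantization expects
out_2d = linear(flat)                        # (B*T, C_out)
out_3d = out_2d.reshape(B, T, C_out)         # restore the original leading dims

assert torch.allclose(out_3d, linear(x), atol=1e-6)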
@@ -622,7 +683,8 @@ def convert_weight(self, weight, inplace=False, **kwargs):

     def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
         if getattr(self, 'layout_type', None) is not None:
-            weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
+            # dtype is now implicit in the layout class
+            weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True)
         else:
             weight = weight.to(self.weight.dtype)
         if return_weight:
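
set_weight re-quantizes an updated full-precision weight with scale="recalculate" and optional stochastic rounding keyed by the seed. A generic sketch of stochastic rounding (an illustration of the idea, not ComfyUI's kernel):

# Sketch only: round up or down with probability proportional to the fractional
# part, so the quantization error is unbiased in expectation.
import torch

def stochastic_round(x: torch.Tensor, seed: int) -> torch.Tensor:
    gen = torch.Generator(device=x.device).manual_seed(seed)
    floor = torch.floor(x)
    frac = x - floor
    noise = torch.rand(x.shape, generator=gen, device=x.device)
    return floor + (noise < frac).to(x.dtype)

x = torch.full((100000,), 0.25)
rounded = stochastic_round(x, seed=0)
print(rounded.mean())  # ~0.25 on average, while plain round-to-nearest would give 0.0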
