Commit af4b7b5

More fp8 torch.compile regressions fixed. (#10625)

1 parent 0f4ef3a

1 file changed: comfy/quant_ops.py (+35 -19 lines)
@@ -446,6 +446,25 @@ def fp8_linear(func, args, kwargs):
 
     return torch.nn.functional.linear(input_tensor, weight, bias)
 
+def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
+    if out_dtype is None:
+        out_dtype = input_tensor._layout_params['orig_dtype']
+
+    plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
+    plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
+
+    output = torch._scaled_mm(
+        plain_input.contiguous(),
+        plain_weight,
+        bias=bias,
+        scale_a=scale_a,
+        scale_b=scale_b,
+        out_dtype=out_dtype,
+    )
+
+    if isinstance(output, tuple):  # TODO: remove when we drop support for torch 2.4
+        output = output[0]
+    return output
 
 @register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
 def fp8_addmm(func, args, kwargs):
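
For reference, fp8_mm_ is essentially a thin wrapper around torch._scaled_mm. The standalone sketch below exercises the same call, with a hand-rolled per-tensor FP8 quantization standing in for whatever TensorCoreFP8Layout.get_plain_tensors returns; it is illustrative only (not ComfyUI code) and assumes a recent PyTorch plus an FP8-capable GPU (e.g. Ada or Hopper).

import torch

# Quantize two fp32 matrices to float8_e4m3fn with per-tensor scales.
x = torch.randn(16, 64, device="cuda")
w = torch.randn(32, 64, device="cuda")  # (out_features, in_features)

fp8_max = torch.finfo(torch.float8_e4m3fn).max
scale_a = x.abs().amax() / fp8_max
scale_b = w.abs().amax() / fp8_max
x_fp8 = (x / scale_a).to(torch.float8_e4m3fn)
w_fp8 = (w / scale_b).to(torch.float8_e4m3fn)

# _scaled_mm wants the first operand contiguous (hence the .contiguous()
# in fp8_mm_) and the second operand column-major (hence the .t() here).
out = torch._scaled_mm(
    x_fp8.contiguous(),
    w_fp8.t(),
    scale_a=scale_a,
    scale_b=scale_b,
    out_dtype=torch.bfloat16,
)
if isinstance(out, tuple):  # torch <= 2.4 returned (output, amax)
    out = out[0]
print(out.shape)  # torch.Size([16, 32])

The tuple check mirrors the TODO in the diff: before PyTorch 2.5, _scaled_mm returned an (output, amax) pair rather than a single tensor.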
@@ -454,25 +473,7 @@ def fp8_addmm(func, args, kwargs):
     bias = args[0]
 
     if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
-        out_dtype = kwargs.get("out_dtype")
-        if out_dtype is None:
-            out_dtype = input_tensor._layout_params['orig_dtype']
-
-        plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
-        plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
-
-        output = torch._scaled_mm(
-            plain_input.contiguous(),
-            plain_weight,
-            bias=bias,
-            scale_a=scale_a,
-            scale_b=scale_b,
-            out_dtype=out_dtype,
-        )
-
-        if isinstance(output, tuple):  # TODO: remove when we drop support for torch 2.4
-            output = output[0]
-        return output
+        return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
 
     a = list(args)
     if isinstance(args[0], QuantizedTensor):
@@ -484,6 +485,21 @@ def fp8_addmm(func, args, kwargs):
 
     return func(*a, **kwargs)
 
+@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
+def fp8_mm(func, args, kwargs):
+    input_tensor = args[0]
+    weight = args[1]
+
+    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
+        return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
+
+    a = list(args)
+    if isinstance(args[0], QuantizedTensor):
+        a[0] = args[0].dequantize()
+    if isinstance(args[1], QuantizedTensor):
+        a[1] = args[1].dequantize()
+    return func(*a, **kwargs)
+
 @register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
 @register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
 def fp8_func(func, args, kwargs):
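
The aten.mm.default registration is the new piece: torch.compile traces nn.functional.linear down to aten ops such as mm/addmm, and an op without a registered handler presumably fell off the fused FP8 path, which is the regression the commit title refers to. The toy below re-creates the register_layout_op dispatch pattern end to end so the control flow is visible; every name except register_layout_op and the aten op is invented here, and it is a sketch of the pattern, not the comfy/quant_ops.py implementation.

import torch

_LAYOUT_OPS = {}  # (aten OpOverload, layout name) -> handler

def register_layout_op(op, layout_name):
    def decorator(fn):
        _LAYOUT_OPS[(op, layout_name)] = fn
        return fn
    return decorator

class ToyQuantizedTensor(torch.Tensor):
    # Stand-in for QuantizedTensor: a subclass tagged with a layout name.
    @staticmethod
    def __new__(cls, data, layout_name):
        t = torch.Tensor._make_subclass(cls, data)
        t._layout_name = layout_name
        return t

    def dequantize(self):
        # as_subclass does not go through __torch_dispatch__, so it is a
        # recursion-safe way to unwrap back to a plain tensor.
        return self.as_subclass(torch.Tensor)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        for a in args:
            handler = _LAYOUT_OPS.get((func, getattr(a, "_layout_name", None)))
            if handler is not None:
                return handler(func, args, kwargs)
        # No handler registered: dequantize and run the plain op, the same
        # fallback shape as the fp8_mm/fp8_addmm bodies in the diff.
        plain = [a.dequantize() if isinstance(a, ToyQuantizedTensor) else a
                 for a in args]
        return func(*plain, **kwargs)

@register_layout_op(torch.ops.aten.mm.default, "ToyLayout")
def toy_mm(func, args, kwargs):
    print("intercepted aten.mm.default")
    a, b = [t.dequantize() if isinstance(t, ToyQuantizedTensor) else t
            for t in args]
    return torch.mm(a, b)

x = ToyQuantizedTensor(torch.randn(4, 8), "ToyLayout")
w = torch.randn(8, 3)
print(torch.mm(x, w).shape)  # prints the intercept message, then torch.Size([4, 3])

A real handler would keep the operands quantized and call torch._scaled_mm, as fp8_mm_ does; the toy dequantizes only so it runs on any hardware.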
