@@ -446,6 +446,25 @@ def fp8_linear(func, args, kwargs):
 
     return torch.nn.functional.linear(input_tensor, weight, bias)
 
+def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
+    if out_dtype is None:
+        out_dtype = input_tensor._layout_params['orig_dtype']
+
+    plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
+    plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
+
+    output = torch._scaled_mm(
+        plain_input.contiguous(),
+        plain_weight,
+        bias=bias,
+        scale_a=scale_a,
+        scale_b=scale_b,
+        out_dtype=out_dtype,
+    )
+
+    if isinstance(output, tuple):  # TODO: remove when we drop support for torch 2.4
+        output = output[0]
+    return output
 
 @register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
 def fp8_addmm(func, args, kwargs):
@@ -454,25 +473,7 @@ def fp8_addmm(func, args, kwargs):
     bias = args[0]
 
     if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
-        out_dtype = kwargs.get("out_dtype")
-        if out_dtype is None:
-            out_dtype = input_tensor._layout_params['orig_dtype']
-
-        plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
-        plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
-
-        output = torch._scaled_mm(
-            plain_input.contiguous(),
-            plain_weight,
-            bias=bias,
-            scale_a=scale_a,
-            scale_b=scale_b,
-            out_dtype=out_dtype,
-        )
-
-        if isinstance(output, tuple):  # TODO: remove when we drop support for torch 2.4
-            output = output[0]
-        return output
+        return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
 
     a = list(args)
     if isinstance(args[0], QuantizedTensor):
@@ -484,6 +485,21 @@ def fp8_addmm(func, args, kwargs):
 
     return func(*a, **kwargs)
 
+@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
+def fp8_mm(func, args, kwargs):
+    input_tensor = args[0]
+    weight = args[1]
+
+    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
+        return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
+
+    a = list(args)
+    if isinstance(args[0], QuantizedTensor):
+        a[0] = args[0].dequantize()
+    if isinstance(args[1], QuantizedTensor):
+        a[1] = args[1].dequantize()
+    return func(*a, **kwargs)
+
 @register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
 @register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
 def fp8_func(func, args, kwargs):
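
For reference, below is a minimal standalone sketch of the torch._scaled_mm call that the new fp8_mm_ helper wraps, stripped of the QuantizedTensor machinery. The tensor names, shapes, and naive per-tensor scaling are illustrative assumptions, not part of this PR, and an FP8-capable CUDA device (compute capability 8.9+) is assumed:

import torch

# Illustrative inputs; fp8_mm_ instead pulls the packed FP8 data and scales
# out of the QuantizedTensor layout via TensorCoreFP8Layout.get_plain_tensors.
a = torch.randn(128, 256, device="cuda")
b = torch.randn(512, 256, device="cuda")

# Naive per-tensor scales so values fit the float8_e4m3fn range (max 448.0).
scale_a = a.abs().amax().float() / torch.finfo(torch.float8_e4m3fn).max
scale_b = b.abs().amax().float() / torch.finfo(torch.float8_e4m3fn).max

a_fp8 = (a / scale_a).to(torch.float8_e4m3fn)        # mat1 must be row-major
b_fp8 = (b / scale_b).to(torch.float8_e4m3fn).t()    # mat2 must be column-major

out = torch._scaled_mm(
    a_fp8,
    b_fp8,
    scale_a=scale_a,
    scale_b=scale_b,
    out_dtype=torch.bfloat16,
)

out = out[0] if isinstance(out, tuple) else out      # older torch returned (out, amax)
print(out.shape)  # torch.Size([128, 512])

With the new aten.mm registration, a plain torch.mm on two TensorCoreFP8Layout QuantizedTensors now routes through this scaled-matmul path instead of falling back to dequantization, mirroring what the addmm handler already did.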