@@ -533,11 +533,6 @@ def quantize_to_fp8(tensor):
         decode = encode.reciprocal()
         return quant, decode, encode
 
-    def mm_float8_emulated(mat_a, encode_a, mat_b, encode_b, out_dtype):
-        a_fp32 = mat_a.to(torch.float32) / encode_a
-        b_fp32 = mat_b.to(torch.float32) / encode_b
-        return (a_fp32 @ b_fp32).to(out_dtype)
-
     def scaled_mm_fp8(mat_a, mat_b, scale_a, scale_b, *, out_dtype):
         return torch.nn.functional.scaled_mm(
             mat_a,
@@ -568,10 +563,7 @@ def scaled_mm_fp8(mat_a, mat_b, scale_a, scale_b, *, out_dtype):
     jf = thunder.jit(lambda a, b, sa, sb: scaled_mm_fp8(a, b, sa, sb, out_dtype=torch.float32))
     thunder_out = jf(mat_a_lp, mat_b_lp, decode_a, decode_b)
 
-    emulated = mm_float8_emulated(mat_a_lp, encode_a, mat_b_lp, encode_b, torch.float32)
-
     assert_close(thunder_out, reference)
-    assert_close(reference, emulated, atol=2e-2, rtol=2e-2)
 
 
 @requiresCUDA
@@ -628,11 +620,6 @@ def rowwise_quantize(tensor, *, dim):
         decode = encode.reciprocal()
         return quant, decode, encode
 
-    def mm_float8_emulated(mat_a, encode_a, mat_b, encode_b, out_dtype):
-        a_fp32 = mat_a.to(torch.float32) / encode_a
-        b_fp32 = mat_b.to(torch.float32) / encode_b
-        return (a_fp32 @ b_fp32).to(out_dtype)
-
     def scaled_mm_rowwise(mat_a, mat_b, scale_a, scale_b, *, out_dtype):
         return torch.nn.functional.scaled_mm(
             mat_a,
@@ -661,13 +648,10 @@ def scaled_mm_rowwise(mat_a, mat_b, scale_a, scale_b, *, out_dtype):
     jf = thunder.jit(lambda a, b, sa, sb: scaled_mm_rowwise(a, b, sa, sb, out_dtype=torch.bfloat16))
     thunder_out = jf(mat_a_lp, mat_b_lp, decode_a, decode_b)
 
-    emulated = mm_float8_emulated(mat_a_lp, encode_a, mat_b_lp, encode_b, torch.float32)
-
     reference_f32 = reference.to(torch.float32)
     thunder_out_f32 = thunder_out.to(torch.float32)
 
     assert_close(thunder_out_f32, reference_f32, atol=3e-2, rtol=3e-2)
-    assert_close(reference_f32, emulated, atol=3e-2, rtol=3e-2)
 
 
 def _blockwise_quantize(tensor: torch.Tensor, block_rows: int, block_cols: int) -> tuple[torch.Tensor, torch.Tensor]:
@@ -802,15 +786,10 @@ def test_scaled_mm_blockwise_matches_emulation(output_dtype, lhs_block, rhs_bloc
     )
     thunder_out = fn(mat_a_lp, mat_b_lp, scale_a, scale_b)
 
-    a_dequant = _dequantize_blockwise(mat_a_lp, encode_a, lhs_block, 128)
-    b_dequant = _dequantize_blockwise(mat_b_lp_rows, encode_b, rhs_block, 128).t()
-    emulated = (a_dequant @ b_dequant).to(torch.float32)
-
     reference_f32 = reference.to(torch.float32)
     thunder_out_f32 = thunder_out.to(torch.float32)
 
     assert_close(thunder_out_f32, reference_f32, atol=7e-1, rtol=7e-2)
-    assert_close(reference_f32, emulated, atol=7e-1, rtol=7e-2)
 
 
 # https://github.com/Lightning-AI/lightning-thunder/issues/1857