Skip to content

Commit e13728e

Browse files
committed
remove comparison of emulation and reference
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
1 parent 81f3f8d commit e13728e

File tree

1 file changed: +0 additions, −21 deletions

thunder/tests/test_ops.py

Lines changed: 0 additions & 21 deletions
Original file line number · Diff line number · Diff line change
@@ -533,11 +533,6 @@ def quantize_to_fp8(tensor):
         decode = encode.reciprocal()
         return quant, decode, encode

-    def mm_float8_emulated(mat_a, encode_a, mat_b, encode_b, out_dtype):
-        a_fp32 = mat_a.to(torch.float32) / encode_a
-        b_fp32 = mat_b.to(torch.float32) / encode_b
-        return (a_fp32 @ b_fp32).to(out_dtype)
-
     def scaled_mm_fp8(mat_a, mat_b, scale_a, scale_b, *, out_dtype):
         return torch.nn.functional.scaled_mm(
             mat_a,
@@ -568,10 +563,7 @@ def scaled_mm_fp8(mat_a, mat_b, scale_a, scale_b, *, out_dtype):
     jf = thunder.jit(lambda a, b, sa, sb: scaled_mm_fp8(a, b, sa, sb, out_dtype=torch.float32))
     thunder_out = jf(mat_a_lp, mat_b_lp, decode_a, decode_b)

-    emulated = mm_float8_emulated(mat_a_lp, encode_a, mat_b_lp, encode_b, torch.float32)
-
     assert_close(thunder_out, reference)
-    assert_close(reference, emulated, atol=2e-2, rtol=2e-2)


 @requiresCUDA
@@ -628,11 +620,6 @@ def rowwise_quantize(tensor, *, dim):
         decode = encode.reciprocal()
         return quant, decode, encode

-    def mm_float8_emulated(mat_a, encode_a, mat_b, encode_b, out_dtype):
-        a_fp32 = mat_a.to(torch.float32) / encode_a
-        b_fp32 = mat_b.to(torch.float32) / encode_b
-        return (a_fp32 @ b_fp32).to(out_dtype)
-
     def scaled_mm_rowwise(mat_a, mat_b, scale_a, scale_b, *, out_dtype):
         return torch.nn.functional.scaled_mm(
             mat_a,
@@ -661,13 +648,10 @@ def scaled_mm_rowwise(mat_a, mat_b, scale_a, scale_b, *, out_dtype):
     jf = thunder.jit(lambda a, b, sa, sb: scaled_mm_rowwise(a, b, sa, sb, out_dtype=torch.bfloat16))
     thunder_out = jf(mat_a_lp, mat_b_lp, decode_a, decode_b)

-    emulated = mm_float8_emulated(mat_a_lp, encode_a, mat_b_lp, encode_b, torch.float32)
-
     reference_f32 = reference.to(torch.float32)
     thunder_out_f32 = thunder_out.to(torch.float32)

     assert_close(thunder_out_f32, reference_f32, atol=3e-2, rtol=3e-2)
-    assert_close(reference_f32, emulated, atol=3e-2, rtol=3e-2)


 def _blockwise_quantize(tensor: torch.Tensor, block_rows: int, block_cols: int) -> tuple[torch.Tensor, torch.Tensor]:
@@ -802,15 +786,10 @@ def test_scaled_mm_blockwise_matches_emulation(output_dtype, lhs_block, rhs_block):
     )
     thunder_out = fn(mat_a_lp, mat_b_lp, scale_a, scale_b)

-    a_dequant = _dequantize_blockwise(mat_a_lp, encode_a, lhs_block, 128)
-    b_dequant = _dequantize_blockwise(mat_b_lp_rows, encode_b, rhs_block, 128).t()
-    emulated = (a_dequant @ b_dequant).to(torch.float32)
-
     reference_f32 = reference.to(torch.float32)
     thunder_out_f32 = thunder_out.to(torch.float32)

     assert_close(thunder_out_f32, reference_f32, atol=7e-1, rtol=7e-2)
-    assert_close(reference_f32, emulated, atol=7e-1, rtol=7e-2)


 # https://github.com/Lightning-AI/lightning-thunder/issues/1857

0 commit comments

Comments (0)