Skip to content

Commit d433277

Browse files
committed
remove redundant test_scaled_mm_blockwise_matches_emulation
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
1 parent ee4d105 commit d433277

File tree

1 file changed

+0
-60
lines changed

1 file changed

+0
-60
lines changed

thunder/tests/test_ops.py

Lines changed: 0 additions & 60 deletions
Original file line number | Diff line number | Diff line change
@@ -719,66 +719,6 @@ def test_scaled_mm_blockwise_matches_torch(output_dtype, lhs_block, rhs_block):
719719
assert_close(thunder_out, expected)
720720

721721

722-
@requiresCUDA
723-
@_require_scaled_mm
724-
@pytest.mark.parametrize("output_dtype", [torch.bfloat16])
725-
@pytest.mark.parametrize("lhs_block,rhs_block", [(1, 1), (128, 1), (1, 128)])
726-
def test_scaled_mm_blockwise_matches_emulation(output_dtype, lhs_block, rhs_block):
727-
device = torch.device("cuda")
728-
_require_fp8_blockwise(device)
729-
730-
M, K, N = 256, 256, 256
731-
mat_a = torch.randn(M, K, device=device, dtype=output_dtype).pow(3)
732-
mat_b_rows = torch.randn(N, K, device=device, dtype=output_dtype).pow(3)
733-
734-
mat_a_lp, encode_a = _blockwise_quantize(mat_a.to(torch.float32), lhs_block, 128)
735-
mat_b_lp_rows, encode_b = _blockwise_quantize(mat_b_rows.to(torch.float32), rhs_block, 128)
736-
mat_b_lp = mat_b_lp_rows.t().contiguous()
737-
738-
scale_a = encode_a.reciprocal().contiguous()
739-
scale_b = encode_b.reciprocal().t().contiguous()
740-
741-
recipe_map = {
742-
1: ScalingType.BlockWise1x128,
743-
128: ScalingType.BlockWise128x128,
744-
}
745-
746-
try:
747-
reference = torch.nn.functional.scaled_mm(
748-
mat_a_lp,
749-
mat_b_lp,
750-
scale_a,
751-
recipe_map[lhs_block],
752-
scale_b,
753-
recipe_map[rhs_block],
754-
swizzle_a=SwizzleType.NO_SWIZZLE,
755-
swizzle_b=SwizzleType.NO_SWIZZLE,
756-
output_dtype=output_dtype,
757-
)
758-
except (RuntimeError, NotImplementedError, ValueError) as exc:
759-
pytest.skip(str(exc))
760-
761-
fn = thunder.jit(
762-
lambda a, b, sa, sb: torch.nn.functional.scaled_mm(
763-
a,
764-
b,
765-
sa,
766-
recipe_map[lhs_block],
767-
sb,
768-
recipe_map[rhs_block],
769-
swizzle_a=SwizzleType.NO_SWIZZLE,
770-
swizzle_b=SwizzleType.NO_SWIZZLE,
771-
output_dtype=output_dtype,
772-
)
773-
)
774-
thunder_out = fn(mat_a_lp, mat_b_lp, scale_a, scale_b)
775-
776-
reference_f32 = reference.to(torch.float32)
777-
thunder_out_f32 = thunder_out.to(torch.float32)
778-
779-
assert_close(thunder_out_f32, reference_f32, atol=7e-1, rtol=7e-2)
780-
781-
782722
# https://github.com/Lightning-AI/lightning-thunder/issues/1857
783723
def test_max_with_int():
784724
def f(x, ids):

0 commit comments

Comments (0)