@@ -719,66 +719,6 @@ def test_scaled_mm_blockwise_matches_torch(output_dtype, lhs_block, rhs_block):
719719 assert_close (thunder_out , expected )
720720
721721
722- @requiresCUDA
723- @_require_scaled_mm
724- @pytest .mark .parametrize ("output_dtype" , [torch .bfloat16 ])
725- @pytest .mark .parametrize ("lhs_block,rhs_block" , [(1 , 1 ), (128 , 1 ), (1 , 128 )])
726- def test_scaled_mm_blockwise_matches_emulation (output_dtype , lhs_block , rhs_block ):
727- device = torch .device ("cuda" )
728- _require_fp8_blockwise (device )
729-
730- M , K , N = 256 , 256 , 256
731- mat_a = torch .randn (M , K , device = device , dtype = output_dtype ).pow (3 )
732- mat_b_rows = torch .randn (N , K , device = device , dtype = output_dtype ).pow (3 )
733-
734- mat_a_lp , encode_a = _blockwise_quantize (mat_a .to (torch .float32 ), lhs_block , 128 )
735- mat_b_lp_rows , encode_b = _blockwise_quantize (mat_b_rows .to (torch .float32 ), rhs_block , 128 )
736- mat_b_lp = mat_b_lp_rows .t ().contiguous ()
737-
738- scale_a = encode_a .reciprocal ().contiguous ()
739- scale_b = encode_b .reciprocal ().t ().contiguous ()
740-
741- recipe_map = {
742- 1 : ScalingType .BlockWise1x128 ,
743- 128 : ScalingType .BlockWise128x128 ,
744- }
745-
746- try :
747- reference = torch .nn .functional .scaled_mm (
748- mat_a_lp ,
749- mat_b_lp ,
750- scale_a ,
751- recipe_map [lhs_block ],
752- scale_b ,
753- recipe_map [rhs_block ],
754- swizzle_a = SwizzleType .NO_SWIZZLE ,
755- swizzle_b = SwizzleType .NO_SWIZZLE ,
756- output_dtype = output_dtype ,
757- )
758- except (RuntimeError , NotImplementedError , ValueError ) as exc :
759- pytest .skip (str (exc ))
760-
761- fn = thunder .jit (
762- lambda a , b , sa , sb : torch .nn .functional .scaled_mm (
763- a ,
764- b ,
765- sa ,
766- recipe_map [lhs_block ],
767- sb ,
768- recipe_map [rhs_block ],
769- swizzle_a = SwizzleType .NO_SWIZZLE ,
770- swizzle_b = SwizzleType .NO_SWIZZLE ,
771- output_dtype = output_dtype ,
772- )
773- )
774- thunder_out = fn (mat_a_lp , mat_b_lp , scale_a , scale_b )
775-
776- reference_f32 = reference .to (torch .float32 )
777- thunder_out_f32 = thunder_out .to (torch .float32 )
778-
779- assert_close (thunder_out_f32 , reference_f32 , atol = 7e-1 , rtol = 7e-2 )
780-
781-
782722# https://github.com/Lightning-AI/lightning-thunder/issues/1857
783723def test_max_with_int ():
784724 def f (x , ids ):
0 commit comments