@@ -2525,6 +2525,126 @@ def format_scales(scales):
25252525 ref = (x32 * a_logical_scales ) @ (y32 * b_logical_scales ).T
25262526 np .testing .assert_allclose (z , ref , atol = 2e-4 , rtol = 5e-6 )
25272527
@parameterized.product(
    in_jax_dtype=(jnp.float4_e2m1fn,),
    scale_jax_dtype=(jnp.float8_e8m0fnu, jnp.float8_e4m3fn),
    m=(128,),
    n=(128, 256),
    swizzle=(128,),
)
def test_mma_block_scaled_sparse_f4(self, m, n, in_jax_dtype, scale_jax_dtype, swizzle):
  """Tests tcgen05 MMA combining 2:4 structured sparsity (LHS) with block scaling.

  The LHS is an f4 operand stored in compressed form (only half of the
  logical K dimension is materialized) together with per-group sparsity
  metadata; both operands additionally carry per-block scales. The result
  is compared against a dense float32 reference built by scattering the
  compressed LHS back to its logical positions.
  """
  out_jax_dtype = jnp.float32
  sparse_meta_dtype = jnp.uint2
  # The scale dtype determines the scaling block size along K.
  if scale_jax_dtype == jnp.float8_e8m0fnu:
    block_size = 64
  elif scale_jax_dtype == jnp.float8_e4m3fn:
    block_size = 32
  k_steps = 2

  in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype)
  # Number of elements that fit in one swizzle period (swizzle is in bytes).
  swizzle_elems = 8 * swizzle // bitwidth(in_mlir_dtype)
  k = swizzle_elems * k_steps  # logical (dense) K dimension
  lhs_tiling = rhs_tiling = (8, swizzle_elems)

  def kernel(ctx, lhs, rhs, lhs_sparse_gmem, lhs_scales_gmem, rhs_scales_gmem, out, scratch):
    # Copies all five inputs GMEM->SMEM, moves metadata and scales
    # SMEM->TMEM, runs one sparse block-scaled MMA, and writes the
    # accumulator back out.
    (
        lhs_smem, rhs_smem, lhs_sparse_smem,
        lhs_scales_smem, rhs_scales_smem,
        barriers, mma_barrier, acc, lhs_sparse, lhs_scales, rhs_scales,
    ) = scratch
    operand_kwargs = dict(
        swizzle=swizzle,
        gmem_transform=mgpu.TileTransform(lhs_tiling),
    )
    ctx.async_copy(src_ref=lhs, dst_ref=lhs_smem, barrier=barriers[0], **operand_kwargs)
    ctx.async_copy(src_ref=rhs, dst_ref=rhs_smem, barrier=barriers[1], swizzle=swizzle, gmem_transform=mgpu.TileTransform(rhs_tiling))
    ctx.async_copy(src_ref=lhs_sparse_gmem, dst_ref=lhs_sparse_smem, barrier=barriers[2])
    ctx.async_copy(src_ref=lhs_scales_gmem, dst_ref=lhs_scales_smem, barrier=barriers[3])
    ctx.async_copy(src_ref=rhs_scales_gmem, dst_ref=rhs_scales_smem, barrier=barriers[4])
    # Wait for all five TMA copies before touching SMEM.
    for i in range(5):
      barriers[i].wait()
    with mgpu.single_thread():
      # Metadata and scales must live in TMEM for the MMA to consume them.
      tcgen05.async_copy_sparse_metadata_smem_to_tmem(lhs_sparse_smem, lhs_sparse)
      tcgen05.async_copy_scales_smem_to_tmem(lhs_scales_smem, lhs_scales)
      tcgen05.async_copy_scales_smem_to_tmem(rhs_scales_smem, rhs_scales)
      tcgen05.mma(
          acc,
          lhs_smem,
          # RHS is stored (n, k)-major; transpose the tiled view so the
          # MMA sees it as K x N.
          mgpu.memref_transpose(rhs_smem, (1, 0, 3, 2)),
          a_swizzle=swizzle,
          b_swizzle=swizzle,
          a_scale=lhs_scales,
          b_scale=rhs_scales,
          a_sparse_metadata=lhs_sparse,
          accumulate=False,
      )
      tcgen05.commit_arrive(mma_barrier)
    mma_barrier.wait(orders_tensor_core=True)
    acc.load().store_untiled(out, optimized=False)

  # Compressed LHS: only half of the logical K elements are stored.
  x_shape = (m, k // 2)
  x = self.prng.uniform(-1, 1, x_shape).astype(in_jax_dtype)
  y_shape = (n, k)
  y = self.prng.uniform(-1, 1, y_shape).astype(in_jax_dtype)
  out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype)
  # One 2-bit metadata entry per pair of logical elements.
  meta_k = k // 4
  scratch_shape = [
      jax.ShapeDtypeStruct(tile_shape(x_shape, lhs_tiling), in_jax_dtype),
      jax.ShapeDtypeStruct(tile_shape(y_shape, rhs_tiling), in_jax_dtype),
      # Tiled layouts below mirror format_sparse_meta/format_scales —
      # presumably the TMEM layouts expected by the hardware copies.
      jax.ShapeDtypeStruct((m // 128, meta_k // 64, 128, 64), sparse_meta_dtype),
      jax.ShapeDtypeStruct((m // 128, k // (block_size * 4), 32, 16), scale_jax_dtype),
      jax.ShapeDtypeStruct((n // 128, k // (block_size * 4), 32, 16), scale_jax_dtype),
      mgpu.TMABarrier(5),
      mgpu.Barrier(1),
      mgpu.TMEM((m, n), out_jax_dtype),
      mgpu.TMEM((m, meta_k), sparse_meta_dtype, layout=tcgen05.sparse_meta_layout()),
      mgpu.TMEM((m, k // block_size), scale_jax_dtype, layout=tcgen05.scales_layout()),
      mgpu.TMEM((n, k // block_size), scale_jax_dtype, layout=tcgen05.scales_layout()),
  ]
  # For 4-bit data sparsity acts on pairs of elements: each group of 8
  # logical elements (4 pairs) keeps 2 of the 4 pairs.
  n_groups = k // 8
  index_pairs = np.asarray(np.meshgrid(range(4), range(4))).T.reshape(-1, 2)
  valid_pairs = index_pairs[index_pairs[:, 0] < index_pairs[:, 1]]
  assert len(valid_pairs) == 6  # C(4, 2) ordered index pairs
  x_pairs = jax.random.randint(jax.random.key(1234), (m, n_groups), 0, 6, dtype=jnp.uint8)
  x_sparse = valid_pairs[x_pairs]  # which 2 of the 4 pair slots are kept
  assert x_sparse.shape == (m, n_groups, 2)
  def format_sparse_meta(meta):
    # Tiles (mn, k_meta) metadata into (mn//128, k_meta//64, 128, 64)
    # blocks, matching the SMEM scratch shape above.
    mn, groups, _2 = meta.shape
    assert _2 == 2
    k_meta = groups * 2
    meta_tiled = (
        meta.reshape(mn // 128, 128, k_meta // 64, 64).transpose(0, 2, 1, 3)
    )
    return (
        meta_tiled.reshape(mn // 128, k_meta // 64, 128, 64)
        .astype(sparse_meta_dtype)
    )
  x_gpu_sparse = format_sparse_meta(x_sparse)
  a_scales, b_scales = self._sample_scales(m, k, n, block_size, scale_jax_dtype)
  def format_scales(scales):
    # Tiles (mn, k) scales into (mn//128, k//4, 32, 16) blocks, matching
    # the SMEM scratch shape expected by async_copy_scales_smem_to_tmem.
    mn, k = scales.shape
    assert mn % 128 == 0 and k % 4 == 0, scales.shape
    return (
        scales.reshape(mn // 128, 4, 32, k // 4, 4)
        .transpose(0, 3, 2, 1, 4)
        .reshape(mn // 128, k // 4, 32, 16)
    )
  a_gpu_scales, b_gpu_scales = map(format_scales, (a_scales, b_scales))
  args = (x, y, x_gpu_sparse, a_gpu_scales, b_gpu_scales)
  z = mgpu.as_gpu_kernel(
      kernel, (1, 1, 1), (128, 1, 1), args, out_shape, scratch_shape
  )(*args)
  # 4-bit sparse data is filled in pairs of elements.
  x32 = x.astype(np.float32).reshape(m, n_groups, 2, 2)
  # Scatter the kept pairs into their logical slots to rebuild dense LHS.
  x_logical32 = np.zeros((m, n_groups, 4, 2), dtype=np.float32)
  np.put_along_axis(x_logical32, x_sparse[..., np.newaxis], x32, axis=-2)
  x_logical32 = x_logical32.reshape(m, k)
  y32 = y.astype(np.float32)
  # Expand per-block scales to per-element for the dense reference.
  a_logical_scales = jnp.repeat(a_scales, block_size, axis=1).astype(jnp.float32)
  b_logical_scales = jnp.repeat(b_scales, block_size, axis=1).astype(jnp.float32)
  ref = (x_logical32 * a_logical_scales) @ (y32 * b_logical_scales).T
  np.testing.assert_allclose(z, ref, atol=2e-4, rtol=5e-6)
25282648 @parameterized .product (
25292649 lhs_transpose = (False , True ),
25302650 rhs_transpose = (False , True ),
0 commit comments