@@ -264,11 +264,22 @@ def triton_fp8_gemm_1x128_128x1(
         num_stages=stages,
     )
     for warps in [4, 8]
+    for stages in [2, 4]
+]
+
+quant_kernel_configs_with_groups = [
+    triton.Config(
+        {"NUM_GROUPS": groups},
+        num_warps=warps,
+        num_stages=stages,
+    )
+    for groups in [2, 16, 32, 64, 128]
+    for warps in [2, 4, 8]
     for stages in [2, 4, 6]
 ]
 
 
-@triton.autotune(configs=quant_kernel_configs, key=["K"])
+@triton.autotune(configs=quant_kernel_configs_with_groups, key=["K"])
 @triton.jit
 def fp8_blockwise_act_quant_lhs_kernel(
     x_ptr,
@@ -283,13 +294,14 @@ def fp8_blockwise_act_quant_lhs_kernel(
     M,
     K: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
+    NUM_GROUPS: tl.constexpr,
     EPS: tl.constexpr,
 ):
     pid_m = tl.program_id(axis=0)
     pid_k = tl.program_id(axis=1)
 
-    # Load (1 x block_size) tile of x, where input is row major
-    m_offs = pid_m
+    # Load (num_groups x block_size) tile of x, where input is row major
+    m_offs = pid_m * NUM_GROUPS + tl.arange(0, NUM_GROUPS)
     k_offs = pid_k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     x_offs = m_offs[:, None] * x_stride_dim_0 + k_offs[None, :] * x_stride_dim_1
     x_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K)
@@ -298,8 +310,10 @@ def fp8_blockwise_act_quant_lhs_kernel(
     # Perform scaling
     max_fp8_e4m3 = 448.0
     min_fp8_e4m3 = -448.0
-    amax = tl.clamp(tl.max(tl.abs(x)), min=EPS, max=float("inf")).to(tl.float64)
-    scale = (max_fp8_e4m3 / amax).to(tl.float32)
+
+    # Scales for (1 x block_size) groups, shape will be (NUM_GROUPS, 1)
+    amax = tl.clamp(tl.max(tl.abs(x), axis=1), min=EPS, max=float("inf")).to(tl.float64)
+    scale = (max_fp8_e4m3 / amax).to(tl.float32)[:, None]
     y = x * scale
     y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to(y_ptr.dtype.element_ty)
 
@@ -309,7 +323,7 @@ def fp8_blockwise_act_quant_lhs_kernel(
     tl.store(y_ptr + y_offs, y, mask=y_mask)
 
     # Write reciprocal scales
-    scale_offs = pid_m * s_stride_dim_0 + pid_k * s_stride_dim_1
+    scale_offs = m_offs[:, None] * s_stride_dim_0 + pid_k * s_stride_dim_1
     tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale))
 
 
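For reference, the per-row scaling that the LHS kernel above performs on each (1 x block_size) group can be sketched in eager PyTorch. This is a minimal sketch, assuming a PyTorch build with `torch.float8_e4m3fn`; the helper name and the `eps` default are illustrative, not part of this patch.

import torch

def ref_act_quant_lhs(x: torch.Tensor, block_size: int = 128, eps: float = 1e-12):
    # x: (M, K) row-major activations; K assumed divisible by block_size.
    M, K = x.shape
    groups = x.view(M, K // block_size, block_size)
    # Per-group amax, clamped by eps and computed in fp64, as in the kernel
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float64)
    scale = (448.0 / amax).to(torch.float32)
    y = (groups * scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn).view(M, K)
    # The kernel stores reciprocal scales, one per (1 x block_size) group
    s = (1.0 / scale).squeeze(-1)  # shape (M, K // block_size)
    return y, s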
@@ -334,7 +348,10 @@ def fp8_blockwise_act_quant_lhs(
         (M, K // block_size),
         (1, M),
     )
-    grid = lambda meta: (M, triton.cdiv(K, meta["BLOCK_SIZE"]))
+    grid = lambda meta: (
+        triton.cdiv(M, meta["NUM_GROUPS"]),
+        triton.cdiv(K, meta["BLOCK_SIZE"]),
+    )
     fp8_blockwise_act_quant_lhs_kernel[grid](
         x,
         x.stride(0),
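With NUM_GROUPS in the launch grid, each program now quantizes NUM_GROUPS rows instead of one. A worked example of the new grid computation follows; the tensor sizes and the NUM_GROUPS value are illustrative, since in practice the autotuner picks NUM_GROUPS from the configs above.

import triton

M, K = 4096, 7168
meta = {"BLOCK_SIZE": 128, "NUM_GROUPS": 64}
# Old grid: (M, cdiv(K, BLOCK_SIZE)) = (4096, 56) programs.
# New grid: (cdiv(M, NUM_GROUPS), cdiv(K, BLOCK_SIZE)) = (64, 56) programs.
grid = (triton.cdiv(M, meta["NUM_GROUPS"]), triton.cdiv(K, meta["BLOCK_SIZE"]))
assert grid == (64, 56)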
@@ -353,7 +370,7 @@ def fp8_blockwise_act_quant_lhs(
     return y, s
 
 
-@triton.autotune(configs=quant_kernel_configs, key=["K"])
+@triton.autotune(configs=quant_kernel_configs_with_groups, key=["K"])
 @triton.jit
 def fp8_blockwise_act_quant_rhs_kernel(
     x_ptr,
@@ -368,33 +385,38 @@ def fp8_blockwise_act_quant_rhs_kernel(
     M,
     K: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
+    NUM_GROUPS: tl.constexpr,
     EPS: tl.constexpr,
 ):
     pid_m = tl.program_id(axis=0)
     pid_k = tl.program_id(axis=1)
 
-    # Load (block_size x 1) tile of x, where input is row major
+    # Load (block_size x block_size) tile of x, where input is row major.
+    # Each scaling group is (block_size x 1), but we load (block_size x block_size)
+    # to facilitate coalesced gmem accesses and improve efficiency.
     m_offs = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-    k_offs = pid_k
+    k_offs = pid_k * NUM_GROUPS + tl.arange(0, NUM_GROUPS)
     x_offs = m_offs[:, None] * x_stride_dim_0 + k_offs[None, :] * x_stride_dim_1
     x_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K)
     x = tl.load(x_ptr + x_offs, mask=x_mask)
 
     # Perform scaling
     max_fp8_e4m3 = 448.0
     min_fp8_e4m3 = -448.0
-    amax = tl.clamp(tl.max(tl.abs(x)), min=EPS, max=float("inf")).to(tl.float64)
-    scale = (max_fp8_e4m3 / amax).to(tl.float32)
+
+    # Column-wise scales for RHS operand, shape (1, block_size)
+    amax = tl.clamp(tl.max(tl.abs(x), axis=0), min=EPS, max=float("inf")).to(tl.float64)
+    scale = (max_fp8_e4m3 / amax).to(tl.float32)[None, :]
     y = x * scale
     y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to(y_ptr.dtype.element_ty)
 
-    # Write output to column major fomrat
+    # Write output to column major format
     y_offs = m_offs[:, None] * y_stride_dim_0 + k_offs[None, :] * y_stride_dim_1
     y_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K)
     tl.store(y_ptr + y_offs, y, mask=y_mask)
 
     # Write scales
-    scale_offs = pid_m * s_stride_dim_0 + pid_k * s_stride_dim_1
+    scale_offs = pid_m * s_stride_dim_0 + k_offs[None, :] * s_stride_dim_1
     tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale))
 
 
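For the RHS path the scaling groups run down the columns, one scale per (block_size x 1) slice, and the quantized output is written in column-major layout. A minimal eager-mode sketch of that reduction, assuming `torch.float8_e4m3fn` support; the helper name and `eps` default are illustrative, not part of this patch.

import torch

def ref_act_quant_rhs(x: torch.Tensor, block_size: int = 128, eps: float = 1e-12):
    # x: (M, K) row-major activations; M assumed divisible by block_size.
    M, K = x.shape
    groups = x.view(M // block_size, block_size, K)
    # Per-column-group amax over block_size rows, clamped by eps, as in the kernel
    amax = groups.abs().amax(dim=1, keepdim=True).clamp(min=eps).to(torch.float64)
    scale = (448.0 / amax).to(torch.float32)
    y = (groups * scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn).view(M, K)
    y = y.t().contiguous().t()  # column-major layout, as the kernel writes
    s = (1.0 / scale).squeeze(1)  # reciprocal scales, shape (M // block_size, K)
    return y, s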
@@ -420,7 +442,7 @@ def fp8_blockwise_act_quant_rhs(
 
     grid = lambda meta: (
         triton.cdiv(M, meta["BLOCK_SIZE"]),
-        K,
+        triton.cdiv(K, meta["NUM_GROUPS"]),
     )
     fp8_blockwise_act_quant_rhs_kernel[grid](
         x,
@@ -440,7 +462,7 @@ def fp8_blockwise_act_quant_rhs(
     return y, s
 
 
-@triton.autotune(configs=quant_kernel_configs, key=["K"])
+@triton.autotune(configs=quant_kernel_configs_with_groups, key=["K"])
 @triton.jit
 def fp8_blockwise_act_quant_transposed_lhs_kernel(
     x_ptr,
@@ -454,8 +476,8 @@ def fp8_blockwise_act_quant_transposed_lhs_kernel(
     s_stride_dim_1,
     M,
     K: tl.constexpr,
-    SCALE_BLOCK_SIZE: tl.constexpr,  # For scaling groups, not for grid/parallelization
-    BLOCK_SIZE_K: tl.constexpr,  # For grid/parallelization, not for scaling groups
+    BLOCK_SIZE: tl.constexpr,  # For scaling groups, not for grid/parallelization
+    NUM_GROUPS: tl.constexpr,  # For grid/parallelization, not for scaling groups
     EPS: tl.constexpr,
 ):
     # This kernel reads data in row-major format, and writes to an output tensor with
@@ -465,12 +487,12 @@ def fp8_blockwise_act_quant_transposed_lhs_kernel(
     pid_m = tl.program_id(axis=0)
     pid_k = tl.program_id(axis=1)
 
-    # Load (block_size x block_size_k) block of input, where input is row major.
+    # Load (block_size x num_groups) block of input, where input is row major.
     # We will be computing (block_size x 1) scaling factors (columns), and computing
-    # `block_size_k` at a time, so we aren't parallelizing with 1 thread per column,
+    # `num_groups` at a time, so we aren't parallelizing with 1 thread per column,
     # which will fail to launch for large tensors, due to max block number of 65535.
-    m_offs = pid_m * SCALE_BLOCK_SIZE + tl.arange(0, SCALE_BLOCK_SIZE)
-    k_offs = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
+    m_offs = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    k_offs = pid_k * NUM_GROUPS + tl.arange(0, NUM_GROUPS)
     x_offs = m_offs[:, None] * x_stride_dim_0 + k_offs[None, :] * x_stride_dim_1
     x_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K)
     x = tl.load(x_ptr + x_offs, mask=x_mask)
@@ -496,7 +518,7 @@ def fp8_blockwise_act_quant_transposed_lhs_kernel(
 
     # Scale tensor size is (K, M // SCALE_BLOCK_SIZE)
     scale_offs = scale_k_offs * s_stride_dim_0 + scale_m_off * s_stride_dim_1
-    scale_mask = (scale_k_offs < K) & (scale_m_off < M // SCALE_BLOCK_SIZE)
+    scale_mask = (scale_k_offs < K) & (scale_m_off < M // BLOCK_SIZE)
 
     # Write out reciprocal scales
     tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale), mask=scale_mask)
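The reason the transposed-LHS kernel tiles the K axis by NUM_GROUPS, rather than launching one program per column, is the per-dimension launch limit mentioned in the kernel comment above (65535 for the second grid axis). An illustrative check; the K and NUM_GROUPS values are examples, not from this patch.

import triton

MAX_GRID_DIM_1 = 65535  # launch limit on the second grid dimension
K = 131072
# One program per column would exceed the limit ...
assert K > MAX_GRID_DIM_1
# ... while grouping columns keeps that axis well under it.
assert triton.cdiv(K, 4) <= MAX_GRID_DIM_1  # e.g. NUM_GROUPS = 4 -> 32768 programs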
@@ -524,8 +546,8 @@ def fp8_blockwise_act_quant_transposed_lhs(
         (1, K),  # stride
     )
     grid = lambda meta: (
-        triton.cdiv(M, meta["SCALE_BLOCK_SIZE"]),
-        triton.cdiv(K, meta["BLOCK_SIZE_K"]),
+        triton.cdiv(M, meta["BLOCK_SIZE"]),
+        triton.cdiv(K, meta["NUM_GROUPS"]),
     )
 
     fp8_blockwise_act_quant_transposed_lhs_kernel[grid](
@@ -540,8 +562,7 @@ def fp8_blockwise_act_quant_transposed_lhs(
         s.stride(1),
         M,
         K=K,
-        SCALE_BLOCK_SIZE=block_size,  # Scaling group size
-        BLOCK_SIZE_K=block_size,  # Just for parallelize the work along K as well
+        BLOCK_SIZE=block_size,  # Scaling group size
         EPS=EPS,
     )
     return y, s
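A usage sketch for the transposed-LHS wrapper, assuming it takes `(x, block_size)` like the other helpers and that a CUDA device is available; the sizes are examples, and only the scale shape stated in the kernel comment is asserted.

import torch

M, K, block_size = 8192, 7168, 128
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
# Signature assumed: (x, block_size) -> (y, s), matching the other wrappers
y, s = fp8_blockwise_act_quant_transposed_lhs(x, block_size)
# Per the "(K, M // SCALE_BLOCK_SIZE)" comment in the kernel
assert s.shape == (K, M // block_size)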