import torch
import triton
import triton.language as tl
+from torch.library import triton_op, wrap_triton

from torchao.prototype.moe_training.utils import (
    _is_column_major,
@@ -119,7 +120,7 @@ def triton_fp8_gemm_1x128_128x128(
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
-    triton_fp8_gemm_1x128_128x128_kernel[grid](
+    wrap_triton(triton_fp8_gemm_1x128_128x128_kernel)[grid](
        a,
        a.stride(0),
        a.stride(1),
@@ -234,7 +235,7 @@ def triton_fp8_gemm_1x128_128x1(
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
-    triton_fp8_gemm_1x128_128x1_kernel[grid](
+    wrap_triton(triton_fp8_gemm_1x128_128x1_kernel)[grid](
        a,
        a.stride(0),
        a.stride(1),
@@ -281,7 +282,7 @@ def triton_fp8_gemm_1x128_128x1(

@triton.autotune(configs=quant_kernel_configs_with_groups, key=["K"])
@triton.jit
-def fp8_blockwise_act_quant_lhs_kernel(
+def triton_fp8_blockwise_act_quant_lhs_kernel(
    x_ptr,
    x_stride_dim_0,
    x_stride_dim_1,
@@ -327,7 +328,8 @@ def fp8_blockwise_act_quant_lhs_kernel(
    tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale))


-def fp8_blockwise_act_quant_lhs(
+@triton_op("torchao::triton_fp8_blockwise_act_quant_lhs", mutates_args={})
+def triton_fp8_blockwise_act_quant_lhs(
    x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
@@ -352,7 +354,7 @@ def fp8_blockwise_act_quant_lhs(
        triton.cdiv(M, meta["NUM_GROUPS"]),
        triton.cdiv(K, meta["BLOCK_SIZE"]),
    )
-    fp8_blockwise_act_quant_lhs_kernel[grid](
+    wrap_triton(triton_fp8_blockwise_act_quant_lhs_kernel)[grid](
        x,
        x.stride(0),
        x.stride(1),
@@ -372,7 +374,7 @@ def fp8_blockwise_act_quant_lhs(

@triton.autotune(configs=quant_kernel_configs_with_groups, key=["K"])
@triton.jit
-def fp8_blockwise_act_quant_rhs_kernel(
+def triton_fp8_blockwise_act_quant_rhs_kernel(
    x_ptr,
    x_stride_dim_0,
    x_stride_dim_1,
@@ -420,7 +422,8 @@ def fp8_blockwise_act_quant_rhs_kernel(
    tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale))


-def fp8_blockwise_act_quant_rhs(
+@triton_op("torchao::triton_fp8_blockwise_act_quant_rhs", mutates_args={})
+def triton_fp8_blockwise_act_quant_rhs(
    x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
@@ -444,7 +447,7 @@ def fp8_blockwise_act_quant_rhs(
        triton.cdiv(M, meta["BLOCK_SIZE"]),
        triton.cdiv(K, meta["NUM_GROUPS"]),
    )
-    fp8_blockwise_act_quant_rhs_kernel[grid](
+    wrap_triton(triton_fp8_blockwise_act_quant_rhs_kernel)[grid](
        x,
        x.stride(0),
        x.stride(1),
@@ -464,7 +467,7 @@ def fp8_blockwise_act_quant_rhs(

@triton.autotune(configs=quant_kernel_configs_with_groups, key=["K"])
@triton.jit
-def fp8_blockwise_act_quant_transposed_lhs_kernel(
+def triton_fp8_blockwise_act_quant_transposed_lhs_kernel(
    x_ptr,
    x_stride_dim_0,
    x_stride_dim_1,
@@ -524,7 +527,8 @@ def fp8_blockwise_act_quant_transposed_lhs_kernel(
    tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale), mask=scale_mask)


-def fp8_blockwise_act_quant_transposed_lhs(
+@triton_op("torchao::triton_fp8_blockwise_act_quant_transposed_lhs", mutates_args={})
+def triton_fp8_blockwise_act_quant_transposed_lhs(
    x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.is_contiguous(), "Input tensor must be contiguous"
@@ -550,7 +554,7 @@ def fp8_blockwise_act_quant_transposed_lhs(
        triton.cdiv(K, meta["NUM_GROUPS"]),
    )

-    fp8_blockwise_act_quant_transposed_lhs_kernel[grid](
+    wrap_triton(triton_fp8_blockwise_act_quant_transposed_lhs_kernel)[grid](
        x,
        x.stride(0),
        x.stride(1),
@@ -570,7 +574,7 @@ def fp8_blockwise_act_quant_transposed_lhs(

@triton.autotune(configs=quant_kernel_configs, key=["M", "N"])
@triton.jit
-def fp8_blockwise_weight_quant_rhs_kernel(
+def triton_fp8_blockwise_weight_quant_rhs_kernel(
    x_ptr,
    x_stride_dim_0,
    x_stride_dim_1,
@@ -615,8 +619,9 @@ def fp8_blockwise_weight_quant_rhs_kernel(
    tl.store(s_ptr + scale_m_off + scale_n_off, tl.div_rn(1.0, scale))


-def fp8_blockwise_weight_quant_rhs(
-    x: torch.Tensor, block_size: int = 128, dtype=torch.float8_e4m3fn
+@triton_op("torchao::triton_fp8_blockwise_weight_quant_rhs", mutates_args={})
+def triton_fp8_blockwise_weight_quant_rhs(
+    x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.is_contiguous(), "Input tensor must be contiguous"
    assert x.dim() == 2, "Input tensor must have 2 dimensions"
@@ -638,7 +643,7 @@ def fp8_blockwise_weight_quant_rhs(
        triton.cdiv(M, meta["BLOCK_SIZE"]),
        triton.cdiv(N, meta["BLOCK_SIZE"]),
    )
-    fp8_blockwise_weight_quant_rhs_kernel[grid](
+    wrap_triton(triton_fp8_blockwise_weight_quant_rhs_kernel)[grid](
        x,
        x.stride(0),
        x.stride(1),
@@ -658,7 +663,7 @@ def fp8_blockwise_weight_quant_rhs(

@triton.autotune(configs=quant_kernel_configs, key=["M", "N"])
@triton.jit
-def fp8_blockwise_weight_quant_transposed_rhs_kernel(
+def triton_fp8_blockwise_weight_quant_transposed_rhs_kernel(
    x_ptr,
    x_stride_dim_0,
    x_stride_dim_1,
@@ -719,8 +724,9 @@ def fp8_blockwise_weight_quant_transposed_rhs_kernel(
    tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale), mask=scale_mask)


-def fp8_blockwise_weight_quant_transposed_rhs(
-    x: torch.Tensor, block_size: int = 128, dtype=torch.float8_e4m3fn
+@triton_op("torchao::triton_fp8_blockwise_weight_quant_transposed_rhs", mutates_args={})
+def triton_fp8_blockwise_weight_quant_transposed_rhs(
+    x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.is_contiguous(), "Input tensor must be contiguous"
    assert x.dim() == 2, "Input tensor must have 2 dimensions"
@@ -742,7 +748,7 @@ def fp8_blockwise_weight_quant_transposed_rhs(
        triton.cdiv(M, meta["BLOCK_SIZE"]),
        triton.cdiv(N, meta["BLOCK_SIZE"]),
    )
-    fp8_blockwise_weight_quant_transposed_rhs_kernel[grid](
+    wrap_triton(triton_fp8_blockwise_weight_quant_transposed_rhs_kernel)[grid](
        x,
        x.stride(0),
        x.stride(1),
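
For reference, the pattern this change applies throughout the file is: register each Python launcher as a custom op with `@triton_op` and route the raw kernel launch through `wrap_triton`, so `torch.compile` can trace the op rather than graph-breaking on the Triton call. Below is a minimal, self-contained sketch of that pattern under illustrative assumptions; the kernel, the `mylib::scale` op name, and the block size are examples, not part of this patch.

```python
import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton


@triton.jit
def _scale_kernel(x_ptr, out_ptr, factor, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance scales one BLOCK_SIZE chunk of the flattened input.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * factor, mask=mask)


# triton_op registers the wrapper as a custom op (here under a hypothetical
# "mylib" namespace), and wrap_triton makes the kernel launch visible to the
# tracer, so torch.compile can see through the call instead of breaking on it.
@triton_op("mylib::scale", mutates_args={})
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    wrap_triton(_scale_kernel)[grid](x, out, factor, n_elements, BLOCK_SIZE=1024)
    return out


if __name__ == "__main__":
    x = torch.randn(4096, device="cuda")
    y = torch.compile(scale)(x, 2.0)
    torch.testing.assert_close(y, x * 2.0)
```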