@@ -51,45 +51,85 @@ def per_token_cast_to_fp8_e8m0(
5151 g , m , n ), sf
5252
5353
def per_block_cast_to_fp8_e8m0(
        x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to ``float8_e4m3fn`` with one scale per 128x128 tile.

    ``x`` is either a 2-D ``(m, n)`` tensor or a 3-D ``(g, m, n)`` tensor.
    The last two axes are copied into a zero-filled buffer whose sizes come
    from ``align(..., 128)`` (presumably rounded up to a multiple of 128 --
    confirm against ``align``), split into 128x128 tiles, and each tile is
    divided by its own scale before the cast to FP8.

    The scale of a tile is ``ceil_to_ue8m0(amax / 448.0)``, where ``amax`` is
    the tile's largest absolute value clamped to at least ``1e-4`` and 448 is
    the largest finite ``float8_e4m3fn`` value.  ``ceil_to_ue8m0`` is defined
    elsewhere; it presumably rounds up to a power of two.

    Returns:
        ``(quantized, scales)``: ``quantized`` has the shape of ``x`` (padding
        removed, contiguous); ``scales`` holds one entry per tile, shaped
        ``(row_tiles, col_tiles)`` for 2-D input and
        ``(g, row_tiles, col_tiles)`` for 3-D input.
    """
    if x.dim() != 2:
        # Batched case: the leading axis is kept, tiling is per matrix.
        num_groups, rows, cols = x.shape
        padded_shape = (num_groups, align(rows, 128), align(cols, 128))
        padded = torch.zeros(padded_shape, dtype=x.dtype, device=x.device)
        padded[:, :rows, :cols] = x

        tiles = padded.view(num_groups, -1, 128, padded.size(-1) // 128, 128)
        # Reduce over the two "inside a tile" axes only.
        tile_amax = tiles.abs().float().amax(dim=(2, 4), keepdim=True)
        tile_amax = tile_amax.clamp(1e-4)
        scale = ceil_to_ue8m0(tile_amax / 448.0)

        quantized = (tiles * (1.0 / scale)).to(torch.float8_e4m3fn)
        quantized = quantized.view_as(padded)[:, :rows, :cols].contiguous()
        scale = scale.view(tiles.size(0), tiles.size(1), tiles.size(3))
        return quantized, scale

    # Single-matrix case.
    rows, cols = x.shape
    padded_shape = (align(rows, 128), align(cols, 128))
    padded = torch.zeros(padded_shape, dtype=x.dtype, device=x.device)
    padded[:rows, :cols] = x

    tiles = padded.view(-1, 128, padded.size(1) // 128, 128)
    # Reduce over the two "inside a tile" axes only.
    tile_amax = tiles.abs().float().amax(dim=(1, 3), keepdim=True)
    tile_amax = tile_amax.clamp(1e-4)
    scale = ceil_to_ue8m0(tile_amax / 448.0)

    quantized = (tiles * (1.0 / scale)).to(torch.float8_e4m3fn)
    quantized = quantized.view_as(padded)[:rows, :cols].contiguous()
    scale = scale.view(tiles.size(0), tiles.size(2))
    return quantized, scale
80-
81-
def resmooth_to_fp8_e8m0(weight: torch.Tensor,
                         sf: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Dequantize block-scaled weights and re-quantize them per 128x128 block.

    Both tensors are moved to the CUDA device.  Every entry of ``sf`` is
    repeated 128 times along the two matrix axes, cropped to the shape of
    ``weight``, and multiplied into ``weight`` (as float32).  The product is
    passed to ``per_block_cast_to_fp8_e8m0``, whose result is returned.

    ``weight`` is handled as a single matrix when it is 2-D; otherwise it is
    indexed as a 3-D batch whose last two axes are the matrix axes.
    """
    weight, sf = weight.cuda(), sf.cuda()
    dequantized = weight.float()

    if weight.dim() == 2:
        expanded = sf.repeat_interleave(128, dim=0)
        expanded = expanded.repeat_interleave(128, dim=1)
        # Drop the overhang of the last, partially filled blocks.
        expanded = expanded[:weight.shape[0], :weight.shape[1]]
    else:
        expanded = sf.repeat_interleave(128, dim=1)
        expanded = expanded.repeat_interleave(128, dim=2)
        # Drop the overhang of the last, partially filled blocks.
        expanded = expanded[:weight.shape[0], :weight.shape[1],
                            :weight.shape[2]]

    return per_block_cast_to_fp8_e8m0(dequantized * expanded)
@triton.jit
def _resmooth_kernel(
    w_ptr,
    s_ptr,
    M,
    K,
    stride_wb,
    stride_wm,
    stride_wk,
    stride_sb,
    stride_sm,
    stride_sk,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Re-scale one ``BLOCK_M x BLOCK_K`` weight tile in place.

    Each program handles the tile selected by
    ``(program_id(0), program_id(1), program_id(2))`` =
    ``(batch, row-block, column-block)``:

    1. load the tile and its current scale, and dequantize to float32;
    2. take the tile's largest absolute value, clamped to at least ``1e-4``;
    3. derive the new scale ``2 ** ceil(log2(amax / 448))``;
    4. store the re-scaled tile and the new scale over the old ones.

    Args:
        w_ptr: base pointer of the weights, addressed as ``(batch, M, K)``.
        s_ptr: base pointer of the scales, one element per tile, addressed
            as ``(batch, row-block, column-block)``.
        M, K: number of rows / columns of one weight matrix.
        stride_wb, stride_wm, stride_wk: element strides of the weights.
        stride_sb, stride_sm, stride_sk: element strides of the scales.
        BLOCK_M, BLOCK_K: compile-time tile extents.
    """
    # program_id is 32-bit, and integer arguments that fit are passed as
    # 32-bit too, so the batch offset would otherwise be computed in int32
    # and wrap for stacked weights with more than 2**31 elements.  Widen the
    # batch index so both batch offsets are computed in 64-bit arithmetic.
    batch_idx = tl.program_id(0).to(tl.int64)
    pid_m = tl.program_id(1)
    pid_k = tl.program_id(2)

    curr_w_ptr = w_ptr + batch_idx * stride_wb
    curr_s_ptr = s_ptr + batch_idx * stride_sb

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rk = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)

    # One scalar scale per tile.
    s_offset = pid_m * stride_sm + pid_k * stride_sk
    old_scale = tl.load(curr_s_ptr + s_offset)

    # Elements past the matrix edge are masked out and read as 0, so they
    # never contribute to the tile maximum.
    w_mask = (rm[:, None] < M) & (rk[None, :] < K)
    w_offsets = rm[:, None] * stride_wm + rk[None, :] * stride_wk
    w_fp8 = tl.load(curr_w_ptr + w_offsets, mask=w_mask, other=0.0)
    w_fp32 = w_fp8.to(tl.float32)

    w_val = w_fp32 * old_scale
    # The lower bound keeps log2 finite for an all-zero tile.
    block_amax = tl.maximum(tl.max(tl.abs(w_val)), 1e-4)

    # UE8M0 sf = 2 ^ ceil(log2(sf))
    new_scale = tl.math.exp2(tl.math.ceil(tl.math.log2(block_amax / 448.0)))
    w_requant = w_val * (1.0 / new_scale)

    # tl.store converts the float32 values to the element type of w_ptr.
    tl.store(curr_w_ptr + w_offsets, w_requant, mask=w_mask)
    tl.store(curr_s_ptr + s_offset, new_scale)
96+
97+
98+ def resmooth_to_fp8_e8m0 (
99+ weight : torch .Tensor ,
100+ weight_scale : torch .Tensor ,
101+ block_size : tuple [int , int ] = (128 , 128 ),
102+ ):
103+ assert weight .dtype == torch .float8_e4m3fn
104+ assert weight_scale .dtype == torch .float32
105+
106+ orig_shape = weight .shape
107+ M , K = orig_shape [- 2 :]
108+ w_view = weight .view (- 1 , M , K )
109+ s_view = weight_scale .view (- 1 , weight_scale .shape [- 2 ],
110+ weight_scale .shape [- 1 ])
111+
112+ num_batches = w_view .shape [0 ]
113+ BLOCK_M , BLOCK_K = block_size
114+
115+ grid = (num_batches , triton .cdiv (M , BLOCK_M ), triton .cdiv (K , BLOCK_K ))
116+
117+ _resmooth_kernel [grid ](
118+ w_view ,
119+ s_view ,
120+ M ,
121+ K ,
122+ w_view .stride (0 ),
123+ w_view .stride (1 ),
124+ w_view .stride (2 ),
125+ s_view .stride (0 ),
126+ s_view .stride (1 ),
127+ s_view .stride (2 ),
128+ BLOCK_M = BLOCK_M ,
129+ BLOCK_K = BLOCK_K ,
130+ )
131+ # this is an in-place operation, however, we return for simplicity
132+ return weight , weight_scale
93133
94134
95135def get_m_alignment_for_contiguous_layout ():
0 commit comments