 #### cast_transpose
 ##########################################

+@triton.autotune(
+    configs=[
+        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'GROUP_M': 1}, num_warps=4),
+        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'GROUP_M': 8}, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=8),
+    ],
+    key=['M', 'N'],
+)
+@triton.jit
+def _amax_reduce_triton(
+    A,
+    stride_am, stride_an,
+    M, N,
+    amax_ptr,  # float32[1], initialize to -inf on host
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    GROUP_M: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    grid_m = (M + BLOCK_M - 1) // BLOCK_M
+    grid_n = (N + BLOCK_N - 1) // BLOCK_N
+
+    width = GROUP_M * grid_n
+    group_id = pid // width
+    group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M)
+    pid_m = group_id * GROUP_M + (pid % group_size)
+    pid_n = (pid % width) // group_size
+
+    rm = pid_m.to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n.to(tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    A_ptrs = A + rm[:, None] * stride_am + rn[None, :] * stride_an
+    mask = (rm < M)[:, None] & (rn < N)[None, :]
+
+    a = tl.load(A_ptrs, mask=mask, other=0).to(tl.float32)
+    tile_amax = tl.max(tl.abs(a))
+    # accumulate tile-wise max into global amax
+    tl.atomic_max(amax_ptr, tile_amax, sem='relaxed')
+
+
+@triton.jit
+def _compute_scale_from_amax_triton(
+    amax_ptr,
+    scale_ptr,
+    inv_ptr,
+    max_fp8,
+    epsilon,
+    value_for_inf,
+    FORCE_POW_2_SCALES: tl.constexpr,
+):
+    # This implementation mimics transformer_engine::compute_scale_from_amax()
+
+    a = tl.load(amax_ptr).to(tl.float32)
+
+    # amax < epsilon -> epsilon (NaNs pass through)
+    a = tl.where(a < epsilon, epsilon, a)
+
+    # bad amax (NaN, inf, 0.0) -> scale = 1.0
+    bad = (a != a) | (tl.abs(a) == float('inf')) | (a == 0.0)
+
+    if bad:
+        s = tl.full((), 1.0, tl.float32)
+    else:
+        s = max_fp8 / a
+        # scale overflowed to inf (amax too small) -> scale = value_for_inf
+        s = tl.where(tl.abs(s) == float('inf'), value_for_inf, s)
+        if FORCE_POW_2_SCALES:
+            s = tl.math.exp2(tl.floor(tl.log2(s)))
+
+    tl.store(scale_ptr, s)
+    tl.store(inv_ptr, 1.0 / s)
+
+
 @triton.autotune(
     configs=[
         triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'GROUP_M': 1}, num_warps=4),
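
For reference, the amax-to-scale logic in `_compute_scale_from_amax_triton` above boils down to the following host-side paraphrase. This is illustration only, not part of the patch; the function name, the default arguments, and 448.0 as an example `max_fp8` (the FP8 E4M3 maximum) are assumptions for the sketch.

```python
import math

def compute_scale_from_amax_ref(amax, max_fp8=448.0, epsilon=0.0,
                                value_for_inf=3.4e38, force_pow_2_scales=False):
    # clamp tiny amax up to epsilon (NaN falls through to the next check)
    amax = max(amax, epsilon)
    # degenerate amax (NaN, inf, 0.0) -> scale = 1.0
    if math.isnan(amax) or math.isinf(amax) or amax == 0.0:
        return 1.0
    scale = max_fp8 / amax
    # scale overflowed to inf (amax extremely small) -> value_for_inf
    if math.isinf(scale):
        scale = value_for_inf
    # optionally round the scale down to a power of two
    if force_pow_2_scales:
        scale = 2.0 ** math.floor(math.log2(scale))
    return scale

print(compute_scale_from_amax_ref(0.5))                           # 896.0
print(compute_scale_from_amax_ref(0.5, force_pow_2_scales=True))  # 512.0
```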
@@ -69,6 +143,52 @@ def _cast_transpose_triton(A, noop_ptr, C, T, stride_am, stride_an, stride_bn, s
     scale_inv_out = tl.fdiv(1.0, scale)
     tl.store(scale_inv_ptr, scale_inv_out)

+
+@triton.autotune(
+    configs=[
+        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'GROUP_M': 1}, num_warps=4),
+        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'GROUP_M': 8}, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=8),
+    ],
+    key=['M', 'N']
+)
+@triton.jit
+def _cast_transpose_triton_current_scaling(A, C, T, stride_am, stride_an, stride_bn, stride_bm, M, N, scale_ptr, max_fp8: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, GROUP_M: tl.constexpr):
+    # Similar (but slightly optimized) version of the delayed scaling kernel
+    # implemented in _cast_transpose_triton().
+    pid = tl.program_id(0)
+    scale = tl.load(scale_ptr)
+
+    grid_m = (M + BLOCK_M - 1) // BLOCK_M
+    grid_n = (N + BLOCK_N - 1) // BLOCK_N
+
+    width = GROUP_M * grid_n
+    group_id = pid // width
+    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+    pid_m = group_id * GROUP_M + (pid % group_size)
+    pid_n = (pid % width) // group_size
+
+    rm = pid_m.to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n.to(tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
+    A = A + rm[:, None] * stride_am + rn[None, :] * stride_an
+    mask = (rm < M)[:, None] & (rn < N)[None, :]
+    a = tl.load(A, mask=mask)
+    a = a.to(tl.float32)
+
+    scaled_a = a * scale
+    scaled_a = tl.clamp(scaled_a, -max_fp8, max_fp8)
+    fp8_a = scaled_a.to(C.type.element_ty)
+    C = C + rm[:, None] * stride_am + rn[None, :] * stride_an
+    tl.store(C, fp8_a, mask=mask)
+
+    # rematerialize to save registers
+    rm = pid_m.to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n.to(tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
+    T = T + rm[:, None] * stride_bm + rn[None, :] * stride_bn
+    mask = (rm < M)[:, None] & (rn < N)[None, :]
+    tl.store(T, fp8_a, mask=mask)
+
+
 FP32_EXPONENT_BIAS = tl.constexpr(127)
 FP32_MANTISSA_BITS = tl.constexpr(23)
 @triton.jit
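
A minimal launch sketch for the current-scaling kernel added above, assuming the kernels from this file are in scope. The uint8-backed FP8 storage, the `tl.float8e4nv` output type, the 448.0 E4M3 maximum, and the fixed scale of 2.0 are illustrative assumptions; in the patch the kernel is driven by `te_cast_transpose_noop_triton` below, which derives the scale from the reduced amax.

```python
import torch
import triton
import triton.language as tl

x = torch.randn(512, 256, device="cuda", dtype=torch.float16)
M, N = x.shape
cast_out = torch.empty(M, N, device="cuda", dtype=torch.uint8)    # FP8 byte storage
trans_out = torch.empty(N, M, device="cuda", dtype=torch.uint8)
scale = torch.tensor([2.0], device="cuda", dtype=torch.float32)   # illustrative scale

grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
_cast_transpose_triton_current_scaling[grid](
    x,
    triton.reinterpret(cast_out, tl.float8e4nv),    # assumed E4M3 output type
    triton.reinterpret(trans_out, tl.float8e4nv),
    x.stride(0), x.stride(1),                       # stride_am, stride_an
    trans_out.stride(0), trans_out.stride(1),       # stride_bn, stride_bm
    M, N, scale, 448.0,                             # 448.0 = E4M3 max (assumption)
)
```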
@@ -232,7 +352,7 @@ def _dequantize_mxfp8_triton(

 # Reshapes input of any given shape to 2D for processing,
 # then uses the Triton kernel to perform casting and transposition efficiently.
-def te_cast_transpose_noop_triton(input, noop_flag, input_scale, cast_out, trans_out, amax_out, scale_inv_out, otype):
+def te_cast_transpose_noop_triton(input, noop_flag, input_scale, cast_out, trans_out, amax_out, scale_inv_out, otype, current_scaling, eps, force_pow_2_scales):

     row_length = input.shape[-1] if len(input.shape) > 0 else 1
     num_rows = input.numel() // row_length
@@ -254,7 +374,35 @@ def te_cast_transpose_noop_triton(input, noop_flag, input_scale, cast_out, trans
         use_noop = False

     grid = lambda META: (triton.cdiv(num_rows, META['BLOCK_M']) * triton.cdiv(row_length, META['BLOCK_N']),)
-    _cast_transpose_triton[grid](input_2d_view, noop_flag, triton.reinterpret(cast_out_2d_view, tl_dtype), triton.reinterpret(trans_out_2d_view, tl_dtype), input_stride_M, input_stride_N, trans_out_stride_M, trans_out_stride_N, num_rows, row_length, input_scale, amax_out, scale_inv_out, get_fp8_max(otype), use_noop)
+
+    if current_scaling:
+        # Current scaling:
+        # 1) global amax reduction
+        # 2) compute current scale
+        # 3) cast+transpose with that current scale (otherwise same as delayed)
+
+        # global amax
+        amax_out.fill_(-float("inf"))
+        _amax_reduce_triton[grid](
+            input_2d_view,
+            input_stride_M, input_stride_N,
+            num_rows, row_length,
+            amax_out,
+        )
+
+        # Compute scale
+        fp8_max = get_fp8_max(otype)
+
+        _compute_scale_from_amax_triton[(1,)](
+            amax_out, input_scale, scale_inv_out,
+            fp8_max, eps, torch.finfo(torch.float32).max,
+            FORCE_POW_2_SCALES=force_pow_2_scales,
+        )
+
+        _cast_transpose_triton_current_scaling[grid](input_2d_view, triton.reinterpret(cast_out_2d_view, tl_dtype), triton.reinterpret(trans_out_2d_view, tl_dtype), input_stride_M, input_stride_N, trans_out_stride_M, trans_out_stride_N, num_rows, row_length, input_scale, get_fp8_max(otype))
+    else:
+        # Delayed scaling
+        _cast_transpose_triton[grid](input_2d_view, noop_flag, triton.reinterpret(cast_out_2d_view, tl_dtype), triton.reinterpret(trans_out_2d_view, tl_dtype), input_stride_M, input_stride_N, trans_out_stride_M, trans_out_stride_N, num_rows, row_length, input_scale, amax_out, scale_inv_out, get_fp8_max(otype), use_noop)

 def te_cast_transpose_mxfp8_triton(input, out, noop_flag=None):
     row_length = input.shape[-1] if len(input.shape) > 0 else 1