1 file changed: +5 −1 lines changed
Original file line number | Diff line number | Diff line change
15 15
16 16   # To check compatibility
17 17   IS_TURING = current_platform.get_device_capability() == (7, 5)
   18 + float8_info = torch.finfo(current_platform.fp8_dtype())
18 19
19 20
20 21   # Here's an example autotuner config for this kernel. This config does provide
@@ -83,7 +84,9 @@ def _fwd_kernel(Q,
83 84   SKIP_DECODE: tl.constexpr,
84 85   USE_FP8: tl.constexpr,
85 86   MAX_Q_LEN: tl.constexpr = 0,
86    - MAX_CTX_LEN: tl.constexpr = 0):
   87 + MAX_CTX_LEN: tl.constexpr = 0,
   88 + FP8_MIN: tl.constexpr = float8_info.min,
   89 + FP8_MAX: tl.constexpr = float8_info.max):
87 90
88 91   cur_batch = tl.program_id(0)
89 92   cur_head = tl.program_id(1)
@@ -278,6 +281,7 @@ def _fwd_kernel(Q,
278 281   out_ptrs = Out + off_o
279 282   if USE_FP8:
280 283   acc = acc / tl.load(out_scale)
    284 + acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
281 285   tl.store(out_ptrs,
282 286   acc,
283 287   mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len))
You can’t perform that action at this time.
0 commit comments