Skip to content

Commit 300a59c

Browse files
Avoid division by zero in cache DS MLA kernel (vllm-project#26174)
Signed-off-by: Matthew Bonanni <[email protected]>
1 parent d76541a commit 300a59c

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

csrc/cache_kernels.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
#include <algorithm>
1818
#include <cassert>
19-
#include <cfloat> // FLT_MIN
19+
#include <cfloat>
2020

2121
#ifdef USE_ROCM
2222
#include <hip/hip_bf16.h>
@@ -479,6 +479,7 @@ __global__ void concat_and_cache_ds_mla_kernel(
479479

480480
// Compute the scale for the tile
481481
float tile_scale = max_abs / 448.f;
482+
tile_scale = fmaxf(tile_scale, FLT_MIN);
482483

483484
// The first lane of each half-warp writes the scale to kv_cache
484485
if ((lane_idx == 0) || (lane_idx == 16)) {

0 commit comments

Comments
 (0)