Skip to content

Commit 3488adf

Browse files
committed
ag_demonstrate_fattn_memory_issue
1 parent 873279b commit 3488adf

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

ggml/src/ggml-cuda/fattn-vec-f16.cuh

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,17 @@ static __global__ void flash_attn_vec_ext_f16(
195195

196196
#pragma unroll
197197
for (int j = 0; j < ncols; ++j) {
198+
199+
// Print debug values on the first iteration of the i_KQ_0 loop, from the thread with blockIdx.x/y == 0 and threadIdx.x/y == 0 (blockIdx.z is not checked).
200+
bool debug_print = (i_KQ_0==0 && blockIdx.x==0 && threadIdx.x == 0 && blockIdx.y==0 && threadIdx.y==0);
201+
if(debug_print)
202+
printf("Before vec_dot_KQ: Q_ds=%f\n",__half2float(Q_ds[0][0].x));
203+
198204
half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
205+
206+
if(debug_print) // expected to print the same value as before the call, but prints NaN instead
207+
printf("After vec_dot_KQ: Q_ds=%f\n",__half2float(Q_ds[0][0].x));
208+
199209
sum = warp_reduce_sum((float)sum);
200210

201211
if (use_logit_softcap) {

0 commit comments

Comments
 (0)