Commit f7886ec

Address review comments
Signed-off-by: Xiaodong Ye <[email protected]>
1 parent 4add034 commit f7886ec

2 files changed (+8, -8 lines)


ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 5 additions & 5 deletions
@@ -315,14 +315,14 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared(
 
     float vals[sizeof(int)] = {0.0f};
 #pragma unroll
-    for (size_t l = 0; l < sizeof(int); ++l) {
+    for (int l = 0; l < int(sizeof(int)); ++l) {
         vals[l] = scale * x[4*threadIdx.x + l];
     }
 
     float amax = fabsf(vals[0]);
     float sum = vals[0];
 #pragma unroll
-    for (size_t l = 1; l < sizeof(int); ++l) {
+    for (int l = 1; l < int(sizeof(int)); ++l) {
         amax = fmaxf(amax, fabsf(vals[l]));
         sum += vals[l];
     }
@@ -338,7 +338,7 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared(
 
     if (d != 0.0f) {
 #pragma unroll
-        for (size_t l = 0; l < sizeof(int); ++l) {
+        for (int l = 0; l < int(sizeof(int)); ++l) {
             q8[l] = roundf(vals[l] / d);
         }
     }
@@ -638,9 +638,9 @@ static __global__ void flash_attn_combine_results(
     float VKQ_denominator = 0.0f;
     for (int l = 0; l < parallel_blocks; ++l) {
         const float diff = meta[l].x - kqmax;
-        const float KQ_max_scale = expf(diff);
+        float KQ_max_scale = expf(diff);
         const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
-        *((uint32_t *) const_cast<float *>(&KQ_max_scale)) &= ftz_mask;
+        *((uint32_t *) &KQ_max_scale) &= ftz_mask;
 
         VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
         VKQ_denominator += KQ_max_scale * meta[l].y;
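The last hunk drops the const_cast: KQ_max_scale was declared const and then modified through a cast pointer, which is undefined behavior for an object defined const, so the variable is now a plain float and the flush-to-zero mask is applied directly. Below is a minimal host-side sketch of the same branchless flush-to-zero idea; the SOFTMAX_FTZ_THRESHOLD value is an assumption for illustration (the real kernel defines its own constant), and std::memcpy is used instead of the pointer pun to keep the sketch portable.

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Assumed value for illustration; fattn-common.cuh defines the real threshold.
static const float SOFTMAX_FTZ_THRESHOLD = -20.0f;

static float ftz_scale(float diff) {
    float KQ_max_scale = std::exp(diff);      // non-const, so rewriting its bits is well-defined
    const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);

    uint32_t bits;
    std::memcpy(&bits, &KQ_max_scale, sizeof(bits));
    bits &= ftz_mask;                         // zero the scale when diff is below the threshold
    std::memcpy(&KQ_max_scale, &bits, sizeof(bits));
    return KQ_max_scale;
}

int main() {
    // The second call is below the threshold, so its scale is flushed to 0.
    std::printf("%g %g\n", ftz_scale(-1.0f), ftz_scale(-30.0f));
    return 0;
}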

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 3 additions & 3 deletions
@@ -1253,7 +1253,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__
         const float d = bxi->d;
 
 #pragma unroll
-        for (size_t l = 0; l < sizeof(int); ++l) {
+        for (int l = 0; l < int(sizeof(int)); ++l) {
             x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l];
         }
 #else
@@ -1376,7 +1376,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__
             const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
 
 #pragma unroll
-            for (size_t l = 0; l < sizeof(int); ++l) {
+            for (int l = 0; l < int(sizeof(int)); ++l) {
                 x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
             }
         }
@@ -1517,7 +1517,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__
             const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
 
 #pragma unroll
-            for (size_t l = 0; l < sizeof(int); ++l) {
+            for (int l = 0; l < int(sizeof(int)); ++l) {
                 x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
             }
         }
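The mmq.cuh hunks mirror the first file: the loops over the four bytes packed in an int now use an int counter with an explicit int(sizeof(int)) bound instead of size_t. The commit message does not spell out the motivation, but a signed counter avoids signed/unsigned comparison warnings and matches the loop style used elsewhere in these kernels. A standalone sketch of the pattern follows; vals, q8, and d are placeholder names for illustration, not the kernels' actual buffers.

#include <cmath>
#include <cstdio>

int main() {
    const float vals[sizeof(int)] = {1.5f, -2.25f, 3.0f, 0.5f};
    const float d = 0.25f;
    signed char q8[sizeof(int)];

    // In the CUDA kernels this loop sits under "#pragma unroll"; the trip count
    // is a compile-time constant either way, only the counter type changed.
    for (int l = 0; l < int(sizeof(int)); ++l) {   // was: for (size_t l = 0; l < sizeof(int); ++l)
        q8[l] = (signed char) std::roundf(vals[l] / d);
    }

    for (int l = 0; l < int(sizeof(int)); ++l) {
        std::printf("%d ", q8[l]);                 // prints: 6 -9 12 2
    }
    std::printf("\n");
    return 0;
}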
