Skip to content

Commit 81cb301

Browse files
committed
update the function to use appropriate types
1 parent bb0685f commit 81cb301

File tree

1 file changed

+19
-19
lines changed

1 file changed

+19
-19
lines changed

ggml/src/ggml-cpu.c

Lines changed: 19 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -11646,22 +11646,22 @@ static void ggml_compute_forward_add_rel_pos(
1164611646
static void ggml_compute_forward_rwkv_wkv6_f32(
1164711647
const struct ggml_compute_params * params,
1164811648
struct ggml_tensor * dst) {
11649-
const size_t T = dst->src[1]->ne[3];
11650-
const size_t C = dst->ne[0];
11651-
const size_t HEADS = dst->src[1]->ne[2];
11652-
const size_t n_seqs = dst->src[5]->ne[1];
11653-
const size_t head_size = C / HEADS;
11649+
const int64_t T = dst->src[1]->ne[3];
11650+
const int64_t C = dst->ne[0];
11651+
const int64_t HEADS = dst->src[1]->ne[2];
11652+
const int64_t n_seqs = dst->src[5]->ne[1];
11653+
const int64_t head_size = C / HEADS;
1165411654

1165511655
float * dst_data = (float *) dst->data;
1165611656
float * state = ((float *) dst->data) + C * T;
1165711657

11658-
if ((size_t)params->ith >= HEADS) {
11658+
if ((int64_t)params->ith >= HEADS) {
1165911659
return;
1166011660
}
1166111661

11662-
size_t h_start = (HEADS * params->ith) / params->nth;
11663-
size_t h_end = ((HEADS * (size_t)(params->ith + 1)) / (size_t)params->nth < HEADS) ?
11664-
(HEADS * (size_t)(params->ith + 1)) / (size_t)params->nth : HEADS;
11662+
int64_t h_start = (HEADS * params->ith) / params->nth;
11663+
int64_t h_end = ((HEADS * (params->ith + 1)) / params->nth < HEADS) ?
11664+
(HEADS * (params->ith + 1)) / params->nth : HEADS;
1166511665

1166611666
float * k = (float *) dst->src[0]->data;
1166711667
float * v = (float *) dst->src[1]->data;
@@ -11708,20 +11708,20 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
1170811708
#endif
1170911709

1171011710
#ifdef WKV_VECTOR_SIZE
11711-
const size_t vec_count = head_size / WKV_VECTOR_SIZE;
11711+
const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
1171211712

11713-
for (size_t t = 0; t < T; t++) {
11713+
for (int64_t t = 0; t < T; t++) {
1171411714
size_t t_offset = t * t_stride;
1171511715
size_t state_offset = head_size * C * (t / (T / n_seqs));
1171611716
float * state_cur = state + state_offset;
1171711717
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
1171811718

11719-
for (size_t h = h_start; h < h_end; h++) {
11719+
for (int64_t h = h_start; h < h_end; h++) {
1172011720
size_t h_offset = h * h_stride;
1172111721
size_t t_h_offset = t_offset + h_offset;
1172211722
size_t h_2d_offset = h * h_stride_2d;
1172311723

11724-
for (size_t i = 0; i < head_size; i++) {
11724+
for (int64_t i = 0; i < head_size; i++) {
1172511725
size_t t_h_i_offset = t_h_offset + i;
1172611726
size_t h_i_offset = h_offset + i;
1172711727
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
@@ -11737,7 +11737,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
1173711737
GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
1173811738
GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
1173911739

11740-
for (size_t j = 0; j < vec_count; j++) {
11740+
for (int64_t j = 0; j < vec_count; j++) {
1174111741
size_t base_j = j * WKV_VECTOR_SIZE;
1174211742
size_t t_h_j_offset = t_h_offset + base_j;
1174311743
size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
@@ -11763,7 +11763,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
1176311763
}
1176411764

1176511765
// Handle remaining elements, this will not be used.
11766-
for (size_t j = vec_count * VECTOR_SIZE; j < head_size; j++) {
11766+
for (int64_t j = vec_count * VECTOR_SIZE; j < head_size; j++) {
1176711767
size_t t_h_j_offset = t_h_offset + j;
1176811768
size_t h_2d_i_j_offset = h_2d_i_offset + j;
1176911769
float v_val = v[t_h_j_offset];
@@ -11782,18 +11782,18 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
1178211782
// dst = r @ (time_faaaa * (k @ v) + state),
1178311783
// state = time_decay * state + (k @ v),
1178411784
// recursive through each token
11785-
for (size_t t = 0; t < T; t++) {
11785+
for (int64_t t = 0; t < T; t++) {
1178611786
size_t t_offset = t * t_stride;
1178711787
size_t state_offset = head_size * C * (t / (T / n_seqs));
1178811788
float * state_cur = state + state_offset;
1178911789
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
1179011790

11791-
for (size_t h = h_start; h < h_end; h++) {
11791+
for (int64_t h = h_start; h < h_end; h++) {
1179211792
size_t h_offset = h * h_stride;
1179311793
size_t t_h_offset = t_offset + h_offset;
1179411794
size_t h_2d_offset = h * h_stride_2d;
1179511795

11796-
for (size_t i = 0; i < head_size; i++) {
11796+
for (int64_t i = 0; i < head_size; i++) {
1179711797
size_t t_h_i_offset = t_h_offset + i;
1179811798
size_t h_i_offset = h_offset + i;
1179911799
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
@@ -11804,7 +11804,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
1180411804
// RWKV v6: different time_decay for each token.
1180511805
float time_decay_val = time_decay[t_h_i_offset];
1180611806

11807-
for (size_t j = 0; j < head_size; j ++) {
11807+
for (int64_t j = 0; j < head_size; j ++) {
1180811808
size_t t_h_j_offset = t_h_offset + j;
1180911809
size_t h_2d_i_j_offset = h_2d_i_offset + j;
1181011810

0 commit comments

Comments
 (0)