@@ -11646,22 +11646,22 @@ static void ggml_compute_forward_add_rel_pos(
1164611646static void ggml_compute_forward_rwkv_wkv6_f32 (
1164711647 const struct ggml_compute_params * params ,
1164811648 struct ggml_tensor * dst ) {
11649- const size_t T = dst -> src [1 ]-> ne [3 ];
11650- const size_t C = dst -> ne [0 ];
11651- const size_t HEADS = dst -> src [1 ]-> ne [2 ];
11652- const size_t n_seqs = dst -> src [5 ]-> ne [1 ];
11653- const size_t head_size = C / HEADS ;
11649+ const int64_t T = dst -> src [1 ]-> ne [3 ];
11650+ const int64_t C = dst -> ne [0 ];
11651+ const int64_t HEADS = dst -> src [1 ]-> ne [2 ];
11652+ const int64_t n_seqs = dst -> src [5 ]-> ne [1 ];
11653+ const int64_t head_size = C / HEADS ;
1165411654
1165511655 float * dst_data = (float * ) dst -> data ;
1165611656 float * state = ((float * ) dst -> data ) + C * T ;
1165711657
11658- if ((size_t )params -> ith >= HEADS ) {
11658+ if ((int64_t )params -> ith >= HEADS ) {
1165911659 return ;
1166011660 }
1166111661
11662- size_t h_start = (HEADS * params -> ith ) / params -> nth ;
11663- size_t h_end = ((HEADS * (size_t )( params -> ith + 1 )) / ( size_t ) params -> nth < HEADS ) ?
11664- (HEADS * (size_t )( params -> ith + 1 )) / ( size_t ) params -> nth : HEADS ;
11662+ int64_t h_start = (HEADS * params -> ith ) / params -> nth ;
11663+ int64_t h_end = ((HEADS * (params -> ith + 1 )) / params -> nth < HEADS ) ?
11664+ (HEADS * (params -> ith + 1 )) / params -> nth : HEADS ;
1166511665
1166611666 float * k = (float * ) dst -> src [0 ]-> data ;
1166711667 float * v = (float * ) dst -> src [1 ]-> data ;
@@ -11708,20 +11708,20 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
1170811708 #endif
1170911709
1171011710 #ifdef WKV_VECTOR_SIZE
11711- const size_t vec_count = head_size / WKV_VECTOR_SIZE ;
11711+ const int64_t vec_count = head_size / WKV_VECTOR_SIZE ;
1171211712
11713- for (size_t t = 0 ; t < T ; t ++ ) {
11713+ for (int64_t t = 0 ; t < T ; t ++ ) {
1171411714 size_t t_offset = t * t_stride ;
1171511715 size_t state_offset = head_size * C * (t / (T / n_seqs ));
1171611716 float * state_cur = state + state_offset ;
1171711717 float * state_prev = t % (T / n_seqs ) ? state_cur : (float * )dst -> src [5 ]-> data + state_offset ;
1171811718
11719- for (size_t h = h_start ; h < h_end ; h ++ ) {
11719+ for (int64_t h = h_start ; h < h_end ; h ++ ) {
1172011720 size_t h_offset = h * h_stride ;
1172111721 size_t t_h_offset = t_offset + h_offset ;
1172211722 size_t h_2d_offset = h * h_stride_2d ;
1172311723
11724- for (size_t i = 0 ; i < head_size ; i ++ ) {
11724+ for (int64_t i = 0 ; i < head_size ; i ++ ) {
1172511725 size_t t_h_i_offset = t_h_offset + i ;
1172611726 size_t h_i_offset = h_offset + i ;
1172711727 size_t h_2d_i_offset = h_2d_offset + i * h_stride ;
@@ -11737,7 +11737,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
1173711737 GGML_F32X time_faaaa_vec = GGML_F32X_SET1 (time_faaaa_val );
1173811738 GGML_F32X time_decay_vec = GGML_F32X_SET1 (time_decay_val );
1173911739
11740- for (size_t j = 0 ; j < vec_count ; j ++ ) {
11740+ for (int64_t j = 0 ; j < vec_count ; j ++ ) {
1174111741 size_t base_j = j * WKV_VECTOR_SIZE ;
1174211742 size_t t_h_j_offset = t_h_offset + base_j ;
1174311743 size_t h_2d_i_j_offset = h_2d_i_offset + base_j ;
@@ -11763,7 +11763,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
1176311763 }
1176411764
1176511765 // Handle remaining elements, this will not be used.
11766- for (size_t j = vec_count * VECTOR_SIZE ; j < head_size ; j ++ ) {
11766+ for (int64_t j = vec_count * VECTOR_SIZE ; j < head_size ; j ++ ) {
1176711767 size_t t_h_j_offset = t_h_offset + j ;
1176811768 size_t h_2d_i_j_offset = h_2d_i_offset + j ;
1176911769 float v_val = v [t_h_j_offset ];
@@ -11782,18 +11782,18 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
1178211782 // dst = r @ (time_faaaa * (k @ v) + state),
1178311783 // state = time_decay * state + (k @ v),
1178411784 // recursive through each token
11785- for (size_t t = 0 ; t < T ; t ++ ) {
11785+ for (int64_t t = 0 ; t < T ; t ++ ) {
1178611786 size_t t_offset = t * t_stride ;
1178711787 size_t state_offset = head_size * C * (t / (T / n_seqs ));
1178811788 float * state_cur = state + state_offset ;
1178911789 float * state_prev = t % (T / n_seqs ) ? state_cur : (float * )dst -> src [5 ]-> data + state_offset ;
1179011790
11791- for (size_t h = h_start ; h < h_end ; h ++ ) {
11791+ for (int64_t h = h_start ; h < h_end ; h ++ ) {
1179211792 size_t h_offset = h * h_stride ;
1179311793 size_t t_h_offset = t_offset + h_offset ;
1179411794 size_t h_2d_offset = h * h_stride_2d ;
1179511795
11796- for (size_t i = 0 ; i < head_size ; i ++ ) {
11796+ for (int64_t i = 0 ; i < head_size ; i ++ ) {
1179711797 size_t t_h_i_offset = t_h_offset + i ;
1179811798 size_t h_i_offset = h_offset + i ;
1179911799 size_t h_2d_i_offset = h_2d_offset + i * h_stride ;
@@ -11804,7 +11804,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
1180411804 // RWKV v6: different time_decay for each token.
1180511805 float time_decay_val = time_decay [t_h_i_offset ];
1180611806
11807- for (size_t j = 0 ; j < head_size ; j ++ ) {
11807+ for (int64_t j = 0 ; j < head_size ; j ++ ) {
1180811808 size_t t_h_j_offset = t_h_offset + j ;
1180911809 size_t h_2d_i_j_offset = h_2d_i_offset + j ;
1181011810
0 commit comments