File tree Expand file tree Collapse file tree 1 file changed +8
-5
lines changed Expand file tree Collapse file tree 1 file changed +8
-5
lines changed Original file line number Diff line number Diff line change @@ -11655,13 +11655,16 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
1165511655 float * dst_data = (float * ) dst -> data ;
1165611656 float * state = ((float * ) dst -> data ) + C * T ;
1165711657
11658- if ((int64_t )params -> ith >= HEADS ) {
11658+ const int ith = params -> ith ;
11659+ const int nth = params -> nth ;
11660+
11661+ if (ith >= HEADS ) {
1165911662 return ;
1166011663 }
1166111664
11662- int64_t h_start = (HEADS * params -> ith ) / params -> nth ;
11663- int64_t h_end = ((HEADS * (params -> ith + 1 )) / params -> nth < HEADS ) ?
11664- (HEADS * (params -> ith + 1 )) / params -> nth : HEADS ;
11665+ const int h_start = (HEADS * ith ) / nth ;
11666+ const int h_end = ((HEADS * (ith + 1 )) / nth < HEADS ) ?
11667+ (HEADS * (ith + 1 )) / nth : HEADS ;
1166511668
1166611669 float * k = (float * ) dst -> src [0 ]-> data ;
1166711670 float * v = (float * ) dst -> src [1 ]-> data ;
@@ -11675,7 +11678,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
1167511678 GGML_ASSERT (C % HEADS == 0 ); // C must be divisible by HEADS
1167611679 size_t h_stride_2d = head_size * head_size ;
1167711680
11678- if (params -> ith == 0 ) {
11681+ if (ith == 0 ) {
1167911682 memset (dst_data , 0 , T * C * sizeof (float ));
1168011683 }
1168111684 ggml_barrier (params -> threadpool );
You can’t perform that action at this time.
0 commit comments