Skip to content

Commit 4693b46

Browse files
committed
rewrite to be more inline with the common pattern for distributing threads
1 parent a749ba7 commit 4693b46

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

ggml/src/ggml-cpu.c

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11655,13 +11655,16 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
1165511655
float * dst_data = (float *) dst->data;
1165611656
float * state = ((float *) dst->data) + C * T;
1165711657

11658-
if ((int64_t)params->ith >= HEADS) {
11658+
const int ith = params->ith;
11659+
const int nth = params->nth;
11660+
11661+
if (ith >= HEADS) {
1165911662
return;
1166011663
}
1166111664

11662-
int64_t h_start = (HEADS * params->ith) / params->nth;
11663-
int64_t h_end = ((HEADS * (params->ith + 1)) / params->nth < HEADS) ?
11664-
(HEADS * (params->ith + 1)) / params->nth : HEADS;
11665+
const int h_start = (HEADS * ith) / nth;
11666+
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
11667+
(HEADS * (ith + 1)) / nth : HEADS;
1166511668

1166611669
float * k = (float *) dst->src[0]->data;
1166711670
float * v = (float *) dst->src[1]->data;
@@ -11675,7 +11678,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
1167511678
GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
1167611679
size_t h_stride_2d = head_size * head_size;
1167711680

11678-
if (params->ith == 0) {
11681+
if (ith == 0) {
1167911682
memset(dst_data, 0, T * C * sizeof(float));
1168011683
}
1168111684
ggml_barrier(params->threadpool);

0 commit comments

Comments
 (0)