Skip to content

Commit b63af60

Browse files
committed
Address review comments
1 parent 956a1d0 commit b63af60

File tree

1 file changed

+55
-57
lines changed

1 file changed

+55
-57
lines changed

ggml/src/ggml-cuda/binbcast.cu

Lines changed: 55 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@ static __global__ void k_bin_bcast(const src0_t * src0,
3434
const int ne0,
3535
const int ne1,
3636
const int ne2,
37-
const uint3 ne3_fastdiv,
38-
const uint3 ne10_fastdiv,
39-
const uint3 ne11_fastdiv,
40-
const uint3 ne12_fastdiv,
41-
const uint3 ne13_fastdiv,
37+
const uint3 ne3,
38+
const uint3 ne10,
39+
const uint3 ne11,
40+
const uint3 ne12,
41+
const uint3 ne13,
4242
/*int s0, */ const int s1,
4343
const int s2,
4444
const int s3,
@@ -51,16 +51,16 @@ static __global__ void k_bin_bcast(const src0_t * src0,
5151
src1_ptrs... src1s) {
5252
const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
5353
const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y);
54-
const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3_fastdiv);
55-
const uint32_t i3 = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3_fastdiv.z);
54+
const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
55+
const uint32_t i3 = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z);
5656

57-
if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3_fastdiv.z) {
57+
if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3.z) {
5858
return;
5959
}
6060

61-
const uint32_t i11 = fastmodulo(i1, ne11_fastdiv);
62-
const uint32_t i12 = fastmodulo(i2, ne12_fastdiv);
63-
const uint32_t i13 = fastmodulo(i3, ne13_fastdiv);
61+
const uint32_t i11 = fastmodulo(i1, ne11);
62+
const uint32_t i12 = fastmodulo(i2, ne12);
63+
const uint32_t i13 = fastmodulo(i3, ne13);
6464

6565
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
6666
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
@@ -70,7 +70,7 @@ static __global__ void k_bin_bcast(const src0_t * src0,
7070
dst_t * dst_row = dst + i_dst;
7171

7272
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
73-
const uint32_t i10 = fastmodulo(i0, ne10_fastdiv);
73+
const uint32_t i10 = fastmodulo(i0, ne10);
7474

7575
float result = src0_row ? (float) src0_row[i0] : 0.0f;
7676
if constexpr (sizeof...(src1_ptrs) > 0) {
@@ -91,16 +91,16 @@ template <float (*bin_op)(const float, const float),
9191
static __global__ void k_bin_bcast_unravel(const src0_t * src0,
9292
const src1_t * src1,
9393
dst_t * dst,
94-
const uint3 ne0_fastdiv,
95-
const uint3 ne1_fastdiv,
96-
const uint3 ne2_fastdiv,
94+
const uint3 ne0,
95+
const uint3 ne1,
96+
const uint3 ne2,
9797
const uint32_t ne3,
98-
const uint3 ne012_fastdiv,
99-
const uint3 ne01_fastdiv,
100-
const uint3 ne10_fastdiv,
101-
const uint3 ne11_fastdiv,
102-
const uint3 ne12_fastdiv,
103-
const uint3 ne13_fastdiv,
98+
const uint3 prod012,
99+
const uint3 prod01,
100+
const uint3 ne10,
101+
const uint3 ne11,
102+
const uint3 ne12,
103+
const uint3 ne13,
104104
/*int s0, */ const int s1,
105105
const int s2,
106106
const int s3,
@@ -113,18 +113,18 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
113113
src1_ptrs... src1s) {
114114
const int i = blockDim.x*blockIdx.x + threadIdx.x;
115115

116-
const uint32_t i3 = fastdiv(i, ne012_fastdiv);
117-
const uint32_t i2 = fastmodulo(fastdiv(i, ne01_fastdiv), ne2_fastdiv);
118-
const uint32_t i1 = fastmodulo(fastdiv(i, ne0_fastdiv), ne1_fastdiv);
119-
const uint32_t i0 = fastmodulo(i, ne0_fastdiv);
116+
const uint32_t i3 = fastdiv(i, prod012);
117+
const uint32_t i2 = fastdiv(i - i3 * prod012.z, prod01);
118+
const uint32_t i1 = fastdiv(i - i3 * prod012.z - i2 * prod01.z, ne0);
119+
const uint32_t i0 = i - i3 * prod012.z - i2 * prod01.z - i1 * ne0.z;
120120

121-
if (i0 >= ne0_fastdiv.z || i1 >= ne1_fastdiv.z || i2 >= ne2_fastdiv.z || i3 >= ne3) {
121+
if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) {
122122
return;
123123
}
124124

125-
const int i11 = fastmodulo(i1, ne11_fastdiv);
126-
const int i12 = fastmodulo(i2, ne12_fastdiv);
127-
const int i13 = fastmodulo(i3, ne13_fastdiv);
125+
const int i11 = fastmodulo(i1, ne11);
126+
const int i12 = fastmodulo(i2, ne12);
127+
const int i13 = fastmodulo(i3, ne13);
128128

129129
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
130130
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
@@ -133,7 +133,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
133133
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
134134
dst_t * dst_row = dst + i_dst;
135135

136-
const int i10 = fastmodulo(i0, ne10_fastdiv);
136+
const int i10 = fastmodulo(i0, ne10);
137137

138138
float result = src0_row ? (float) src0_row[i0] : 0.0f;
139139
if constexpr (sizeof...(src1_ptrs) > 0) {
@@ -267,50 +267,48 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
267267
dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, (ne1 + block_dims.y - 1) / block_dims.y,
268268
(ne2 * ne3 + block_dims.z - 1) / block_dims.z);
269269

270-
271-
const uint3 ne10_fastdiv = init_fastdiv_values((uint32_t) cne1[0]);
272-
const uint3 ne11_fastdiv = init_fastdiv_values((uint32_t) cne1[1]);
273-
const uint3 ne12_fastdiv = init_fastdiv_values((uint32_t) cne1[2]);
274-
const uint3 ne13_fastdiv = init_fastdiv_values((uint32_t) cne1[3]);
270+
const uint3 ne10 = init_fastdiv_values((uint32_t) cne1[0]);
271+
const uint3 ne11 = init_fastdiv_values((uint32_t) cne1[1]);
272+
const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);
273+
const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);
275274

276275
if (block_nums.z > 65535) {
277-
int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
278-
const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);
279-
const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
280-
const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);
281-
const uint3 ne012_fastdiv = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
282-
const uint3 ne01_fastdiv = init_fastdiv_values((uint32_t) (ne0 * ne1));
276+
int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
277+
const uint3 prod012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
278+
const uint3 prod01 = init_fastdiv_values((uint32_t) (ne0 * ne1));
279+
const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);
280+
const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
281+
const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);
282+
283283
if constexpr (sizeof...(I) > 0) {
284284
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
285-
src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, ne012_fastdiv, ne01_fastdiv,
286-
ne10_fastdiv, ne11_fastdiv, ne12_fastdiv, ne13_fastdiv,
285+
src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod012, prod01, ne10, ne11,
286+
ne12, ne13,
287287
/* s0, */ s1, s2, s3,
288288
/* s00,*/ s01, s02, s03,
289289
/* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
290290
} else {
291-
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
292-
src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, ne012_fastdiv, ne01_fastdiv,
293-
ne10_fastdiv, ne11_fastdiv, ne12_fastdiv, ne13_fastdiv,
294-
/* s0, */ s1, s2, s3,
295-
/* s00,*/ s01, s02, s03,
296-
/* s10,*/ s11, s12, s13);
291+
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
292+
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
293+
ne2_fastdiv, ne3, prod012, prod01, ne10, ne11, ne12, ne13,
294+
/* s0, */ s1, s2, s3,
295+
/* s00,*/ s01, s02, s03,
296+
/* s10,*/ s11, s12, s13);
297297
}
298298
} else {
299299
const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
300300
if constexpr (sizeof...(I) > 0) {
301301
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
302-
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10_fastdiv, ne11_fastdiv, ne12_fastdiv,
303-
ne13_fastdiv,
302+
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
304303
/* s0, */ s1, s2, s3,
305304
/* s00,*/ s01, s02, s03,
306305
/* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
307306
} else {
308-
k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
309-
<<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv,
310-
ne10_fastdiv, ne11_fastdiv, ne12_fastdiv, ne13_fastdiv,
311-
/* s0, */ s1, s2, s3,
312-
/* s00,*/ s01, s02, s03,
313-
/* s10,*/ s11, s12, s13);
307+
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
308+
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
309+
/* s0, */ s1, s2, s3,
310+
/* s00,*/ s01, s02, s03,
311+
/* s10,*/ s11, s12, s13);
314312
}
315313
}
316314
}

0 commit comments

Comments
 (0)