Skip to content

Commit 9deb764

Browse files
committed
Review: use int64_t for blockDim.x, rename nb->s for clarity
1 parent 85e2a20 · commit 9deb764

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

ggml/src/ggml-cuda/set-rows.cu

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -20,11 +20,11 @@ static __global__ void k_set_rows(
2020
const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst,
2121
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
2222
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
23-
const size_t nb01, const size_t nb02, const size_t nb03,
24-
const size_t nb10, const size_t nb11, const size_t nb12,
25-
const size_t nb1, const size_t nb2, const size_t nb3) {
23+
const int64_t s01, const int64_t s02, const int64_t s03,
24+
const int64_t s10, const int64_t s11, const int64_t s12,
25+
const int64_t s1, const int64_t s2, const int64_t s3) {
2626

27-
const int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
27+
const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
2828
const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
2929

3030
if (i >= ne_total) {
@@ -40,10 +40,10 @@ static __global__ void k_set_rows(
4040
const int64_t i11 = i02 % ne11;
4141
const int64_t i10 = i01;
4242

43-
const int64_t dst_row = *(src1 + i10*nb10 + i11*nb11 + i12*nb12);
43+
const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
4444

45-
const src_t * src0_row = src0 + i01*nb01 + i02*nb02 + i03*nb03;
46-
dst_t * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;
45+
const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
46+
dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3;
4747

4848
const src_t* src_elem = src0_row + i00;
4949
dst_t* dst_elem = dst_row_ptr + i00;

0 commit comments

Comments
 (0)