Skip to content

Commit 8f8e95e

Browse files
committed
Review: add const and use int64_t for nelements
1 parent 225bf8c commit 8f8e95e

File tree

1 file changed: 18 additions and 18 deletions.

ggml/src/ggml-cuda/upscale.cu

Lines changed: 18 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -28,22 +28,22 @@ static __global__ void upscale_f32_bilinear(const float * x, float * dst,
2828
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
2929
const float sf0, const float sf1, const float sf2, const float sf3,
3030
const float pixel_offset) {
31-
int index = threadIdx.x + blockIdx.x * blockDim.x;
32-
int dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
31+
const int64_t index = threadIdx.x + blockIdx.x * blockDim.x;
32+
const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
3333

3434
if (index >= dst_total_elements) {
3535
return;
3636
}
3737

38-
int i10_dst = index % ne10_dst;
39-
int i11_dst = (index / ne10_dst) % ne11_dst;
40-
int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
41-
int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
38+
const int i10_dst = index % ne10_dst;
39+
const int i11_dst = (index / ne10_dst) % ne11_dst;
40+
const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
41+
const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
4242

43-
int i02_src = (int)(i12_dst / sf2);
44-
int i03_src = (int)(i13_dst / sf3);
43+
const int i02_src = (int)(i12_dst / sf2);
44+
const int i03_src = (int)(i13_dst / sf3);
4545

46-
float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
46+
const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
4747
int y0_src = (int)floorf(y_src_f);
4848
int y1_src = y0_src + 1;
4949

@@ -63,10 +63,10 @@ static __global__ void upscale_f32_bilinear(const float * x, float * dst,
6363
float dx = x_src_f - (float)x0_src;
6464
dx = max(0.0f, min(dx, 1.0f));
6565

66-
const float * p_a = (const float *)((const char *)x + (long)x0_src * nb00 + (long)y0_src * nb01 + (long)i02_src * nb02 + (long)i03_src * nb03);
67-
const float * p_b = (const float *)((const char *)x + (long)x1_src * nb00 + (long)y0_src * nb01 + (long)i02_src * nb02 + (long)i03_src * nb03);
68-
const float * p_c = (const float *)((const char *)x + (long)x0_src * nb00 + (long)y1_src * nb01 + (long)i02_src * nb02 + (long)i03_src * nb03);
69-
const float * p_d = (const float *)((const char *)x + (long)x1_src * nb00 + (long)y1_src * nb01 + (long)i02_src * nb02 + (long)i03_src * nb03);
66+
const float * p_a = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
67+
const float * p_b = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
68+
const float * p_c = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
69+
const float * p_d = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
7070

7171
const float val_a = *p_a;
7272
const float val_b = *p_b;
@@ -86,8 +86,8 @@ static void upscale_f32_cuda(const float * x, float * dst,
8686
const int ne10, const int ne11, const int ne12, const int ne13,
8787
const float sf0, const float sf1, const float sf2, const float sf3,
8888
cudaStream_t stream) {
89-
int dst_size = ne10 * ne11 * ne12 * ne13;
90-
int num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
89+
const int64_t dst_size = ne10 * ne11 * ne12 * ne13;
90+
const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
9191

9292
upscale_f32<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3);
9393
}
@@ -98,8 +98,8 @@ static void upscale_f32_bilinear_cuda(const float * x, float * dst,
9898
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
9999
const float sf0, const float sf1, const float sf2, const float sf3,
100100
const float pixel_offset, cudaStream_t stream) {
101-
int dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
102-
int num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
101+
const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
102+
const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
103103

104104
upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
105105
}
@@ -119,7 +119,7 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
119119
float sf0 = (float)dst->ne[0]/src0->ne[0];
120120
float sf1 = (float)dst->ne[1]/src0->ne[1];
121121
float sf2 = (float)dst->ne[2]/src0->ne[2];
122-
float sf3 = (float)dst->ne[3]/src0->ne[3];
122+
const float sf3 = (float)dst->ne[3]/src0->ne[3];
123123

124124
if (mode == GGML_SCALE_MODE_NEAREST) {
125125
upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);

0 commit comments

Comments
 (0)