Skip to content

Commit 1e9981a

Browse files
committed
fix MUSA compiler warning
1 parent c9a07d2 commit 1e9981a

File tree

2 files changed

+44
-44
lines changed

2 files changed

+44
-44
lines changed

ggml/src/ggml-cuda/ssm-conv.cu

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@ template <size_t split_d_inner, size_t d_conv>
44
static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
55
const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
66
float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
7-
const int nc, const int ncs, const int nr, const int n_t, const int n_s) {
7+
const int64_t n_t) {
8+
(void) src0_nb0;
89
const int tid = threadIdx.x;
910
const int bidx = blockIdx.x;
1011
const int bidy = blockIdx.y;
1112

12-
const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1);
13-
const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1);
13+
const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1);
14+
const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1);
1415
float * y_block = (float *) ((char *) dst + bidx * dst_nb2 + bidy * split_d_inner * dst_nb0);
1516

1617
const int stride_x = src0_nb1 / sizeof(float);
@@ -21,43 +22,42 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float
2122
float w[d_conv] = { 0.0f };
2223

2324
#pragma unroll
24-
for (int j = 0; j < d_conv; j++) {
25+
for (size_t j = 0; j < d_conv; j++) {
2526
w[j] = w_block[tid * stride_w + j];
2627
}
2728

28-
for (int i = 0; i < n_t; i++) {
29+
for (int64_t i = 0; i < n_t; i++) {
2930
float sumf = 0.0f;
3031

3132
if (i == 0) {
32-
for (int j = 0; j < d_conv; j++) {
33+
for (size_t j = 0; j < d_conv; j++) {
3334
x[j] = x_block[tid * stride_x + j];
3435
}
3536
} else {
3637
x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
3738
}
3839

3940
#pragma unroll
40-
for (int j = 0; j < d_conv; j++) {
41+
for (size_t j = 0; j < d_conv; j++) {
4142
sumf += x[(i + j) % d_conv] * w[j];
4243
}
4344
y_block[i * stride_y + tid] = sumf;
4445
}
4546
}
4647

47-
template <size_t split_d_inner, size_t d_conv, size_t split_n_t>
48+
template <size_t split_d_inner, size_t d_conv, int64_t split_n_t>
4849
static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
4950
const int src0_nb0, const int src0_nb1, const int src0_nb2,
5051
const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
51-
const int dst_nb1, const int dst_nb2, const int nc, const int ncs,
52-
const int nr, const int n_t, const int n_s) {
52+
const int dst_nb1, const int dst_nb2, const int64_t n_t) {
5353
const int tid = threadIdx.x;
5454
const int bidx = blockIdx.x;
5555
const int bidy = blockIdx.y;
5656
const int bidz = blockIdx.z;
5757

58-
const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 +
58+
const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 +
5959
bidz * split_n_t * src0_nb0);
60-
const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1);
60+
const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1);
6161
float * y_block =
6262
(float *) ((char *) dst + bidx * dst_nb2 + bidz * split_n_t * dst_nb1 + bidy * split_d_inner * dst_nb0);
6363

@@ -69,25 +69,25 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
6969
float w[d_conv] = { 0.0f };
7070

7171
#pragma unroll
72-
for (int j = 0; j < d_conv; j++) {
72+
for (size_t j = 0; j < d_conv; j++) {
7373
w[j] = w_block[tid * stride_w + j];
7474
}
7575

7676
#pragma unroll
77-
for (int i = 0; i < split_n_t; i++) {
77+
for (int64_t i = 0; i < split_n_t; i++) {
7878
if (bidz * split_n_t + i < n_t) {
7979
float sumf = 0.0f;
8080

8181
if (i == 0) {
82-
for (int j = 0; j < d_conv; j++) {
82+
for (size_t j = 0; j < d_conv; j++) {
8383
x[j] = x_block[tid * stride_x + j];
8484
}
8585
} else {
8686
x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
8787
}
8888

8989
#pragma unroll
90-
for (int j = 0; j < d_conv; j++) {
90+
for (size_t j = 0; j < d_conv; j++) {
9191
sumf += x[(i + j) % d_conv] * w[j];
9292
}
9393
y_block[i * stride_y + tid] = sumf;
@@ -97,27 +97,25 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
9797

9898
static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
9999
const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
100-
const int dst_nb2, const int nc, const int ncs, const int nr, const int n_t,
101-
const int n_s, cudaStream_t stream) {
100+
const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t,
101+
const int64_t n_s, cudaStream_t stream) {
102102
const int threads = 128;
103103
GGML_ASSERT(nr % threads == 0);
104104

105105
if (n_t <= 32) {
106106
const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
107107
if (nc == 4) {
108108
ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
109-
dst, dst_nb0, dst_nb1, dst_nb2, nc, ncs, nr, n_t,
110-
n_s);
109+
dst, dst_nb0, dst_nb1, dst_nb2, n_t);
111110
} else {
112111
GGML_ABORT("Only support kernel size = 4 now.");
113112
}
114113
} else {
115114
if (nc == 4) {
116-
const int split_n_t = 32;
117-
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
118-
ssm_conv_long_token_f32<threads, 4, split_n_t>
119-
<<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0,
120-
dst_nb1, dst_nb2, nc, ncs, nr, n_t, n_s);
115+
const int64_t split_n_t = 32;
116+
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
117+
ssm_conv_long_token_f32<threads, 4, split_n_t><<<blocks, threads, 0, stream>>>(
118+
src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
121119
} else {
122120
GGML_ABORT("Only support kernel size = 4 right now.");
123121
}
@@ -128,11 +126,10 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
128126
const struct ggml_tensor * src0 = dst->src[0]; // conv_x
129127
const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
130128

131-
const int nc = src1->ne[0]; // d_conv
132-
const int ncs = src0->ne[0]; // d_conv - 1 + n_t
133-
const int nr = src0->ne[1]; // d_inner
134-
const int n_t = dst->ne[1]; // tokens per sequence
135-
const int n_s = dst->ne[2]; // number of sequences in the batch
129+
const int64_t nc = src1->ne[0]; // d_conv
130+
const int64_t nr = src0->ne[1]; // d_inner
131+
const int64_t n_t = dst->ne[1]; // tokens per sequence
132+
const int64_t n_s = dst->ne[2]; // number of sequences in the batch
136133

137134
GGML_ASSERT(dst->ne[0] == nr);
138135
GGML_ASSERT(src0->nb[0] == sizeof(float));
@@ -147,5 +144,5 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
147144
GGML_ASSERT(src0->type == GGML_TYPE_F32);
148145
GGML_ASSERT(dst->type == GGML_TYPE_F32);
149146
ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1],
150-
dst->nb[2], nc, ncs, nr, n_t, n_s, stream);
147+
dst->nb[2], nc, nr, n_t, n_s, stream);
151148
}

ggml/src/ggml-cuda/ssm-scan.cu

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ __global__ void __launch_bounds__(splitD, 2)
1212
const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2,
1313
const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
1414
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
15-
float * __restrict__ dst, const int D, const int L, const int B) {
15+
float * __restrict__ dst, const int64_t L) {
16+
(void) src1_nb0;
17+
(void) src2_nb0;
1618
const int bidx = blockIdx.x; // split along B
1719
const int bidy = blockIdx.y; // split along D
1820
const int tid = threadIdx.x;
@@ -25,12 +27,12 @@ __global__ void __launch_bounds__(splitD, 2)
2527
float * smem_A = smem;
2628
float * smem_s0 = smem_A + splitD * stride_sA;
2729

28-
const float * s0_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
29-
const float * x_block = (const float *) ((char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
30-
const float * dt_block = (const float *) ((char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
31-
const float * A_block = (const float *) ((char *) src3 + bidy * splitD * src3_nb1);
32-
const float * B_block = (const float *) ((char *) src4 + (bidx * src4_nb2));
33-
const float * C_block = (const float *) ((char *) src5 + (bidx * src5_nb2));
30+
const float * s0_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
31+
const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
32+
const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
33+
const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
34+
const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb2));
35+
const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb2));
3436
float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
3537
float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
3638

@@ -46,31 +48,31 @@ __global__ void __launch_bounds__(splitD, 2)
4648
// can N not be 16? for example 32?
4749
if (N == 16) {
4850
#pragma unroll
49-
for (int i = 0; i < splitD / 4; i += 2) {
51+
for (size_t i = 0; i < splitD / 4; i += 2) {
5052
float value = A_block[(wid * warpSize + i) * stride_A + wtid];
5153
// todo: bank conflict
5254
// I am always confused with how to use the swizzling method to solve
5355
// bank conflict. Hoping somebody can tell me.
5456
smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
5557
}
5658
#pragma unroll
57-
for (int i = 0; i < splitD / 4; i += 2) {
59+
for (size_t i = 0; i < splitD / 4; i += 2) {
5860
float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid];
5961
smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
6062
}
6163
}
6264

6365
__syncthreads();
6466

65-
for (int i = 0; i < L; i++) {
67+
for (int64_t i = 0; i < L; i++) {
6668
float dt_soft_plus = dt_block[i * stride_dt + tid];
6769
if (dt_soft_plus <= 20.0f) {
6870
dt_soft_plus = log1pf(exp(dt_soft_plus));
6971
}
7072
float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
7173
float sumf = 0.0f;
7274
#pragma unroll
73-
for (int j = 0; j < N; j++) {
75+
for (size_t j = 0; j < N; j++) {
7476
float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) +
7577
(B_block[i * stride_B + j] * x_dt);
7678
sumf += state * C_block[i * stride_C + j];
@@ -90,7 +92,8 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
9092
const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
9193
const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
9294
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
93-
float * dst, const int N, const int D, const int L, const int B, cudaStream_t stream) {
95+
float * dst, const int64_t N, const int64_t D, const int64_t L, const int64_t B,
96+
cudaStream_t stream) {
9497
const int threads = 128;
9598
// todo: consider the case where D is not evenly divisible — does this situation exist?
9699
GGML_ASSERT(D % threads == 0);
@@ -99,7 +102,7 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
99102
if (N == 16) {
100103
ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
101104
src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0,
102-
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, D, L, B);
105+
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
103106
} else {
104107
GGML_ABORT("doesn't support N!=16.");
105108
}

0 commit comments

Comments
 (0)