Commit 1120d94

remove restrict from pointers
1 parent deab327 commit 1120d94

5 files changed, 52 insertions(+), 47 deletions(-)

ggml/src/ggml-alloc.c

Lines changed: 5 additions & 0 deletions

@@ -37,6 +37,7 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
     return true;
 }
 
+// ops that return true for this function must not use restrict pointers for their backend implementations
 static bool ggml_op_can_inplace(enum ggml_op op) {
     switch (op) {
         case GGML_OP_SCALE:
@@ -52,8 +53,12 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
         case GGML_OP_LOG:
         case GGML_OP_UNARY:
         case GGML_OP_ROPE:
+        case GGML_OP_ROPE_BACK:
+        case GGML_OP_SILU_BACK:
         case GGML_OP_RMS_NORM:
+        case GGML_OP_RMS_NORM_BACK:
         case GGML_OP_SOFT_MAX:
+        case GGML_OP_SOFT_MAX_BACK:
             return true;
 
         default:
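
Background for the comment added above: `restrict` (spelled `__restrict__` in CUDA C++) promises the compiler that the annotated pointer is the only handle to its memory, so loads through `x` may be cached or reordered across stores through `dst`. For any op where `ggml_op_can_inplace` returns true, the graph allocator may hand the backend kernel the same buffer as both input and output, which breaks that promise and is undefined behavior. A minimal sketch of the two variants (the kernel names `scale_restrict_f32` and `scale_f32` are hypothetical, not from this commit):

// With __restrict__ the compiler may assume x and dst never overlap, e.g.
// issuing the load of x[i] through a read-only path or reordering it across
// the store to dst[i]; launching this with x == dst is undefined behavior.
static __global__ void scale_restrict_f32(const float * __restrict__ x, float * __restrict__ dst, const int k, const float s) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    dst[i] = s * x[i];
}

// Without __restrict__ every store to dst[i] must be assumed potentially
// visible through x, so an in-place launch (x == dst) is well defined.
static __global__ void scale_f32(const float * x, float * dst, const int k, const float s) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    dst[i] = s * x[i];
}

For one-read-one-write elementwise kernels like those changed below, `__restrict__` likely buys little, so dropping it is a cheap way to make the in-place path safe.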

ggml/src/ggml-cuda/norm.cu

Lines changed: 4 additions & 4 deletions

@@ -1,7 +1,7 @@
 #include "norm.cuh"
 
 template <int block_size>
-static __global__ void norm_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const float eps) {
+static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
@@ -41,7 +41,7 @@ static __global__ void norm_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const float eps) {
 }
 
 template <int block_size>
-static __global__ void group_norm_f32(const float * __restrict__ x, float * __restrict__ dst, const int group_size, const int ne_elements, const float eps) {
+static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
     // blockIdx.x: num_groups idx
     // threadIdx.x: block_size idx
     const int start = blockIdx.x*group_size + threadIdx.x;
@@ -97,7 +97,7 @@ static __global__ void group_norm_f32(const float * __restrict__ x, float * __restrict__ dst, const int group_size, const int ne_elements, const float eps) {
 }
 
 template <int block_size>
-static __global__ void rms_norm_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const float eps) {
+static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
@@ -136,7 +136,7 @@ static __global__ void rms_norm_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const float eps) {
 
 template <int block_size>
 static __global__ void rms_norm_back_f32(
-        const float * __restrict__ grad, const float * __restrict__ xf, float * __restrict__ dst, const int ncols, const float eps) {
+        const float * grad, const float * xf, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 

ggml/src/ggml-cuda/rope.cu

Lines changed: 24 additions & 24 deletions

@@ -39,9 +39,9 @@ static __device__ void rope_yarn(
 
 template<bool forward, bool has_ff, typename T>
 static __global__ void rope_norm(
-        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
-        const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) {
+        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
+        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
@@ -83,9 +83,9 @@ static __global__ void rope_norm(
 
 template<bool forward, bool has_ff, typename T>
 static __global__ void rope_neox(
-        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
-        const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) {
+        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
+        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
@@ -127,9 +127,9 @@ static __global__ void rope_neox(
 
 template<bool forward, bool has_ff, typename T>
 static __global__ void rope_multi(
-        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
-        const int n_dims, const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) {
+        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
+        const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
@@ -187,9 +187,9 @@ static __global__ void rope_multi(
 
 template<bool forward, bool has_ff, typename T>
 static __global__ void rope_vision(
-        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
-        const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
-        const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) {
+        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
+        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
+        const float theta_scale, const float * freq_factors, const mrope_sections sections) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
@@ -234,9 +234,9 @@ static __global__ void rope_vision(
 
 template<bool forward, typename T>
 static void rope_norm_cuda(
-        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
-        const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) {
+        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
+        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -257,9 +257,9 @@ static void rope_norm_cuda(
 
 template<bool forward, typename T>
 static void rope_neox_cuda(
-        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
-        const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) {
+        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
+        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -280,9 +280,9 @@ static void rope_neox_cuda(
 
 template<bool forward, typename T>
 static void rope_multi_cuda(
-        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
-        const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) {
+        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
+        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -303,9 +303,9 @@ static void rope_multi_cuda(
 
 template<bool forward, typename T>
 static void rope_vision_cuda(
-        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
-        const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) {
+        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
+        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);

ggml/src/ggml-cuda/softmax.cu

Lines changed: 2 additions & 2 deletions

@@ -15,7 +15,7 @@ __device__ float __forceinline__ t2f32<half>(half val) {
 
 template <bool use_shared, int ncols_template, int block_size_template, typename T>
 static __global__ void soft_max_f32(
-        const float * __restrict__ x, const T * __restrict__ mask, float * __restrict__ dst, const int ncols_par, const int nrows_y,
+        const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
         const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
 
@@ -120,7 +120,7 @@ static __global__ void soft_max_f32(
 }
 
 static __global__ void soft_max_back_f32(
-        const float * __restrict__ grad, const float * __restrict__ dstf, float * __restrict__ dst, const int ncols, const float scale) {
+        const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {
     const int tid = threadIdx.x;
     const int rowx = blockIdx.x;
 

ggml/src/ggml-cuda/unary.cu

Lines changed: 17 additions & 17 deletions

@@ -1,6 +1,6 @@
 #include "unary.cuh"
 
-static __global__ void neg_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
+static __global__ void neg_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -10,7 +10,7 @@ static __global__ void neg_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
     dst[i] = -x[i];
 }
 
-static __global__ void step_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
+static __global__ void step_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -20,7 +20,7 @@ static __global__ void step_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
     dst[i] = x[i] > 0.0f;
 }
 
-static __global__ void gelu_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
+static __global__ void gelu_f32(const float * x, float * dst, const int k) {
     const float GELU_COEF_A = 0.044715f;
     const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -33,7 +33,7 @@ static __global__ void gelu_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
     dst[i] = 0.5f*xi*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
 }
 
-static __global__ void gelu_quick_f32(const float * __restrict__ x, float * __restrict__ dst, int k) {
+static __global__ void gelu_quick_f32(const float * x, float * dst, int k) {
     const float GELU_QUICK_COEF = -1.702f;
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
     if (i >= k) {
@@ -42,7 +42,7 @@ static __global__ void gelu_quick_f32(const float * __restrict__ x, float * __restrict__ dst, int k) {
     dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i])));
 }
 
-static __global__ void silu_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
+static __global__ void silu_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -52,7 +52,7 @@ static __global__ void silu_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
 }
 
 static __global__ void silu_back_f32(
-        const float * __restrict__ grad, const float * __restrict__ xf, float * __restrict__ dst, const int k) {
+        const float * grad, const float * xf, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -64,15 +64,15 @@ static __global__ void silu_back_f32(
     dst[i] = grad[i] * s * (1.0f + xfi * (1.0f - s));
 }
 
-static __global__ void tanh_f32(const float * __restrict__ x, float * __restrict__ dst, int k) {
+static __global__ void tanh_f32(const float * x, float * dst, int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
     if (i >= k) {
         return;
     }
     dst[i] = tanhf(x[i]);
 }
 
-static __global__ void relu_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
+static __global__ void relu_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -81,7 +81,7 @@ static __global__ void relu_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
     dst[i] = fmaxf(x[i], 0);
 }
 
-static __global__ void sigmoid_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
+static __global__ void sigmoid_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
@@ -90,7 +90,7 @@ static __global__ void sigmoid_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
     dst[i] = 1.0f / (1.0f + expf(-x[i]));
 }
 
-static __global__ void hardsigmoid_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
+static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -99,7 +99,7 @@ static __global__ void hardsigmoid_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
     dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
 }
 
-static __global__ void hardswish_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
+static __global__ void hardswish_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -108,7 +108,7 @@ static __global__ void hardswish_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
     dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
 }
 
-static __global__ void exp_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
+static __global__ void exp_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -117,15 +117,15 @@ static __global__ void exp_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
     dst[i] = expf(x[i]);
 }
 
-static __global__ void leaky_relu_f32(const float * __restrict__ x, float * __restrict__ dst, const int k, const float negative_slope) {
+static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
     if (i >= k) {
         return;
     }
     dst[i] = fmaxf(x[i], 0) + fminf(x[i], 0.0f) * negative_slope;
 }
 
-static __global__ void sqr_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
+static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -134,7 +134,7 @@ static __global__ void sqr_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
     dst[i] = x[i] * x[i];
 }
 
-static __global__ void sqrt_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
+static __global__ void sqrt_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -143,7 +143,7 @@ static __global__ void sqrt_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
     dst[i] = sqrtf(x[i]);
 }
 
-static __global__ void sin_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
+static __global__ void sin_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -152,7 +152,7 @@ static __global__ void sin_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
     dst[i] = sinf(x[i]);
 }
 
-static __global__ void cos_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
+static __global__ void cos_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
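
Taken together, these signature changes let the graph allocator reuse an op's input buffer as its output. A minimal usage sketch, assuming the wrapper name and block size (both illustrative, not part of this commit), of an in-place launch of the `silu_f32` kernel above:

// Illustrative only: launching silu_f32 in place, with the same device
// buffer passed as both x and dst. With __restrict__ still on the
// parameters, this call would have been undefined behavior.
static void silu_f32_inplace(float * d_buf, const int k, cudaStream_t stream) {
    const int block_size = 256; // assumed block size, chosen for the sketch
    const int num_blocks = (k + block_size - 1) / block_size;
    silu_f32<<<num_blocks, block_size, 0, stream>>>(d_buf, d_buf, k);
}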
