@@ -11,7 +11,7 @@ constant float d[2] = {3.543889200, 1.637067800};
1111template <typename T0, typename T1>
1212kernel void erfinv_kernel (
1313 device T0* output [[buffer(0 )]],
14- device T1* input [[buffer(1 )]],
14+ constant T1* input [[buffer(1 )]],
1515 uint index [[thread_position_in_grid]]) {
1616 float y = input[index];
1717 float x, z, num, dem; /* working variables */
@@ -40,15 +40,15 @@ kernel void erfinv_kernel(
4040template <typename T0, typename T1>
4141kernel void exp_kernel (
4242 device T0* output [[buffer(0 )]],
43- device T1* input [[buffer(1 )]],
43+ constant T1* input [[buffer(1 )]],
4444 uint index [[thread_position_in_grid]]) {
4545 output[index] = T0 (precise::exp (input[index]));
4646}
4747
4848template <typename T0>
4949kernel void exp_complex_kernel (
5050 device vec2type_t <T0>* output [[buffer(0 )]],
51- device vec2type_t<T0>* input [[buffer(1 )]],
51+ constant vec2type_t<T0>* input [[buffer(1 )]],
5252 uint index [[thread_position_in_grid]]) {
5353 output[index].x =
5454 T0 (precise::exp (input[index].x ) * precise::cos (input[index].y ));
@@ -59,7 +59,7 @@ kernel void exp_complex_kernel(
5959template <typename T0, typename T1>
6060kernel void tanh_kernel (
6161 device T0* output [[buffer(0 )]],
62- device T1* input [[buffer(1 )]],
62+ constant T1* input [[buffer(1 )]],
6363 uint index [[thread_position_in_grid]]) {
6464 output[index] = T0 (precise::tanh (input[index]));
6565}
@@ -83,7 +83,7 @@ T complex_div(T a, T b) {
8383template <typename T0>
8484kernel void tanh_complex_kernel (
8585 device vec2type_t <T0>* output [[buffer(0 )]],
86- device vec2type_t<T0>* input [[buffer(1 )]],
86+ constant vec2type_t<T0>* input [[buffer(1 )]],
8787 uint index [[thread_position_in_grid]]) {
8888 // tanh(x+iy)=(tanh(x)+itan(y))/(1+itahnh(x)*tan(y));
8989 auto tanh_x = T0 (precise::tanh (input[index].x ));
@@ -96,15 +96,15 @@ kernel void tanh_complex_kernel(
9696 template [[host_name(" erfinv_" #DTYPE0 " _" #DTYPE1)]] kernel void \
9797 erfinv_kernel ( \
9898 device DTYPE0* output [[buffer(0 )]], \
99- device DTYPE1* input [[buffer(1 )]], \
99+ constant DTYPE1* input [[buffer(1 )]], \
100100 uint id [[thread_position_in_grid]]); \
101101 template [[host_name(" exp_" #DTYPE0 " _" #DTYPE1)]] kernel void exp_kernel ( \
102102 device DTYPE0* output [[buffer(0 )]], \
103- device DTYPE1* input [[buffer(1 )]], \
103+ constant DTYPE1* input [[buffer(1 )]], \
104104 uint id [[thread_position_in_grid]]); \
105105 template [[host_name(" tanh_" #DTYPE0 " _" #DTYPE1)]] kernel void tanh_kernel ( \
106106 device DTYPE0* output [[buffer(0 )]], \
107- device DTYPE1* input [[buffer(1 )]], \
107+ constant DTYPE1* input [[buffer(1 )]], \
108108 uint id [[thread_position_in_grid]]);
109109
110110#if __METAL_VERSION__ >= 310
@@ -123,12 +123,12 @@ INSTANTIATE_UNARY_KERNELS2(float, long);
123123 template [[host_name(" exp_complex_" #DTYPE0 " _" #DTYPE1)]] kernel void \
124124 exp_complex_kernel<DTYPE0>( \
125125 device vec2type_t <DTYPE0> * output [[buffer(0 )]], \
126- device vec2type_t <DTYPE0> * input [[buffer(1 )]], \
126+ constant vec2type_t <DTYPE0> * input [[buffer(1 )]], \
127127 uint did [[thread_position_in_grid]]); \
128128 template [[host_name(" tanh_complex_" #DTYPE0 " _" #DTYPE1)]] kernel void \
129129 tanh_complex_kernel<DTYPE0>( \
130130 device vec2type_t <DTYPE0> * output [[buffer(0 )]], \
131- device vec2type_t <DTYPE0> * input [[buffer(1 )]], \
131+ constant vec2type_t <DTYPE0> * input [[buffer(1 )]], \
132132 uint did [[thread_position_in_grid]]);
133133
134134INSTANTIATE_UNARY_KERNELS_VEC2 (short , short );
0 commit comments