Skip to content

Commit f0e9f87

Browse files
dccipytorchmergebot
authored and committed
[BE/mps] Mark input args as constant to prevent incorrect usage. (pytorch#145535)
Pull Request resolved: pytorch#145535 Approved by: https://github.com/malfet, https://github.com/jansel
1 parent 6aaae9d commit f0e9f87

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

aten/src/ATen/native/mps/kernels/UnaryKernel.metal

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@ constant float d[2] = {3.543889200, 1.637067800};
1111
template <typename T0, typename T1>
1212
kernel void erfinv_kernel(
1313
device T0* output [[buffer(0)]],
14-
device T1* input [[buffer(1)]],
14+
constant T1* input [[buffer(1)]],
1515
uint index [[thread_position_in_grid]]) {
1616
float y = input[index];
1717
float x, z, num, dem; /*working variables */
@@ -40,15 +40,15 @@ kernel void erfinv_kernel(
4040
template <typename T0, typename T1>
4141
kernel void exp_kernel(
4242
device T0* output [[buffer(0)]],
43-
device T1* input [[buffer(1)]],
43+
constant T1* input [[buffer(1)]],
4444
uint index [[thread_position_in_grid]]) {
4545
output[index] = T0(precise::exp(input[index]));
4646
}
4747

4848
template <typename T0>
4949
kernel void exp_complex_kernel(
5050
device vec2type_t<T0>* output [[buffer(0)]],
51-
device vec2type_t<T0>* input [[buffer(1)]],
51+
constant vec2type_t<T0>* input [[buffer(1)]],
5252
uint index [[thread_position_in_grid]]) {
5353
output[index].x =
5454
T0(precise::exp(input[index].x) * precise::cos(input[index].y));
@@ -59,7 +59,7 @@ kernel void exp_complex_kernel(
5959
template <typename T0, typename T1>
6060
kernel void tanh_kernel(
6161
device T0* output [[buffer(0)]],
62-
device T1* input [[buffer(1)]],
62+
constant T1* input [[buffer(1)]],
6363
uint index [[thread_position_in_grid]]) {
6464
output[index] = T0(precise::tanh(input[index]));
6565
}
@@ -83,7 +83,7 @@ T complex_div(T a, T b) {
8383
template <typename T0>
8484
kernel void tanh_complex_kernel(
8585
device vec2type_t<T0>* output [[buffer(0)]],
86-
device vec2type_t<T0>* input [[buffer(1)]],
86+
constant vec2type_t<T0>* input [[buffer(1)]],
8787
uint index [[thread_position_in_grid]]) {
8888
// tanh(x+iy) = (tanh(x) + i*tan(y)) / (1 + i*tanh(x)*tan(y));
8989
auto tanh_x = T0(precise::tanh(input[index].x));
@@ -96,15 +96,15 @@ kernel void tanh_complex_kernel(
9696
template [[host_name("erfinv_" #DTYPE0 "_" #DTYPE1)]] kernel void \
9797
erfinv_kernel( \
9898
device DTYPE0* output [[buffer(0)]], \
99-
device DTYPE1* input [[buffer(1)]], \
99+
constant DTYPE1* input [[buffer(1)]], \
100100
uint id [[thread_position_in_grid]]); \
101101
template [[host_name("exp_" #DTYPE0 "_" #DTYPE1)]] kernel void exp_kernel( \
102102
device DTYPE0* output [[buffer(0)]], \
103-
device DTYPE1* input [[buffer(1)]], \
103+
constant DTYPE1* input [[buffer(1)]], \
104104
uint id [[thread_position_in_grid]]); \
105105
template [[host_name("tanh_" #DTYPE0 "_" #DTYPE1)]] kernel void tanh_kernel( \
106106
device DTYPE0* output [[buffer(0)]], \
107-
device DTYPE1* input [[buffer(1)]], \
107+
constant DTYPE1* input [[buffer(1)]], \
108108
uint id [[thread_position_in_grid]]);
109109

110110
#if __METAL_VERSION__ >= 310
@@ -123,12 +123,12 @@ INSTANTIATE_UNARY_KERNELS2(float, long);
123123
template [[host_name("exp_complex_" #DTYPE0 "_" #DTYPE1)]] kernel void \
124124
exp_complex_kernel<DTYPE0>( \
125125
device vec2type_t<DTYPE0> * output [[buffer(0)]], \
126-
device vec2type_t<DTYPE0> * input [[buffer(1)]], \
126+
constant vec2type_t<DTYPE0> * input [[buffer(1)]], \
127127
uint did [[thread_position_in_grid]]); \
128128
template [[host_name("tanh_complex_" #DTYPE0 "_" #DTYPE1)]] kernel void \
129129
tanh_complex_kernel<DTYPE0>( \
130130
device vec2type_t<DTYPE0> * output [[buffer(0)]], \
131-
device vec2type_t<DTYPE0> * input [[buffer(1)]], \
131+
constant vec2type_t<DTYPE0> * input [[buffer(1)]], \
132132
uint did [[thread_position_in_grid]]);
133133

134134
INSTANTIATE_UNARY_KERNELS_VEC2(short, short);

0 commit comments

Comments
 (0)