Skip to content

Commit 3789518

Browse files
committed
replace ifs with merge
1 parent 14af3f9 commit 3789518

File tree

1 file changed

+11
-40
lines changed

1 file changed

+11
-40
lines changed

src/stdlib_specialfunctions_activations.fypp

Lines changed: 11 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -35,22 +35,14 @@ elemental module function elu_${rk}$( x , a ) result ( y )
3535
${rt}$, intent(in) :: x
3636
${rt}$, intent(in) :: a
3737
${rt}$ :: y
38-
if(x >= 0._${rk}$)then
39-
y = x
40-
else
41-
y = a * (exp(x) - 1._${rk}$)
42-
end if
38+
y = merge( x , a * (exp(x) - 1._${rk}$), x >= 0._${rk}$)
4339
end function
4440

4541
elemental module function elu_grad_${rk}$( x , a ) result ( y )
4642
${rt}$, intent(in) :: x
4743
${rt}$, intent(in) :: a
4844
${rt}$ :: y
49-
if(x >= 0._${rk}$)then
50-
y = 1._${rk}$
51-
else
52-
y = a * exp(x)
53-
end if
45+
y = merge( 1._${rk}$ , a * exp(x), x >= 0._${rk}$)
5446
end function
5547

5648
#:endfor
@@ -68,11 +60,7 @@ end function
6860
elemental module function relu_grad_${rk}$( x ) result( y )
6961
${rt}$, intent(in) :: x
7062
${rt}$ :: y
71-
if(x > 0._${rk}$)then
72-
y = 1._${rk}$
73-
else
74-
y = 0._${rk}$
75-
end if
63+
y = merge( 1._${rk}$ , 0._${rk}$, x > 0._${rk}$)
7664
end function
7765

7866
#:endfor
@@ -121,23 +109,16 @@ elemental module function selu_${rk}$( x ) result( y )
121109
${rt}$ :: y
122110
${rt}$, parameter :: scale = 1.0507009873554804934193349852946_${rk}$
123111
${rt}$, parameter :: alpha = 1.6732632423543772848170429916717_${rk}$
124-
if(x > 0._${rk}$)then
125-
y = scale * x
126-
else
127-
y = scale * (alpha * exp(x) - alpha)
128-
end if
112+
y = merge( x , alpha * exp(x) - alpha, x > 0._${rk}$)
113+
y = scale * y
129114
end function
130115

131116
elemental module function selu_grad_${rk}$( x ) result( y )
132117
${rt}$, intent(in) :: x
133118
${rt}$ :: y
134119
${rt}$, parameter :: scale = 1.0507009873554804934193349852946_${rk}$
135120
${rt}$, parameter :: alpha = 1.6732632423543772848170429916717_${rk}$
136-
if(x > 0._${rk}$)then
137-
y = scale
138-
else
139-
y = scale * alpha * exp(x)
140-
end if
121+
y = merge( scale , scale * alpha * exp(x), x > 0._${rk}$)
141122
end function
142123

143124
#:endfor
@@ -186,11 +167,7 @@ end function
186167
elemental module function Step_${rk}$( x ) result( y )
187168
${rt}$, intent(in) :: x
188169
${rt}$ :: y
189-
if(x > 0._${rk}$)then
190-
y = 1._${rk}$
191-
else
192-
y = 0._${rk}$
193-
end if
170+
y = merge( 1._${rk}$ , 0._${rk}$, x > 0._${rk}$)
194171
end function
195172

196173
elemental module function Step_grad_${rk}$( x ) result( y )
@@ -360,16 +337,10 @@ elemental module function ftanh_${rk}$( x ) result( y )
360337
${rt}$ :: y
361338
${rt}$ :: x2, a, b
362339

363-
if (x > 5._${rk}$) then
364-
y = 1._${rk}$
365-
elseif (x < -5._${rk}$) then
366-
y = -1._${rk}$
367-
else
368-
x2 = x*x
369-
a = x * (135135.0_${rk}$ + x2 * (17325.0_${rk}$ + x2 * (378.0_${rk}$ + x2)))
370-
b = 135135.0_${rk}$ + x2 * (62370.0_${rk}$ + x2 * (3150.0_${rk}$ + x2 * 28.0_${rk}$))
371-
y = a / b
372-
end if
340+
x2 = x*x
341+
a = x * (135135.0_${rk}$ + x2 * (17325.0_${rk}$ + x2 * (378.0_${rk}$ + x2)))
342+
b = 135135.0_${rk}$ + x2 * (62370.0_${rk}$ + x2 * (3150.0_${rk}$ + x2 * 28.0_${rk}$))
343+
y = merge( a / b , sign(1._${rk}$,x) , x2 <= 25._${rk}$ )
373344
end function
374345

375346
elemental module function ferf_${rk}$( x ) result( y )

0 commit comments

Comments
 (0)