Skip to content

Commit e483325

Browse files
committed
add leaky relu activation
1 parent f06ab3b commit e483325

File tree

6 files changed

+156
-0
lines changed

6 files changed

+156
-0
lines changed

doc/specs/stdlib_specialfunctions_activations.md

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,79 @@ Elemental function
203203

204204
The function returns a value with the same type and kind as input argument.
205205

206+
## `Leaky_relu` - Leaky Rectified Linear Unit function
207+
208+
### Status
209+
210+
Experimental
211+
212+
### Description
213+
214+
Computes the gaussian function:
215+
$$
216+
\text{f}(x) =
217+
\begin{cases}
218+
x, & \text{if } x \geq 0 \\
219+
a * x, & \text{otherwise}
220+
\end{cases}
221+
$$
222+
223+
### Syntax
224+
225+
`result = ` [[stdlib_specialfunctions(module):leaky_relu(interface)]] ` (x,a)`
226+
227+
### Class
228+
229+
Elemental function
230+
231+
### Arguments
232+
233+
`x`: Shall be a scalar or array of any `real` kind.
234+
`a`: Shall be a scalar or array of any `real` kind.
235+
236+
### Return value
237+
238+
The function returns a value with the same type and kind as input argument.
239+
240+
### Example
241+
```fortran
242+
{!example/specialfunctions_activations/example_leaky_relu.f90!}
243+
```
244+
245+
## `Leaky_relu_grad` - Gradient of the Leaky Rectified Linear Unit function
246+
247+
### Status
248+
249+
Experimental
250+
251+
### Description
252+
253+
Computes the gradient of the leaky_relu function:
254+
$$
255+
\text{f}(x) =
256+
\begin{cases}
257+
1, & \text{if } x \geq 0 \\
258+
a , & \text{otherwise}
259+
\end{cases}
260+
$$
261+
262+
### Syntax
263+
264+
`result = ` [[stdlib_specialfunctions(module):leaky_relu_grad(interface)]] ` (x,a)`
265+
266+
### Class
267+
268+
Elemental function
269+
270+
### Arguments
271+
272+
`x`: Shall be a scalar or array of any `real` kind.
273+
`a`: Shall be a scalar or array of any `real` kind.
274+
275+
### Return value
276+
277+
The function returns a value with the same type and kind as the input argument.
278+
206279
## `Gelu` - Gaussian Error Linear Unit function
207280

208281
### Status

example/specialfunctions_activations/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
ADD_EXAMPLE(elu)
22
ADD_EXAMPLE(gaussian)
33
ADD_EXAMPLE(gelu)
4+
ADD_EXAMPLE(leaky_relu)
45
ADD_EXAMPLE(relu)
56
ADD_EXAMPLE(selu)
67
ADD_EXAMPLE(silu)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
program example_gelu
2+
use stdlib_kinds, only: sp
3+
use stdlib_math, only: linspace
4+
use stdlib_specialfunctions, only: leaky_relu
5+
implicit none
6+
integer, parameter :: n = 10
7+
real(sp) :: x(n), y(n)
8+
9+
x = linspace(-2._sp, 2._sp, n)
10+
y = leaky_relu( x , 0.1_sp )
11+
print *, y
12+
end program example_gelu
13+

src/stdlib_specialfunctions.fypp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,36 @@ module stdlib_specialfunctions
117117
end interface
118118
public :: relu_grad
119119

120+
interface leaky_relu
121+
!! Version: experimental
122+
!!
123+
!! Leaky Rectified linear unit function
124+
!> ([Specification](../page/specs/stdlib_specialfunctions.html#leaky_relu))
125+
#:for rk, rt in REAL_KINDS_TYPES
126+
elemental module function leaky_relu_${rk}$( x , a ) result( y )
127+
${rt}$, intent(in) :: x
128+
${rt}$, intent(in) :: a
129+
${rt}$ :: y
130+
end function
131+
#:endfor
132+
end interface
133+
public :: leaky_relu
134+
135+
interface leaky_relu_grad
136+
!! Version: experimental
137+
!!
138+
!! Gradient of the Leaky Rectified linear unit function
139+
!> ([Specification](../page/specs/stdlib_specialfunctions.html#leaky_relu_grad))
140+
#:for rk, rt in REAL_KINDS_TYPES
141+
elemental module function leaky_relu_grad_${rk}$( x , a ) result( y )
142+
${rt}$, intent(in) :: x
143+
${rt}$, intent(in) :: a
144+
${rt}$ :: y
145+
end function
146+
#:endfor
147+
end interface
148+
public :: leaky_relu_grad
149+
120150
interface gelu
121151
!! Version: experimental
122152
!!

src/stdlib_specialfunctions_activations.fypp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,26 @@ end function
6565

6666
#:endfor
6767

68+
!==================================================
69+
! Leaky Rectified Linear Unit
70+
!==================================================
71+
#:for rk, rt in REAL_KINDS_TYPES
72+
elemental module function Leaky_relu_${rk}$( x , a ) result( y )
73+
${rt}$, intent(in) :: x
74+
${rt}$, intent(in) :: a
75+
${rt}$ :: y
76+
y = merge( x, a * x , x >= 0._${rk}$)
77+
end function
78+
79+
elemental module function Leaky_relu_grad_${rk}$( x , a ) result( y )
80+
${rt}$, intent(in) :: x
81+
${rt}$, intent(in) :: a
82+
${rt}$ :: y
83+
y = merge( 1._${rk}$ , a , x >= 0._${rk}$)
84+
end function
85+
86+
#:endfor
87+
6888
!==================================================
6989
! GELU: Gaussian Error Linear Units function
7090
!==================================================

test/specialfunctions/test_specialfunctions_activations.fypp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ contains
2424
new_unittest("gaussian", test_gaussian), &
2525
new_unittest("elu", test_elu), &
2626
new_unittest("relu", test_relu), &
27+
new_unittest("leaky_relu", test_leaky_relu), &
2728
new_unittest("gelu" , test_gelu), &
2829
new_unittest("selu" , test_selu), &
2930
new_unittest("sigmoid", test_sigmoid), &
@@ -147,6 +148,24 @@ contains
147148
if (allocated(error)) return
148149
end subroutine
149150

151+
subroutine test_leaky_relu(error)
152+
type(error_type), allocatable, intent(out) :: error
153+
integer, parameter :: n = 10
154+
real(sp) :: x(n), y(n), y_ref(n), a
155+
156+
call random_number(x)
157+
a = 0.1_sp
158+
where(x>=0._sp)
159+
y_ref = x
160+
elsewhere
161+
y_ref = a * x
162+
end where
163+
y = Leaky_relu( x , a )
164+
165+
call check(error, norm2(y-y_ref) < n*tol_sp )
166+
if (allocated(error)) return
167+
end subroutine
168+
150169
subroutine test_sigmoid(error)
151170
type(error_type), allocatable, intent(out) :: error
152171
integer, parameter :: n = 10

0 commit comments

Comments
 (0)