Skip to content

Commit 5d41402

Browse files
committed
refactor tanh, add docs, tests on all real precisions
1 parent 9dacc48 commit 5d41402

File tree

4 files changed

+413
-318
lines changed

4 files changed

+413
-318
lines changed

doc/specs/stdlib_specialfunctions_activations.md

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -762,4 +762,85 @@ Elemental function
762762

763763
### Return value
764764

765+
The function returns a value with the same type and kind as input argument.
766+
767+
## `Fast tanh` - Approximation of the hyperbolic tangent function
768+
769+
### Status
770+
771+
Experimental
772+
773+
### Description
774+
775+
Computes an approximated but faster solution to:
776+
$$f(x)=\tanh(x)$$
777+
778+
### Syntax
779+
780+
`result = ` [[stdlib_specialfunctions(module):fast_tanh(interface)]] ` (x)`
781+
782+
### Class
783+
784+
Elemental function
785+
786+
### Arguments
787+
788+
`x`: Shall be a scalar or array of any `real` kind.
789+
790+
### Return value
791+
792+
The function returns a value with the same type and kind as input argument.
793+
794+
## `fast_tanh_grad` - Gradient of the approximation of the hyperbolic tangent function
795+
796+
### Status
797+
798+
Experimental
799+
800+
### Description
801+
802+
Computes the gradient of the `fast_tanh` function:
803+
$$f(x)=1 - \fast_tanh(x)^2$$
804+
805+
### Syntax
806+
807+
`result = ` [[stdlib_specialfunctions(module):fast_tanh_grad(interface)]] ` (x)`
808+
809+
### Class
810+
811+
Elemental function
812+
813+
### Arguments
814+
815+
`x`: Shall be a scalar or array of any `real` kind.
816+
817+
### Return value
818+
819+
The function returns a value with the same type and kind as input argument.
820+
821+
## `Fast erf` - Approximation of the error function
822+
823+
### Status
824+
825+
Experimental
826+
827+
### Description
828+
829+
Computes an approximated but faster solution to:
830+
$$f(x)=\erf(x)$$
831+
832+
### Syntax
833+
834+
`result = ` [[stdlib_specialfunctions(module):fast_erf(interface)]] ` (x)`
835+
836+
### Class
837+
838+
Elemental function
839+
840+
### Arguments
841+
842+
`x`: Shall be a scalar or array of any `real` kind.
843+
844+
### Return value
845+
765846
The function returns a value with the same type and kind as input argument.

src/stdlib_specialfunctions.fypp

Lines changed: 20 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -315,32 +315,6 @@ module stdlib_specialfunctions
315315
end interface
316316
public :: step_grad
317317

318-
interface tanh
319-
!! Version: experimental
320-
!!
321-
!! gaussian function
322-
#:for rk, rt in REAL_KINDS_TYPES
323-
elemental module function tanh_${rk}$( x ) result( y )
324-
${rt}$, intent(in) :: x
325-
${rt}$ :: y
326-
end function
327-
#:endfor
328-
end interface
329-
public :: tanh
330-
331-
interface tanh_grad
332-
!! Version: experimental
333-
!!
334-
!! gradient of the hyperbolic tangent function
335-
#:for rk, rt in REAL_KINDS_TYPES
336-
elemental module function tanh_grad_${rk}$( x ) result( y )
337-
${rt}$, intent(in) :: x
338-
${rt}$ :: y
339-
end function
340-
#:endfor
341-
end interface
342-
public :: tanh_grad
343-
344318
interface softmax
345319
!! Version: experimental
346320
!!
@@ -435,32 +409,43 @@ module stdlib_specialfunctions
435409
end interface
436410
public :: softplus_grad
437411

438-
interface ftanh
412+
interface fast_tanh
439413
!! Version: experimental
440414
!!
441415
!! Fast approximation of the tanh function
442-
!! Source: https://fortran-lang.discourse.group/t/fastgpt-faster-than-pytorch-in-300-lines-of-fortran/5385/31
443416
#:for rk, rt in REAL_KINDS_TYPES
444-
elemental module function ftanh_${rk}$( x ) result( y )
417+
elemental module function fast_tanh_${rk}$( x ) result( y )
418+
${rt}$, intent(in) :: x
419+
${rt}$ :: y
420+
end function
421+
#:endfor
422+
end interface
423+
public :: fast_tanh
424+
425+
interface fast_tanh_grad
426+
!! Version: experimental
427+
!!
428+
!! gradient of the hyperbolic tangent function
429+
#:for rk, rt in REAL_KINDS_TYPES
430+
elemental module function fast_tanh_grad_${rk}$( x ) result( y )
445431
${rt}$, intent(in) :: x
446432
${rt}$ :: y
447433
end function
448434
#:endfor
449435
end interface
450-
public :: ftanh
436+
public :: fast_tanh_grad
451437

452-
interface ferf
438+
interface fast_erf
453439
!! Version: experimental
454440
!!
455441
!! Fast approximation of the erf function
456-
!! Source: https://fortran-lang.discourse.group/t/fastgpt-faster-than-pytorch-in-300-lines-of-fortran/5385/31
457442
#:for rk, rt in REAL_KINDS_TYPES
458-
elemental module function ferf_${rk}$( x ) result( y )
443+
elemental module function fast_erf_${rk}$( x ) result( y )
459444
${rt}$, intent(in) :: x
460445
${rt}$ :: y
461446
end function
462447
#:endfor
463448
end interface
464-
public :: ferf
449+
public :: fast_erf
465450

466-
end module stdlib_specialfunctions
451+
end module stdlib_specialfunctions

src/stdlib_specialfunctions_activations.fypp

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -108,13 +108,13 @@ end function
108108
elemental module function gelu_approx_${rk}$( x ) result( y )
109109
${rt}$, intent(in) :: x
110110
${rt}$ :: y
111-
y = 0.5_${rk}$ * x * (1._${rk}$ + ferf(x * isqrt2_${rk}$))
111+
y = 0.5_${rk}$ * x * (1._${rk}$ + fast_erf(x * isqrt2_${rk}$))
112112
end function
113113

114114
elemental module function gelu_approx_grad_${rk}$( x ) result( y )
115115
${rt}$, intent(in) :: x
116116
${rt}$ :: y
117-
y = 0.5_${rk}$ * (1._${rk}$ + ferf(x * isqrt2_${rk}$) )
117+
y = 0.5_${rk}$ * (1._${rk}$ + fast_erf(x * isqrt2_${rk}$) )
118118
y = y + x * isqrt2_${rk}$ * exp( - 0.5_${rk}$ * x**2 )
119119
end function
120120

@@ -198,24 +198,6 @@ end function
198198

199199
#:endfor
200200

201-
!==================================================
202-
! tanh
203-
!==================================================
204-
#:for rk, rt in REAL_KINDS_TYPES
205-
elemental module function tanh_${rk}$( x ) result( y )
206-
${rt}$, intent(in) :: x
207-
${rt}$ :: y
208-
y = ftanh(x)
209-
end function
210-
211-
elemental module function tanh_grad_${rk}$( x ) result( y )
212-
${rt}$, intent(in) :: x
213-
${rt}$ :: y
214-
y = 1._${rk}$ - ftanh(x)**2
215-
end function
216-
217-
#:endfor
218-
219201
!==================================================
220202
! softmax
221203
!==================================================
@@ -352,7 +334,8 @@ end function
352334
!==================================================
353335

354336
#:for rk, rt in REAL_KINDS_TYPES
355-
elemental module function ftanh_${rk}$( x ) result( y )
337+
! Source: https://fortran-lang.discourse.group/t/fastgpt-faster-than-pytorch-in-300-lines-of-fortran/5385/31
338+
elemental module function fast_tanh_${rk}$( x ) result( y )
356339
${rt}$, intent(in) :: x
357340
${rt}$ :: y
358341
${rt}$ :: x2, a, b
@@ -363,7 +346,14 @@ elemental module function ftanh_${rk}$( x ) result( y )
363346
y = merge( a / b , sign(1._${rk}$,x) , x2 <= 25._${rk}$ )
364347
end function
365348

366-
elemental module function ferf_${rk}$( x ) result( y )
349+
elemental module function fast_tanh_grad_${rk}$( x ) result( y )
350+
${rt}$, intent(in) :: x
351+
${rt}$ :: y
352+
y = 1._${rk}$ - fast_tanh(x)**2
353+
end function
354+
355+
! Source: https://fortran-lang.discourse.group/t/fastgpt-faster-than-pytorch-in-300-lines-of-fortran/5385/31
356+
elemental module function fast_erf_${rk}$( x ) result( y )
367357
${rt}$, intent(in) :: x
368358
${rt}$ :: y
369359
${rt}$ :: abs_x

0 commit comments

Comments
 (0)