3
3
# Some activation functions have wrapper functions for the GPU in NNlibCUDA.jl.
4
4
# https://github.com/JuliaGPU/CuArrays.jl/issues/614
5
5
6
- const ACTIVATIONS = [
6
+ ACTIVATIONS = [
7
7
:σ , :hardσ , :hardtanh , :relu ,
8
8
:leakyrelu , :relu6 , :rrelu , :elu , :gelu , :swish , :hardswish , :selu ,
9
9
:celu , :softplus , :softsign , :logσ , :logcosh ,
@@ -32,6 +32,8 @@ The ascii name `sigmoid` is also exported.
32
32
See also [`sigmoid_fast`](@ref).
33
33
34
34
```
35
+ julia> using UnicodePlots
36
+
35
37
julia> lineplot(sigmoid, -5, 5, height=7)
36
38
┌────────────────────────────────────────┐
37
39
1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⣀⡠⠤⠖⠒⠒⠋⠉⠉⠉⠉⠉⠉│ σ(x)
@@ -292,9 +294,9 @@ julia> elu(-10f0, 2)
292
294
-1.9999092f0
293
295
```
294
296
"""
295
- elu (x, α= 1 ) = ifelse (x ≥ 0 , float (x), @fastmath α * (exp (x) - 1 ))
297
+ function elu(x, α=1)
+     # `oftf` is defined elsewhere in this file; presumably it converts `α`
+     # to the floating-point type of `x` -- confirm against its definition.
+     scale = oftf(x, α)
+     # Both branches are evaluated eagerly; `ifelse` only selects the result.
+     ifelse(x ≥ 0, float(x), @fastmath scale * (exp(x) - 1))
+ end
296
298
297
- deriv_elu (Ω, α= 1 ) = ifelse (Ω ≥ 0 , one (Ω), Ω + α )
299
+ function deriv_elu(Ω, α=1)
+     # `Ω` is the forward output of `elu` (the gradient rules call this with Ω).
+     # On the negative branch Ω = α * (exp(x) - 1), so the slope α * exp(x) is Ω + α.
+     negative_slope = Ω + oftype(Ω, α)
+     ifelse(Ω ≥ 0, one(Ω), negative_slope)
+ end
298
300
299
301
"""
300
302
gelu(x) = 0.5x * (1 + tanh(√(2/π) * (x + 0.044715x^3)))
@@ -314,6 +316,21 @@ julia> lineplot(gelu, -2, 2, height=7)
314
316
└────────────────────────────────────────┘
315
317
⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
316
318
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
319
+
320
+ julia> lineplot(gelu, -5, 0, height=7);
321
+
322
+ julia> lineplot!(ans, swish)
323
+ ┌────────────────────────────────────────┐
324
+ 0 │⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠒⠒⠤⣄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸│ gelu(x)
325
+ │⠑⠒⠢⠤⣄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇│ swish(x)
326
+ │⠀⠀⠀⠀⠀⠈⠉⠒⠤⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠑⢆⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣸⠁│
327
+ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠒⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠑⢄⠀⠀⠀⠀⠀⠀⠀⠀⢠⡇⠀│
328
+ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠓⣄⠀⠀⠀⠀⠀⢠⡞⠀⠀│
329
+ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠓⢄⣀⣀⡤⢣⠃⠀⠀│
330
+ -0.2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠓⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢠⠇⠀⠀⠀│
331
+ └────────────────────────────────────────┘
332
+ ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀0⠀
333
+ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
317
334
```
318
335
"""
319
336
function gelu (x)
@@ -353,26 +370,47 @@ julia> lineplot(swish, -2, 2, height=7)
353
370
"""
354
371
hardswish(x) = x * hardσ(x)
355
372
356
- Hard-Swish activation function
357
- See (["Searching for MobileNetV3"](https://arxiv.org/abs/1905.02244)).
358
-
359
- ```
360
- julia> lineplot(hardswish, -2, 2, height = 7)
361
- ┌────────────────────────────────────────┐
362
- 2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀│ hardswish(x)
363
- │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠖⠋⠁│
364
- │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⠤⠒⠋⠁⠀⠀⠀⠀│
365
- f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⣀⡠⠤⠒⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│
366
- │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⣤⣤⡤⡧⠴⠶⠯⠭⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│
367
- │⠒⠲⠤⠤⠤⠤⠤⠤⠖⠒⠒⠒⠒⠊⠉⠉⠉⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
368
- -1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
369
- └────────────────────────────────────────┘
370
- ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
371
- ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
373
+ Hard-Swish activation function.
374
+ See ["Searching for MobileNetV3"](https://arxiv.org/abs/1905.02244).
375
+
376
+ ```
377
+ julia> lineplot(hardswish, -2, 5, height = 7)
378
+ ┌────────────────────────────────────────┐
379
+ 5 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠔⠒⠉│ hardswish(x)
380
+ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡠⠔⠒⠉⠁⠀⠀⠀⠀│
381
+ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠖⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│
382
+ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
383
+ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⢀⣀⠤⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
384
+ │⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣇⣤⣤⣖⣚⣉⣁⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀│
385
+ -1 │⠉⠒⠒⠒⠒⠉⠉⠉⠉⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
386
+ └────────────────────────────────────────┘
387
+ ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀
388
+ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
389
+
390
+ julia> lineplot(hardswish, -4, 0, height = 7);
391
+
392
+ julia> lineplot!(ans, swish)
393
+ ┌────────────────────────────────────────┐
394
+ 0 │⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⢣⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡜│ hardswish(x)
395
+ │⠒⠒⠢⠤⢄⣀⡀⠀⠀⠀⠀⠱⡄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠎⠀│ swish(x)
396
+ │⠀⠀⠀⠀⠀⠀⠈⠉⠑⠒⠦⢄⣘⢄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡴⠃⠀⠀│
397
+ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠑⡖⠦⢄⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⢔⠏⠁⠀⠀⠀│
398
+ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠣⣄⠀⠉⠑⠒⠦⠤⢄⣀⣀⣀⣀⡠⠤⠖⣊⠕⠁⠀⠀⠀⠀⠀│
399
+ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠓⠤⡀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠖⠁⠀⠀⠀⠀⠀⠀⠀│
400
+ -0.4 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠉⠒⠢⠤⠤⠔⠒⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
401
+ └────────────────────────────────────────┘
402
+ ⠀-4⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀0⠀
403
+ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
404
+
405
+ julia> hardswish.(-5:5)'
406
+ 1×11 adjoint(::Vector{Float64}) with eltype Float64:
407
+ -0.0 -0.0 -0.0 -0.333333 -0.333333 0.0 0.666667 1.66667 3.0 4.0 5.0
372
408
```
373
409
"""
374
410
@inline function hardswish(x)
    # Scale the input by its hard-sigmoid gate.
    gate = hardσ(x)
    return x * gate
end
375
411
412
+ function deriv_hardswish(x)
+     # Piecewise slope of `hardswish` as a function of the input `x`:
+     # 0 below -3, 1 above 3, and x/3 + 1/2 in between.
+     middle = x/3 + 1//2
+     upper = ifelse(x > 3, oftf(x, 1), middle)
+     ifelse(x < -3, oftf(x, 0), upper)
+ end
413
+
376
414
"""
377
415
lisht(x) = x * tanh(x)
378
416
@@ -392,6 +430,19 @@ julia> lineplot(lisht, -2, 2, height=7)
392
430
└────────────────────────────────────────┘
393
431
⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
394
432
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
433
+
434
+ julia> lineplot!(ans, logcosh)
435
+ ┌────────────────────────────────────────┐
436
+ 2 │⠢⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔│ lisht(x)
437
+ │⠀⠈⠑⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠊⠁⠀│ logcosh(x)
438
+ │⠢⣄⠀⠀⠈⠣⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠁⠀⠀⣀⠔│
439
+ f(x) │⠀⠈⠑⠢⣀⠀⠀⠑⢆⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠊⠁⠀⣀⠔⠊⠁⠀│
440
+ │⠀⠀⠀⠀⠀⠉⠢⢄⡀⠉⠢⡄⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⠔⠋⠀⡠⠔⠋⠁⠀⠀⠀⠀│
441
+ │⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⠦⣌⡓⢄⡀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⢀⡠⠖⣁⠤⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀│
442
+ 0 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⠪⠷⣦⣄⣀⣀⣇⣀⣀⣤⠶⠕⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
443
+ └────────────────────────────────────────┘
444
+ ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
445
+ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
395
446
```
396
447
"""
397
448
function lisht(x)
    # x * tanh(x), using the file's fast tanh variant.
    squashed = tanh_fast(x)
    return x * squashed
end
@@ -406,17 +457,17 @@ Scaled exponential linear units.
406
457
See ["Self-Normalizing Neural Networks"](https://arxiv.org/abs/1706.02515).
407
458
408
459
```
409
- julia> lineplot(selu, -2 , 2, height=7)
460
+ julia> lineplot(selu, -3 , 2, height=7)
410
461
┌────────────────────────────────────────┐
411
- 3 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ selu(x)
412
- │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡠⠤⠔ ⠒│
413
- │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇ ⠀⠀⠀⠀⠀⠀⠀⣀⣀⠤⠔⠒⠋⠉⠀ ⠀⠀⠀⠀│
414
- f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⢀⣀⡤⠤⠒⠊⠉⠁⠀⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀│
415
- │⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⣉⡩⠭⠛⡏ ⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉│
416
- │⠀⠀⠀⠀⠀⠀⠀⣀⣀⣀⡤⠤⠔ ⠒⠊⠉⠁ ⠀⠀⠀⡇⠀⠀⠀⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
417
- -2 │⠒⠒ ⠉⠉⠉⠉⠉ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
462
+ 3 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ selu(x)
463
+ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠤ ⠒│
464
+ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇ ⠀⠀⠀⠀⠀⢀⣀⠤⠖⠊⠉ ⠀⠀⠀⠀│
465
+ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⣀⡠⠤⠒⠋⠁ ⠀⠀⠀⠀⠀⠀⠀⠀⠀│
466
+ │⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⣉⠭⠛⡏ ⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉│
467
+ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⣀⡤⠤ ⠒⠊⠉⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
468
+ -2 │⠤⠤⠖⠒⠒⠒⠒⠒⠒⠒ ⠉⠉⠉⠁ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
418
469
└────────────────────────────────────────┘
419
- ⠀-2 ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
470
+ ⠀-3 ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
420
471
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
421
472
422
473
julia> selu(-10f0)
@@ -461,7 +512,9 @@ julia> celu(-10f0)
461
512
-0.9999546f0
462
513
```
463
514
"""
464
- celu (x, α= 1 ) = ifelse (x ≥ 0 , float (x), α * (exp (x/ α) - 1 ))
515
+ function celu(x, α=1)
+     # Convert `α` once with the file's `oftf` helper and reuse it for both
+     # the scale and the exponent divisor.
+     a = oftf(x, α)
+     ifelse(x ≥ 0, float(x), a * (exp(x / a) - 1))
+ end
516
+
517
+ function deriv_celu(Ω, α=1)
+     # `Ω` is the forward output of `celu`. On the negative branch
+     # Ω = α * (exp(x/α) - 1), so the slope exp(x/α) equals Ω/α + 1.
+     negative_slope = Ω / oftf(Ω, α) + 1
+     ifelse(Ω > 0, oftf(Ω, 1), negative_slope)
+ end
465
518
466
519
"""
467
520
trelu(x, theta=1) = x > theta ? x : 0
@@ -529,6 +582,8 @@ julia> softsign(100f0)
529
582
"""
530
583
function softsign(x)
    # x / (1 + |x|)
    denom = 1 + abs(x)
    return x / denom
end
531
584
585
+ function deriv_softsign(x)
+     # Slope of x / (1 + |x|) with respect to the input `x`: 1 / (1 + |x|)^2.
+     denom = 1 + abs(x)
+     1 / denom^2
+ end
586
+
532
587
"""
533
588
softplus(x) = log(exp(x) + 1)
534
589
@@ -675,9 +730,9 @@ julia> softshrink.((-10f0, 10f0))
675
730
(-9.5f0, 9.5f0)
676
731
```
677
732
"""
678
- function softshrink (x, λ= oftf (x, 0.5 ) )
679
- lo = x - λ
680
- hi = x + λ
733
+ function softshrink(x, λ = 0.5)
+     # Convert the threshold once with the file's `oftf` helper, then form
+     # the two shifted candidates x - λ and x + λ.
+     threshold = oftf(x, λ)
+     below = x - threshold
+     above = x + threshold
+     # above ≤ 0          -> x + λ
+     # above > 0, below<0 -> 0
+     # otherwise          -> x - λ
+     ifelse(above > 0, ifelse(below < 0, zero(above), below), above)
end
683
738
@@ -785,58 +840,84 @@ this replacement for some array or element types.
785
840
786
841
# # Define rrules for some activation functions, along with the
787
842
# # broadcasted rrule activation functions.
788
- # # TODO : add to the lists below all activations.
789
843
790
844
# # This is a performance hack specifically for Zygote, because it doesn't handle fused
791
845
# # broadcasts well; but it generally should be good (or at least harmless) for any AD, as
792
846
# # it saves ADing the broadcasting machinery.
793
847
# # Related Issue https://github.com/JuliaDiff/ChainRulesCore.jl/issues/271
794
848
795
- UNARY_ACTS = [ # f, df
796
- ( :relu , :(x > 0 )),
797
- ( :hardtanh , :( - 1 < x < 1 )),
798
- ( :selu , :( deriv_selu (Ω))),
849
+ # # TODO : add to the lists below all activations.
850
+
851
+ UNARY_ACTS = [ # f, dfdx
852
+ # # In the same order as above!
799
853
(:σ , :(conj (Ω * (1 - Ω)))),
854
+ ## hardσ is flat where it saturates (Ω == 0 or Ω == 1), so its slope there is 0;
+ ## only the linear segment 0 < Ω < 1 has slope 1//6.
+ (:hardσ, :(ifelse((Ω>0)&(Ω<1), 1//6, 0//1))),
855
+ (:logσ , :(sigmoid_fast (- x))),
856
+ (:hardtanh , :((Ω> - 1 ) & (Ω< 1 ))),
857
+ (:relu , :(Ω > 0 )),
858
+ (:leakyrelu , :(ifelse (Ω > 0 , 1 // 1 , 1 // 100 ))),
859
+ (:relu6 , :((Ω> 0 ) & (Ω< 6 ))),
860
+ # rrelu is random, can't write a rule.
800
861
(:elu , :(deriv_elu (Ω))),
801
- (:softplus , :(σ (x))),
802
-
862
+ # gelu
863
+ (:swish , :(Ω + sigmoid_fast (x) * (1 - Ω))),
864
+ (:hardswish , :(deriv_hardswish (x))),
865
+ # lisht
866
+ (:selu , :(deriv_selu (Ω))),
867
+ (:celu , :(deriv_celu (Ω))),
868
+ (:trelu , :(Ω > 0 )),
869
+ (:softsign , :(deriv_softsign (x))),
870
+ (:softplus , :(sigmoid_fast (x))),
871
+ # (:softplus, :(1 - @fastmath exp(-Ω))), # slightly faster, check accuracy?
872
+ # logcosh
873
+ # mish
874
+ (:tanhshrink , :((x - Ω)^ 2 )),
875
+ (:softshrink , :(Ω != 0 )),
876
+ # # Fast variants are the same!
803
877
(:tanh_fast , :(conj (1 - Ω^ 2 ))),
804
878
(:sigmoid_fast , :(conj (Ω * (1 - Ω)))),
805
- ]
879
+ ]
806
880
807
- for (f, df ) in UNARY_ACTS
808
- @eval @scalar_rule ($ f (x), $ df )
881
+ for (f, dfdx ) in UNARY_ACTS
882
+ @eval @scalar_rule ($ f (x), $ dfdx )
809
883
810
884
pullback = Symbol (:broadcasted_ , f, :_pullback )
811
885
@eval function rrule (:: typeof (broadcasted),
812
886
:: typeof ($ f), x:: Numeric )
813
887
Ω = $ f .(x)
814
- function $pullback (Δ )
888
+ function $pullback (dΩ )
815
889
x_thunk = InplaceableThunk (
816
- dx -> @. (dx += Δ * $ df ),
817
- @thunk @. (Δ * $ df )
890
+ dx -> @. (dx += dΩ * $ dfdx ),
891
+ @thunk @. (dΩ * $ dfdx )
818
892
)
819
893
NoTangent (), NoTangent (), x_thunk
820
894
end
821
895
return Ω, $ pullback
822
896
end
823
897
end
824
898
825
- BINARY_ACTS = [ # f, df1, df2
826
- (:elu , :(deriv_elu (Ω, x2)), :(NoTangent ())), # TODO use real deriv instead of DNE
827
- ]
899
+ # NO_ACT_GRAD = ChainRulesCore.@not_implemented "for simplicity NNlib assumes the 2nd argument of this activation function is a constant"
900
+ ## Placeholder "gradient" for the 2nd argument of two-argument activations.
+ ## NaN still reminds you not to use this, but is perhaps more GPU friendly.
+ ## Declared `const` because the generated broadcast pullbacks read this global
+ ## at runtime; a non-const global would make their return type unstable.
+ const NO_ACT_GRAD = NaN
828
901
829
- for (f, df1, df2) in BINARY_ACTS
830
- @eval @scalar_rule ($ f (x1, x2), ($ df1, $ df2))
902
+ BINARY_ACTS = [ # f, dfdx1, dfdx2
903
+ # # In the same order as above!
904
+ (:leakyrelu , :(ifelse (Ω > 0 , oftf (Ω, 1 ), oftf (Ω, x2))), NO_ACT_GRAD),
905
+ (:elu , :(deriv_elu (Ω, x2)), NO_ACT_GRAD),
906
+ (:celu , :(deriv_celu (Ω, x2)), NO_ACT_GRAD),
907
+ (:trelu , :(Ω > 0 ), ZeroTangent ()),
908
+ (:softshrink , :(Ω != 0 ), NO_ACT_GRAD),
909
+ ]
831
910
832
- pullback = Symbol (:broadcasted_ , f, :_pullback )
911
+ for (f, dfdx1, dfdx2) in BINARY_ACTS
912
+ @eval @scalar_rule ($ f (x1, x2), ($ dfdx1, $ dfdx2))
913
+
914
+ pullback = Symbol (:broadcasted_ , f, :_pullback_2arg )
833
915
@eval function rrule (:: typeof (broadcasted),
834
916
:: typeof ($ f),
835
- x1:: Numeric , x2:: Numeric )
917
+ x1:: Numeric , x2:: Number )
836
918
Ω = $ f .(x1, x2)
837
- function $pullback (Δ)
838
- NoTangent (), NoTangent (), @. (Δ * $ df1), @. (Δ * $ df2)
839
- end
919
+ # # Allowing x2::Array would allow size(Ω) != size(x1), which is not handled here:
920
+ $ pullback (dΩ) = (NoTangent (), NoTangent (), @. (dΩ * $ dfdx1), NO_ACT_GRAD)
840
921
return Ω, $ pullback
841
922
end
842
923
end
0 commit comments