Skip to content

Commit 886b34c

Browse files
authored
Avoid Rational in activation function gradients (#399)
* avoid rational numbers * move CUDA tests first, add overall testset * NNLIB_TEST_CUDA: true for v1 * two more rationals
1 parent 0c8396e commit 886b34c

File tree

4 files changed

+75
-70
lines changed

4 files changed

+75
-70
lines changed

.buildkite/pipeline.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ steps:
3636
agents:
3737
queue: "juliagpu"
3838
cuda: "*"
39+
env:
40+
NNLIB_TEST_CUDA: true
3941
timeout_in_minutes: 60
4042

4143
# - label: "GPU julia nightly"

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "NNlib"
22
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
3-
version = "0.8.3"
3+
version = "0.8.4"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/activations.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -195,14 +195,16 @@ julia> lineplot(x -> leakyrelu(x, 0.5), -2, 2, height=7)
195195
⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
196196
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
197197
198-
julia> leakyrelu(-10f0, 1//5)
198+
julia> leakyrelu(-10f0, 0.2)
199199
-2.0f0
200200
201-
julia> leakyrelu(-10f0, 1//20)
201+
julia> leakyrelu(-10f0, 0.02)
202202
-0.5f0
203203
```
204204
"""
205-
leakyrelu(x, a=oftf(x, 0.01)) = ifelse(x>0, float(x), oftf(x, a*x)) # max(a*x, x) is 3x slower
205+
leakyrelu(x, a=oftf(x, leakyrelu_a)) = ifelse(x>0, float(x), oftf(x, a*x)) # max(a*x, x) is 3x slower
206+
207+
const leakyrelu_a = 0.01 # also used in gradient below
206208

207209
"""
208210
relu6(x) = min(max(0, x), 6)
@@ -254,7 +256,7 @@ julia> extrema(rrelu.(fill(-10f0, 1000)))
254256
(-3.3316886f0, -1.2548422f0)
255257
```
256258
"""
257-
function rrelu(x::T, l=1//8, u=1//3) where T<:Number
259+
function rrelu(x::T, l=oftf(x,1/8), u=oftf(x,1/3)) where T<:Number
258260
a = (u - l) * rand(float(T)) + l
259261
return leakyrelu(x, a)
260262
end
@@ -402,7 +404,7 @@ julia> hardswish.(-5:5)'
402404
"""
403405
@inline hardswish(x) = x * hardσ(x)
404406

405-
deriv_hardswish(x) = ifelse(x < -3, oftf(x,0), ifelse(x > 3, oftf(x,1), x/3 + 1//2))
407+
deriv_hardswish(x) = ifelse(x < -3, oftf(x,0), ifelse(x > 3, oftf(x,1), x/3 + oftf(x,1/2)))
406408

407409
"""
408410
lisht(x) = x * tanh(x)
@@ -844,11 +846,11 @@ this replacement for some array or element types.
844846
UNARY_ACTS = [ # f, dfdx
845847
## In the same order as above!
846848
(:σ, :(conj(Ω * (1 - Ω)))),
847-
(:hardσ, :(ifelse((Ω>0)&(Ω<1), 1//6, 1//1))),
849+
(:hardσ, :(ifelse((Ω>0)&(Ω<1), oftf(Ω, 1/6), oftf(Ω, 1)))),
848850
(:logσ, :(sigmoid_fast(-x))),
849851
(:hardtanh, :((Ω>-1) & (Ω<1))),
850852
(:relu, :(Ω > 0)),
851-
(:leakyrelu, :(ifelse(Ω > 0, 1//1, 1//100))),
853+
(:leakyrelu, :(ifelse(Ω > 0, oftf(Ω, 1), oftf(Ω, leakyrelu_a)))),
852854
(:relu6, :((Ω>0) & (Ω<6))),
853855
# rrelu is random, can't write a rule.
854856
(:elu, :(deriv_elu(Ω))),

test/runtests.jl

Lines changed: 63 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -7,83 +7,84 @@ using Zygote: gradient
77
using StableRNGs
88
using CUDA
99

10-
if VERSION < v"1.6"
11-
@info "skipping doctests, on Julia $VERSION"
12-
else
13-
using Documenter
14-
DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib, UnicodePlots); recursive=true)
15-
@testset "Doctests" begin
16-
doctest(NNlib, manual=false)
17-
end
18-
end
19-
2010
const rng = StableRNG(123)
21-
2211
include("test_utils.jl")
2312

24-
@testset "Activation Functions" begin
25-
include("activations.jl")
26-
end
13+
@testset verbose=true "NNlib.jl" begin
14+
if CUDA.functional()
15+
if get(ENV, "NNLIB_TEST_CUDA", "false") == "true"
16+
import Pkg
17+
using NNlibCUDA
18+
@testset "CUDA" begin
19+
Pkg.test("NNlibCUDA")
20+
end
21+
else
22+
@info "Skipping CUDA tests, set NNLIB_TEST_CUDA=true to run them"
23+
end
24+
else
25+
@info "Insufficient version or CUDA not found; Skipping CUDA tests"
26+
end
2727

28-
@testset "Batched Multiplication" begin
29-
include("batchedmul.jl")
30-
end
28+
if VERSION < v"1.6"
29+
@info "skipping doctests, on Julia $VERSION"
30+
else
31+
using Documenter
32+
DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib, UnicodePlots); recursive=true)
33+
@testset "Doctests" begin
34+
doctest(NNlib, manual=false)
35+
end
36+
end
3137

32-
@testset "Convolution" begin
33-
include("conv.jl")
34-
include("conv_bias_act.jl")
35-
end
38+
@testset "Activation Functions" begin
39+
include("activations.jl")
40+
end
3641

37-
@testset "Inference" begin
38-
include("inference.jl")
39-
end
42+
@testset "Batched Multiplication" begin
43+
include("batchedmul.jl")
44+
end
4045

41-
@testset "Pooling" begin
42-
include("pooling.jl")
43-
end
46+
@testset "Convolution" begin
47+
include("conv.jl")
48+
include("conv_bias_act.jl")
49+
end
4450

45-
@testset "Padding" begin
46-
include("padding.jl")
47-
end
51+
@testset "Inference" begin
52+
include("inference.jl")
53+
end
4854

49-
@testset "Softmax" begin
50-
include("softmax.jl")
51-
end
55+
@testset "Pooling" begin
56+
include("pooling.jl")
57+
end
5258

53-
@testset "Upsampling" begin
54-
include("upsample.jl")
55-
end
59+
@testset "Padding" begin
60+
include("padding.jl")
61+
end
5662

57-
@testset "Gather" begin
58-
include("gather.jl")
59-
end
63+
@testset "Softmax" begin
64+
include("softmax.jl")
65+
end
6066

61-
@testset "Scatter" begin
62-
include("scatter.jl")
63-
end
67+
@testset "Upsampling" begin
68+
include("upsample.jl")
69+
end
6470

65-
@testset "Utilities" begin
66-
include("utils.jl")
67-
end
71+
@testset "Gather" begin
72+
include("gather.jl")
73+
end
6874

69-
@testset "Grid Sampling" begin
70-
include("sampling.jl")
71-
end
75+
@testset "Scatter" begin
76+
include("scatter.jl")
77+
end
7278

73-
@testset "Functions" begin
74-
include("functions.jl")
75-
end
79+
@testset "Utilities" begin
80+
include("utils.jl")
81+
end
7682

77-
if VERSION >= v"1.6" && CUDA.functional()
78-
if get(ENV, "NNLIB_TEST_CUDA", "false") == "true"
79-
import Pkg
80-
using NNlibCUDA
81-
@testset "CUDA" begin
82-
Pkg.test("NNlibCUDA")
83-
end
84-
else
85-
@info "Skipping CUDA tests, set NNLIB_TEST_CUDA=true to run them"
83+
@testset "Grid Sampling" begin
84+
include("sampling.jl")
85+
end
86+
87+
@testset "Functions" begin
88+
include("functions.jl")
8689
end
87-
else
88-
@info "Insufficient version or CUDA not found; Skipping CUDA tests"
8990
end

0 commit comments

Comments
 (0)