Skip to content

Commit 220c442

Browse files
author
Avik Pal
committed
Update NNPACK interface to conform to the new NNlib
2 parents e2947d8 + ed4fe9a commit 220c442

39 files changed

+3824
-1858
lines changed

.travis.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ notifications:
99
email: false
1010
git:
1111
depth: 99999999
12+
env:
13+
# Disable test fuzzing for the moment, as we're a little too slow for Travis
14+
- NNLIB_TEST_FUZZING=false
1215

1316
# Submit to Codecov
1417
after_success:

Manifest.toml

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,30 @@
1+
# This file is machine-generated - editing it directly is not advised
2+
13
[[Base64]]
24
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
35

46
[[BinaryProvider]]
57
deps = ["Libdl", "Pkg", "SHA", "Test"]
6-
git-tree-sha1 = "9930c1a6cd49d9fcd7218df6be417e6ae4f1468a"
8+
git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
79
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
8-
version = "0.5.2"
10+
version = "0.5.3"
911

10-
[[Compat]]
11-
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
12-
git-tree-sha1 = "2d9e14d19bad3f9ad5cc5e4cffabc3cfa59de825"
13-
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
14-
version = "1.3.0"
12+
[[Crayons]]
13+
deps = ["Test"]
14+
git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
15+
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
16+
version = "4.0.0"
1517

1618
[[Dates]]
1719
deps = ["Printf"]
1820
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
1921

20-
[[DelimitedFiles]]
21-
deps = ["Mmap"]
22-
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
23-
2422
[[Distributed]]
25-
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
23+
deps = ["Random", "Serialization", "Sockets"]
2624
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
2725

2826
[[InteractiveUtils]]
29-
deps = ["LinearAlgebra", "Markdown"]
27+
deps = ["Markdown"]
3028
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
3129

3230
[[LibGit2]]
@@ -42,19 +40,10 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
4240
[[Logging]]
4341
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
4442

45-
[[MacroTools]]
46-
deps = ["Compat"]
47-
git-tree-sha1 = "c443e1c8d58a4e9f61b708ad0a88286c7042145b"
48-
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
49-
version = "0.4.4"
50-
5143
[[Markdown]]
5244
deps = ["Base64"]
5345
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
5446

55-
[[Mmap]]
56-
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
57-
5847
[[Pkg]]
5948
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
6049
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
@@ -83,10 +72,6 @@ uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
8372
[[Serialization]]
8473
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
8574

86-
[[SharedArrays]]
87-
deps = ["Distributed", "Mmap", "Random", "Serialization"]
88-
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
89-
9075
[[Sockets]]
9176
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
9277

@@ -102,8 +87,14 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
10287
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
10388
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
10489

90+
[[TimerOutputs]]
91+
deps = ["Crayons", "Printf", "Test", "Unicode"]
92+
git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c"
93+
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
94+
version = "0.5.0"
95+
10596
[[UUIDs]]
106-
deps = ["Random"]
97+
deps = ["Random", "SHA"]
10798
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
10899

109100
[[Unicode]]

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
name = "NNlib"
22
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
3+
version = "0.6.0"
34

45
[deps]
56
BinaryProvider = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
67
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
78
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8-
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
99
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
10+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1011
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
12+
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"

REQUIRE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
julia 0.7-
1+
julia 1.0
22
Requires
33
MacroTools
44
BinaryProvider

src/NNlib.jl

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,31 @@
11
module NNlib
2+
using Requires, TimerOutputs
23

3-
using Requires, Libdl
4-
5-
export σ, sigmoid, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logσ, logsigmoid,
6-
softmax, logsoftmax, maxpool, meanpool
7-
8-
include("numeric.jl")
4+
# Include APIs
5+
include("dim_helpers.jl")
96
include("activation.jl")
107
include("softmax.jl")
11-
include("logsoftmax.jl")
12-
include("linalg.jl")
8+
include("gemm.jl")
139
include("conv.jl")
14-
include("cubroadcast.jl")
10+
include("pooling.jl")
1511

16-
try
17-
global ENABLE_NNPACK = parse(UInt64, ENV["ENABLE_NNPACK"])
18-
catch
19-
global ENABLE_NNPACK = 1
20-
end
12+
## Include implementations
13+
include("impl/padding_edges.jl")
14+
15+
# Direct implementations of convolutional and depthwise-convolutional algorithms
16+
include("impl/conv_direct.jl")
17+
include("impl/depthwiseconv_direct.jl")
18+
# im2col implementations of convolutional and depthwise-convolutional algorithms
19+
include("impl/conv_im2col.jl")
20+
include("impl/depthwiseconv_im2col.jl")
21+
22+
# Direct implementations of pooling
23+
include("impl/pooling_direct.jl")
24+
25+
to = TimerOutput()
2126

22-
if Sys.islinux() && ENABLE_NNPACK == 1
27+
if Sys.islinux()
2328
include("nnpack/NNPACK.jl")
24-
include("backends.jl")
2529
end
2630

27-
end # module
31+
end # module NNlib

src/activation.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,40 @@
1+
export σ, sigmoid, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logσ,
2+
logsigmoid
3+
14
"""
25
σ(x) = 1 / (1 + exp(-x))
36
47
Classic [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) activation
58
function.
69
"""
710
σ(x) = one(x) / (one(x) + exp(-x))
8-
911
const sigmoid = σ
1012

1113
# ForwardDiff numerical stability hack
1214
σ_stable(x) = ifelse(x < -80, zero(x), one(x) / (one(x) + exp(-x)))
13-
1415
σ(x::Float32) = σ_stable(x)
15-
1616
@init @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin
1717
σ(x::ForwardDiff.Dual{T,Float32}) where T = σ_stable(x)
1818
end
1919

20+
2021
"""
2122
logσ(x)
2223
2324
Return `log(σ(x))` which is computed in a numerically stable way.
2425
25-
julia> logσ(0.)
26+
julia> logσ(0)
2627
-0.6931471805599453
27-
julia> logσ.([-100, -10, 100.])
28+
julia> logσ.([-100, -10, 100])
2829
3-element Array{Float64,1}:
29-
-100.0
30-
-10.0
31-
-0.0
32-
"""
33-
function logσ(x)
34-
max_v = max(zero(x), -x)
35-
z = exp(-max_v) + exp(-x-max_v)
36-
-(max_v + log(z))
37-
end
38-
30+
-100.0
31+
-10.000045398899218
32+
-3.720075976020836e-44
33+
"""
34+
logσ(x) = -softplus(-x)
3935
const logsigmoid = logσ
4036

37+
4138
"""
4239
relu(x) = max(0, x)
4340
@@ -56,6 +53,7 @@ You can also specify the coefficient explicitly, e.g. `leakyrelu(x, 0.01)`.
5653
"""
5754
leakyrelu(x, a = oftype(x/1, 0.01)) = max(a*x, x/1)
5855

56+
5957
"""
6058
elu(x, α = 1) =
6159
x > 0 ? x : α * (exp(x) - 1)
@@ -66,6 +64,7 @@ You can also specify the coefficient explicitly, e.g. `elu(x, 1)`.
6664
"""
6765
elu(x, α = one(x)) = ifelse(x 0, x/1, α * (exp(x) - one(x)))
6866

67+
6968
"""
7069
gelu(x) = 0.5x*(1 + tanh(√(2/π)*(x + 0.044715x^3)))
7170
@@ -103,6 +102,7 @@ function selu(x)
103102
λ * ifelse(x > 0, x/1, α * (exp(x) - 1))
104103
end
105104

105+
106106
"""
107107
softsign(x) = x / (1 + |x|)
108108
@@ -116,4 +116,4 @@ softsign(x) = x / (one(x) + abs(x))
116116
117117
See [Deep Sparse Rectifier Neural Networks](http://proceedings.mlr.press/v15/glorot11a/glorot11a.pdf).
118118
"""
119-
softplus(x) = log1p(exp(x))
119+
softplus(x) = ifelse(x > 0, x + log1p(exp(-x)), log1p(exp(x)))

src/backends.jl

Lines changed: 0 additions & 68 deletions
This file was deleted.

0 commit comments

Comments
 (0)