
Commit 7ee2a6f

add: LPNormPool

1 parent 7997174

File tree

4 files changed: +104 -2 lines

Project.toml
src/Flux.jl
src/layers/conv.jl
test/layers/conv.jl


Project.toml

Lines changed: 1 addition & 1 deletion

@@ -30,7 +30,7 @@ ChainRulesCore = "1.12"
 Functors = "0.3, 0.4"
 MLUtils = "0.2, 0.3.1, 0.4"
 MacroTools = "0.5"
-NNlib = "0.8.14"
+NNlib = "0.8.16"
 NNlibCUDA = "0.2.4"
 OneHotArrays = "0.1, 0.2"
 Optimisers = "0.2.12"

src/Flux.jl

Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@ Optimisers.base(dx::Zygote.Grads) = error("Optimisers.jl cannot be used with Zyg
 export Chain, Dense, Embedding, Maxout, SkipConnection, Parallel, PairwiseFusion,
       RNN, LSTM, GRU, GRUv3,
       SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
-      AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
+      AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, GlobalLPNormPool, MaxPool, MeanPool, LPNormPool,
       Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
       Upsample, PixelShuffle,
       fmap, cpu, gpu, f32, f64,

src/layers/conv.jl

Lines changed: 98 additions & 0 deletions

@@ -633,6 +633,40 @@ function Base.show(io::IO, g::GlobalMeanPool)
   print(io, "GlobalMeanPool()")
 end

+"""
+    GlobalLPNormPool(p::Number)
+
+Global lp norm pooling layer.
+
+Transforms a (w,h,c,b)-shaped input into a (1,1,c,b)-shaped output
+by performing lp norm pooling on the complete (w,h)-shaped feature maps.
+
+See also [`LPNormPool`](@ref).
+
+```jldoctest
+julia> xs = rand(Float32, 100, 100, 3, 50);
+
+julia> m = Chain(Conv((3,3), 3 => 7), GlobalLPNormPool(2));
+
+julia> m(xs) |> size
+(1, 1, 7, 50)
+```
+"""
+struct GlobalLPNormPool
+  p::Number
+end
+
+function (g::GlobalLPNormPool)(x)
+  x_size = size(x)
+  k = x_size[1:end-2]            # pool over the entire spatial extent
+  pdims = PoolDims(x, k)
+  return lpnormpool(x, pdims; p=g.p)
+end
+
+function Base.show(io::IO, g::GlobalLPNormPool)
+  print(io, "GlobalLPNormPool(p=", g.p, ")")
+end
+
 """
     MaxPool(window::NTuple; pad=0, stride=window)
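As a sanity check on the new global layer, its output can be reproduced with an explicit norm over the spatial dimensions. A minimal sketch (my own illustration, not part of the commit), assuming this branch of Flux with NNlib ≥ 0.8.16 providing `lpnormpool`:

```julia
using Flux

x = rand(Float32, 4, 4, 3, 2)
y = GlobalLPNormPool(2)(x)                 # size (1, 1, 3, 2)

# for p = 2, each channel collapses to the Euclidean norm of its feature map
manual = sqrt.(sum(abs2, x; dims=(1, 2)))
@assert y ≈ manual
```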
@@ -754,3 +788,67 @@ function Base.show(io::IO, m::MeanPool)
   m.stride == m.k || print(io, ", stride=", _maybetuple_string(m.stride))
   print(io, ")")
 end
+
+"""
+    LPNormPool(window::NTuple, p::Number; pad=0, stride=window)
+
+Lp norm pooling layer, computing the p-norm of each window;
+known as `LPPool` in PyTorch.
+
+Expects as input an array with `ndims(x) == N+2`, i.e. channel and
+batch dimensions, after the `N` feature dimensions, where `N = length(window)`.
+
+By default the window size is also the stride in each dimension.
+The keyword `pad` accepts the same options as for the `Conv` layer,
+including `SamePad()`.
+
+See also [`Conv`](@ref), [`MaxPool`](@ref), [`GlobalLPNormPool`](@ref).
+
+# Examples
+
+```jldoctest
+julia> xs = rand(Float32, 100, 100, 3, 50);
+
+julia> m = Chain(Conv((5,5), 3 => 7), LPNormPool((5,5), 2; pad=SamePad()))
+Chain(
+  Conv((5, 5), 3 => 7),                 # 532 parameters
+  LPNormPool((5, 5), p=2, pad=2),
+)
+
+julia> m[1](xs) |> size
+(96, 96, 7, 50)
+
+julia> m(xs) |> size
+(20, 20, 7, 50)
+
+julia> layer = LPNormPool((5,), 2, pad=2, stride=(3,))  # one-dimensional window
+LPNormPool((5,), p=2, pad=2, stride=3)
+
+julia> layer(rand(Float32, 100, 7, 50)) |> size
+(34, 7, 50)
+```
+"""
+struct LPNormPool{N,M}
+  k::NTuple{N,Int}
+  p::Number
+  pad::NTuple{M,Int}
+  stride::NTuple{N,Int}
+end
+
+function LPNormPool(k::NTuple{N,Integer}, p::Number; pad = 0, stride = k) where N
+  stride = expand(Val(N), stride)
+  pad = calc_padding(LPNormPool, pad, k, 1, stride)
+  return LPNormPool(k, p, pad, stride)
+end
+
+function (l::LPNormPool)(x)
+  pdims = PoolDims(x, l.k; padding=l.pad, stride=l.stride)
+  return lpnormpool(x, pdims; p=l.p)
+end
+
+function Base.show(io::IO, l::LPNormPool)
+  print(io, "LPNormPool(", l.k, ", p=", l.p)
+  all(==(0), l.pad) || print(io, ", pad=", _maybetuple_string(l.pad))
+  l.stride == l.k || print(io, ", stride=", _maybetuple_string(l.stride))
+  print(io, ")")
+end
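Since `lpnormpool` computes `(sum(abs(x)^p))^(1/p)` over each window, particular values of `p` recover familiar pooling layers, which makes the new layer easy to sanity-check. A rough sketch (again my own, not from the commit; Float64 input keeps `x^p` from underflowing at large `p`):

```julia
using Flux

x = rand(8, 8, 1, 1)   # non-negative Float64 entries, so abs is a no-op

# p = 1 sums each window, i.e. window-size times MeanPool:
@assert LPNormPool((2, 2), 1)(x) ≈ 4 .* MeanPool((2, 2))(x)

# large p approaches MaxPool, up to a 4^(1/p) factor on a 2×2 window:
@assert isapprox(LPNormPool((2, 2), 99)(x), MaxPool((2, 2))(x); rtol=0.02)
```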

test/layers/conv.jl

Lines changed: 4 additions & 0 deletions

@@ -17,10 +17,14 @@ using Flux: gradient
   @test size(gmp(x)) == (1, 1, 3, 2)
   gmp = GlobalMeanPool()
   @test size(gmp(x)) == (1, 1, 3, 2)
+  glmp = GlobalLPNormPool(2)
+  @test size(glmp(x)) == (1, 1, 3, 2)
   mp = MaxPool((2, 2))
   @test mp(x) == maxpool(x, PoolDims(x, 2))
   mp = MeanPool((2, 2))
   @test mp(x) == meanpool(x, PoolDims(x, 2))
+  lnp = LPNormPool((2,2), 2)
+  @test lnp(x) == lpnormpool(x, PoolDims(x, 2); p=2)
 end

 @testset "CNN" begin
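The added tests only exercise the forward pass. A differentiability smoke test in the same style might look like the sketch below (my assumption: NNlib 0.8.16 also defines the pullback for `lpnormpool`, so Zygote can differentiate through the layer):

```julia
using Flux, Zygote

x = rand(Float32, 8, 8, 3, 2)
layer = LPNormPool((2, 2), 2)

# the gradient should flow through lpnormpool and match the input's shape
g = Zygote.gradient(x -> sum(layer(x)), x)[1]
@assert size(g) == size(x)
@assert all(isfinite, g)
```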
