4 changes: 4 additions & 0 deletions src/layers/conv.jl
@@ -640,6 +640,8 @@ Global lp norm pooling layer.

Transform (w,h,c,b)-shaped input into (1,1,c,b)-shaped output,
by performing lp norm pooling on the complete (w,h)-shaped feature maps.
The input `x` is expected to satisfy `all(x .>= 0)`; otherwise `^(x, y)` from
`Base.Math` may throw a `DomainError`.

See also [`LPNormPool`](@ref).
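
For context, a minimal sketch of the reduction this layer performs. The `GlobalLPNormPool(p)` constructor is an assumption (it is not shown in this hunk), and the equality check assumes `lpnormpool` follows the PyTorch LPPool definition `(∑ xᵖ)^(1/p)`:

```julia
using Flux  # assumes a Flux build that contains this PR

x = rand(Float32, 4, 4, 3, 2)     # (w, h, c, b); rand is non-negative, as required
g = GlobalLPNormPool(2.0)         # assumed constructor: takes the exponent p
y = g(x)
size(y)                           # (1, 1, 3, 2)

# Each output entry reduces one whole (w, h) feature map:
manual = sum(x[:, :, 1, 1] .^ 2) ^ (1 / 2)
isapprox(y[1, 1, 1, 1], manual; rtol = 1e-5)   # expected to hold, up to Float32 error
```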

@@ -793,6 +795,7 @@ also known as LPPool in pytorch.

Expects as input an array with `ndims(x) == N+2`, i.e. channel and
batch dimensions, after the `N` feature dimensions, where `N = length(window)`.
Also expects `all(x .>= 0)`, since `^(x, y)` from `Base.Math` would otherwise throw a `DomainError`.

By default the window size is also the stride in each dimension.
The keyword `pad` accepts the same options as for the `Conv` layer,
@@ -839,6 +842,7 @@ function LPNormPool(k::NTuple{N,Integer}, p::Real; pad = 0, stride = k) where {N
end

function (l::LPNormPool)(x)
    all(x .>= 0) || throw(DomainError("LPNormPool requires `all(x .>= 0)`; applying relu before LPNormPool is recommended."))
pdims = PoolDims(x, l.k; padding=l.pad, stride=l.stride)
return lpnormpool(x, pdims; p=l.p)
end
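
For illustration, a hedged usage sketch of the guard added above. The `LPNormPool((2, 2), 2.0)` call matches the constructor signature visible in this diff; the `Chain` at the end is only an assumed arrangement:

```julia
using Flux

pool = LPNormPool((2, 2), 2.0)     # 2×2 windows, p = 2; stride defaults to the window size

x = randn(Float32, 8, 8, 3, 1)     # randn produces negative entries
# pool(x) would now hit the new guard and throw a DomainError
y = pool(relu.(x))                 # apply relu first, as the error message recommends
size(y)                            # (4, 4, 3, 1)

# Inside a model, place a non-negative activation before the pooling layer:
m = Chain(Conv((3, 3), 3 => 8, relu), LPNormPool((2, 2), 2.0))
```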