
Commit 5053c70

Reduce unnecessary conversion in lpnormpool (#467)
* fix: convert `p` to the same type as `eltype(x)` when `eltype(x) <: Real`
* fix: remove the re-assignment; use `Real`
1 parent 1672035 commit 5053c70

File tree

1 file changed, +6 -7 lines changed


src/pooling.jl

Lines changed: 6 additions & 7 deletions
@@ -185,7 +185,7 @@ end
 
 
 """
-    lpnormpool(x, p::Number, k::NTuple{N, Integer}; pad=0, stride=k)
+    lpnormpool(x, p::Real, k::NTuple{N, Integer}; pad=0, stride=k)
 
 Perform Lp pool operation with value of the Lp norm `p` and window size `k` on input tensor `x`, also known as LPPool in pytorch.
 This pooling operator from [Learned-Norm Pooling for Deep Feedforward and Recurrent Neural Networks](https://arxiv.org/abs/1311.1780).
@@ -201,12 +201,11 @@ For all elements `x` in a size `k` window, lpnormpool computes `(∑ᵢ xᵢ^p)^
 
 Thus `lpnormpool(x, 1, k) ./ prod(k) ≈ meanpool(x, k)` and `lpnormpool(x, 2, k).^2 ./ prod(k) ≈ meanpool(x.^2, k)`.
 """
-function lpnormpool(x, p::Number, k::NTuple{N, Integer}; pad=0, stride=k) where N
-    (isinf(p) || p < 0) && error("p value of Lp norm pool expects `0 < p < Inf`, but p is $(p) now.")
-    pad = expand(Val(N), pad)
-    stride = expand(Val(N), stride)
-    pdims = PoolDims(x, k; padding=pad, stride=stride)
-    return lpnormpool(x, pdims; p=p)
+function lpnormpool(x, p::Real, k::NTuple{N, Integer}; pad=0, stride=k) where {N}
+    pow = p isa Integer ? p : convert(float(eltype(x)), p)
+    (isinf(pow) || pow < 0) && error("p value of Lp norm pool expects `0 < p < Inf`, but p is $(pow) now.")
+    pdims = PoolDims(x, k; padding=expand(Val(N), pad), stride=expand(Val(N), stride))
+    return lpnormpool(x, pdims; p=pow)
 end
 
 
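For reference, a small usage sketch of the changed method: an Integer `p` is passed through unchanged, while a non-integer `p` is now converted once to `float(eltype(x))` before pooling. This assumes the file is NNlib.jl's src/pooling.jl (so `lpnormpool` and `meanpool` come from the NNlib package); the input shape and window size are illustrative.

using NNlib   # assumption: this src/pooling.jl belongs to the NNlib.jl package

x = rand(Float32, 8, 8, 1, 1)          # width × height × channels × batch

# Integer p is used as-is; 2.0 is converted to float(eltype(x)) == Float32,
# so the pooling kernel is not pushed into Float64 intermediates.
y_int   = lpnormpool(x, 2,   (2, 2))
y_float = lpnormpool(x, 2.0, (2, 2))

y_int ≈ y_float                                       # same L2 pooling either way
lpnormpool(x, 1, (2, 2)) ./ 4 ≈ meanpool(x, (2, 2))   # identity from the docstring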