Skip to content

Commit d58adb6

Browse files
authored
Replace broadcasting over distributions with broadcasting with partially applied functions (#1818)
* Replace broadcasting over distributions with broadcasting with partially applied functions * Delete src/multivariate/mvnormal copy.jl
1 parent 28bf738 commit d58adb6

24 files changed

+82
-77
lines changed

src/deprecates.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,13 @@ for fun in [:pdf, :logpdf,
3434
fun! = Symbol(fun, '!')
3535

3636
@eval begin
37-
@deprecate ($_fun!)(r::AbstractArray{<:Real}, d::UnivariateDistribution, X::AbstractArray{<:Real}) r .= ($fun).(d, X) false
38-
@deprecate ($fun!)(r::AbstractArray{<:Real}, d::UnivariateDistribution, X::AbstractArray{<:Real}) r .= ($fun).(d, X) false
39-
@deprecate ($fun)(d::UnivariateDistribution, X::AbstractArray{<:Real}) ($fun).(d, X)
37+
@deprecate ($_fun!)(r::AbstractArray{<:Real}, d::UnivariateDistribution, X::AbstractArray{<:Real}) r .= Base.Fix1($fun, d).(X) false
38+
@deprecate ($fun!)(r::AbstractArray{<:Real}, d::UnivariateDistribution, X::AbstractArray{<:Real}) r .= Base.Fix1($fun, d).(X) false
39+
@deprecate ($fun)(d::UnivariateDistribution, X::AbstractArray{<:Real}) map(Base.Fix1($fun, d), X)
4040
end
4141
end
4242

43-
@deprecate pdf(d::DiscreteUnivariateDistribution) pdf.(Ref(d), support(d))
43+
@deprecate pdf(d::DiscreteUnivariateDistribution) map(Base.Fix1(pdf, d), support(d))
4444

4545
# Wishart constructors
4646
@deprecate Wishart(df::Real, S::AbstractPDMat, warn::Bool) Wishart(df, S)

src/mixtures/mixturemodel.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ function _mixpdf!(r::AbstractArray, d::AbstractMixtureModel, x)
297297
pi = p[i]
298298
if pi > 0.0
299299
if d isa UnivariateMixture
300-
t .= pdf.(component(d, i), x)
300+
t .= Base.Fix1(pdf, component(d, i)).(x)
301301
else
302302
pdf!(t, component(d, i), x)
303303
end
@@ -326,7 +326,7 @@ function _mixlogpdf!(r::AbstractArray, d::AbstractMixtureModel, x)
326326
lp_i = view(Lp, :, i)
327327
# compute logpdf in batch and store
328328
if d isa UnivariateMixture
329-
lp_i .= logpdf.(component(d, i), x)
329+
lp_i .= Base.Fix1(logpdf, component(d, i)).(x)
330330
else
331331
logpdf!(lp_i, component(d, i), x)
332332
end
@@ -398,7 +398,7 @@ function _cwise_pdf!(r::AbstractMatrix, d::AbstractMixtureModel, X)
398398
size(r) == (n, K) || error("The size of r is incorrect.")
399399
for i = 1:K
400400
if d isa UnivariateMixture
401-
view(r,:,i) .= pdf.(Ref(component(d, i)), X)
401+
view(r,:,i) .= Base.Fix1(pdf, component(d, i)).(X)
402402
else
403403
pdf!(view(r,:,i),component(d, i), X)
404404
end
@@ -412,7 +412,7 @@ function _cwise_logpdf!(r::AbstractMatrix, d::AbstractMixtureModel, X)
412412
size(r) == (n, K) || error("The size of r is incorrect.")
413413
for i = 1:K
414414
if d isa UnivariateMixture
415-
view(r,:,i) .= logpdf.(Ref(component(d, i)), X)
415+
view(r,:,i) .= Base.Fix1(logpdf, component(d, i)).(X)
416416
else
417417
logpdf!(view(r,:,i), component(d, i), X)
418418
end

src/multivariate/jointorderstatistics.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ function _rand!(rng::AbstractRNG, d::JointOrderStatistics, x::AbstractVector{<:R
162162
else
163163
s += randexp(rng, T)
164164
end
165-
x .= quantile.(d.dist, x ./ s)
165+
x .= Base.Fix1(quantile, d.dist).(x ./ s)
166166
end
167167
return x
168168
end

src/qq.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,14 @@ function qqbuild(x::AbstractVector, d::UnivariateDistribution)
3535
n = length(x)
3636
grid = ppoints(n)
3737
qx = quantile(x, grid)
38-
qd = quantile.(Ref(d), grid)
38+
qd = map(Base.Fix1(quantile, d), grid)
3939
return QQPair(qx, qd)
4040
end
4141

4242
function qqbuild(d::UnivariateDistribution, x::AbstractVector)
4343
n = length(x)
4444
grid = ppoints(n)
45-
qd = quantile.(Ref(d), grid)
45+
qd = map(Base.Fix1(quantile, d), grid)
4646
qx = quantile(x, grid)
4747
return QQPair(qd, qx)
4848
end

src/univariate/continuous/uniform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ Base.:*(c::Real, d::Uniform) = Uniform(minmax(c * d.a, c * d.b)...)
154154
rand(rng::AbstractRNG, d::Uniform) = d.a + (d.b - d.a) * rand(rng)
155155

156156
_rand!(rng::AbstractRNG, d::Uniform, A::AbstractArray{<:Real}) =
157-
A .= quantile.(d, rand!(rng, A))
157+
A .= Base.Fix1(quantile, d).(rand!(rng, A))
158158

159159

160160
#### Fitting

src/univariate/discrete/betabinomial.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,15 @@ for f in (:ccdf, :logcdf, :logccdf)
103103
end
104104
end
105105

106-
entropy(d::BetaBinomial) = entropy(Categorical(pdf.(Ref(d),support(d))))
107-
median(d::BetaBinomial) = median(Categorical(pdf.(Ref(d),support(d)))) - 1
108-
mode(d::BetaBinomial) = argmax(pdf.(Ref(d),support(d))) - 1
109-
modes(d::BetaBinomial) = modes(Categorical(pdf.(Ref(d),support(d)))) .- 1
106+
# Shifted categorical distribution corresponding to `BetaBinomial`
107+
_categorical(d::BetaBinomial) = Categorical(map(Base.Fix1(pdf, d), support(d)))
110108

111-
quantile(d::BetaBinomial, p::Float64) = quantile(Categorical(pdf.(Ref(d), support(d))), p) - 1
109+
entropy(d::BetaBinomial) = entropy(_categorical(d))
110+
median(d::BetaBinomial) = median(_categorical(d)) - 1
111+
mode(d::BetaBinomial) = mode(_categorical(d)) - 1
112+
modes(d::BetaBinomial) = modes(_categorical(d)) .- 1
113+
114+
quantile(d::BetaBinomial, p::Float64) = quantile(_categorical(d), p) - 1
112115

113116
#### Sampling
114117

src/univariate/discrete/hypergeometric.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ function kurtosis(d::Hypergeometric)
7575
a/b
7676
end
7777

78-
entropy(d::Hypergeometric) = entropy(pdf.(Ref(d), support(d)))
78+
entropy(d::Hypergeometric) = entropy(map(Base.Fix1(pdf, d), support(d)))
7979

8080
### Evaluation & Sampling
8181

src/univariate/discrete/noncentralhypergeometric.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -256,9 +256,12 @@ end
256256
Base.convert(::Type{WalleniusNoncentralHypergeometric{T}}, d::WalleniusNoncentralHypergeometric{T}) where {T<:Real} = d
257257

258258
# Properties
259-
mean(d::WalleniusNoncentralHypergeometric) = sum(support(d) .* pdf.(Ref(d), support(d)))
260-
var(d::WalleniusNoncentralHypergeometric) = sum((support(d) .- mean(d)).^2 .* pdf.(Ref(d), support(d)))
261-
mode(d::WalleniusNoncentralHypergeometric) = support(d)[argmax(pdf.(Ref(d), support(d)))]
259+
function _discretenonparametric(d::WalleniusNoncentralHypergeometric)
260+
return DiscreteNonParametric(support(d), map(Base.Fix1(pdf, d), support(d)))
261+
end
262+
mean(d::WalleniusNoncentralHypergeometric) = mean(_discretenonparametric(d))
263+
var(d::WalleniusNoncentralHypergeometric) = var(_discretenonparametric(d))
264+
mode(d::WalleniusNoncentralHypergeometric) = mode(_discretenonparametric(d))
262265

263266
entropy(d::WalleniusNoncentralHypergeometric) = 1
264267

test/censored.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ end
207207
end
208208
@test @inferred(median(d)) clamp(median(d0), l, u)
209209
@inferred quantile(d, 0.5)
210-
@test quantile.(d, 0:0.01:1) clamp.(quantile.(d0, 0:0.01:1), l, u)
210+
@test Base.Fix1(quantile, d).(0:0.01:1) clamp.(Base.Fix1(quantile, d0).(0:0.01:1), l, u)
211211
# special-case pdf/logpdf/loglikelihood since when replacing Dirac(μ) with
212212
# Normal(μ, 0), they are infinite
213213
if lower === nothing || !isfinite(lower)
@@ -253,7 +253,7 @@ end
253253
@test f(d) f(dmix)
254254
end
255255
@test median(d) clamp(median(d0), l, u)
256-
@test quantile.(d, 0:0.01:1) clamp.(quantile.(d0, 0:0.01:1), l, u)
256+
@test Base.Fix1(quantile, d).(0:0.01:1) clamp.(Base.Fix1(quantile, d0).(0:0.01:1), l, u)
257257
# special-case pdf/logpdf/loglikelihood since when replacing Dirac(μ) with
258258
# Normal(μ, 0), they are infinite
259259
if lower === nothing
@@ -311,7 +311,7 @@ end
311311
end
312312
@test @inferred(median(d)) clamp(median(d0), l, u)
313313
@inferred quantile(d, 0.5)
314-
@test quantile.(d, 0:0.01:1) clamp.(quantile.(d0, 0:0.01:1), l, u)
314+
@test Base.Fix1(quantile, d).(0:0.01:1) clamp.(Base.Fix1(quantile, d0).(0:0.01:1), l, u)
315315
# rand
316316
x = rand(d, 10_000)
317317
@test all(x -> insupport(d, x), x)
@@ -346,7 +346,7 @@ end
346346
@test f(d, 5) f(dmix, 5)
347347
end
348348
@test median(d) clamp(median(d0), l, u)
349-
@test quantile.(d, 0:0.01:0.99) clamp.(quantile.(d0, 0:0.01:0.99), l, u)
349+
@test Base.Fix1(quantile, d).(0:0.01:0.99) clamp.(Base.Fix1(quantile, d0).(0:0.01:0.99), l, u)
350350
x = rand(d, 100)
351351
@test loglikelihood(d, x) loglikelihood(dmix, x)
352352
# rand

test/matrixvariates.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ function test_convert(d::MatrixDistribution)
117117
@test d == deepcopy(d)
118118
for elty in (Float32, Float64, BigFloat)
119119
del1 = convert(distname{elty}, d)
120-
del2 = convert(distname{elty}, getfield.(Ref(d), fieldnames(typeof(d)))...)
120+
del2 = convert(distname{elty}, (Base.Fix1(getfield, d)).(fieldnames(typeof(d)))...)
121121
@test del1 isa distname{elty}
122122
@test del2 isa distname{elty}
123123
@test partype(del1) == elty

0 commit comments

Comments
 (0)