Skip to content

Commit 2d5f145

Browse files
committed
make arithmetic for UnivariateFiniteArray performant
1 parent 7835e19 commit 2d5f145

File tree

3 files changed

+51
-14
lines changed

3 files changed

+51
-14
lines changed

src/methods.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -390,55 +390,61 @@ const ERR_DIFFERENT_SAMPLE_SPACES = ArgumentError(
390390

391391
import Base: +, *, /, -
392392

393-
function _plus(d1, d2, T, N...) # N... is single integer or absent
393+
function _plus(d1, d2, T)
394394
classes(d1) == classes(d2) || throw(ERR_DIFFERENT_SAMPLE_SPACES)
395395
S = d1.scitype
396396
decoder = d1.decoder
397397
prob_given_ref = copy(d1.prob_given_ref)
398398
for ref in keys(prob_given_ref)
399399
prob_given_ref[ref] += d2.prob_given_ref[ref]
400400
end
401-
return T(S, decoder, prob_given_ref, N...)
401+
return T(S, decoder, prob_given_ref)
402402
end
403403
+(d1::U, d2::U) where U <: UnivariateFinite = _plus(d1, d2, UnivariateFinite)
404+
+(d1::U, d2::U) where U <: UnivariateFiniteArray =
405+
_plus(d1, d2, UnivariateFiniteArray)
404406

405-
function _minus(d, T, N...)
407+
function _minus(d, T)
406408
S = d.scitype
407409
decoder = d.decoder
408410
prob_given_ref = copy(d.prob_given_ref)
409411
for ref in keys(prob_given_ref)
410412
prob_given_ref[ref] = -prob_given_ref[ref]
411413
end
412-
return T(S, decoder, prob_given_ref, N...)
414+
return T(S, decoder, prob_given_ref)
413415
end
414416
-(d::UnivariateFinite) = _minus(d, UnivariateFinite)
417+
-(d::UnivariateFiniteArray) = _minus(d, UnivariateFiniteArray)
415418

416-
function _minus(d1, d2, T, N...)
419+
function _minus(d1, d2, T)
417420
classes(d1) == classes(d2) || throw(ERR_DIFFERENT_SAMPLE_SPACES)
418421
S = d1.scitype
419422
decoder = d1.decoder
420423
prob_given_ref = copy(d1.prob_given_ref)
421424
for ref in keys(prob_given_ref)
422425
prob_given_ref[ref] -= d2.prob_given_ref[ref]
423426
end
424-
return T(S, decoder, prob_given_ref, N...)
427+
return T(S, decoder, prob_given_ref)
425428
end
426429
-(d1::U, d2::U) where U <: UnivariateFinite = _minus(d1, d2, UnivariateFinite)
430+
-(d1::U, d2::U) where U <: UnivariateFiniteArray =
431+
_minus(d1, d2, UnivariateFiniteArray)
427432

428433
# TODO: remove type restrction on `x` in the following methods if
429434
# https://github.com/JuliaStats/Distributions.jl/issues/1438 is
430435
# resolved. Currently we'd have a method ambiguity
431436

432-
function _times(d, x, T, N...)
437+
function _times(d, x, T)
433438
S = d.scitype
434439
decoder = d.decoder
435440
prob_given_ref = copy(d.prob_given_ref)
436441
for ref in keys(prob_given_ref)
437442
prob_given_ref[ref] *= x
438443
end
439-
return UnivariateFinite(d.scitype, decoder, prob_given_ref)
444+
return T(d.scitype, decoder, prob_given_ref)
440445
end
441446
*(d::UnivariateFinite, x::Real) = _times(d, x, UnivariateFinite)
447+
*(d::UnivariateFiniteArray, x::Real) = _times(d, x, UnivariateFiniteArray)
442448

443449
*(x::Real, d::SingletonOrArray) = d*x
444450
/(d::SingletonOrArray, x::Real) = d*inv(x)

src/types.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,6 @@ function _augment_probs(::Val{true},
277277
probs::AbstractArray{P,N}) where {P,N}
278278
_check_probs_01(probs)
279279
aug_size = [size(probs)..., 2]
280-
@show probs P
281280
augmentation = one(P) .- probs
282281
all(0 .<= augmentation .<= 1) || throw(ERR_AUG)
283282
aug_probs = Array{P}(undef, aug_size...)

test/methods.jl

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -287,23 +287,55 @@ d2 = UnivariateFinite(L, rand(rng, 2), pool=missing)
287287
for op in [:+, :-]
288288
quote
289289
s = $op(d1, d2 )
290-
@test $op(pdf(d1, L), pdf(d2, L)) pdf(s, L)
290+
@test $op(pdf.(d1, L), pdf.(d2, L)) pdf.(s, L)
291291
end |> eval
292292
end
293293

294294
# negative:
295295
d_neg = -d1
296-
@test pdf(d_neg, L) == -pdf(d1, L)
296+
@test pdf.(d_neg, L) == -pdf.(d1, L)
297297

298298
# multiplication by scalar:
299299
d3 = d1*42
300-
@test pdf(d3, L) pdf(d1, L)*42
300+
@test pdf.(d3, L) pdf.(d1, L)*42
301301
d3 = 42*d1
302-
@test pdf(d3, L) pdf(d1, L)*42
302+
@test pdf.(d3, L) pdf.(d1, L)*42
303303

304304
# division by scalar:
305305
d3 = d1/42
306-
@test pdf(d3, L) pdf(d1, L)/42
306+
@test pdf.(d3, L) pdf.(d1, L)/42
307+
end
308+
309+
p = [0.1, 0.9]
310+
P = vcat(fill(p', 10^5)...)
311+
slow = fill(UnivariateFinite(L, p, pool=missing), 10^5)
312+
fast = UnivariateFinite(L, P, pool=missing)
313+
# @assert pdf(slow, L) == pdf(fast, L)
314+
315+
@testset "performant arithmetic for UnivariateFiniteArray" begin
316+
@test pdf(slow + slow, L) == pdf(fast + fast, L)
317+
t_slow = @elapsed @eval slow + slow
318+
t_fast = @elapsed @eval fast + fast
319+
@test t_slow/t_fast > 10
320+
321+
@test pdf(slow - slow, L) == pdf(fast - fast, L)
322+
t_slow = @elapsed @eval slow - slow
323+
t_fast = @elapsed @eval fast - fast
324+
@test t_slow/t_fast > 10
325+
326+
@test pdf(42*slow, L) == pdf(42*fast, L)
327+
@test pdf(slow*42, L) == pdf(fast*42, L)
328+
t_slow = @elapsed @eval 42*slow
329+
t_fast = @elapsed @eval 42*fast
330+
@test t_slow/t_fast > 10
331+
t_slow = @elapsed @eval slow*42
332+
t_fast = @elapsed @eval fast*42
333+
@test t_slow/t_fast > 10
334+
335+
@test pdf(slow/42, L) == pdf(fast/42, L)
336+
t_slow = @elapsed @eval slow/42
337+
t_fast = @elapsed @eval fast/42
338+
@test t_slow/t_fast > 10
307339
end
308340

309341
end # module

0 commit comments

Comments
 (0)