Skip to content

Commit 233a048

Browse files
committed
move arithmetic code out into separate files
1 parent 3bd2683 commit 233a048

File tree

6 files changed

+142
-127
lines changed

6 files changed

+142
-127
lines changed

src/CategoricalDistributions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ include("utilities.jl")
1717
include("types.jl")
1818
include("methods.jl")
1919
include("arrays.jl")
20+
include("arithmetic.jl")
2021

2122
export UnivariateFinite, UnivariateFiniteArray
2223

src/arithmetic.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# ## ARITHMETIC
2+
3+
const ERR_DIFFERENT_SAMPLE_SPACES = ArgumentError(
4+
"Adding two `UnivariateFinite` objects whose "*
5+
"sample spaces have different labellings is not allowed. ")
6+
7+
import Base: +, *, /, -
8+
9+
function _plus(d1, d2, T)
10+
classes(d1) == classes(d2) || throw(ERR_DIFFERENT_SAMPLE_SPACES)
11+
S = d1.scitype
12+
decoder = d1.decoder
13+
prob_given_ref = copy(d1.prob_given_ref)
14+
for ref in keys(prob_given_ref)
15+
prob_given_ref[ref] += d2.prob_given_ref[ref]
16+
end
17+
return T(S, decoder, prob_given_ref)
18+
end
19+
+(d1::U, d2::U) where U <: UnivariateFinite = _plus(d1, d2, UnivariateFinite)
20+
+(d1::U, d2::U) where U <: UnivariateFiniteArray =
21+
_plus(d1, d2, UnivariateFiniteArray)
22+
23+
function _minus(d, T)
24+
S = d.scitype
25+
decoder = d.decoder
26+
prob_given_ref = copy(d.prob_given_ref)
27+
for ref in keys(prob_given_ref)
28+
prob_given_ref[ref] = -prob_given_ref[ref]
29+
end
30+
return T(S, decoder, prob_given_ref)
31+
end
32+
-(d::UnivariateFinite) = _minus(d, UnivariateFinite)
33+
-(d::UnivariateFiniteArray) = _minus(d, UnivariateFiniteArray)
34+
35+
function _minus(d1, d2, T)
36+
classes(d1) == classes(d2) || throw(ERR_DIFFERENT_SAMPLE_SPACES)
37+
S = d1.scitype
38+
decoder = d1.decoder
39+
prob_given_ref = copy(d1.prob_given_ref)
40+
for ref in keys(prob_given_ref)
41+
prob_given_ref[ref] -= d2.prob_given_ref[ref]
42+
end
43+
return T(S, decoder, prob_given_ref)
44+
end
45+
-(d1::U, d2::U) where U <: UnivariateFinite = _minus(d1, d2, UnivariateFinite)
46+
-(d1::U, d2::U) where U <: UnivariateFiniteArray =
47+
_minus(d1, d2, UnivariateFiniteArray)
48+
49+
# TODO: remove type restrction on `x` in the following methods if
50+
# https://github.com/JuliaStats/Distributions.jl/issues/1438 is
51+
# resolved. Currently we'd have a method ambiguity
52+
53+
function _times(d, x, T)
54+
S = d.scitype
55+
decoder = d.decoder
56+
prob_given_ref = copy(d.prob_given_ref)
57+
for ref in keys(prob_given_ref)
58+
prob_given_ref[ref] *= x
59+
end
60+
return T(d.scitype, decoder, prob_given_ref)
61+
end
62+
*(d::UnivariateFinite, x::Real) = _times(d, x, UnivariateFinite)
63+
*(d::UnivariateFiniteArray, x::Real) = _times(d, x, UnivariateFiniteArray)
64+
65+
*(x::Real, d::SingletonOrArray) = d*x
66+
/(d::SingletonOrArray, x::Real) = d*inv(x)

src/methods.jl

Lines changed: 0 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -382,69 +382,3 @@ function Dist.fit(d::Type{<:UnivariateFinite},
382382
end
383383

384384

385-
# ## ARITHMETIC
386-
387-
const ERR_DIFFERENT_SAMPLE_SPACES = ArgumentError(
388-
"Adding two `UnivariateFinite` objects whose "*
389-
"sample spaces have different labellings is not allowed. ")
390-
391-
import Base: +, *, /, -
392-
393-
function _plus(d1, d2, T)
394-
classes(d1) == classes(d2) || throw(ERR_DIFFERENT_SAMPLE_SPACES)
395-
S = d1.scitype
396-
decoder = d1.decoder
397-
prob_given_ref = copy(d1.prob_given_ref)
398-
for ref in keys(prob_given_ref)
399-
prob_given_ref[ref] += d2.prob_given_ref[ref]
400-
end
401-
return T(S, decoder, prob_given_ref)
402-
end
403-
+(d1::U, d2::U) where U <: UnivariateFinite = _plus(d1, d2, UnivariateFinite)
404-
+(d1::U, d2::U) where U <: UnivariateFiniteArray =
405-
_plus(d1, d2, UnivariateFiniteArray)
406-
407-
function _minus(d, T)
408-
S = d.scitype
409-
decoder = d.decoder
410-
prob_given_ref = copy(d.prob_given_ref)
411-
for ref in keys(prob_given_ref)
412-
prob_given_ref[ref] = -prob_given_ref[ref]
413-
end
414-
return T(S, decoder, prob_given_ref)
415-
end
416-
-(d::UnivariateFinite) = _minus(d, UnivariateFinite)
417-
-(d::UnivariateFiniteArray) = _minus(d, UnivariateFiniteArray)
418-
419-
function _minus(d1, d2, T)
420-
classes(d1) == classes(d2) || throw(ERR_DIFFERENT_SAMPLE_SPACES)
421-
S = d1.scitype
422-
decoder = d1.decoder
423-
prob_given_ref = copy(d1.prob_given_ref)
424-
for ref in keys(prob_given_ref)
425-
prob_given_ref[ref] -= d2.prob_given_ref[ref]
426-
end
427-
return T(S, decoder, prob_given_ref)
428-
end
429-
-(d1::U, d2::U) where U <: UnivariateFinite = _minus(d1, d2, UnivariateFinite)
430-
-(d1::U, d2::U) where U <: UnivariateFiniteArray =
431-
_minus(d1, d2, UnivariateFiniteArray)
432-
433-
# TODO: remove type restrction on `x` in the following methods if
434-
# https://github.com/JuliaStats/Distributions.jl/issues/1438 is
435-
# resolved. Currently we'd have a method ambiguity
436-
437-
function _times(d, x, T)
438-
S = d.scitype
439-
decoder = d.decoder
440-
prob_given_ref = copy(d.prob_given_ref)
441-
for ref in keys(prob_given_ref)
442-
prob_given_ref[ref] *= x
443-
end
444-
return T(d.scitype, decoder, prob_given_ref)
445-
end
446-
*(d::UnivariateFinite, x::Real) = _times(d, x, UnivariateFinite)
447-
*(d::UnivariateFiniteArray, x::Real) = _times(d, x, UnivariateFiniteArray)
448-
449-
*(x::Real, d::SingletonOrArray) = d*x
450-
/(d::SingletonOrArray, x::Real) = d*inv(x)

test/arithmetic.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
module TestArithmetic
2+
3+
using Test
4+
using CategoricalDistributions
5+
using StableRNGs
6+
rng = StableRNG(123)
7+
8+
L = ["yes", "no"]
9+
d1 = UnivariateFinite(L, rand(rng, 2), pool=missing)
10+
d2 = UnivariateFinite(L, rand(rng, 2), pool=missing)
11+
12+
@testset "arithmetic" begin
13+
14+
# addition and subtraction:
15+
for op in [:+, :-]
16+
quote
17+
s = $op(d1, d2 )
18+
@test $op(pdf.(d1, L), pdf.(d2, L)) pdf.(s, L)
19+
end |> eval
20+
end
21+
22+
# negative:
23+
d_neg = -d1
24+
@test pdf.(d_neg, L) == -pdf.(d1, L)
25+
26+
# multiplication by scalar:
27+
d3 = d1*42
28+
@test pdf.(d3, L) pdf.(d1, L)*42
29+
d3 = 42*d1
30+
@test pdf.(d3, L) pdf.(d1, L)*42
31+
32+
# division by scalar:
33+
d3 = d1/42
34+
@test pdf.(d3, L) pdf.(d1, L)/42
35+
end
36+
37+
p = [0.1, 0.9]
38+
P = vcat(fill(p', 10^5)...)
39+
slow = fill(UnivariateFinite(L, p, pool=missing), 10^5)
40+
fast = UnivariateFinite(L, P, pool=missing)
41+
# @assert pdf(slow, L) == pdf(fast, L)
42+
43+
@testset "performant arithmetic for UnivariateFiniteArray" begin
44+
@test pdf(slow + slow, L) == pdf(fast + fast, L)
45+
t_slow = @elapsed @eval slow + slow
46+
t_fast = @elapsed @eval fast + fast
47+
@test t_slow/t_fast > 10
48+
49+
@test pdf(slow - slow, L) == pdf(fast - fast, L)
50+
t_slow = @elapsed @eval slow - slow
51+
t_fast = @elapsed @eval fast - fast
52+
@test t_slow/t_fast > 10
53+
54+
@test pdf(42*slow, L) == pdf(42*fast, L)
55+
@test pdf(slow*42, L) == pdf(fast*42, L)
56+
t_slow = @elapsed @eval 42*slow
57+
t_fast = @elapsed @eval 42*fast
58+
@test t_slow/t_fast > 10
59+
t_slow = @elapsed @eval slow*42
60+
t_fast = @elapsed @eval fast*42
61+
@test t_slow/t_fast > 10
62+
63+
@test pdf(slow/42, L) == pdf(fast/42, L)
64+
t_slow = @elapsed @eval slow/42
65+
t_fast = @elapsed @eval fast/42
66+
@test t_slow/t_fast > 10
67+
end
68+
69+
end # module
70+
71+
true

test/methods.jl

Lines changed: 0 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -277,67 +277,6 @@ end
277277
# @test v ≈ v_close
278278
end
279279

280-
L = ["yes", "no"]
281-
d1 = UnivariateFinite(L, rand(rng, 2), pool=missing)
282-
d2 = UnivariateFinite(L, rand(rng, 2), pool=missing)
283-
284-
@testset "arithmetic" begin
285-
286-
# addition and subtraction:
287-
for op in [:+, :-]
288-
quote
289-
s = $op(d1, d2 )
290-
@test $op(pdf.(d1, L), pdf.(d2, L)) pdf.(s, L)
291-
end |> eval
292-
end
293-
294-
# negative:
295-
d_neg = -d1
296-
@test pdf.(d_neg, L) == -pdf.(d1, L)
297-
298-
# multiplication by scalar:
299-
d3 = d1*42
300-
@test pdf.(d3, L) pdf.(d1, L)*42
301-
d3 = 42*d1
302-
@test pdf.(d3, L) pdf.(d1, L)*42
303-
304-
# division by scalar:
305-
d3 = d1/42
306-
@test pdf.(d3, L) pdf.(d1, L)/42
307-
end
308-
309-
p = [0.1, 0.9]
310-
P = vcat(fill(p', 10^5)...)
311-
slow = fill(UnivariateFinite(L, p, pool=missing), 10^5)
312-
fast = UnivariateFinite(L, P, pool=missing)
313-
# @assert pdf(slow, L) == pdf(fast, L)
314-
315-
@testset "performant arithmetic for UnivariateFiniteArray" begin
316-
@test pdf(slow + slow, L) == pdf(fast + fast, L)
317-
t_slow = @elapsed @eval slow + slow
318-
t_fast = @elapsed @eval fast + fast
319-
@test t_slow/t_fast > 10
320-
321-
@test pdf(slow - slow, L) == pdf(fast - fast, L)
322-
t_slow = @elapsed @eval slow - slow
323-
t_fast = @elapsed @eval fast - fast
324-
@test t_slow/t_fast > 10
325-
326-
@test pdf(42*slow, L) == pdf(42*fast, L)
327-
@test pdf(slow*42, L) == pdf(fast*42, L)
328-
t_slow = @elapsed @eval 42*slow
329-
t_fast = @elapsed @eval 42*fast
330-
@test t_slow/t_fast > 10
331-
t_slow = @elapsed @eval slow*42
332-
t_fast = @elapsed @eval fast*42
333-
@test t_slow/t_fast > 10
334-
335-
@test pdf(slow/42, L) == pdf(fast/42, L)
336-
t_slow = @elapsed @eval slow/42
337-
t_fast = @elapsed @eval fast/42
338-
@test t_slow/t_fast > 10
339-
end
340-
341280
end # module
342281

343282
true

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,7 @@ end
2323
@testset "arrays.jl" begin
2424
@test include("arrays.jl")
2525
end
26+
27+
@testset "arithmetic.jl" begin
28+
@test include("arithmetic.jl")
29+
end

0 commit comments

Comments
 (0)