Skip to content

Commit 4676e02

Browse files
committed
address possibility of different encodings for summands
1 parent 5a65151 commit 4676e02

File tree

2 files changed

+14
-27
lines changed

2 files changed

+14
-27
lines changed

src/arithmetic.jl

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,14 @@ const ERR_DIFFERENT_SAMPLE_SPACES = ArgumentError(
66

77
import Base: +, *, /, -
88

9-
function _plus(d1, d2, T)
10-
classes(d1) == classes(d2) || throw(ERR_DIFFERENT_SAMPLE_SPACES)
11-
S = d1.scitype
12-
decoder = d1.decoder
13-
prob_given_ref = copy(d1.prob_given_ref)
14-
for ref in keys(prob_given_ref)
15-
prob_given_ref[ref] += d2.prob_given_ref[ref]
16-
end
17-
return T(S, decoder, prob_given_ref)
9+
pdf_matrix(d::UnivariateFinite, L) = pdf.(d, L)
10+
pdf_matrix(d::AbstractArray{<:UnivariateFinite}, L) = pdf(d, L)
11+
12+
function +(d1::U, d2::U) where U <: SingletonOrArray
13+
L = classes(d1)
14+
L == classes(d2) || throw(ERR_DIFFERENT_SAMPLE_SPACES)
15+
return UnivariateFinite(L, pdf_matrix(d1, L) + pdf_matrix(d2, L))
1816
end
19-
+(d1::U, d2::U) where U <: UnivariateFinite = _plus(d1, d2, UnivariateFinite)
20-
+(d1::U, d2::U) where U <: UnivariateFiniteArray =
21-
_plus(d1, d2, UnivariateFiniteArray)
2217

2318
function _minus(d, T)
2419
S = d.scitype
@@ -32,19 +27,11 @@ end
3227
-(d::UnivariateFinite) = _minus(d, UnivariateFinite)
3328
-(d::UnivariateFiniteArray) = _minus(d, UnivariateFiniteArray)
3429

35-
function _minus(d1, d2, T)
36-
classes(d1) == classes(d2) || throw(ERR_DIFFERENT_SAMPLE_SPACES)
37-
S = d1.scitype
38-
decoder = d1.decoder
39-
prob_given_ref = copy(d1.prob_given_ref)
40-
for ref in keys(prob_given_ref)
41-
prob_given_ref[ref] -= d2.prob_given_ref[ref]
42-
end
43-
return T(S, decoder, prob_given_ref)
30+
function -(d1::U, d2::U) where U <: SingletonOrArray
31+
L = classes(d1)
32+
L == classes(d2) || throw(ERR_DIFFERENT_SAMPLE_SPACES)
33+
return UnivariateFinite(L, pdf_matrix(d1, L) - pdf_matrix(d2, L))
4434
end
45-
-(d1::U, d2::U) where U <: UnivariateFinite = _minus(d1, d2, UnivariateFinite)
46-
-(d1::U, d2::U) where U <: UnivariateFiniteArray =
47-
_minus(d1, d2, UnivariateFiniteArray)
4835

4936
# It seems that the restriction `x::Number` below (applying only to the
5037
# array case) is unavoidable because of a method ambiguity with

test/arithmetic.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ d2 = UnivariateFinite(L, rand(rng, 2), pool=missing)
3939
end
4040

4141
p = [0.1, 0.9]
42-
P = vcat(fill(p', 10^5)...)
43-
slow = fill(UnivariateFinite(L, p, pool=missing), 10^5)
44-
fast = UnivariateFinite(L, P, pool=missing)
42+
P = vcat(fill(p', 10^5)...);
43+
slow = fill(UnivariateFinite(L, p, pool=missing), 10^5);
44+
fast = UnivariateFinite(L, P, pool=missing);
4545
# @assert pdf(slow, L) == pdf(fast, L)
4646

4747
@testset "performant arithmetic for UnivariateFiniteArray" begin

0 commit comments

Comments
 (0)