Skip to content

Commit 5ccc885

Browse files
authored
Merge pull request #71 from tiemvanderdeure/fix_arithmetic_dispatch
allow dispatch on mixed number types for + and -
2 parents 5c19ce3 + c4b8ae9 commit 5ccc885

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

src/arithmetic.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@ import Base: +, *, /, -
99
pdf_matrix(d::UnivariateFinite, L) = pdf.(d, L)
1010
pdf_matrix(d::AbstractArray{<:UnivariateFinite}, L) = pdf(d, L)
1111

12-
function +(d1::U, d2::U) where U <: SingletonOrArray
12+
function +(d1::UnivariateFinite{S, V}, d2::UnivariateFinite{S, V}) where {S, V}
13+
L = classes(d1)
14+
L == classes(d2) || throw(ERR_DIFFERENT_SAMPLE_SPACES)
15+
return UnivariateFinite(L, pdf_matrix(d1, L) + pdf_matrix(d2, L))
16+
end
17+
function +(d1::UnivariateFiniteArray{S, V}, d2::UnivariateFiniteArray{S, V}) where {S, V}
1318
L = classes(d1)
1419
L == classes(d2) || throw(ERR_DIFFERENT_SAMPLE_SPACES)
1520
return UnivariateFinite(L, pdf_matrix(d1, L) + pdf_matrix(d2, L))
@@ -27,7 +32,12 @@ end
2732
-(d::UnivariateFinite) = _minus(d, UnivariateFinite)
2833
-(d::UnivariateFiniteArray) = _minus(d, UnivariateFiniteArray)
2934

30-
function -(d1::U, d2::U) where U <: SingletonOrArray
35+
function -(d1::UnivariateFinite{S, V}, d2::UnivariateFinite{S, V}) where {S, V}
36+
L = classes(d1)
37+
L == classes(d2) || throw(ERR_DIFFERENT_SAMPLE_SPACES)
38+
return UnivariateFinite(L, pdf_matrix(d1, L) - pdf_matrix(d2, L))
39+
end
40+
function -(d1::UnivariateFiniteArray{S, V}, d2::UnivariateFiniteArray{S, V}) where {S, V}
3141
L = classes(d1)
3242
L == classes(d2) || throw(ERR_DIFFERENT_SAMPLE_SPACES)
3343
return UnivariateFinite(L, pdf_matrix(d1, L) - pdf_matrix(d2, L))

test/arithmetic.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,17 @@ end
1313
L = ["yes", "no"]
1414
d1 = UnivariateFinite(L, rand(rng, 2), pool=missing)
1515
d2 = UnivariateFinite(L, rand(rng, 2), pool=missing)
16+
df32 = UnivariateFinite(L, rand(rng, Float32, 2), pool=missing)
1617

1718
@testset "arithmetic" begin
1819

1920
# addition and subtraction:
2021
for op in [:+, :-]
2122
quote
2223
s = $op(d1, d2 )
24+
s2 = $op(d1, df32 )
2325
@test $op(pdf.(d1, L), pdf.(d2, L)) pdf.(s, L)
26+
@test $op(pdf.(d1, L), pdf.(df32, L)) pdf.(s2, L)
2427
end |> eval
2528
end
2629

0 commit comments

Comments
 (0)