Skip to content

Commit 7835e19

Browse files
committed
some fixes and minor refactoring of arithmetic to allow generalization
1 parent d065f06 commit 7835e19

File tree

3 files changed

+28
-16
lines changed

3 files changed

+28
-16
lines changed

src/methods.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -388,45 +388,48 @@ const ERR_DIFFERENT_SAMPLE_SPACES = ArgumentError(
388388
"Adding two `UnivariateFinite` objects whose "*
389389
"sample spaces have different labellings is not allowed. ")
390390

391-
import Base: +, *, /
391+
import Base: +, *, /, -
392392

393-
function +(d1::U, d2::U) where U <: UnivariateFinite
393+
function _plus(d1, d2, T, N...) # N... is single integer or absent
394394
classes(d1) == classes(d2) || throw(ERR_DIFFERENT_SAMPLE_SPACES)
395395
S = d1.scitype
396396
decoder = d1.decoder
397397
prob_given_ref = copy(d1.prob_given_ref)
398398
for ref in keys(prob_given_ref)
399399
prob_given_ref[ref] += d2.prob_given_ref[ref]
400400
end
401-
return UnivariateFinite(S, decoder, prob_given_ref)
401+
return T(S, decoder, prob_given_ref, N...)
402402
end
403+
+(d1::U, d2::U) where U <: UnivariateFinite = _plus(d1, d2, UnivariateFinite)
403404

404-
function -(d::UnivariateFinite)
405+
function _minus(d, T, N...)
405406
S = d.scitype
406407
decoder = d.decoder
407408
prob_given_ref = copy(d.prob_given_ref)
408409
for ref in keys(prob_given_ref)
409410
prob_given_ref[ref] = -prob_given_ref[ref]
410411
end
411-
return UnivariateFinite(S, decoder, prob_given_ref)
412+
return T(S, decoder, prob_given_ref, N...)
412413
end
414+
-(d::UnivariateFinite) = _minus(d, UnivariateFinite)
413415

414-
function -(d1::U, d2::U) where U <: UnivariateFinite
416+
function _minus(d1, d2, T, N...)
415417
classes(d1) == classes(d2) || throw(ERR_DIFFERENT_SAMPLE_SPACES)
416418
S = d1.scitype
417419
decoder = d1.decoder
418420
prob_given_ref = copy(d1.prob_given_ref)
419421
for ref in keys(prob_given_ref)
420422
prob_given_ref[ref] -= d2.prob_given_ref[ref]
421423
end
422-
return UnivariateFinite(S, decoder, prob_given_ref)
424+
return T(S, decoder, prob_given_ref, N...)
423425
end
426+
-(d1::U, d2::U) where U <: UnivariateFinite = _minus(d1, d2, UnivariateFinite)
424427

425428
# TODO: remove type restrction on `x` in the following methods if
426429
# https://github.com/JuliaStats/Distributions.jl/issues/1438 is
427430
# resolved. Currently we'd have a method ambiguity
428431

429-
function *(d::UnivariateFinite, x::Real)
432+
function _times(d, x, T, N...)
430433
S = d.scitype
431434
decoder = d.decoder
432435
prob_given_ref = copy(d.prob_given_ref)
@@ -435,6 +438,7 @@ function *(d::UnivariateFinite, x::Real)
435438
end
436439
return UnivariateFinite(d.scitype, decoder, prob_given_ref)
437440
end
438-
*(x::Real, d::UnivariateFinite) = d*x
441+
*(d::UnivariateFinite, x::Real) = _times(d, x, UnivariateFinite)
439442

440-
/(d::UnivariateFinite, x::Real) = d*inv(x)
443+
*(x::Real, d::SingletonOrArray) = d*x
444+
/(d::SingletonOrArray, x::Real) = d*inv(x)

src/types.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,10 @@ end
203203

204204
const UnivariateFiniteVector{S,V,R,P} = UnivariateFiniteArray{S,V,R,P,1}
205205

206+
# private:
207+
const SingletonOrArray{S,V,R,P} = Union{UnivariateFinite{S,V,R,P},
208+
UnivariateFiniteArray{S,V,R,P}}
209+
206210

207211
# # CHECKS AND ERROR MESSAGES
208212

@@ -273,6 +277,7 @@ function _augment_probs(::Val{true},
273277
probs::AbstractArray{P,N}) where {P,N}
274278
_check_probs_01(probs)
275279
aug_size = [size(probs)..., 2]
280+
@show probs P
276281
augmentation = one(P) .- probs
277282
all(0 .<= augmentation .<= 1) || throw(ERR_AUG)
278283
aug_probs = Array{P}(undef, aug_size...)

test/methods.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -277,10 +277,11 @@ end
277277
# @test v ≈ v_close
278278
end
279279

280-
@tesset "arithmetic" begin
281-
L = ["yes", "no"]
282-
d1 = UnivariateFinite(L, rand(rng, 2), pool=missing)
283-
d2 = UnivariateFinite(L, rand(rng, 2), pool=missing)
280+
L = ["yes", "no"]
281+
d1 = UnivariateFinite(L, rand(rng, 2), pool=missing)
282+
d2 = UnivariateFinite(L, rand(rng, 2), pool=missing)
283+
284+
@testset "arithmetic" begin
284285

285286
# addition and subtraction:
286287
for op in [:+, :-]
@@ -295,8 +296,10 @@ end
295296
@test pdf(d_neg, L) == -pdf(d1, L)
296297

297298
# multiplication by scalar:
298-
d3 = d1/42
299-
@test pdf(d3, L) pdf(d1, L)/42
299+
d3 = d1*42
300+
@test pdf(d3, L) pdf(d1, L)*42
301+
d3 = 42*d1
302+
@test pdf(d3, L) pdf(d1, L)*42
300303

301304
# division by scalar:
302305
d3 = d1/42

0 commit comments

Comments
 (0)