@@ -388,45 +388,48 @@ const ERR_DIFFERENT_SAMPLE_SPACES = ArgumentError(
388
388
" Adding two `UnivariateFinite` objects whose " *
389
389
" sample spaces have different labellings is not allowed. " )
390
390
391
- import Base: + , * , /
391
+ import Base: + , * , / , -
392
392
393
- function + (d1:: U , d2:: U ) where U <: UnivariateFinite
393
+ function _plus (d1, d2, T, N ... ) # N... is single integer or absent
394
394
classes (d1) == classes (d2) || throw (ERR_DIFFERENT_SAMPLE_SPACES)
395
395
S = d1. scitype
396
396
decoder = d1. decoder
397
397
prob_given_ref = copy (d1. prob_given_ref)
398
398
for ref in keys (prob_given_ref)
399
399
prob_given_ref[ref] += d2. prob_given_ref[ref]
400
400
end
401
- return UnivariateFinite (S, decoder, prob_given_ref)
401
+ return T (S, decoder, prob_given_ref, N ... )
402
402
end
403
+ + (d1:: U , d2:: U ) where U <: UnivariateFinite = _plus (d1, d2, UnivariateFinite)
403
404
404
- function - (d :: UnivariateFinite )
405
+ function _minus (d, T, N ... )
405
406
S = d. scitype
406
407
decoder = d. decoder
407
408
prob_given_ref = copy (d. prob_given_ref)
408
409
for ref in keys (prob_given_ref)
409
410
prob_given_ref[ref] = - prob_given_ref[ref]
410
411
end
411
- return UnivariateFinite (S, decoder, prob_given_ref)
412
+ return T (S, decoder, prob_given_ref, N ... )
412
413
end
414
+ - (d:: UnivariateFinite ) = _minus (d, UnivariateFinite)
413
415
414
- function - (d1:: U , d2:: U ) where U <: UnivariateFinite
416
+ function _minus (d1, d2, T, N ... )
415
417
classes (d1) == classes (d2) || throw (ERR_DIFFERENT_SAMPLE_SPACES)
416
418
S = d1. scitype
417
419
decoder = d1. decoder
418
420
prob_given_ref = copy (d1. prob_given_ref)
419
421
for ref in keys (prob_given_ref)
420
422
prob_given_ref[ref] -= d2. prob_given_ref[ref]
421
423
end
422
- return UnivariateFinite (S, decoder, prob_given_ref)
424
+ return T (S, decoder, prob_given_ref, N ... )
423
425
end
426
+ - (d1:: U , d2:: U ) where U <: UnivariateFinite = _minus (d1, d2, UnivariateFinite)
424
427
425
428
# TODO : remove type restrction on `x` in the following methods if
426
429
# https://github.com/JuliaStats/Distributions.jl/issues/1438 is
427
430
# resolved. Currently we'd have a method ambiguity
428
431
429
- function * (d :: UnivariateFinite , x:: Real )
432
+ function _times (d , x, T, N ... )
430
433
S = d. scitype
431
434
decoder = d. decoder
432
435
prob_given_ref = copy (d. prob_given_ref)
@@ -435,6 +438,7 @@ function *(d::UnivariateFinite, x::Real)
435
438
end
436
439
return UnivariateFinite (d. scitype, decoder, prob_given_ref)
437
440
end
438
- * (x :: Real , d :: UnivariateFinite ) = d * x
441
+ * (d :: UnivariateFinite , x :: Real ) = _times (d, x, UnivariateFinite)
439
442
440
- / (d:: UnivariateFinite , x:: Real ) = d* inv (x)
443
+ * (x:: Real , d:: SingletonOrArray ) = d* x
444
+ / (d:: SingletonOrArray , x:: Real ) = d* inv (x)
0 commit comments