@@ -390,55 +390,61 @@ const ERR_DIFFERENT_SAMPLE_SPACES = ArgumentError(
390
390
391
391
import Base: + , * , / , -
392
392
393
- function _plus (d1, d2, T, N ... ) # N... is single integer or absent
393
+ function _plus (d1, d2, T)
394
394
classes (d1) == classes (d2) || throw (ERR_DIFFERENT_SAMPLE_SPACES)
395
395
S = d1. scitype
396
396
decoder = d1. decoder
397
397
prob_given_ref = copy (d1. prob_given_ref)
398
398
for ref in keys (prob_given_ref)
399
399
prob_given_ref[ref] += d2. prob_given_ref[ref]
400
400
end
401
- return T (S, decoder, prob_given_ref, N ... )
401
+ return T (S, decoder, prob_given_ref)
402
402
end
403
403
+ (d1:: U , d2:: U ) where U <: UnivariateFinite = _plus (d1, d2, UnivariateFinite)
404
+ + (d1:: U , d2:: U ) where U <: UnivariateFiniteArray =
405
+ _plus (d1, d2, UnivariateFiniteArray)
404
406
405
- function _minus (d, T, N ... )
407
+ function _minus (d, T)
406
408
S = d. scitype
407
409
decoder = d. decoder
408
410
prob_given_ref = copy (d. prob_given_ref)
409
411
for ref in keys (prob_given_ref)
410
412
prob_given_ref[ref] = - prob_given_ref[ref]
411
413
end
412
- return T (S, decoder, prob_given_ref, N ... )
414
+ return T (S, decoder, prob_given_ref)
413
415
end
414
416
- (d:: UnivariateFinite ) = _minus (d, UnivariateFinite)
417
+ - (d:: UnivariateFiniteArray ) = _minus (d, UnivariateFiniteArray)
415
418
416
- function _minus (d1, d2, T, N ... )
419
+ function _minus (d1, d2, T)
417
420
classes (d1) == classes (d2) || throw (ERR_DIFFERENT_SAMPLE_SPACES)
418
421
S = d1. scitype
419
422
decoder = d1. decoder
420
423
prob_given_ref = copy (d1. prob_given_ref)
421
424
for ref in keys (prob_given_ref)
422
425
prob_given_ref[ref] -= d2. prob_given_ref[ref]
423
426
end
424
- return T (S, decoder, prob_given_ref, N ... )
427
+ return T (S, decoder, prob_given_ref)
425
428
end
426
429
- (d1:: U , d2:: U ) where U <: UnivariateFinite = _minus (d1, d2, UnivariateFinite)
430
+ - (d1:: U , d2:: U ) where U <: UnivariateFiniteArray =
431
+ _minus (d1, d2, UnivariateFiniteArray)
427
432
428
433
# TODO : remove type restrction on `x` in the following methods if
429
434
# https://github.com/JuliaStats/Distributions.jl/issues/1438 is
430
435
# resolved. Currently we'd have a method ambiguity
431
436
432
- function _times (d, x, T, N ... )
437
+ function _times (d, x, T)
433
438
S = d. scitype
434
439
decoder = d. decoder
435
440
prob_given_ref = copy (d. prob_given_ref)
436
441
for ref in keys (prob_given_ref)
437
442
prob_given_ref[ref] *= x
438
443
end
439
- return UnivariateFinite (d. scitype, decoder, prob_given_ref)
444
+ return T (d. scitype, decoder, prob_given_ref)
440
445
end
441
446
* (d:: UnivariateFinite , x:: Real ) = _times (d, x, UnivariateFinite)
447
+ * (d:: UnivariateFiniteArray , x:: Real ) = _times (d, x, UnivariateFiniteArray)
442
448
443
449
* (x:: Real , d:: SingletonOrArray ) = d* x
444
450
/ (d:: SingletonOrArray , x:: Real ) = d* inv (x)
0 commit comments