@@ -79,11 +79,22 @@ function Base.show(stream::IO, d::UnivariateFinite)
79
79
print (stream, " UnivariateFinite{$(d. scitype) }($arg_str )" )
80
80
end
81
81
82
+ Base. show (io:: IO , mime:: MIME"text/plain" ,
83
+ d:: UnivariateFinite ) = show (io, d)
84
+
85
+ # in common case of `Real` probabilities we can do a pretty bar plot:
82
86
function Base. show (io:: IO , mime:: MIME"text/plain" ,
83
- d:: UnivariateFinite{S} ) where S
87
+ d:: UnivariateFinite{<:Finite{K},V,R,P} ) where {K,V,R,P<: Real }
88
+ show_bars = false
89
+ if K <= MAX_NUM_LEVELS_TO_SHOW_BARS &&
90
+ all (>= (0 ), values (d. prob_given_ref))
91
+ show_bars = true
92
+ end
93
+ show_bars || return show (io, d)
84
94
s = support (d)
85
95
x = string .(CategoricalArrays. DataAPI. unwrap .(s))
86
96
y = pdf .(d, s)
97
+ S = d. scitype
87
98
plt = barplot (x, y, title= " UnivariateFinite{$S }" )
88
99
show (io, mime, plt)
89
100
end
@@ -371,3 +382,59 @@ function Dist.fit(d::Type{<:UnivariateFinite},
371
382
end
372
383
373
384
385
+ # ## ARITHMETIC
386
+
387
+ const ERR_DIFFERENT_SAMPLE_SPACES = ArgumentError (
388
+ " Adding two `UnivariateFinite` objects whose " *
389
+ " sample spaces have different labellings is not allowed. " )
390
+
391
+ import Base: + , * , /
392
+
393
+ function + (d1:: U , d2:: U ) where U <: UnivariateFinite
394
+ classes (d1) == classes (d2) || throw (ERR_DIFFERENT_SAMPLE_SPACES)
395
+ S = d1. scitype
396
+ decoder = d1. decoder
397
+ prob_given_ref = copy (d1. prob_given_ref)
398
+ for ref in keys (prob_given_ref)
399
+ prob_given_ref[ref] += d2. prob_given_ref[ref]
400
+ end
401
+ return UnivariateFinite (S, decoder, prob_given_ref)
402
+ end
403
+
404
+ function - (d:: UnivariateFinite )
405
+ S = d. scitype
406
+ decoder = d. decoder
407
+ prob_given_ref = copy (d. prob_given_ref)
408
+ for ref in keys (prob_given_ref)
409
+ prob_given_ref[ref] = - prob_given_ref[ref]
410
+ end
411
+ return UnivariateFinite (S, decoder, prob_given_ref)
412
+ end
413
+
414
+ function - (d1:: U , d2:: U ) where U <: UnivariateFinite
415
+ classes (d1) == classes (d2) || throw (ERR_DIFFERENT_SAMPLE_SPACES)
416
+ S = d1. scitype
417
+ decoder = d1. decoder
418
+ prob_given_ref = copy (d1. prob_given_ref)
419
+ for ref in keys (prob_given_ref)
420
+ prob_given_ref[ref] -= d2. prob_given_ref[ref]
421
+ end
422
+ return UnivariateFinite (S, decoder, prob_given_ref)
423
+ end
424
+
425
+ # TODO : remove type restrction on `x` in the following methods if
426
+ # https://github.com/JuliaStats/Distributions.jl/issues/1438 is
427
+ # resolved. Currently we'd have a method ambiguity
428
+
429
+ function * (d:: UnivariateFinite , x:: Real )
430
+ S = d. scitype
431
+ decoder = d. decoder
432
+ prob_given_ref = copy (d. prob_given_ref)
433
+ for ref in keys (prob_given_ref)
434
+ prob_given_ref[ref] *= x
435
+ end
436
+ return UnivariateFinite (d. scitype, decoder, prob_given_ref)
437
+ end
438
+ * (x:: Real , d:: UnivariateFinite ) = d* x
439
+
440
+ / (d:: UnivariateFinite , x:: Real ) = d* inv (x)
0 commit comments