@@ -79,11 +79,22 @@ function Base.show(stream::IO, d::UnivariateFinite)
79
79
print (stream, " UnivariateFinite{$(d. scitype) }($arg_str )" )
80
80
end
81
81
82
+ Base. show (io:: IO , mime:: MIME"text/plain" ,
83
+ d:: UnivariateFinite ) = show (io, d)
84
+
85
+ # in common case of `Real` probabilities we can do a pretty bar plot:
82
86
function Base. show (io:: IO , mime:: MIME"text/plain" ,
83
- d:: UnivariateFinite{S} ) where S
87
+ d:: UnivariateFinite{<:Finite{K},V,R,P} ) where {K,V,R,P<: Real }
88
+ show_bars = false
89
+ if K <= MAX_NUM_LEVELS_TO_SHOW_BARS &&
90
+ all (>= (0 ), values (d. prob_given_ref))
91
+ show_bars = true
92
+ end
93
+ show_bars || return show (io, d)
84
94
s = support (d)
85
95
x = string .(CategoricalArrays. DataAPI. unwrap .(s))
86
96
y = pdf .(d, s)
97
+ S = d. scitype
87
98
plt = barplot (x, y, title= " UnivariateFinite{$S }" )
88
99
show (io, mime, plt)
89
100
end
125
136
126
137
# TODO : It would be useful to define == as well.
127
138
128
- # TODO : Now that `UnivariateFinite` is any finite measure, we can
129
- # replace the following nonsense with an overloading of `+`. I think
130
- # it is only used in MLJEnsembles.jl - but need to check. I believe
131
- # this is a private method we can easily remove
132
-
133
- function average (dvec:: AbstractVector{UnivariateFinite{S,V,R,P}} ;
134
- weights= nothing ) where {S,V,R,P}
135
-
136
- n = length (dvec)
137
-
138
- Dist. @check_args (UnivariateFinite, weights == nothing || n== length (weights))
139
-
140
- # check all distributions have consistent pool:
141
- first_index = first (dvec). decoder. classes
142
- for d in dvec
143
- d. decoder. classes == first_index ||
144
- error (" Averaging UnivariateFinite distributions with incompatible" *
145
- " pools. " )
146
- end
147
-
148
- # get all refs:
149
- refs = reduce (union, [keys (d. prob_given_ref) for d in dvec]) |> collect
150
-
151
- # initialize the prob dictionary for the distribution sum:
152
- prob_given_ref = LittleDict {R,P} ([refs... ], zeros (P, length (refs)))
153
-
154
- # make vector of all the distributions dicts padded to have same common keys:
155
- prob_given_ref_vec = map (dvec) do d
156
- merge (prob_given_ref, d. prob_given_ref)
157
- end
158
-
159
- # sum up:
160
- if weights == nothing
161
- scale = 1 / n
162
- for x in refs
163
- for k in 1 : n
164
- prob_given_ref[x] += scale* prob_given_ref_vec[k][x]
165
- end
166
- end
167
- else
168
- scale = 1 / sum (weights)
169
- for x in refs
170
- for k in 1 : n
171
- prob_given_ref[x] +=
172
- weights[k]* prob_given_ref_vec[k][x]* scale
173
- end
174
- end
175
- end
176
- d1 = first (dvec)
177
- return UnivariateFinite (sample_scitype (d1), d1. decoder, prob_given_ref)
178
- end
179
-
180
139
"""
181
140
Dist.pdf(d::UnivariateFinite, x)
182
141
@@ -371,3 +330,9 @@ function Dist.fit(d::Type{<:UnivariateFinite},
371
330
end
372
331
373
332
333
+ # # BROADCASTING OVER SINGLE UNIVARIATE FINITE
334
+
335
+ # This mirrors behaviour assigned Distributions.Distribution objects,
336
+ # which allows `pdf.(d::UnivariateFinite, support(d))` to work.
337
+
338
+ Broadcast. broadcastable (d:: UnivariateFinite ) = Ref (d)
0 commit comments