@@ -4,23 +4,54 @@ const UniFinArr = UnivariateFiniteArray
4
4
5
5
Base. size (u:: UniFinArr , args... ) =
6
6
size (first (values (u. prob_given_ref)), args... )
7
-
8
- function Base. getindex (u:: UniFinArr{<:Any,<:Any,R,P,N} ,
9
- i:: Integer... ) where {R,P,N}
10
- prob_given_ref = LittleDict {R,P} ()
11
- for ref in keys (u. prob_given_ref)
12
- prob_given_ref[ref] = getindex (u. prob_given_ref[ref], i... )
7
+
8
+ function Base. getindex (u:: UniFinArr{<:Any,<:Any,R, P} , i... ) where {R, P}
9
+ # It's faster to generate `Array`s of `refs` and indexed `ref_probs`
10
+ # and pass them to the `LittleDict` constructor.
11
+ # The first element of `u.prob_given_ref` is used to get the dimensions
12
+ # for allocating these arrays.
13
+ u_dict = u. prob_given_ref
14
+ a, rest = Iterators. peel (u_dict)
15
+ # `a` is of the form `key => value`.
16
+ a_ref, a_prob = first (a), getindex (last (a), i... )
17
+
18
+ # Preallocate Arrays using the key and value of the first
19
+ # element (i.e `a`) of `u_dict`.
20
+ n_refs = length (u_dict)
21
+ refs = Vector {R} (undef, n_refs)
22
+ if a_prob isa AbstractArray
23
+ ref_probs = Vector {Array{P, ndims(a_prob)}} (undef, n_refs)
24
+ unf_constructor = UniFinArr
25
+ else
26
+ ref_probs = Vector {P} (undef, n_refs)
27
+ unf_constructor = UnivariateFinite
28
+ end
29
+
30
+ # Fill in the first elements
31
+ # Both `refs` and `ref_probs` are both of type `Vector` and hence support
32
+ # linear indexing with index starting at `1`
33
+ refs[1 ] = a_ref
34
+ ref_probs[1 ] = a_prob
35
+
36
+ # Fill in the rest
37
+ iter = 2
38
+ for (ref, ref_prob) in rest
39
+ refs[iter] = ref
40
+ ref_probs[iter] = getindex (ref_prob, i... )
41
+ iter += 1
13
42
end
14
- return UnivariateFinite (u. scitype, u. decoder, prob_given_ref)
43
+
44
+ # `keytype(prob_given_ref)` is always same as `keytype(u_dict)`.
45
+ # But `ndims(valtype(prob_given_ref))` might not be the same
46
+ # as `ndims(valtype(u_dict))`.
47
+ prob_given_ref = LittleDict {R, eltype(ref_probs)} (refs, ref_probs)
48
+
49
+ return unf_constructor (u. scitype, u. decoder, prob_given_ref)
15
50
end
16
51
17
- function Base. getindex (u:: UniFinArr{<:Any,<:Any,R,P,N} ,
18
- I... ) where {R,P,N}
19
- prob_given_ref = LittleDict {R,Array{P,N}} ()
20
- for ref in keys (u. prob_given_ref)
21
- prob_given_ref[ref] = getindex (u. prob_given_ref[ref], I... )
22
- end
23
- return UniFinArr (u. scitype, u. decoder, prob_given_ref)
52
+ function Base. getindex (u:: UniFinArr , idx:: CartesianIndex )
53
+ checkbounds (u, idx)
54
+ return u[Tuple (idx)... ]
24
55
end
25
56
26
57
function Base. setindex! (u:: UniFinArr{S,V,R,P,N} ,
35
66
# TODO : return an exception without throwing it:
36
67
37
68
_err_incompatible_levels () = throw (DomainError (
38
- " Cannot concatenate `UnivariateFiniteArray`s with " *
39
- " different categorical levels (classes), " *
40
- " or whose levels, when ordered, are not " *
69
+ " Cannot concatenate `UnivariateFiniteArray`s with " *
70
+ " different categorical levels (classes), " *
71
+ " or whose levels, when ordered, are not " *
41
72
" consistently ordered. " ))
42
73
43
74
# terminology:
@@ -61,14 +92,12 @@ function Base.cat(us::UniFinArr{S,V,R,P,N}...;
61
92
for i in 2 : length (us)
62
93
isordered (us[i]) == ordered || _err_incompatible_levels ()
63
94
if ordered
64
- classes (us[i]) ==
65
- _classes|| _err_incompatible_levels ()
95
+ classes (us[i]) == _classes || _err_incompatible_levels ()
66
96
else
67
- Set (classes (us[i])) ==
68
- Set (_classes) || _err_incompatible_levels ()
97
+ Set (classes (us[i])) == Set (_classes) || _err_incompatible_levels ()
69
98
end
70
- support_with_duplicates =
71
- vcat (support_with_duplicates, Dist. support (us[i]))
99
+ support_with_duplicates = vcat (support_with_duplicates,
100
+ Dist. support (us[i]))
72
101
end
73
102
_support = unique (support_with_duplicates) # no-longer categorical!
74
103
@@ -99,14 +128,12 @@ for func in [:pdf, :logpdf]
99
128
eval (quote
100
129
function Distributions. $func (
101
130
u:: AbstractArray{UnivariateFinite{S,V,R,P},N} ,
102
- C:: AbstractVector {<: Union {
103
- V,
104
- CategoricalValue{V,R}}}) where {S,V,R,P,N}
131
+ C:: AbstractVector ) where {S,V,R,P,N}
105
132
106
- # ret = Array{P,N+1}(undef, size(u)..., length(C))
107
133
ret = zeros (P, size (u)... , length (C))
108
- for i in eachindex (C)
109
- ret[fill (:,N)... ,i] .= broadcast ($ func, u, C[i])
134
+ # note that we do not require C to use 1-base indexing
135
+ for (i, c) in enumerate (C)
136
+ ret[fill (:,N)... , i] .= broadcast ($ func, u, c)
110
137
end
111
138
return ret
112
139
end
@@ -124,30 +151,35 @@ end
124
151
125
152
# dummy function
126
153
# returns `x[i]` for `Array` inputs `x`
127
- # For non-Array inputs returns `zero(dtype)`
154
+ # For non-Array inputs returns the input
128
155
# This avoids using an if statement
129
- _getindex (x:: Array ,i, dtype) = x[i]
130
- _getindex (:: Nothing , i, dtype ) = zero (dtype)
156
+ _getindex (x:: Array , i) = x[i]
157
+ _getindex (x , i) = x
131
158
132
159
# pdf.(u, cv)
133
160
function Base. Broadcast. broadcasted (
134
161
:: typeof (pdf),
135
162
u:: UniFinArr{S,V,R,P,N} ,
136
163
cv:: CategoricalValue ) where {S,V,R,P,N}
137
164
138
- cv in classes (u) || throw (err_missing_class (cv))
165
+ # we assume that we compare categorical values by their unwrapped value
166
+ # and pick the index of this value from classes(u)
167
+ _classes = classes (u)
168
+ cv_loc = get (CategoricalArrays. pool (_classes), cv, zero (R))
169
+ isequal (cv_loc, 0 ) && throw (err_missing_class (cv))
139
170
140
171
f () = zeros (P, size (u)) # default caller function
141
172
142
173
return Base. Broadcast. Broadcasted (
143
174
identity,
144
- (get (f, u. prob_given_ref, int (cv) ),)
145
- )
175
+ (get (f, u. prob_given_ref, cv_loc ),)
176
+ )
146
177
end
178
+
147
179
Base. Broadcast. broadcasted (
148
180
:: typeof (pdf),
149
181
u:: UniFinArr{S,V,R,P,N} ,
150
- :: Missing ) where {S,V,R,P,N} = Missings. missings (P, length (u))
182
+ :: Missing ) where {S,V,R,P,N} = Missings. missings (P, size (u))
151
183
152
184
# pdf.(u, v)
153
185
function Base. Broadcast. broadcasted (
@@ -160,17 +192,23 @@ function Base.Broadcast.broadcasted(
160
192
length (u) == length (v) || throw (DimensionMismatch (
161
193
" Arrays could not be broadcast to a common size; " *
162
194
" got a dimension with lengths $(length (u)) and $(length (v)) " ))
163
- for cv in v
164
- ismissing (cv) || cv in classes (u) || throw (err_missing_class (cv))
195
+
196
+ _classes = classes (u)
197
+ _classes_pool = CategoricalArrays. pool (_classes)
198
+ T = eltype (v) >: Missing ? Missing : Union{}
199
+ v_loc_flat = Vector {Tuple{Union{R, T}, Int}} (undef, length (v))
200
+
201
+
202
+ for (i, x) in enumerate (v)
203
+ cv_ref = ismissing (x) ? missing : get (_classes_pool, x, zero (R))
204
+ isequal (cv_ref, 0 ) && throw (err_missing_class (x))
205
+ v_loc_flat[i] = (cv_ref, i)
165
206
end
166
207
167
- # will use linear indexing:
168
- v_flat = ((v[i], i) for i in 1 : length (v))
169
-
170
- getter ((cv, i), dtype) =
171
- _getindex (get (u. prob_given_ref, int (cv), nothing ), i, dtype)
172
- getter (:: Tuple{Missing,Any} , dtype) = missing
173
- ret_flat = getter .(v_flat, P)
208
+ getter ((cv_ref, i)) =
209
+ _getindex (get (u. prob_given_ref, cv_ref, zero (P)), i)
210
+ getter (:: Tuple{Missing,Any} ) = missing
211
+ ret_flat = getter .(v_loc_flat)
174
212
return reshape (ret_flat, size (u))
175
213
end
176
214
@@ -243,10 +281,10 @@ function Base.Broadcast.broadcasted(::typeof(mode),
243
281
mode_flat = map (1 : length (u)) do i
244
282
max_prob = maximum (dic[ref][i] for ref in keys (dic))
245
283
m = zero (R)
246
-
247
- # `maximum` of any iterable containing `NaN` would return `NaN`
284
+
285
+ # `maximum` of any iterable containing `NaN` would return `NaN`
248
286
# For this case the index `m` won't be updated in the loop as relations
249
- # involving NaN as one of it's argument always returns false
287
+ # involving NaN as one of it's argument always returns false
250
288
# (e.g `==(NaN, NaN)` returns false)
251
289
throw_nan_error_if_needed (max_prob)
252
290
for ref in keys (dic)
@@ -269,9 +307,7 @@ const ERR_EMPTY_UNIVARIATE_FINITE = ArgumentError(
269
307
" No `UnivariateFinite` object found from which to extract classes. " )
270
308
271
309
function classes (yhat:: AbstractArray{<:Union{Missing,UnivariateFinite}} )
272
- i = findfirst (x -> ! ismissing (x) , yhat)
310
+ i = findfirst (! ismissing, yhat)
273
311
i === nothing && throw (ERR_EMPTY_UNIVARIATE_FINITE)
274
312
return classes (yhat[i])
275
313
end
276
-
277
-
0 commit comments