@@ -26,7 +26,7 @@ function Base.getindex(u::UniFinArr{<:Any,<:Any,R, P}, i...) where {R, P}
26
26
ref_probs = Vector {P} (undef, n_refs)
27
27
unf_constructor = UnivariateFinite
28
28
end
29
-
29
+
30
30
# Fill in the first elements
31
31
# Both `refs` and `ref_probs` are both of type `Vector` and hence support
32
32
# linear indexing with index starting at `1`
@@ -49,6 +49,11 @@ function Base.getindex(u::UniFinArr{<:Any,<:Any,R, P}, i...) where {R, P}
49
49
return unf_constructor (u. scitype, u. decoder, prob_given_ref)
50
50
end
51
51
52
+ function Base. getindex (u:: UniFinArr , idx:: CartesianIndex )
53
+ checkbounds (u, idx)
54
+ return u[Tuple (idx)... ]
55
+ end
56
+
52
57
function Base. setindex! (u:: UniFinArr{S,V,R,P,N} ,
53
58
v:: UnivariateFinite{S,V,R,P} ,
54
59
i:: Integer... ) where {S,V,R,P,N}
61
66
# TODO : return an exception without throwing it:
62
67
63
68
_err_incompatible_levels () = throw (DomainError (
64
- " Cannot concatenate `UnivariateFiniteArray`s with " *
65
- " different categorical levels (classes), " *
66
- " or whose levels, when ordered, are not " *
69
+ " Cannot concatenate `UnivariateFiniteArray`s with " *
70
+ " different categorical levels (classes), " *
71
+ " or whose levels, when ordered, are not " *
67
72
" consistently ordered. " ))
68
73
69
74
# terminology:
@@ -87,14 +92,12 @@ function Base.cat(us::UniFinArr{S,V,R,P,N}...;
87
92
for i in 2 : length (us)
88
93
isordered (us[i]) == ordered || _err_incompatible_levels ()
89
94
if ordered
90
- classes (us[i]) ==
91
- _classes|| _err_incompatible_levels ()
95
+ classes (us[i]) == _classes || _err_incompatible_levels ()
92
96
else
93
- Set (classes (us[i])) ==
94
- Set (_classes) || _err_incompatible_levels ()
97
+ Set (classes (us[i])) == Set (_classes) || _err_incompatible_levels ()
95
98
end
96
- support_with_duplicates =
97
- vcat (support_with_duplicates, Dist. support (us[i]))
99
+ support_with_duplicates = vcat (support_with_duplicates,
100
+ Dist. support (us[i]))
98
101
end
99
102
_support = unique (support_with_duplicates) # no-longer categorical!
100
103
@@ -125,14 +128,12 @@ for func in [:pdf, :logpdf]
125
128
eval (quote
126
129
function Distributions. $func (
127
130
u:: AbstractArray{UnivariateFinite{S,V,R,P},N} ,
128
- C:: AbstractVector {<: Union {
129
- V,
130
- CategoricalValue{V,R}}}) where {S,V,R,P,N}
131
+ C:: AbstractVector ) where {S,V,R,P,N}
131
132
132
- # ret = Array{P,N+1}(undef, size(u)..., length(C))
133
133
ret = zeros (P, size (u)... , length (C))
134
- for i in eachindex (C)
135
- ret[fill (:,N)... ,i] .= broadcast ($ func, u, C[i])
134
+ # note that we do not require C to use 1-base indexing
135
+ for (i, c) in enumerate (C)
136
+ ret[fill (:,N)... , i] .= broadcast ($ func, u, c)
136
137
end
137
138
return ret
138
139
end
152
153
# returns `x[i]` for `Array` inputs `x`
153
154
# For non-Array inputs returns `zero(dtype)`
154
155
# This avoids using an if statement
155
- _getindex (x:: Array ,i, dtype)= x[i]
156
+ _getindex (x:: Array , i, dtype)= x[i]
156
157
_getindex (:: Nothing , i, dtype) = zero (dtype)
157
158
158
159
# pdf.(u, cv)
@@ -161,19 +162,23 @@ function Base.Broadcast.broadcasted(
161
162
u:: UniFinArr{S,V,R,P,N} ,
162
163
cv:: CategoricalValue ) where {S,V,R,P,N}
163
164
164
- cv in classes (u) || throw (err_missing_class (cv))
165
+ # we assume that we compare categorical values by their unwrapped value
166
+ # and pick the index of this value from classes(u)
167
+ cv_loc = findfirst (== (cv), classes (u))
168
+ cv_loc == 0 && throw (err_missing_class (cv))
165
169
166
170
f () = zeros (P, size (u)) # default caller function
167
171
168
172
return Base. Broadcast. Broadcasted (
169
173
identity,
170
- (get (f, u. prob_given_ref, int (cv) ),)
174
+ (get (f, u. prob_given_ref, cv_loc ),)
171
175
)
172
176
end
177
+
173
178
Base. Broadcast. broadcasted (
174
179
:: typeof (pdf),
175
180
u:: UniFinArr{S,V,R,P,N} ,
176
- :: Missing ) where {S,V,R,P,N} = Missings. missings (P, length (u))
181
+ :: Missing ) where {S,V,R,P,N} = Missings. missings (P, size (u))
177
182
178
183
# pdf.(u, v)
179
184
function Base. Broadcast. broadcasted (
@@ -186,17 +191,15 @@ function Base.Broadcast.broadcasted(
186
191
length (u) == length (v) || throw (DimensionMismatch (
187
192
" Arrays could not be broadcast to a common size; " *
188
193
" got a dimension with lengths $(length (u)) and $(length (v)) " ))
189
- for cv in v
190
- ismissing (cv) || cv in classes (u) || throw (err_missing_class (cv))
191
- end
192
194
193
- # will use linear indexing:
194
- v_flat = ((v[i], i) for i in 1 : length (v))
195
+ v_loc_flat = [(ismissing (x) ? missing : findfirst (== (x), classes (u)), i)
196
+ for (i, x) in enumerate (v)]
197
+ any (isequal (0 ), v_loc_flat) && throw (err_missing_class (cv))
195
198
196
- getter ((cv , i), dtype) =
197
- _getindex (get (u. prob_given_ref, int (cv) , nothing ), i, dtype)
199
+ getter ((cv_loc , i), dtype) =
200
+ _getindex (get (u. prob_given_ref, cv_loc , nothing ), i, dtype)
198
201
getter (:: Tuple{Missing,Any} , dtype) = missing
199
- ret_flat = getter .(v_flat , P)
202
+ ret_flat = getter .(v_loc_flat , P)
200
203
return reshape (ret_flat, size (u))
201
204
end
202
205
@@ -269,10 +272,10 @@ function Base.Broadcast.broadcasted(::typeof(mode),
269
272
mode_flat = map (1 : length (u)) do i
270
273
max_prob = maximum (dic[ref][i] for ref in keys (dic))
271
274
m = zero (R)
272
-
273
- # `maximum` of any iterable containing `NaN` would return `NaN`
275
+
276
+ # `maximum` of any iterable containing `NaN` would return `NaN`
274
277
# For this case the index `m` won't be updated in the loop as relations
275
- # involving NaN as one of it's argument always returns false
278
+ # involving NaN as one of it's argument always returns false
276
279
# (e.g `==(NaN, NaN)` returns false)
277
280
throw_nan_error_if_needed (max_prob)
278
281
for ref in keys (dic)
@@ -295,9 +298,7 @@ const ERR_EMPTY_UNIVARIATE_FINITE = ArgumentError(
295
298
" No `UnivariateFinite` object found from which to extract classes. " )
296
299
297
300
function classes (yhat:: AbstractArray{<:Union{Missing,UnivariateFinite}} )
298
- i = findfirst (x -> ! ismissing (x) , yhat)
301
+ i = findfirst (! ismissing, yhat)
299
302
i === nothing && throw (ERR_EMPTY_UNIVARIATE_FINITE)
300
303
return classes (yhat[i])
301
304
end
302
-
303
-
0 commit comments