@@ -14,6 +14,11 @@ function Base.getindex(u::UniFinArr{<:Any,<:Any,R,P,N},
14
14
return UnivariateFinite (u. scitype, u. decoder, prob_given_ref)
15
15
end
16
16
17
+ function Base. getindex (u:: UniFinArr , idx:: CartesianIndex )
18
+ checkbounds (u, idx)
19
+ return u[Tuple (idx)... ]
20
+ end
21
+
17
22
function Base. getindex (u:: UniFinArr{<:Any,<:Any,R,P,N} ,
18
23
I... ) where {R,P,N}
19
24
prob_given_ref = LittleDict {R,Array{P,N}} ()
35
40
# TODO : return an exception without throwing it:
36
41
37
42
_err_incompatible_levels () = throw (DomainError (
38
- " Cannot concatenate `UnivariateFiniteArray`s with " *
39
- " different categorical levels (classes), " *
40
- " or whose levels, when ordered, are not " *
43
+ " Cannot concatenate `UnivariateFiniteArray`s with " *
44
+ " different categorical levels (classes), " *
45
+ " or whose levels, when ordered, are not " *
41
46
" consistently ordered. " ))
42
47
43
48
# terminology:
@@ -61,14 +66,12 @@ function Base.cat(us::UniFinArr{S,V,R,P,N}...;
61
66
for i in 2 : length (us)
62
67
isordered (us[i]) == ordered || _err_incompatible_levels ()
63
68
if ordered
64
- classes (us[i]) ==
65
- _classes|| _err_incompatible_levels ()
69
+ classes (us[i]) == _classes || _err_incompatible_levels ()
66
70
else
67
- Set (classes (us[i])) ==
68
- Set (_classes) || _err_incompatible_levels ()
71
+ Set (classes (us[i])) == Set (_classes) || _err_incompatible_levels ()
69
72
end
70
- support_with_duplicates =
71
- vcat (support_with_duplicates, Dist. support (us[i]))
73
+ support_with_duplicates = vcat (support_with_duplicates,
74
+ Dist. support (us[i]))
72
75
end
73
76
_support = unique (support_with_duplicates) # no-longer categorical!
74
77
@@ -99,14 +102,12 @@ for func in [:pdf, :logpdf]
99
102
eval (quote
100
103
function Distributions. $func (
101
104
u:: AbstractArray{UnivariateFinite{S,V,R,P},N} ,
102
- C:: AbstractVector {<: Union {
103
- V,
104
- CategoricalValue{V,R}}}) where {S,V,R,P,N}
105
+ C:: AbstractVector ) where {S,V,R,P,N}
105
106
106
- # ret = Array{P,N+1}(undef, size(u)..., length(C))
107
107
ret = zeros (P, size (u)... , length (C))
108
- for i in eachindex (C)
109
- ret[fill (:,N)... ,i] .= broadcast ($ func, u, C[i])
108
+ # note that we do not require C to use 1-base indexing
109
+ for (i, c) in enumerate (C)
110
+ ret[fill (:,N)... , i] .= broadcast ($ func, u, c)
110
111
end
111
112
return ret
112
113
end
126
127
# returns `x[i]` for `Array` inputs `x`
127
128
# For non-Array inputs returns `zero(dtype)`
128
129
# This avoids using an if statement
129
- _getindex (x:: Array ,i, dtype)= x[i]
130
+ _getindex (x:: Array , i, dtype)= x[i]
130
131
_getindex (:: Nothing , i, dtype) = zero (dtype)
131
132
132
133
# pdf.(u, cv)
@@ -135,19 +136,23 @@ function Base.Broadcast.broadcasted(
135
136
u:: UniFinArr{S,V,R,P,N} ,
136
137
cv:: CategoricalValue ) where {S,V,R,P,N}
137
138
138
- cv in classes (u) || throw (err_missing_class (cv))
139
+ # we assume that we compare categorical values by their unwrapped value
140
+ # and pick the index of this value from classes(u)
141
+ cv_loc = findfirst (== (cv), classes (u))
142
+ cv_loc == 0 && throw (err_missing_class (cv))
139
143
140
144
f () = zeros (P, size (u)) # default caller function
141
145
142
146
return Base. Broadcast. Broadcasted (
143
147
identity,
144
- (get (f, u. prob_given_ref, int (cv) ),)
148
+ (get (f, u. prob_given_ref, cv_loc ),)
145
149
)
146
150
end
151
+
147
152
Base. Broadcast. broadcasted (
148
153
:: typeof (pdf),
149
154
u:: UniFinArr{S,V,R,P,N} ,
150
- :: Missing ) where {S,V,R,P,N} = Missings. missings (P, length (u))
155
+ :: Missing ) where {S,V,R,P,N} = Missings. missings (P, size (u))
151
156
152
157
# pdf.(u, v)
153
158
function Base. Broadcast. broadcasted (
@@ -160,17 +165,15 @@ function Base.Broadcast.broadcasted(
160
165
length (u) == length (v) || throw (DimensionMismatch (
161
166
" Arrays could not be broadcast to a common size; " *
162
167
" got a dimension with lengths $(length (u)) and $(length (v)) " ))
163
- for cv in v
164
- ismissing (cv) || cv in classes (u) || throw (err_missing_class (cv))
165
- end
166
168
167
- # will use linear indexing:
168
- v_flat = ((v[i], i) for i in 1 : length (v))
169
+ v_loc_flat = [(ismissing (x) ? missing : findfirst (== (x), classes (u)), i)
170
+ for (i, x) in enumerate (v)]
171
+ any (isequal (0 ), v_loc_flat) && throw (err_missing_class (cv))
169
172
170
- getter ((cv , i), dtype) =
171
- _getindex (get (u. prob_given_ref, int (cv) , nothing ), i, dtype)
173
+ getter ((cv_loc , i), dtype) =
174
+ _getindex (get (u. prob_given_ref, cv_loc , nothing ), i, dtype)
172
175
getter (:: Tuple{Missing,Any} , dtype) = missing
173
- ret_flat = getter .(v_flat , P)
176
+ ret_flat = getter .(v_loc_flat , P)
174
177
return reshape (ret_flat, size (u))
175
178
end
176
179
@@ -243,10 +246,10 @@ function Base.Broadcast.broadcasted(::typeof(mode),
243
246
mode_flat = map (1 : length (u)) do i
244
247
max_prob = maximum (dic[ref][i] for ref in keys (dic))
245
248
m = zero (R)
246
-
247
- # `maximum` of any iterable containing `NaN` would return `NaN`
249
+
250
+ # `maximum` of any iterable containing `NaN` would return `NaN`
248
251
# For this case the index `m` won't be updated in the loop as relations
249
- # involving NaN as one of it's argument always returns false
252
+ # involving NaN as one of it's argument always returns false
250
253
# (e.g `==(NaN, NaN)` returns false)
251
254
throw_nan_error_if_needed (max_prob)
252
255
for ref in keys (dic)
@@ -269,9 +272,7 @@ const ERR_EMPTY_UNIVARIATE_FINITE = ArgumentError(
269
272
" No `UnivariateFinite` object found from which to extract classes. " )
270
273
271
274
function classes (yhat:: AbstractArray{<:Union{Missing,UnivariateFinite}} )
272
- i = findfirst (x -> ! ismissing (x) , yhat)
275
+ i = findfirst (! ismissing, yhat)
273
276
i === nothing && throw (ERR_EMPTY_UNIVARIATE_FINITE)
274
277
return classes (yhat[i])
275
278
end
276
-
277
-
0 commit comments