@@ -151,10 +151,10 @@ end
151
151
152
152
# dummy function
153
153
# returns `x[i]` for `Array` inputs `x`
154
- # For non-Array inputs returns `zero(dtype)`
154
+ # For non-Array inputs returns the input
155
155
# This avoids using an if statement
156
- _getindex (x:: Array , i, dtype) = x[i]
157
- _getindex (:: Nothing , i, dtype ) = zero (dtype)
156
+ _getindex (x:: Array , i) = x[i]
157
+ _getindex (x , i) = x
158
158
159
159
# pdf.(u, cv)
160
160
function Base. Broadcast. broadcasted (
@@ -164,15 +164,16 @@ function Base.Broadcast.broadcasted(
164
164
165
165
# we assume that we compare categorical values by their unwrapped value
166
166
# and pick the index of this value from classes(u)
167
- cv_loc = findfirst (== (cv), classes (u))
168
- cv_loc == 0 && throw (err_missing_class (cv))
167
+ _classes = classes (u)
168
+ cv_loc = get (CategoricalArrays. pool (_classes), cv, zero (R))
169
+ isequal (cv_loc, 0 ) && throw (err_missing_class (cv))
169
170
170
171
f () = zeros (P, size (u)) # default caller function
171
172
172
173
return Base. Broadcast. Broadcasted (
173
174
identity,
174
175
(get (f, u. prob_given_ref, cv_loc),)
175
- )
176
+ )
176
177
end
177
178
178
179
Base. Broadcast. broadcasted (
@@ -191,15 +192,23 @@ function Base.Broadcast.broadcasted(
191
192
length (u) == length (v) || throw (DimensionMismatch (
192
193
" Arrays could not be broadcast to a common size; " *
193
194
" got a dimension with lengths $(length (u)) and $(length (v)) " ))
195
+
196
+ _classes = classes (u)
197
+ _classes_pool = CategoricalArrays. pool (_classes)
198
+ T = eltype (v) >: Missing ? Missing : Union{}
199
+ v_loc_flat = Vector {Tuple{Union{R, T}, Int}} (undef, length (v))
200
+
201
+
202
+ for (i, x) in enumerate (v)
203
+ cv_ref = ismissing (x) ? missing : get (_classes_pool, x, zero (R))
204
+ isequal (cv_ref, 0 ) && throw (err_missing_class (x))
205
+ v_loc_flat[i] = (cv_ref, i)
206
+ end
194
207
195
- v_loc_flat = [(ismissing (x) ? missing : findfirst (== (x), classes (u)), i)
196
- for (i, x) in enumerate (v)]
197
- any (isequal (0 ), v_loc_flat) && throw (err_missing_class (cv))
198
-
199
- getter ((cv_loc, i), dtype) =
200
- _getindex (get (u. prob_given_ref, cv_loc, nothing ), i, dtype)
201
- getter (:: Tuple{Missing,Any} , dtype) = missing
202
- ret_flat = getter .(v_loc_flat, P)
208
+ getter ((cv_ref, i)) =
209
+ _getindex (get (u. prob_given_ref, cv_ref, zero (P)), i)
210
+ getter (:: Tuple{Missing,Any} ) = missing
211
+ ret_flat = getter .(v_loc_flat)
203
212
return reshape (ret_flat, size (u))
204
213
end
205
214
0 commit comments