@@ -182,22 +182,19 @@ const Prob{P} = Union{P, AbstractArray{P}} where P
182
182
# TODO : are some of these now obsolete?
183
183
184
184
const ERR_01 = DomainError (" Probabilities must be in [0,1]." )
185
- _err_sum_1 () = throw (DomainError (
186
- " Probability arrays must sum to one along the last axis. Perhaps " *
187
- " you meant to specify `augment=true`? " ))
188
- _err_dim (support, probs) = throw (DimensionMismatch (
189
- " Probability array is incompatible " *
190
- " with the number of classes, $(length (support)) , which should " *
191
- " be equal to `$(size (probs)[end ]) `, the last dimension " *
192
- " of the array. Perhaps you meant to set `augment=true`? " ))
193
- _err_dim_augmented (support, probs) = throw (DimensionMismatch (
194
- " Probability array to be augmented is incompatible " *
195
- " with the number of classes, $(length (support)) , which should " *
196
- " be one more than `$(size (probs)[end ]) `, the last dimension " *
197
- " of the array. " ))
198
- _err_aug () = throw (ArgumentError (
185
+ err_dim (support, probs) = DimensionMismatch (
186
+ " Probability array is incompatible " *
187
+ " with the number of classes, $(length (support)) , which should " *
188
+ " be equal to `$(size (probs)[end ]) `, the last dimension " *
189
+ " of the probability array. Perhaps you meant to set `augment=true`? " )
190
+ err_dim_augmented (support, probs) = DimensionMismatch (
191
+ " Probability array to be augmented is incompatible " *
192
+ " with the number of classes, $(length (support)) , which should " *
193
+ " be one more than `$(size (probs)[end ]) `, the last dimension " *
194
+ " of the probability array. " )
195
+ const ERR_AUG = ArgumentError (
199
196
" Array cannot be augmented. There are " *
200
- " sums along the last axis exceeding one. " ))
197
+ " sums along the last axis exceeding one. " )
201
198
202
199
function _check_pool (pool)
203
200
ismissing (pool) || pool == nothing ||
@@ -207,12 +204,10 @@ function _check_pool(pool)
207
204
end
208
205
_check_probs_01 (probs) =
209
206
all (0 .<= probs .<= 1 ) || throw (ERR_01)
210
- _check_probs_sum (probs:: Vector{<:Prob{P}} ) where P =
211
- all (x -> x≈ one (P), sum (probs)) || _err_sum_1 ()
212
207
_check_probs (probs) = (_check_probs_01 (probs); _check_probs_sum (probs))
213
208
_check_augmentable (support, probs) = _check_probs_01 (probs) &&
214
209
size (probs)[end ] + 1 == length (support) ||
215
- _err_dim_augmented ( support, probs)
210
+ throw ( err_dim_augmented ( support, probs) )
216
211
217
212
218
213
# # AUGMENTING ARRAYS TO MAKE THEM PROBABILITY ARRAYS
@@ -232,7 +227,7 @@ function _augment_probs(::Val{false},
232
227
aug_size = size (probs) |> collect
233
228
aug_size[end ] += 1
234
229
augmentation = _unwrap (one (P) .- sum (probs, dims= N))
235
- all (0 .<= augmentation .<= 1 ) || _err_aug ( )
230
+ all (0 .<= augmentation .<= 1 ) || throw (ERR_AUG )
236
231
aug_probs = Array {P} (undef, aug_size... )
237
232
aug_probs[fill (:, N - 1 )... , 2 : end ] = probs
238
233
aug_probs[fill (:, N - 1 )... , 1 ] = augmentation
@@ -244,7 +239,7 @@ function _augment_probs(::Val{true},
244
239
_check_probs_01 (probs)
245
240
aug_size = [size (probs)... , 2 ]
246
241
augmentation = one (P) .- probs
247
- all (0 .<= augmentation .<= 1 ) || _err_aug ( )
242
+ all (0 .<= augmentation .<= 1 ) || throw (ERR_AUG )
248
243
aug_probs = Array {P} (undef, aug_size... )
249
244
aug_probs[fill (:, N)... , 2 ] = probs
250
245
aug_probs[fill (:, N)... , 1 ] = augmentation
@@ -370,6 +365,9 @@ function _UnivariateFinite(support::AbstractVector{CategoricalValue{V,R}},
370
365
371
366
_probs = augment ? _augment_probs (support, probs) : probs
372
367
368
+ augment || length (support) == size (_probs) |> last ||
369
+ throw (err_dim (support, _probs))
370
+
373
371
# it's necessary to force the typing of the LittleDict otherwise it
374
372
# flips to Any type (unlike regular Dict):
375
373
0 commit comments