Skip to content

Commit e1448c1

Browse files
committed
more tests
1 parent ce2cbdf commit e1448c1

File tree

2 files changed

+40
-21
lines changed

2 files changed

+40
-21
lines changed

src/types.jl

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -182,22 +182,19 @@ const Prob{P} = Union{P, AbstractArray{P}} where P
182182
# TODO: are some of these now obsolete?
183183

184184
const ERR_01 = DomainError("Probabilities must be in [0,1].")
185-
_err_sum_1() = throw(DomainError(
186-
"Probability arrays must sum to one along the last axis. Perhaps "*
187-
"you meant to specify `augment=true`? "))
188-
_err_dim(support, probs) = throw(DimensionMismatch(
189-
"Probability array is incompatible "*
190-
"with the number of classes, $(length(support)), which should "*
191-
"be equal to `$(size(probs)[end])`, the last dimension "*
192-
"of the array. Perhaps you meant to set `augment=true`? "))
193-
_err_dim_augmented(support, probs) = throw(DimensionMismatch(
194-
"Probability array to be augmented is incompatible "*
195-
"with the number of classes, $(length(support)), which should "*
196-
"be one more than `$(size(probs)[end])`, the last dimension "*
197-
"of the array. "))
198-
_err_aug() = throw(ArgumentError(
185+
err_dim(support, probs) = DimensionMismatch(
186+
"Probability array is incompatible "*
187+
"with the number of classes, $(length(support)), which should "*
188+
"be equal to `$(size(probs)[end])`, the last dimension "*
189+
"of the probability array. Perhaps you meant to set `augment=true`? ")
190+
err_dim_augmented(support, probs) = DimensionMismatch(
191+
"Probability array to be augmented is incompatible "*
192+
"with the number of classes, $(length(support)), which should "*
193+
"be one more than `$(size(probs)[end])`, the last dimension "*
194+
"of the probability array. ")
195+
const ERR_AUG = ArgumentError(
199196
"Array cannot be augmented. There are "*
200-
"sums along the last axis exceeding one. "))
197+
"sums along the last axis exceeding one. ")
201198

202199
function _check_pool(pool)
203200
ismissing(pool) || pool == nothing ||
@@ -207,12 +204,10 @@ function _check_pool(pool)
207204
end
208205
_check_probs_01(probs) =
209206
all(0 .<= probs .<= 1) || throw(ERR_01)
210-
_check_probs_sum(probs::Vector{<:Prob{P}}) where P =
211-
all(x -> xone(P), sum(probs)) || _err_sum_1()
212207
_check_probs(probs) = (_check_probs_01(probs); _check_probs_sum(probs))
213208
_check_augmentable(support, probs) = _check_probs_01(probs) &&
214209
size(probs)[end] + 1 == length(support) ||
215-
_err_dim_augmented(support, probs)
210+
throw(err_dim_augmented(support, probs))
216211

217212

218213
## AUGMENTING ARRAYS TO MAKE THEM PROBABILITY ARRAYS
@@ -232,7 +227,7 @@ function _augment_probs(::Val{false},
232227
aug_size = size(probs) |> collect
233228
aug_size[end] += 1
234229
augmentation = _unwrap(one(P) .- sum(probs, dims=N))
235-
all(0 .<= augmentation .<= 1) || _err_aug()
230+
all(0 .<= augmentation .<= 1) || throw(ERR_AUG)
236231
aug_probs = Array{P}(undef, aug_size...)
237232
aug_probs[fill(:, N - 1)..., 2:end] = probs
238233
aug_probs[fill(:, N - 1)..., 1] = augmentation
@@ -244,7 +239,7 @@ function _augment_probs(::Val{true},
244239
_check_probs_01(probs)
245240
aug_size = [size(probs)..., 2]
246241
augmentation = one(P) .- probs
247-
all(0 .<= augmentation .<= 1) || _err_aug()
242+
all(0 .<= augmentation .<= 1) || throw(ERR_AUG)
248243
aug_probs = Array{P}(undef, aug_size...)
249244
aug_probs[fill(:, N)..., 2] = probs
250245
aug_probs[fill(:, N)..., 1] = augmentation
@@ -370,6 +365,9 @@ function _UnivariateFinite(support::AbstractVector{CategoricalValue{V,R}},
370365

371366
_probs = augment ? _augment_probs(support, probs) : probs
372367

368+
augment || length(support) == size(_probs) |> last ||
369+
throw(err_dim(support, _probs))
370+
373371
# it's necessary to force the typing of the LittleDict otherwise it
374372
# flips to Any type (unlike regular Dict):
375373

test/types.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import Random
2121

2222
@test_logs((:warn, r"No "),
2323
UnivariateFinite(['f', 'q', 's'], [0.7, 0.2, 0.1]))
24+
2425
end
2526

2627
@testset "array constructors" begin
@@ -32,7 +33,27 @@ end
3233
probs = rand(rng, n)
3334
supp = ["class1", "class2"]
3435

35-
u = UnivariateFinite(supp, probs, pool=missing, augment=true);
36+
UnivariateFinite(supp, probs, pool=missing, augment=true);
37+
38+
# dimension mismatches:
39+
badprobs = rand(40, 3)
40+
@test_throws(CategoricalDistributions.err_dim(supp, badprobs),
41+
UnivariateFinite(supp, badprobs, pool=missing))
42+
43+
# dimension mismatch, augmented case:
44+
probs2 = [0.1 0.5 0.1;
45+
0.3 0.2 0.1]
46+
supp2 = ["no", "yes", "maybe"]
47+
@test_throws(CategoricalDistributions.err_dim_augmented(supp2, probs2),
48+
UnivariateFinite(supp2, probs2, augment=true, pool=missing))
49+
50+
# not augmentable:
51+
@test_throws(CategoricalDistributions.ERR_AUG,
52+
UnivariateFinite(["no", "yes", "maybe"],
53+
[0.6 0.5; # sum exceeding one!
54+
0.3 0.2],
55+
augment=true,
56+
pool=missing))
3657

3758
# autosupport:
3859
u = UnivariateFinite(probs, pool=missing, augment=true);

0 commit comments

Comments
 (0)