Skip to content

Commit f13ee87

Browse files
authored
Merge pull request #597 from JuliaAI/levels
Bump compat CategoricalArrays="1"
2 parents 910c741 + 83966c7 commit f13ee87

File tree

8 files changed

+30
-145
lines changed

8 files changed

+30
-145
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ jobs:
4343
- uses: julia-actions/julia-runtest@v1
4444
env:
4545
# This environment variable enables the integration tests:
46-
MLJ_TEST_REGISTRY: '1'
46+
MLJ_TEST_REGISTRY: "false"
4747
- uses: julia-actions/julia-processcoverage@v1
4848
- uses: codecov/codecov-action@v4
4949
with:

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2828
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
2929

3030
[compat]
31-
CategoricalArrays = "0.9, 0.10"
32-
CategoricalDistributions = "0.1"
31+
CategoricalArrays = "1"
32+
CategoricalDistributions = "0.2"
3333
Combinatorics = "1.0"
3434
Dates = "1"
3535
Distances = "0.9,0.10"

src/MLJModels.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ using Combinatorics
2222
import Distributions
2323
import REPL # stdlib, needed for `Term`
2424
import PrettyPrinting
25-
import CategoricalDistributions: UnivariateFinite, UnivariateFiniteArray,
26-
classes
25+
import CategoricalDistributions: UnivariateFinite, UnivariateFiniteArray
2726
import StatisticalTraits # for `info`
2827

2928
# from loading.jl:

src/builtins/Constant.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ function MLJModelInterface.fit(::ConstantClassifier,
5555
y,
5656
w=nothing)
5757
d = Distributions.fit(UnivariateFinite, y, w)
58-
C = classes(d)
58+
C = levels(d)
5959
fitresult = (C, Distributions.pdf([d, ], C))
6060
cache = nothing
6161
report = NamedTuple()
@@ -66,10 +66,10 @@ MLJModelInterface.fitted_params(::ConstantClassifier, fitresult) =
6666
(target_distribution=fitresult,)
6767

6868
function MLJModelInterface.predict(::ConstantClassifier, fitresult, Xnew)
69-
_classes, probs1 = fitresult
69+
_levels, probs1 = fitresult
7070
N = nrows(Xnew)
71-
probs = reshape(vcat(fill(probs1, N)...), N, length(_classes))
72-
return UnivariateFinite(_classes, probs)
71+
probs = reshape(vcat(fill(probs1, N)...), N, length(_levels))
72+
return UnivariateFinite(_levels, probs)
7373
end
7474

7575

@@ -216,10 +216,11 @@ ConstantRegressor
216216
217217
This "dummy" probabilistic predictor always returns the same distribution, irrespective of
218218
the provided input pattern. The distribution `d` returned is the `UnivariateFinite`
219-
distribution based on frequency of classes observed in the training target data. So,
220-
`pdf(d, level)` is the number of times the training target takes on the value `level`.
221-
Use `predict_mode` instead of `predict` to obtain the training target mode instead. For
222-
more on the `UnivariateFinite` type, see the CategoricalDistributions.jl package.
219+
distribution based on frequency of levels (classes) observed in the training target
220+
data. So, `pdf(d, level)` is the proportion of times the training target takes on the value
221+
`level`. Use `predict_mode` rather than `predict` to obtain the training target mode
222+
instead. For more on the `UnivariateFinite` type, see the CategoricalDistributions.jl
223+
package.
223224
224225
Almost any reasonable model is expected to outperform `ConstantClassifier`, which is used
225226
almost exclusively for testing and establishing performance baselines.

src/builtins/ThresholdPredictors.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,16 +56,16 @@ const ThresholdSupported = Union{keys(_type_given_atom)...}
5656

5757
const ERR_MODEL_UNSPECIFIED = ArgumentError(
5858
"Expecting atomic model as argument. None specified. ")
59-
warn_classes(first_class, second_class) =
59+
warn_levels(first_class, second_class) =
6060
"Taking positive class as `$(second_class)` and negative class as"*
6161
"`$(first_class)`."*
6262
"Coerce target to `OrderedFactor{2}` to suppress this warning, "*
6363
"ensuring that positive class > negative class. "
64-
const ERR_CLASSES_DETECTOR = ArgumentError(
64+
const ERR_LEVELS_DETECTOR = ArgumentError(
6565
"Targets for detector models must be ordered. Consider coercing to "*
6666
"`OrderedFactor`, ensuring that outlier class > inlier class. ")
6767
const ERR_TARGET_NOT_BINARY = ArgumentError(
68-
"Target `y` must have two classes in its pool, even if only one "*
68+
"Target `y` must have two levels in its pool, even if only one "*
6969
"class is manifest. ")
7070
const err_unsupported_model_type(T) = ArgumentError(
7171
"`BinaryThresholdPredictor` does not support atomic models with supertype `$T`. "*
@@ -208,9 +208,9 @@ function MMI.fit(model::ThresholdUnion, verbosity::Int, args...)
208208
length(L) == 2 || throw(ERR_TARGET_NOT_BINARY)
209209
first_class, second_class = L
210210
if model.model isa Probabilistic
211-
@warn warn_classes(first_class, second_class)
211+
@warn warn_levels(first_class, second_class)
212212
else
213-
throw(ERR_CLASSES_DETECTOR)
213+
throw(ERR_LEVELS_DETECTOR)
214214
end
215215
end
216216
model_fitresult, model_cache, model_report = MMI.fit(
@@ -259,7 +259,7 @@ function _predict_threshold(yhat::UnivariateFinite, threshold)
259259
dict = yhat.prob_given_ref
260260
length(threshold) == length(dict) || throw(
261261
ArgumentError(
262-
"`length(threshold)` has to equal number of classes in specified "*
262+
"`length(threshold)` has to equal number of levels in specified "*
263263
"`UnivariateFinite` distribution."
264264
)
265265
)
@@ -277,14 +277,14 @@ function _predict_threshold(yhat::UnivariateFiniteArray{S,V,R,P,N},
277277
dict = yhat.prob_given_ref
278278
length(threshold) == length(dict) || throw(
279279
ArgumentError(
280-
"`length(threshold)` has to equal number of classes in specified "*
280+
"`length(threshold)` has to equal number of levels in specified "*
281281
"`UnivariateFiniteArray`."
282282
)
283283
)
284284
d = yhat.decoder(1)
285285
levs = levels(d)
286286
ord = isordered(d)
287-
# Array to house the predicted classes
287+
# Array to house the predicted levels
288288
ret = CategoricalArray{V, N, R}(undef, size(yhat), levels=levs, ordered=ord)
289289
#ret = Array{CategoricalValue{V, R}, N}(undef, size(yhat))
290290
# `temp` vector allocated once to be used for calculations in each loop

test/builtins/Constant.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,14 @@ end
3535
d = MLJBase.UnivariateFinite([y[1], y[2], y[4]], [0.5, 0.25, 0.25])
3636

3737
yhat = MLJBase.predict_mode(model, fitresult, X)
38-
@test MLJBase.classes(yhat[1]) == MLJBase.classes(y[1])
38+
@test levels(yhat[1]) == levels(y[1])
3939
@test yhat[5] == y[1]
4040
@test length(yhat) == 10
4141

4242
yhat = MLJBase.predict(model, fitresult, X)
4343
yhat1 = yhat[1]
4444

45-
for c in MLJBase.classes(d)
45+
for c in levels(d)
4646
Distributions.pdf(yhat1, c) ≈ Distributions.pdf(d, c)
4747
end
4848

test/builtins/ThresholdPredictors.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ y2_ = categorical(yraw[2:end], ordered=true)
2727
)
2828

2929
# Check warning when `y` is not ordered:
30-
@test_logs((:warn, MLJModels.warn_classes(levels(y_)...)),
30+
@test_logs((:warn, MLJModels.warn_levels(levels(y_)...)),
3131
MMI.fit(model, 1, MMI.reformat(model, X_, y1_)...))
32-
# Check predictions containing two classes
32+
# Check predictions containing two levels
3333
@test_throws ArgumentError BinaryThresholdPredictor(ConstantRegressor())
3434
@test_logs((:warn, r"`threshold` should be"),
3535
BinaryThresholdPredictor(atom, threshold=-1))
@@ -88,13 +88,13 @@ end
8888
v1 = categorical(['a', 'b', 'a'])
8989
v2 = categorical(['a', 'b', 'a', 'c'])
9090
# Test with UnivariateFinite object
91-
d1 = UnivariateFinite(MMI.classes(v1), [0.4, 0.6])
91+
d1 = UnivariateFinite(levels(v1), [0.4, 0.6])
9292
@test_throws ArgumentError MLJModels._predict_threshold(d1, 0.7)
9393
@test MLJModels._predict_threshold(d1, (0.7, 0.3)) == v1[2]
9494
@test MLJModels._predict_threshold(d1, [0.5, 0.5]) == v1[2]
9595
@test MLJModels._predict_threshold(d1, (0.4, 0.6)) == v1[1]
9696
@test MLJModels._predict_threshold(d1, [0.2, 0.8]) == v1[1]
97-
d2 = UnivariateFinite(MMI.classes(v2), [0.4, 0.3, 0.3])
97+
d2 = UnivariateFinite(levels(v2), [0.4, 0.3, 0.3])
9898
@test_throws ArgumentError MLJModels._predict_threshold(d2, (0.7, 0.3))
9999
@test MLJModels._predict_threshold(d2, (0.2, 0.5, 0.3)) == v2[1]
100100
@test MLJModels._predict_threshold(d2, [0.3, 0.2, 0.5]) == v2[2]
@@ -117,14 +117,14 @@ end
117117

118118
# Test with UnivariateFiniteArray object
119119
probs1 = [0.2 0.8; 0.7 0.3; 0.1 0.9]
120-
unf_arr1 = UnivariateFinite(MMI.classes(v1), probs1)
120+
unf_arr1 = UnivariateFinite(levels(v1), probs1)
121121
@test_throws ArgumentError MLJModels._predict_threshold(unf_arr1, 0.7)
122122
@test MLJModels._predict_threshold(unf_arr1, (0.7, 0.3)) == [v1[2], v1[1], v1[2]]
123123
@test MLJModels._predict_threshold(unf_arr1, [0.5, 0.5]) == [v1[2], v1[1], v1[2]]
124124
@test MLJModels._predict_threshold(unf_arr1, (0.4, 0.6)) == [v1[2], v1[1], v1[2]]
125125
@test MLJModels._predict_threshold(unf_arr1, [0.2, 0.8]) == [v1[1], v1[1], v1[2]]
126126
probs2 = [0.2 0.3 0.5;0.1 0.6 0.3; 0.4 0.0 0.6]
127-
unf_arr2 = UnivariateFinite(MMI.classes(v2), probs2)
127+
unf_arr2 = UnivariateFinite(levels(v2), probs2)
128128
@test_throws ArgumentError MLJModels._predict_threshold(unf_arr2, (0.7, 0.3))
129129
@test MLJModels._predict_threshold(unf_arr2, (0.2, 0.5, 0.3)) == [v2[4], v2[2], v2[1]]
130130
@test MLJModels._predict_threshold(unf_arr2, [0.3, 0.2, 0.5]) == [v2[2], v2[2], v2[1]]
@@ -144,7 +144,7 @@ MMI.input_scitype(::Type{<:DummyDetector}) = MMI.Table
144144

145145
@testset "BinaryThresholdPredictor - ProbabilisticUnsupervisedDetector" begin
146146
detector = BinaryThresholdPredictor(DummyDetector(), threshold=0.2)
147-
@test_throws MLJModels.ERR_CLASSES_DETECTOR MMI.fit(
147+
@test_throws MLJModels.ERR_LEVELS_DETECTOR MMI.fit(
148148
detector, 1, MMI.reformat(detector, X_, y1_)...
149149
)
150150

test/testutils.jl

Lines changed: 0 additions & 115 deletions
This file was deleted.

0 commit comments

Comments
 (0)