Skip to content

Commit 270cb4c

Browse files
authored
Merge pull request #22 from JuliaAI/dev
For a 0.1.2 release
2 parents 1e1b3b7 + defa20f commit 270cb4c

File tree

7 files changed

+128
-38
lines changed

7 files changed

+128
-38
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
fail-fast: false
2020
matrix:
2121
version:
22-
- '1.7'
22+
- '1.6'
2323
- '1'
2424

2525
os: [ubuntu-latest, windows-latest, macOS-latest]

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ OrderedCollections = "1.6"
1717
MLJModelInterface = "1.9"
1818
MLUtils = "0.4"
1919
StatsBase = "0.34"
20-
julia = "1.7"
20+
julia = "1.6"
2121

2222
[extras]
2323
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"

src/MLJBalancing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using MLJModelInterface
55
using MLUtils
66
using OrderedCollections
77
using Random
8-
using Random: AbstractRNG, Xoshiro, rand
8+
using Random: AbstractRNG, rand
99
using StatsBase: sample
1010

1111
MMI = MLJModelInterface

src/balanced_bagging.jl

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,20 +79,34 @@ mutable struct BalancedBaggingClassifier{RI<:Union{AbstractRNG, Integer},I<:Inte
7979
rng::RI
8080
end
8181

82-
rng_handler(rng::Integer) = Random.Xoshiro(rng)
82+
# To support Julia 1.6 which does not have Xoshiro
83+
XoshiroOrMT(rng::Integer) = (VERSION < v"1.7") ? Random.MersenneTwister(rng) : Random.Xoshiro(rng)
84+
rng_handler(rng::Integer) = XoshiroOrMT(rng)
8385
rng_handler(rng::AbstractRNG) = rng
86+
8487
const ERR_MISSING_CLF = "No model specified. Please specify a probabilistic classifier using the `model` keyword argument."
8588
const ERR_BAD_T = "The number of ensemble models `T` cannot be negative."
8689
const INFO_DEF_T(T_def) = "The number of ensemble models was not given and was thus, automatically set to $T_def"*
8790
" which is the ratio of the frequency of the majority class to that of the minority class"
88-
function BalancedBaggingClassifier(;
91+
const ERR_NUM_ARGS_BB = "`BalancedBaggingClassifier` can at most have one non-keyword argument where the model is passed."
92+
const WRN_MODEL_GIVEN = "Ignoring keyword argument `model=...` as model already given as positional argument. "
93+
94+
function BalancedBaggingClassifier(args...;
8995
model = nothing,
9096
T = 0,
9197
rng = Random.default_rng(),
9298
)
93-
model === nothing && error(ERR_MISSING_CLF)
94-
T < 0 && error(ERR_BAD_T)
95-
rng = rng_handler(rng)
99+
length(args) <= 1 || throw(ERR_NUM_ARGS_BB)
100+
if length(args) === 1
101+
atom = first(args)
102+
model === nothing ||
103+
@warn WRN_MODEL_GIVEN
104+
model = atom
105+
else
106+
model === nothing && throw(ERR_MISSING_CLF)
107+
end
108+
T < 0 && error(ERR_BAD_T)
109+
rng = rng_handler(rng)
96110
return BalancedBaggingClassifier(model, T, rng)
97111
end
98112

@@ -208,8 +222,8 @@ Train the machine with `fit!(mach, rows=...)`.
208222
- `T::Integer=0`: The number of bags to be used in the ensemble. If not given, will be set as
209223
the ratio between the frequency of the majority and minority classes. Can be later found in `report(mach)`.
210224
211-
- `rng::Union{AbstractRNG, Integer}=default_rng()`: Either an `AbstractRNG` object or an `Integer`
212-
seed to be used with `Xoshiro`
225+
- `rng::Union{AbstractRNG, Integer}=default_rng()`: Either an `AbstractRNG` object or an `Integer`
226+
seed to be used with `Xoshiro` if Julia `VERSION>=1.7`. Otherwise, uses `MersenneTwister`.
213227
214228
# Operations
215229

src/balanced_model.jl

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,33 +55,78 @@ const UNION_MODEL_TYPES = Union{keys(MODELTYPE_TO_COMPOSITETYPE_EVAL)...}
5555

5656

5757
# Possible Errors (for the constructor as well)
58-
const ERR_MODEL_UNSPECIFIED = ArgumentError("Expected an atomic model as argument. None specified. ")
58+
const ERR_MODEL_UNSPECIFIED = ErrorException("Expected an atomic model as argument. None specified. ")
5959

6060
const WRN_BALANCER_UNSPECIFIED = "No balancer was provided. Data will be directly passed to the model. "
6161

6262
const PRETTY_SUPPORTED_MODEL_TYPES = join([string("`", opt, "`") for opt in SUPPORTED_MODEL_TYPES], ", ",", and ")
6363

64-
const ERR_UNSUPPORTED_MODEL(model) = ArgumentError(
64+
const ERR_UNSUPPORTED_MODEL(model) = ErrorException(
6565
"Only these model supertypes support wrapping: "*
6666
"$PRETTY_SUPPORTED_MODEL_TYPES.\n"*
6767
"Model provided has type `$(typeof(model))`. "
6868
)
69+
const ERR_NUM_ARGS_BM = "`BalancedModel` can at most have one non-keyword argument where the model is passed."
6970

7071

7172
"""
72-
BalancedModel(; balancers=[], model=nothing)
73+
BalancedModel(; model=nothing, balancer1=balancer_model1, balancer2=balancer_model2, ...)
74+
BalancedModel(model; balancer1=balancer_model1, balancer2=balancer_model2, ...)
7375
74-
Wraps a classification model with balancers that resample the data before passing it to the model.
76+
Given a classification model, and one or more balancer models that all implement the `MLJModelInterface`,
77+
`BalancedModel` allows constructing a sequential pipeline that wraps an arbitrary number of balancing models
78+
and a classifier together.
7579
76-
# Arguments
77-
- `balancers::AbstractVector=[]`: A vector of balancers (i.e., resampling models).
78-
Data passed to the model will be first passed to the balancers sequentially.
79-
- `model=nothing`: The classification model which must be provided.
80+
# Operation
81+
- During training, data is first passed to `balancer1`, the result is passed to `balancer2`, and so on; the result from the final balancer
82+
is then passed to the classifier for training.
83+
- During prediction, the balancers have no effect.
8084
85+
# Arguments
86+
- `model::Supervised`: A classification model that implements the `MLJModelInterface`.
87+
- `balancer1::Static=...`: The first balancer model to pass the data to. This keyword argument can have any name.
88+
- `balancer2::Static=...`: The second balancer model to pass the data to. This keyword argument can have any name.
89+
- and so on for an arbitrary number of balancers.
90+
91+
# Returns
92+
- An instance of type `ProbabilisticBalancedModel` or `DeterministicBalancedModel`, depending on the prediction type of `model`.
93+
94+
# Example
95+
```julia
96+
using MLJ
97+
using Imbalance
98+
99+
# generate data
100+
X, y = Imbalance.generate_imbalanced_data(1000, 5; class_probs=[0.2, 0.3, 0.5])
101+
102+
# prepare classification and balancing models
103+
SMOTENC = @load SMOTENC pkg=Imbalance verbosity=0
104+
TomekUndersampler = @load TomekUndersampler pkg=Imbalance verbosity=0
105+
LogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels verbosity=0
106+
107+
oversampler = SMOTENC(k=5, ratios=1.0, rng=42)
108+
undersampler = TomekUndersampler(min_ratios=0.5, rng=42)
109+
logistic_model = LogisticClassifier()
110+
111+
# wrap them in a BalancedModel
112+
balanced_model = BalancedModel(model=logistic_model, balancer1=oversampler, balancer2=undersampler)
113+
114+
# now this behaves as a unified model that can be trained, validated, fine-tuned, etc.
115+
mach = machine(balanced_model, X, y)
116+
fit!(mach)
117+
```
81118
"""
82-
function BalancedModel(; model=nothing, named_balancers...)
119+
function BalancedModel(args...; model=nothing, named_balancers...)
83120
# check model and balancer are given
84-
model === nothing && throw(ERR_MODEL_UNSPECIFIED)
121+
length(args) <= 1 || throw(ERR_NUM_ARGS_BM)
122+
if length(args) === 1
123+
atom = first(args)
124+
model === nothing ||
125+
@warn WRN_MODEL_GIVEN
126+
model = atom
127+
else
128+
model === nothing && throw(ERR_MODEL_UNSPECIFIED)
129+
end
85130
# check model is supported
86131
model isa UNION_MODEL_TYPES || throw(ERR_UNSUPPORTED_MODEL(model))
87132

@@ -116,6 +161,7 @@ for model_type in SUPPORTED_MODEL_TYPES
116161
eval(ex)
117162
end
118163

164+
119165
const ERR_NO_PROP = ArgumentError("trying to access property $name which does not exist")
120166
# overload set property to set the property from the vector in the struct
121167
for model_type in SUPPORTED_MODEL_TYPES

test/balanced_bagging.jl

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
@testset "group_inds and get_majority_minority_inds_counts" begin
33
y = [0, 0, 0, 0, 1, 1, 1, 0]
4-
@test MLJBalancing.group_inds(y) == Dict(0 => [1, 2, 3, 4, 8], 1 => [5, 6, 7])
4+
@test MLJBalancing.group_inds(y) == Dict(0 => [1, 2, 3, 4, 8], 1 => [5, 6, 7])
55
@test MLJBalancing.get_majority_minority_inds_counts(y) ==
66
([1, 2, 3, 4, 8], [5, 6, 7], 5, 3)
77
y = [0, 0, 0, 0, 1, 1, 1, 0, 2, 2, 2]
@@ -18,7 +18,7 @@ end
1818
num_vals_per_category = [3, 2, 1, 2],
1919
class_probs = [0.9, 0.1],
2020
type = "ColTable",
21-
rng = 42,
21+
rng = Random.MersenneTwister(42),
2222
)
2323
majority_inds, minority_inds, majority_count, minority_count =
2424
MLJBalancing.get_majority_minority_inds_counts(y)
@@ -30,7 +30,7 @@ end
3030
minority_inds,
3131
majority_count,
3232
minority_count,
33-
Random.Xoshiro(42)
33+
Random.MersenneTwister(42)
3434
)
3535
X_sub, y_sub = X_sub(rows = 1:100), y_sub(rows = 1:100)
3636
majority_inds_sub, minority_inds_sub, _, _ =
@@ -51,7 +51,7 @@ end
5151

5252
@testset "End-to-end Test" begin
5353
## setup parameters
54-
R = Random.Xoshiro(42)
54+
R = Random.MersenneTwister(42)
5555
T = 2
5656
LogisticClassifier = @load LogisticClassifier pkg = MLJLinearModels verbosity = 0
5757
model = LogisticClassifier()
@@ -64,7 +64,7 @@ end
6464
num_vals_per_category = [3, 2, 1, 2],
6565
class_probs = [0.9, 0.1],
6666
type = "ColTable",
67-
rng = 42,
67+
rng = Random.MersenneTwister(42),
6868
)
6969
# testing
7070
Xt, yt = generate_imbalanced_data(
@@ -73,7 +73,7 @@ end
7373
num_vals_per_category = [3, 2, 1, 2],
7474
class_probs = [0.9, 0.1],
7575
type = "ColTable",
76-
rng = 42,
76+
rng = Random.MersenneTwister(42),
7777
)
7878

7979
## prepare subsets
@@ -111,14 +111,30 @@ end
111111
pred_manual = mean([pred1, pred2])
112112

113113
## using BalancedBagging
114-
modelo = BalancedBaggingClassifier(model = model, T = 2, rng = Random.Xoshiro(42))
114+
modelo = BalancedBaggingClassifier(model = model, T = 2, rng = Random.MersenneTwister(42))
115115
mach = machine(modelo, X, y)
116116
fit!(mach)
117117
pred_auto = MLJBase.predict(mach, Xt)
118118
@test sum(pred_manual) ≈ sum(pred_auto)
119-
modelo = BalancedBaggingClassifier(model = model, rng = Random.Xoshiro(42))
119+
modelo = BalancedBaggingClassifier(model = model, rng = Random.MersenneTwister(42))
120120
mach = machine(modelo, X, y)
121121
fit!(mach)
122-
@test report(mach) == (chosen_T = 5,)
123-
122+
@test report(mach) == (chosen_T = 9,)
124123
end
124+
125+
126+
127+
128+
@testset "Equivalence of Constructions" begin
129+
## setup parameters
130+
R = Random.MersenneTwister(42)
131+
T = 2
132+
LogisticClassifier = @load LogisticClassifier pkg = MLJLinearModels verbosity = 0
133+
model = LogisticClassifier()
134+
BalancedBaggingClassifier(model=model, T=T, rng=R) == BalancedBaggingClassifier(model; T=T, rng=R)
135+
136+
@test_throws MLJBalancing.ERR_NUM_ARGS_BB BalancedBaggingClassifier(model, model; T=T, rng=R)
137+
@test_logs (:warn, MLJBalancing.WRN_MODEL_GIVEN) begin
138+
BalancedBaggingClassifier(model; model=model, T=T, rng=R)
139+
end
140+
end

test/balanced_model.jl

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
@testset "BalancedModel" begin
22
### end-to-end test
33
# Create and split data
4-
X, y = generate_imbalanced_data(100, 5; class_probs = [0.2, 0.3, 0.5])
4+
X, y = generate_imbalanced_data(100, 5; class_probs = [0.2, 0.3, 0.5], rng=Random.MersenneTwister(42))
55
X = DataFrame(X)
66
train_inds, test_inds =
7-
partition(eachindex(y), 0.8, shuffle = true, stratify = y, rng = Random.Xoshiro(42))
7+
partition(eachindex(y), 0.8, shuffle = true, stratify = y, rng = Random.MersenneTwister(42))
88
X_train, X_test = X[train_inds, :], X[test_inds, :]
99
y_train, y_test = y[train_inds], y[test_inds]
1010

@@ -18,9 +18,9 @@
1818
# And here are three resamplers from Imbalance.
1919
# The package should actually work with any `Static` transformer of the form `(X, y) -> (Xout, yout)`
2020
# provided that it implements the MLJ interface. Here, the balancer is the transformer
21-
balancer1 = Imbalance.MLJ.RandomOversampler(ratios = 1.0, rng = 42)
22-
balancer2 = Imbalance.MLJ.SMOTENC(k = 10, ratios = 1.2, rng = 42)
23-
balancer3 = Imbalance.MLJ.ROSE(ratios = 1.3, rng = 42)
21+
balancer1 = Imbalance.MLJ.RandomOversampler(ratios = 1.0, rng = Random.MersenneTwister(42))
22+
balancer2 = Imbalance.MLJ.SMOTENC(k = 10, ratios = 1.2, rng = Random.MersenneTwister(42))
23+
balancer3 = Imbalance.MLJ.ROSE(ratios = 1.3, rng = Random.MersenneTwister(42))
2424

2525
### 1. Make a pipeline of the three balancers and a probabilistic model
2626
## ordinary way
@@ -35,8 +35,8 @@
3535
fit!(mach)
3636
y_pred = MLJBase.predict(mach, X_test)
3737

38-
# with MLJ balancing
39-
@test_throws MLJBalancing.ERR_MODEL_UNSPECIFIED begin
38+
# with MLJ balancing
39+
@test_throws MLJBalancing.ERR_MODEL_UNSPECIFIED begin
4040
BalancedModel(b1 = balancer1, b2 = balancer2, b3 = balancer3)
4141
end
4242
@test_throws(
@@ -46,7 +46,6 @@
4646
@test_logs (:warn, MLJBalancing.WRN_BALANCER_UNSPECIFIED) begin
4747
BalancedModel(model = model_prob)
4848
end
49-
5049
balanced_model =
5150
BalancedModel(model = model_prob, b1 = balancer1, b2 = balancer2, b3 = balancer3)
5251
mach = machine(balanced_model, X_train, y_train)
@@ -86,3 +85,18 @@
8685
Base.setproperty!(balanced_model, :name11, balancer2),
8786
)
8887
end
88+
89+
90+
@testset "Equivalence of Constructions" begin
91+
## setup parameters
92+
R = Random.MersenneTwister(42)
93+
LogisticClassifier = @load LogisticClassifier pkg = MLJLinearModels verbosity = 0
94+
balancer1 = Imbalance.MLJ.RandomOversampler(ratios = 1.0, rng = Random.MersenneTwister(42))
95+
model = LogisticClassifier()
96+
BalancedModel(model=model, balancer1=balancer1) == BalancedModel(model; balancer1=balancer1)
97+
98+
@test_throws MLJBalancing.ERR_NUM_ARGS_BM BalancedModel(model, model; balancer1=balancer1)
99+
@test_logs (:warn, MLJBalancing.WRN_MODEL_GIVEN) begin
100+
BalancedModel(model; model=model, balancer1=balancer1)
101+
end
102+
end

0 commit comments

Comments
 (0)