Skip to content

Commit 1e1b3b7

Browse files
authored
Merge pull request #15 from JuliaAI/dev
For a 0.1.1 release
2 parents 104adcf + e6b1aaa commit 1e1b3b7

File tree

11 files changed

+345
-201
lines changed

11 files changed

+345
-201
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
fail-fast: false
2020
matrix:
2121
version:
22-
- '1.8'
22+
- '1.7'
2323
- '1'
2424

2525
os: [ubuntu-latest, windows-latest, macOS-latest]

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
/Manifest.toml
22
.CondaPkg/*
33
.CondaPkg
4+
.vscode/settings.json

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJBalancing"
22
uuid = "45f359ea-796d-4f51-95a5-deb1a414c586"
33
authors = ["Essam Wisam <[email protected]>", "Anthony Blaom <[email protected]> and contributors"]
4-
version = "0.1.0"
4+
version = "0.1.1"
55

66
[deps]
77
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
@@ -12,12 +12,12 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1212
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1313

1414
[compat]
15-
MLJBase = "0.21"
15+
MLJBase = "1"
1616
OrderedCollections = "1.6"
17-
julia = "1.6"
1817
MLJModelInterface = "1.9"
1918
MLUtils = "0.4"
2019
StatsBase = "0.34"
20+
julia = "1.7"
2121

2222
[extras]
2323
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"

README.md

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# MLJBalancing
22
A package providing composite models wrapping class imbalance algorithms from [Imbalance.jl](https://github.com/JuliaAI/Imbalance.jl) with classifiers from [MLJ](https://github.com/alan-turing-institute/MLJ.jl).
33

4-
## Instalattion
4+
## Installation
55
```julia
66
import Pkg;
77
Pkg.add("MLJBalancing")
@@ -17,6 +17,7 @@ This package allows chaining of resampling methods from Imbalance.jl with classi
1717
```julia
1818
SMOTENC = @load SMOTENC pkg=Imbalance verbosity=0
1919
TomekUndersampler = @load TomekUndersampler pkg=Imbalance verbosity=0
20+
LogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels verbosity=0
2021

2122
oversampler = SMOTENC(k=5, ratios=1.0, rng=42)
2223
undersampler = TomekUndersampler(min_ratios=0.5, rng=42)
@@ -33,7 +34,7 @@ Here training data will be passed to `balancer1` then `balancer2`, whose output
3334
In general, there can be any number of balancers, and the user can give the balancers arbitrary names.
3435

3536
#### At this point, they behave like one single model
36-
You can fit, predict, cross-validate and finetune it like any other MLJ model. Here is an example for finetuning
37+
You can fit, predict, cross-validate and fine-tune it like any other MLJ model. Here is an example for fine-tuning
3738
```julia
3839
r1 = range(balanced_model, :(balancer1.k), lower=3, upper=10)
3940
r2 = range(balanced_model, :(balancer2.min_ratios), lower=0.1, upper=0.9)
@@ -57,7 +58,7 @@ The package also offers an implementation of bagging over probabilistic classifi
5758

5859

5960
#### Construct a BalancedBaggingClassifier
60-
In this you must specify the model, and optionally specify the number of bags `T` and the random number generator `rng`. If `T` is not specified it is set as the ratio between the majority and minority counts. If `rng` isn't specified then `default_rng()` is used.
61+
In this you must specify some probabilistic model, and optionally specify the number of bags `T` and the random number generator `rng`. If `T` is not specified it is set as the ratio between the majority and minority counts. If `rng` isn't specified then `default_rng()` is used.
6162

6263
```julia
6364
LogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels verbosity=0
@@ -66,9 +67,9 @@ bagging_model = BalancedBaggingClassifier(model=logistic_model, T=10, rng=Random
6667
```
6768

6869
#### Now it behaves like one single model
69-
You can fit, predict, cross-validate and finetune it like any other probabilistic MLJ model where `X` must be a table input (e.g., a dataframe).
70+
You can fit, predict, cross-validate and fine-tune it like any other probabilistic MLJ model where `X` must be a table input (e.g., a dataframe).
7071
```julia
7172
mach = machine(bagging_model, X, y)
7273
fit!(mach)
7374
pred = predict(mach, X)
74-
```
75+
```

example/BalancedBagging.ipynb

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,16 @@
4040
{
4141
"data": {
4242
"text/plain": [
43-
"((Column1 = [0.9695150609084499, 0.012898301755861596, 0.7555027304121053, 0.3467415729179013, 0.35969402837473463, 0.2601876747805505, 0.9522580699968279, 0.06304475092339623, 0.18909001622655808, 0.19934942931986965 … 0.021532597906190776, 0.8482825697641306, 0.10773487816863903, 0.32189982199036116, 0.12662208474317038, 0.28529465447429614, 0.2907506630258835, 0.36872799387588473, 0.061489791166806085, 0.45645058368583713], Column2 = [0.06546916714160167, 0.7243956502957003, 0.5183099801474415, 0.7555562860508294, 0.11226218114407538, 0.9135150277876691, 0.8739421974558176, 0.2268482788660101, 0.580604436651146, 0.4142252330250549 … 0.6517425913240111, 0.01713263102740481, 0.7175499403837856, 0.7362894157420817, 0.24893665902538054, 0.41499951381631595, 0.2159527717429719, 0.8966879835264249, 0.87252430655793, 0.41461921031276117], Column3 = [0.5939320702328891, 0.19329886972497456, 0.04656947038518311, 0.22095698685781184, 0.678807659662497, 0.12720198818430306, 0.6795750371448686, 0.9314917999820301, 0.22920734893984274, 0.5148148980955375 … 0.55049773593343, 0.038576459283091946, 0.27765727942909757, 0.2753072414696357, 0.8823620780359746, 0.44831794170895023, 0.9073846432163745, 0.4648550947905655, 0.311984726769037, 0.25829997798611304], Column4 = [0.12253944650540982, 0.8259140842535423, 0.4034477332184384, 0.5279399406265695, 0.5579944087437719, 0.24650366028608328, 0.6874897000162434, 0.23391406844015605, 0.5641254897013973, 0.6250622796341656 … 0.21708181942178983, 0.35224683896541464, 0.8444113778983325, 0.4547214584884428, 0.13508852017592232, 0.9510137735662383, 0.5723463533029658, 0.626377972762265, 0.7854013810594317, 0.15394691114473347], Column5 = [0.47958743625921163, 0.45779753417165514, 0.6367059235247621, 0.8601116026079643, 0.3334020182022719, 0.41593698717526373, 0.13208968772625174, 0.16951044109747648, 0.8137887839507706, 0.4429229861115882 … 0.01308976221980429, 0.48597926808091163, 0.20768781798463476, 0.30045611276046247, 0.15759293576302558, 0.975806377881983, 0.19451065500145392, 0.9638103356367584, 0.3594043445295293, 0.7792867217495332], Column6 = [3.0, 3.0, 1.0, 3.0, 1.0, 2.0, 3.0, 2.0, 3.0, 3.0 … 3.0, 2.0, 1.0, 2.0, 1.0, 2.0, 2.0, 3.0, 3.0, 1.0], Column7 = [2.0, 2.0, 2.0, 2.0, 1.0, 2.0, 2.0, 2.0, 1.0, 1.0 … 2.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 1.0, 1.0]), CategoricalArrays.CategoricalValue{Int64, UInt32}[0, 0, 0, 0, 0, 0, 0, 0, 1, 0 … 0, 0, 1, 0, 1, 0, 0, 0, 0, 0])"
43+
"((Column1 = [0.564, 0.862, 0.793, 0.505, 0.683, 0.699, 0.545, 0.693, 0.95, 0.44 … 0.423, 0.632, 0.922, 0.592, 0.944, 0.517, 0.785, 0.579, 0.725, 0.711], Column2 = [0.42, 0.715, 0.358, -0.009, 0.228, 0.725, 0.786, 0.52, 0.646, 0.582 … 0.65, 0.633, 0.263, 0.141, 0.472, 0.45, -0.019, 0.593, 0.777, 0.877], Column3 = [0.638, 0.719, 0.716, 0.604, 0.616, 0.784, 0.697, 0.711, 0.878, 0.739 … 0.722, 0.672, 0.879, 0.598, 0.879, 0.669, 0.728, 0.768, 0.736, 0.725], Column4 = [0.29, 0.164, 0.164, 0.262, 0.246, 0.211, 0.155, 0.03, 1.842, 0.324 … 0.192, 0.143, 1.323, 0.251, 1.084, 0.165, 0.138, 0.176, 0.155, 0.217], Column5 = [0.605, 0.287, 0.565, 0.121, 0.752, 0.317, 0.165, 0.497, 0.361, 0.293 … 0.726, 0.781, 0.694, 0.728, 0.692, 0.351, 0.089, 0.478, 0.067, -0.19], Column6 = [2.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 2.0, 2.0, 3.0 … 1.0, 3.0, 2.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0], Column7 = [2.0, 2.0, 1.0, 2.0, 1.0, 1.0, 2.0, 2.0, 1.0, 2.0 … 1.0, 2.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0]), CategoricalArrays.CategoricalValue{Int64, UInt32}[0, 0, 0, 0, 0, 0, 0, 0, 1, 0 … 0, 0, 1, 0, 1, 0, 0, 0, 0, 0])"
4444
]
4545
},
4646
"metadata": {},
4747
"output_type": "display_data"
4848
}
4949
],
5050
"source": [
51-
"X, y = generate_imbalanced_data(100, 5; cat_feats_num_vals = [3, 2], \n",
52-
" probs = [0.9, 0.1], \n",
51+
"X, y = generate_imbalanced_data(100, 5; num_vals_per_category = [3, 2], \n",
52+
" class_probs = [0.9, 0.1], \n",
5353
" type = \"ColTable\", \n",
5454
" rng=42)"
5555
]
@@ -73,6 +73,15 @@
7373
"WARNING: using StaticArrays.setindex in module FiniteDiff conflicts with an existing identifier.\n"
7474
]
7575
},
76+
{
77+
"name": "stderr",
78+
"output_type": "stream",
79+
"text": [
80+
"┌ Warning: The call to compilecache failed to create a usable precompiled cache file for MLJLinearModels [6ee0df7b-362f-4a72-a706-9e79364fb692]\n",
81+
"│ exception = ErrorException(\"Required dependency Optim [429524aa-4258-5aef-a3af-852621145aeb] failed to load from a cache file.\")\n",
82+
"└ @ Base loading.jl:1349\n"
83+
]
84+
},
7685
{
7786
"data": {
7887
"text/plain": [
@@ -108,7 +117,7 @@
108117
},
109118
{
110119
"cell_type": "code",
111-
"execution_count": 10,
120+
"execution_count": 4,
112121
"metadata": {},
113122
"outputs": [
114123
{
@@ -127,26 +136,26 @@
127136
"data": {
128137
"text/plain": [
129138
"100-element CategoricalDistributions.UnivariateFiniteVector{Multiclass{2}, Int64, UInt32, Float64}:\n",
130-
" UnivariateFinite{Multiclass{2}}(0=>0.928, 1=>0.0722)\n",
131-
" UnivariateFinite{Multiclass{2}}(0=>0.845, 1=>0.155)\n",
132-
" UnivariateFinite{Multiclass{2}}(0=>0.749, 1=>0.251)\n",
133-
" UnivariateFinite{Multiclass{2}}(0=>0.902, 1=>0.0977)\n",
134-
" UnivariateFinite{Multiclass{2}}(0=>0.804, 1=>0.196)\n",
135-
" UnivariateFinite{Multiclass{2}}(0=>0.864, 1=>0.136)\n",
136-
" UnivariateFinite{Multiclass{2}}(0=>0.851, 1=>0.149)\n",
137-
" UnivariateFinite{Multiclass{2}}(0=>0.954, 1=>0.0458)\n",
138-
" UnivariateFinite{Multiclass{2}}(0=>0.853, 1=>0.147)\n",
139-
" UnivariateFinite{Multiclass{2}}(0=>0.86, 1=>0.14)\n",
139+
" UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n",
140+
" UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n",
141+
" UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n",
142+
" UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n",
143+
" UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n",
144+
" UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n",
145+
" UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n",
146+
" UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n",
147+
" UnivariateFinite{Multiclass{2}}(0=>0.0, 1=>1.0)\n",
148+
" UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n",
140149
"\n",
141-
" UnivariateFinite{Multiclass{2}}(0=>0.671, 1=>0.329)\n",
142-
" UnivariateFinite{Multiclass{2}}(0=>0.73, 1=>0.27)\n",
143-
" UnivariateFinite{Multiclass{2}}(0=>0.843, 1=>0.157)\n",
144-
" UnivariateFinite{Multiclass{2}}(0=>0.941, 1=>0.0594)\n",
145-
" UnivariateFinite{Multiclass{2}}(0=>0.872, 1=>0.128)\n",
146-
" UnivariateFinite{Multiclass{2}}(0=>0.92, 1=>0.0797)\n",
147-
" UnivariateFinite{Multiclass{2}}(0=>0.929, 1=>0.0714)\n",
148-
" UnivariateFinite{Multiclass{2}}(0=>0.791, 1=>0.209)\n",
149-
" UnivariateFinite{Multiclass{2}}(0=>0.827, 1=>0.173)"
150+
" UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n",
151+
" UnivariateFinite{Multiclass{2}}(0=>0.0, 1=>1.0)\n",
152+
" UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n",
153+
" UnivariateFinite{Multiclass{2}}(0=>0.0, 1=>1.0)\n",
154+
" UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n",
155+
" UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n",
156+
" UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n",
157+
" UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n",
158+
" UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)"
150159
]
151160
},
152161
"metadata": {},

0 commit comments

Comments
 (0)