Skip to content

Commit 1f67a9c

Browse files
committed
🎨 Some polishing
1 parent 84572f3 commit 1f67a9c

File tree

9 files changed

+1395
-462
lines changed

9 files changed

+1395
-462
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ authors = ["Essam <[email protected]> and contributors"]
44
version = "1.0.0-DEV"
55

66
[deps]
7-
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
87
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
98
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
109
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
@@ -16,6 +15,9 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1615
MLJBase = "0.21"
1716
OrderedCollections = "1.6"
1817
julia = "1.6"
18+
MLJModelInterface = "1.9"
19+
MLUtils = "0.4"
20+
StatsBase = "0.34"
1921

2022
[extras]
2123
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"

README.md

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# MLJBalancing
2-
A package providing composite models wrapping class imbalance algorithms from [Imbalance.jl](https://github.com/JuliaAI/Imbalance.jl).
2+
A package providing composite models wrapping class imbalance algorithms from [Imbalance.jl](https://github.com/JuliaAI/Imbalance.jl) with classifiers from [MLJ](https://github.com/alan-turing-institute/MLJ.jl).
33

44
## ⏬ Instalattion
55
```julia
@@ -51,6 +51,24 @@ fit!(mach, verbosity=0);
5151
fitted_params(mach).best_model
5252
```
5353

54-
## 🚆🚆 Parallel Resampling with EasyEnsemble
54+
## 🚆🚆 Parallel Resampling with Balanced Bagging
5555

56-
Coming soon...
56+
The package also offers an implementation of bagging over probabilistic classifiers where the majority class is repeatedly undersampled `T` times down to the size of the minority class. This undersampling scheme was proposed in the *EasyEnsemble* algorithm found in the paper *Exploratory Undersampling for Class-Imbalance Learning.* by *Xu-Ying Liu, Jianxin Wu, & Zhi-Hua Zhou* where an Adaboost model was used and the output scores were averaged.
57+
58+
59+
#### Construct a BalancedBaggingClassifier
60+
In this you must specify the model, and optionally specify the number of bags `T` and the random number generator `rng`. If `T` is not specified it is set as the ratio between the majority and minority counts. If `rng` isn't specified then `default_rng()` is used.
61+
62+
```julia
63+
LogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels verbosity=0
64+
logistic_model = LogisticClassifier()
65+
bagging_model = BalancedBaggingClassifier(model=logistic_model, T=10, rng=Random.Xoshiro(42))
66+
```
67+
68+
#### Now it behaves like one single model
69+
You can fit, predict, cross-validate and finetune it like any other probabilistic MLJ model where `X` must be a table input (e.g., a dataframe).
70+
```julia
71+
mach = machine(bagging_model, X, y)
72+
fit!(mach)
73+
pred = predict(mach, X)
74+
```
Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,23 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": null,
5+
"execution_count": 1,
66
"metadata": {},
7-
"outputs": [],
7+
"outputs": [
8+
{
9+
"name": "stderr",
10+
"output_type": "stream",
11+
"text": [
12+
"\u001b[32m\u001b[1m Activating\u001b[22m\u001b[39m project at `~/Documents/GitHub/MLJBalancing/example`\n"
13+
]
14+
}
15+
],
816
"source": [
917
"ENV[\"JULIA_PKG_SERVER\"] = \"\"\n",
1018
"using Pkg\n",
1119
"Pkg.activate(@__DIR__)\n",
1220
"Pkg.instantiate()\n",
1321
"\n",
14-
"\n",
1522
"using MLJBalancing\n",
1623
"using Imbalance\n",
1724
"using MLJ\n",
@@ -27,7 +34,7 @@
2734
},
2835
{
2936
"cell_type": "code",
30-
"execution_count": 15,
37+
"execution_count": 2,
3138
"metadata": {},
3239
"outputs": [
3340
{
@@ -56,9 +63,16 @@
5663
},
5764
{
5865
"cell_type": "code",
59-
"execution_count": 16,
66+
"execution_count": 3,
6067
"metadata": {},
6168
"outputs": [
69+
{
70+
"name": "stderr",
71+
"output_type": "stream",
72+
"text": [
73+
"WARNING: using StaticArrays.setindex in module FiniteDiff conflicts with an existing identifier.\n"
74+
]
75+
},
6276
{
6377
"data": {
6478
"text/plain": [
@@ -82,7 +96,7 @@
8296
"source": [
8397
"LogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels verbosity=0\n",
8498
"logistic_model = LogisticClassifier()\n",
85-
"model = BalancedBaggingClassifier(classifier=logistic_model, T=10, rng=Random.Xoshiro(42))"
99+
"model = BalancedBaggingClassifier(model=logistic_model, T=10, rng=Random.Xoshiro(42))"
86100
]
87101
},
88102
{
@@ -94,7 +108,7 @@
94108
},
95109
{
96110
"cell_type": "code",
97-
"execution_count": 18,
111+
"execution_count": 10,
98112
"metadata": {},
99113
"outputs": [
100114
{

0 commit comments

Comments
 (0)