Skip to content

Commit 84572f3

Browse files
authored
Merge pull request #2 from JuliaAI/BalancedBagging
✨ Introduce BalancedBagging
2 parents e2cbd1d + 7f0f569 commit 84572f3

File tree

10 files changed

+558
-1000
lines changed

10 files changed

+558
-1000
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ jobs:
2020
matrix:
2121
version:
2222
- '1.8'
23+
- '1'
24+
2325
os: [ubuntu-latest, windows-latest, macOS-latest]
2426
arch:
2527
- x64

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
/Manifest.toml
2+
.CondaPkg/*
23
.CondaPkg

Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@ authors = ["Essam <[email protected]> and contributors"]
44
version = "1.0.0-DEV"
55

66
[deps]
7+
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
78
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
89
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
10+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
911
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
1012
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
13+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1114

1215
[compat]
1316
MLJBase = "0.21"
@@ -17,9 +20,10 @@ julia = "1.6"
1720
[extras]
1821
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
1922
Imbalance = "c709b415-507b-45b7-9a3d-1767c89fde68"
20-
MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"
2123
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
24+
MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"
25+
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
2226
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2327

2428
[targets]
25-
test = ["Test", "Imbalance", "DataFrames", "MLJLinearModels", "MLJModels"]
29+
test = ["Test", "Imbalance", "DataFrames", "MLJLinearModels", "MLJModels", "Tables"]

examples/BalancedBagging.ipynb

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"ENV[\"JULIA_PKG_SERVER\"] = \"\"\n",
10+
"using Pkg\n",
11+
"Pkg.activate(@__DIR__)\n",
12+
"Pkg.instantiate()\n",
13+
"\n",
14+
"\n",
15+
"using MLJBalancing\n",
16+
"using Imbalance\n",
17+
"using MLJ\n",
18+
"using Random"
19+
]
20+
},
21+
{
22+
"cell_type": "markdown",
23+
"metadata": {},
24+
"source": [
25+
"#### Load Data"
26+
]
27+
},
28+
{
29+
"cell_type": "code",
30+
"execution_count": 15,
31+
"metadata": {},
32+
"outputs": [
33+
{
34+
"data": {
35+
"text/plain": [
36+
"((Column1 = [0.9695150609084499, 0.012898301755861596, 0.7555027304121053, 0.3467415729179013, 0.35969402837473463, 0.2601876747805505, 0.9522580699968279, 0.06304475092339623, 0.18909001622655808, 0.19934942931986965 … 0.021532597906190776, 0.8482825697641306, 0.10773487816863903, 0.32189982199036116, 0.12662208474317038, 0.28529465447429614, 0.2907506630258835, 0.36872799387588473, 0.061489791166806085, 0.45645058368583713], Column2 = [0.06546916714160167, 0.7243956502957003, 0.5183099801474415, 0.7555562860508294, 0.11226218114407538, 0.9135150277876691, 0.8739421974558176, 0.2268482788660101, 0.580604436651146, 0.4142252330250549 … 0.6517425913240111, 0.01713263102740481, 0.7175499403837856, 0.7362894157420817, 0.24893665902538054, 0.41499951381631595, 0.2159527717429719, 0.8966879835264249, 0.87252430655793, 0.41461921031276117], Column3 = [0.5939320702328891, 0.19329886972497456, 0.04656947038518311, 0.22095698685781184, 0.678807659662497, 0.12720198818430306, 0.6795750371448686, 0.9314917999820301, 0.22920734893984274, 0.5148148980955375 … 0.55049773593343, 0.038576459283091946, 0.27765727942909757, 0.2753072414696357, 0.8823620780359746, 0.44831794170895023, 0.9073846432163745, 0.4648550947905655, 0.311984726769037, 0.25829997798611304], Column4 = [0.12253944650540982, 0.8259140842535423, 0.4034477332184384, 0.5279399406265695, 0.5579944087437719, 0.24650366028608328, 0.6874897000162434, 0.23391406844015605, 0.5641254897013973, 0.6250622796341656 … 0.21708181942178983, 0.35224683896541464, 0.8444113778983325, 0.4547214584884428, 0.13508852017592232, 0.9510137735662383, 0.5723463533029658, 0.626377972762265, 0.7854013810594317, 0.15394691114473347], Column5 = [0.47958743625921163, 0.45779753417165514, 0.6367059235247621, 0.8601116026079643, 0.3334020182022719, 0.41593698717526373, 0.13208968772625174, 0.16951044109747648, 0.8137887839507706, 0.4429229861115882 … 0.01308976221980429, 0.48597926808091163, 0.20768781798463476, 0.30045611276046247, 
0.15759293576302558, 0.975806377881983, 0.19451065500145392, 0.9638103356367584, 0.3594043445295293, 0.7792867217495332], Column6 = [3.0, 3.0, 1.0, 3.0, 1.0, 2.0, 3.0, 2.0, 3.0, 3.0 … 3.0, 2.0, 1.0, 2.0, 1.0, 2.0, 2.0, 3.0, 3.0, 1.0], Column7 = [2.0, 2.0, 2.0, 2.0, 1.0, 2.0, 2.0, 2.0, 1.0, 1.0 … 2.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 1.0, 1.0]), CategoricalArrays.CategoricalValue{Int64, UInt32}[0, 0, 0, 0, 0, 0, 0, 0, 1, 0 … 0, 0, 1, 0, 1, 0, 0, 0, 0, 0])"
37+
]
38+
},
39+
"metadata": {},
40+
"output_type": "display_data"
41+
}
42+
],
43+
"source": [
44+
"X, y = generate_imbalanced_data(100, 5; cat_feats_num_vals = [3, 2], \n",
45+
" probs = [0.9, 0.1], \n",
46+
" type = \"ColTable\", \n",
47+
" rng=42)"
48+
]
49+
},
50+
{
51+
"cell_type": "markdown",
52+
"metadata": {},
53+
"source": [
54+
"#### Construct `BalancedBaggingClassifier` Model"
55+
]
56+
},
57+
{
58+
"cell_type": "code",
59+
"execution_count": 16,
60+
"metadata": {},
61+
"outputs": [
62+
{
63+
"data": {
64+
"text/plain": [
65+
"BalancedBaggingClassifier(\n",
66+
" model = LogisticClassifier(\n",
67+
" lambda = 2.220446049250313e-16, \n",
68+
" gamma = 0.0, \n",
69+
" penalty = :l2, \n",
70+
" fit_intercept = true, \n",
71+
" penalize_intercept = false, \n",
72+
" scale_penalty_with_samples = true, \n",
73+
" solver = nothing), \n",
74+
" T = 10, \n",
75+
" rng = Xoshiro(0xa379de7eeeb2a4e8, 0x953dccb6b532b3af, 0xf597b8ff8cfd652a, 0xccd7337c571680d1))"
76+
]
77+
},
78+
"metadata": {},
79+
"output_type": "display_data"
80+
}
81+
],
82+
"source": [
83+
"LogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels verbosity=0\n",
84+
"logistic_model = LogisticClassifier()\n",
85+
"model = BalancedBaggingClassifier(classifier=logistic_model, T=10, rng=Random.Xoshiro(42))"
86+
]
87+
},
88+
{
89+
"cell_type": "markdown",
90+
"metadata": {},
91+
"source": [
92+
"#### Train & Evaluate the Model"
93+
]
94+
},
95+
{
96+
"cell_type": "code",
97+
"execution_count": 18,
98+
"metadata": {},
99+
"outputs": [
100+
{
101+
"name": "stderr",
102+
"output_type": "stream",
103+
"text": [
104+
"┌ Info: Training machine(LogisticClassifier(lambda = 2.220446049250313e-16, …), …).\n",
105+
"└ @ MLJBase /Users/essam/.julia/packages/MLJBase/ByFwA/src/machines.jl:492\n",
106+
"┌ Info: Solver: MLJLinearModels.LBFGS{Optim.Options{Float64, Nothing}, NamedTuple{(), Tuple{}}}\n",
107+
"│ optim_options: Optim.Options{Float64, Nothing}\n",
108+
"│ lbfgs_options: NamedTuple{(), Tuple{}} NamedTuple()\n",
109+
"└ @ MLJLinearModels /Users/essam/.julia/packages/MLJLinearModels/zSQnL/src/mlj/interface.jl:72\n"
110+
]
111+
},
112+
{
113+
"data": {
114+
"text/plain": [
115+
"100-element CategoricalDistributions.UnivariateFiniteVector{Multiclass{2}, Int64, UInt32, Float64}:\n",
116+
" UnivariateFinite{Multiclass{2}}(0=>0.928, 1=>0.0722)\n",
117+
" UnivariateFinite{Multiclass{2}}(0=>0.845, 1=>0.155)\n",
118+
" UnivariateFinite{Multiclass{2}}(0=>0.749, 1=>0.251)\n",
119+
" UnivariateFinite{Multiclass{2}}(0=>0.902, 1=>0.0977)\n",
120+
" UnivariateFinite{Multiclass{2}}(0=>0.804, 1=>0.196)\n",
121+
" UnivariateFinite{Multiclass{2}}(0=>0.864, 1=>0.136)\n",
122+
" UnivariateFinite{Multiclass{2}}(0=>0.851, 1=>0.149)\n",
123+
" UnivariateFinite{Multiclass{2}}(0=>0.954, 1=>0.0458)\n",
124+
" UnivariateFinite{Multiclass{2}}(0=>0.853, 1=>0.147)\n",
125+
" UnivariateFinite{Multiclass{2}}(0=>0.86, 1=>0.14)\n",
126+
"\n",
127+
" UnivariateFinite{Multiclass{2}}(0=>0.671, 1=>0.329)\n",
128+
" UnivariateFinite{Multiclass{2}}(0=>0.73, 1=>0.27)\n",
129+
" UnivariateFinite{Multiclass{2}}(0=>0.843, 1=>0.157)\n",
130+
" UnivariateFinite{Multiclass{2}}(0=>0.941, 1=>0.0594)\n",
131+
" UnivariateFinite{Multiclass{2}}(0=>0.872, 1=>0.128)\n",
132+
" UnivariateFinite{Multiclass{2}}(0=>0.92, 1=>0.0797)\n",
133+
" UnivariateFinite{Multiclass{2}}(0=>0.929, 1=>0.0714)\n",
134+
" UnivariateFinite{Multiclass{2}}(0=>0.791, 1=>0.209)\n",
135+
" UnivariateFinite{Multiclass{2}}(0=>0.827, 1=>0.173)"
136+
]
137+
},
138+
"metadata": {},
139+
"output_type": "display_data"
140+
}
141+
],
142+
"source": [
143+
"mach = machine(model, X, y)\n",
144+
"fit!(mach)\n",
145+
"pred = predict(mach, X)"
146+
]
147+
}
148+
],
149+
"metadata": {
150+
"kernelspec": {
151+
"display_name": "Julia 1.8.5",
152+
"language": "julia",
153+
"name": "julia-1.8"
154+
},
155+
"language_info": {
156+
"file_extension": ".jl",
157+
"mimetype": "application/julia",
158+
"name": "julia",
159+
"version": "1.8.5"
160+
},
161+
"orig_nbformat": 4
162+
},
163+
"nbformat": 4,
164+
"nbformat_minor": 2
165+
}

0 commit comments

Comments
 (0)