|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "code", |
| 5 | + "execution_count": null, |
| 6 | + "metadata": {}, |
| 7 | + "outputs": [], |
| 8 | + "source": [ |
| 9 | + "ENV[\"JULIA_PKG_SERVER\"] = \"\"\n", |
| 10 | + "using Pkg\n", |
| 11 | + "Pkg.activate(@__DIR__)\n", |
| 12 | + "Pkg.instantiate()\n", |
| 13 | + "\n", |
| 14 | + "\n", |
| 15 | + "using MLJBalancing\n", |
| 16 | + "using Imbalance\n", |
| 17 | + "using MLJ\n", |
| 18 | + "using Random" |
| 19 | + ] |
| 20 | + }, |
| 21 | + { |
| 22 | + "cell_type": "markdown", |
| 23 | + "metadata": {}, |
| 24 | + "source": [ |
| 25 | + "#### Load Data" |
| 26 | + ] |
| 27 | + }, |
| 28 | + { |
| 29 | + "cell_type": "code", |
| 30 | + "execution_count": 15, |
| 31 | + "metadata": {}, |
| 32 | + "outputs": [ |
| 33 | + { |
| 34 | + "data": { |
| 35 | + "text/plain": [ |
| 36 | + "((Column1 = [0.9695150609084499, 0.012898301755861596, 0.7555027304121053, 0.3467415729179013, 0.35969402837473463, 0.2601876747805505, 0.9522580699968279, 0.06304475092339623, 0.18909001622655808, 0.19934942931986965 … 0.021532597906190776, 0.8482825697641306, 0.10773487816863903, 0.32189982199036116, 0.12662208474317038, 0.28529465447429614, 0.2907506630258835, 0.36872799387588473, 0.061489791166806085, 0.45645058368583713], Column2 = [0.06546916714160167, 0.7243956502957003, 0.5183099801474415, 0.7555562860508294, 0.11226218114407538, 0.9135150277876691, 0.8739421974558176, 0.2268482788660101, 0.580604436651146, 0.4142252330250549 … 0.6517425913240111, 0.01713263102740481, 0.7175499403837856, 0.7362894157420817, 0.24893665902538054, 0.41499951381631595, 0.2159527717429719, 0.8966879835264249, 0.87252430655793, 0.41461921031276117], Column3 = [0.5939320702328891, 0.19329886972497456, 0.04656947038518311, 0.22095698685781184, 0.678807659662497, 0.12720198818430306, 0.6795750371448686, 0.9314917999820301, 0.22920734893984274, 0.5148148980955375 … 0.55049773593343, 0.038576459283091946, 0.27765727942909757, 0.2753072414696357, 0.8823620780359746, 0.44831794170895023, 0.9073846432163745, 0.4648550947905655, 0.311984726769037, 0.25829997798611304], Column4 = [0.12253944650540982, 0.8259140842535423, 0.4034477332184384, 0.5279399406265695, 0.5579944087437719, 0.24650366028608328, 0.6874897000162434, 0.23391406844015605, 0.5641254897013973, 0.6250622796341656 … 0.21708181942178983, 0.35224683896541464, 0.8444113778983325, 0.4547214584884428, 0.13508852017592232, 0.9510137735662383, 0.5723463533029658, 0.626377972762265, 0.7854013810594317, 0.15394691114473347], Column5 = [0.47958743625921163, 0.45779753417165514, 0.6367059235247621, 0.8601116026079643, 0.3334020182022719, 0.41593698717526373, 0.13208968772625174, 0.16951044109747648, 0.8137887839507706, 0.4429229861115882 … 0.01308976221980429, 0.48597926808091163, 0.20768781798463476, 0.30045611276046247, 0.15759293576302558, 0.975806377881983, 0.19451065500145392, 0.9638103356367584, 0.3594043445295293, 0.7792867217495332], Column6 = [3.0, 3.0, 1.0, 3.0, 1.0, 2.0, 3.0, 2.0, 3.0, 3.0 … 3.0, 2.0, 1.0, 2.0, 1.0, 2.0, 2.0, 3.0, 3.0, 1.0], Column7 = [2.0, 2.0, 2.0, 2.0, 1.0, 2.0, 2.0, 2.0, 1.0, 1.0 … 2.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 1.0, 1.0]), CategoricalArrays.CategoricalValue{Int64, UInt32}[0, 0, 0, 0, 0, 0, 0, 0, 1, 0 … 0, 0, 1, 0, 1, 0, 0, 0, 0, 0])" |
| 37 | + ] |
| 38 | + }, |
| 39 | + "metadata": {}, |
| 40 | + "output_type": "display_data" |
| 41 | + } |
| 42 | + ], |
| 43 | + "source": [ |
| 44 | + "X, y = generate_imbalanced_data(100, 5; cat_feats_num_vals = [3, 2], \n", |
| 45 | + " probs = [0.9, 0.1], \n", |
| 46 | + " type = \"ColTable\", \n", |
| 47 | + " rng=42)" |
| 48 | + ] |
| 49 | + }, |
| 50 | + { |
| 51 | + "cell_type": "markdown", |
| 52 | + "metadata": {}, |
| 53 | + "source": [ |
| 54 | + "#### Construct `BalancedBaggingClassifier` Model" |
| 55 | + ] |
| 56 | + }, |
| 57 | + { |
| 58 | + "cell_type": "code", |
| 59 | + "execution_count": 16, |
| 60 | + "metadata": {}, |
| 61 | + "outputs": [ |
| 62 | + { |
| 63 | + "data": { |
| 64 | + "text/plain": [ |
| 65 | + "BalancedBaggingClassifier(\n", |
| 66 | + " model = LogisticClassifier(\n", |
| 67 | + " lambda = 2.220446049250313e-16, \n", |
| 68 | + " gamma = 0.0, \n", |
| 69 | + " penalty = :l2, \n", |
| 70 | + " fit_intercept = true, \n", |
| 71 | + " penalize_intercept = false, \n", |
| 72 | + " scale_penalty_with_samples = true, \n", |
| 73 | + " solver = nothing), \n", |
| 74 | + " T = 10, \n", |
| 75 | + " rng = Xoshiro(0xa379de7eeeb2a4e8, 0x953dccb6b532b3af, 0xf597b8ff8cfd652a, 0xccd7337c571680d1))" |
| 76 | + ] |
| 77 | + }, |
| 78 | + "metadata": {}, |
| 79 | + "output_type": "display_data" |
| 80 | + } |
| 81 | + ], |
| 82 | + "source": [ |
| 83 | + "LogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels verbosity=0\n", |
| 84 | + "logistic_model = LogisticClassifier()\n", |
| 85 | + "model = BalancedBaggingClassifier(classifier=logistic_model, T=10, rng=Random.Xoshiro(42))" |
| 86 | + ] |
| 87 | + }, |
| 88 | + { |
| 89 | + "cell_type": "markdown", |
| 90 | + "metadata": {}, |
| 91 | + "source": [ |
| 92 | + "#### Train & Evaluate the Model" |
| 93 | + ] |
| 94 | + }, |
| 95 | + { |
| 96 | + "cell_type": "code", |
| 97 | + "execution_count": 18, |
| 98 | + "metadata": {}, |
| 99 | + "outputs": [ |
| 100 | + { |
| 101 | + "name": "stderr", |
| 102 | + "output_type": "stream", |
| 103 | + "text": [ |
| 104 | + "┌ Info: Training machine(LogisticClassifier(lambda = 2.220446049250313e-16, …), …).\n", |
| 105 | + "└ @ MLJBase /Users/essam/.julia/packages/MLJBase/ByFwA/src/machines.jl:492\n", |
| 106 | + "┌ Info: Solver: MLJLinearModels.LBFGS{Optim.Options{Float64, Nothing}, NamedTuple{(), Tuple{}}}\n", |
| 107 | + "│ optim_options: Optim.Options{Float64, Nothing}\n", |
| 108 | + "│ lbfgs_options: NamedTuple{(), Tuple{}} NamedTuple()\n", |
| 109 | + "└ @ MLJLinearModels /Users/essam/.julia/packages/MLJLinearModels/zSQnL/src/mlj/interface.jl:72\n" |
| 110 | + ] |
| 111 | + }, |
| 112 | + { |
| 113 | + "data": { |
| 114 | + "text/plain": [ |
| 115 | + "100-element CategoricalDistributions.UnivariateFiniteVector{Multiclass{2}, Int64, UInt32, Float64}:\n", |
| 116 | + " UnivariateFinite{Multiclass{2}}(0=>0.928, 1=>0.0722)\n", |
| 117 | + " UnivariateFinite{Multiclass{2}}(0=>0.845, 1=>0.155)\n", |
| 118 | + " UnivariateFinite{Multiclass{2}}(0=>0.749, 1=>0.251)\n", |
| 119 | + " UnivariateFinite{Multiclass{2}}(0=>0.902, 1=>0.0977)\n", |
| 120 | + " UnivariateFinite{Multiclass{2}}(0=>0.804, 1=>0.196)\n", |
| 121 | + " UnivariateFinite{Multiclass{2}}(0=>0.864, 1=>0.136)\n", |
| 122 | + " UnivariateFinite{Multiclass{2}}(0=>0.851, 1=>0.149)\n", |
| 123 | + " UnivariateFinite{Multiclass{2}}(0=>0.954, 1=>0.0458)\n", |
| 124 | + " UnivariateFinite{Multiclass{2}}(0=>0.853, 1=>0.147)\n", |
| 125 | + " UnivariateFinite{Multiclass{2}}(0=>0.86, 1=>0.14)\n", |
| 126 | + " ⋮\n", |
| 127 | + " UnivariateFinite{Multiclass{2}}(0=>0.671, 1=>0.329)\n", |
| 128 | + " UnivariateFinite{Multiclass{2}}(0=>0.73, 1=>0.27)\n", |
| 129 | + " UnivariateFinite{Multiclass{2}}(0=>0.843, 1=>0.157)\n", |
| 130 | + " UnivariateFinite{Multiclass{2}}(0=>0.941, 1=>0.0594)\n", |
| 131 | + " UnivariateFinite{Multiclass{2}}(0=>0.872, 1=>0.128)\n", |
| 132 | + " UnivariateFinite{Multiclass{2}}(0=>0.92, 1=>0.0797)\n", |
| 133 | + " UnivariateFinite{Multiclass{2}}(0=>0.929, 1=>0.0714)\n", |
| 134 | + " UnivariateFinite{Multiclass{2}}(0=>0.791, 1=>0.209)\n", |
| 135 | + " UnivariateFinite{Multiclass{2}}(0=>0.827, 1=>0.173)" |
| 136 | + ] |
| 137 | + }, |
| 138 | + "metadata": {}, |
| 139 | + "output_type": "display_data" |
| 140 | + } |
| 141 | + ], |
| 142 | + "source": [ |
| 143 | + "mach = machine(logistic_model, X, y)\n", |
| 144 | + "fit!(mach)\n", |
| 145 | + "pred = predict(mach, X)" |
| 146 | + ] |
| 147 | + } |
| 148 | + ], |
| 149 | + "metadata": { |
| 150 | + "kernelspec": { |
| 151 | + "display_name": "Julia 1.8.5", |
| 152 | + "language": "julia", |
| 153 | + "name": "julia-1.8" |
| 154 | + }, |
| 155 | + "language_info": { |
| 156 | + "file_extension": ".jl", |
| 157 | + "mimetype": "application/julia", |
| 158 | + "name": "julia", |
| 159 | + "version": "1.8.5" |
| 160 | + }, |
| 161 | + "orig_nbformat": 4 |
| 162 | + }, |
| 163 | + "nbformat": 4, |
| 164 | + "nbformat_minor": 2 |
| 165 | +} |
0 commit comments