# MLJLinearModels.jl

This is a convenience package gathering functionality to solve a number of generalised linear regression/classification problems which, inherently, correspond to an optimisation problem of the form

```math
L(y, X\theta) + P(\theta)
```

where ``L`` is a _loss function_ and ``P`` is a _penalty function_ (both of which can be scaled or composed).
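
For instance, a sketch of such a composition (assuming the exported building blocks `GeneralizedLinearRegression`, `L2Loss` and `L2Penalty`, and that a penalty can be scaled with `*`):

```julia
using MLJLinearModels

# a generic L(y, Xθ) + P(θ) objective: squared loss plus a scaled l2 penalty
glr = GeneralizedLinearRegression(loss=L2Loss(), penalty=0.5 * L2Penalty())
```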

A well-known example is [Ridge regression](https://en.wikipedia.org/wiki/Tikhonov_regularization), where the problem amounts to minimising

```math
\|y - X\theta\|_2^2 + \lambda\|\theta\|_2^2.
```
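
As a minimal sketch (assuming `X` and `y` hold your data; see the quick start below), this model can be created and fitted with the `RidgeRegression` convenience constructor:

```julia
using MLJLinearModels

λ = 0.5
ridge = RidgeRegression(λ)
theta = fit(ridge, X, y)   # X: feature matrix, y: response vector
```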

## Goals for the package

- make these regression models "easy to call" and callable in a unified way,
- interface with [`MLJ.jl`](https://github.com/alan-turing-institute/MLJ.jl),
- focus on performance, including in "big data" settings, exploiting packages such as [`Optim.jl`](https://github.com/JuliaNLSolvers/Optim.jl) and [`IterativeSolvers.jl`](https://github.com/JuliaMath/IterativeSolvers.jl),
- use a "machine learning" perspective, i.e. focus primarily on prediction; hyper-parameters should be obtained via a data-driven procedure such as cross-validation.

All models allow fitting an intercept, and the penalty can optionally be applied to the intercept as well (it is not by default).
All models attempt to be efficient in terms of memory allocation to avoid unnecessary copies of the data.
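
For instance, a sketch (assuming the `fit_intercept` and `penalize_intercept` keyword arguments of the constructors):

```julia
using MLJLinearModels

# lasso with an intercept that is excluded from the penalty (the default)
lasso = LassoRegression(0.7; fit_intercept=true, penalize_intercept=false)
```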

## Quick start

The package works by

1. specifying the kind of model you want along with its hyper-parameters,
2. calling `fit` with that model and the data: `fit(model, X, y)`.

!!! note

    The convention is that the feature matrix has dimensions ``n \times p`` where ``n`` is the number of records (points) and ``p`` is the number of features (dimensions).
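
For example, a synthetic dataset following this convention (hypothetical sizes, just to fix ideas):

```julia
n, p = 500, 5
X = randn(n, p)                      # n records, p features
y = X * randn(p) .+ 0.1 .* randn(n)  # noisy linear response
```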

### Lasso regression

The lasso regression corresponds to an l2 loss function with an l1 penalty:

```math
\theta_{\text{Lasso}} = \arg\min_\theta\ \frac12\|y-X\theta\|_2^2 + \lambda\|\theta\|_1
```

which you can create and fit as follows:

```julia
using MLJLinearModels

λ = 0.7
lasso = LassoRegression(λ)
theta = fit(lasso, X, y)
```
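
The call to `fit` returns the vector of fitted coefficients; a minimal sketch of forming predictions, assuming the fitted intercept is appended as the last element of that vector:

```julia
# predictions on (new) data, assuming the intercept is the last coefficient
ŷ = X * theta[1:end-1] .+ theta[end]
```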

### (Multinomial) logistic classifier

In a classification context, the multinomial logistic regression returns a predicted score per class which can be interpreted as the probability that a point belongs to a given class under the trained model.
It is given by the multinomial loss plus an optional penalty (typically the l2 penalty).

Here's a way to do this:

```julia
λ = 0.1
mlr = MultinomialRegression(λ) # you can also just use LogisticRegression
fit(mlr, X, y)
```

In a **binary** context, ``y`` is expected to have values ``y_i \in \{\pm 1\}`` whereas in the **multiclass** context, ``y`` is expected to have values ``y_i \in \{1, \dots, c\}`` where ``c > 2`` is the number of classes.
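
For instance, a sketch of preparing labels in these encodings (hypothetical raw labels, just for illustration):

```julia
y_bool = rand(Bool, 10)          # hypothetical raw binary labels
y_pm   = ifelse.(y_bool, 1, -1)  # binary: encode as ±1

y_mc = [1, 3, 2, 2, 1]           # multiclass: integers 1, ..., c (here c = 3)
```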

## Available models

### Regression models (continuous target)

| Regressors          | Formulation¹           | Available solvers                 | Comments  |
| :------------------ | :--------------------- | :-------------------------------- | :-------- |
| OLS & Ridge         | L2Loss + 0/L2          | Analytical² or CG³                |           |
| Lasso & Elastic-Net | L2Loss + 0/L2 + L1     | (F)ISTA⁴                          |           |
| Robust 0/L2         | RobustLoss⁵ + 0/L2     | Newton, NewtonCG, LBFGS, IWLS-CG⁶ | no scale⁷ |
| Robust L1/EN        | RobustLoss + 0/L2 + L1 | (F)ISTA                           |           |
| Quantile⁸ 0/L2      | RobustLoss + 0/L2      | LBFGS, IWLS-CG                    |           |
| Quantile L1/EN      | RobustLoss + 0/L2 + L1 | (F)ISTA                           |           |

1. "0" stands for no penalty.
2. Analytical means the solution is computed in "one shot" using the `\` solver.
3. CG = conjugate gradient.
4. (Accelerated) proximal gradient descent.
5. _Huber_, _Andrews_, _Bisquare_, _Logistic_, _Fair_ and _Talwar_ weighting functions available.
6. Iteratively re-weighted least squares, where each system is solved iteratively via CG.
7. In other packages such as Scikit-Learn, a scale factor is estimated along with the parameters. This is a bit ad hoc and corresponds more to a statistical perspective; further, it does not work well with penalties. We recommend using cross-validation to set the parameter of the Huber loss instead.
8. Includes as a special case the _least absolute deviation_ (LAD) regression when `δ=0.5`.
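
A sketch of how such models can be constructed and fitted (assuming the robustness parameter `δ` is the first argument of these constructors):

```julia
using MLJLinearModels

huber = HuberRegression(1.35)    # Huber robust loss with δ = 1.35
quant = QuantileRegression(0.5)  # 0.5 quantile, i.e. LAD regression
theta = fit(huber, X, y; solver=LBFGS())
```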

### Classification models (finite target)

| Classifiers       | Formulation                 | Available solvers        | Comments       |
| :---------------- | :-------------------------- | :----------------------- | :------------- |
| Logistic 0/L2     | LogisticLoss + 0/L2         | Newton, Newton-CG, LBFGS | `yᵢ∈{±1}`      |
| Logistic L1/EN    | LogisticLoss + 0/L2 + L1    | (F)ISTA                  | `yᵢ∈{±1}`      |
| Multinomial 0/L2  | MultinomialLoss + 0/L2      | Newton-CG, LBFGS         | `yᵢ∈{1,...,c}` |
| Multinomial L1/EN | MultinomialLoss + 0/L2 + L1 | ISTA, FISTA              | `yᵢ∈{1,...,c}` |

Unless otherwise specified:

* Newton-like solvers use Hager-Zhang line search (the default in [`Optim.jl`](https://github.com/JuliaNLSolvers/Optim.jl)),
* ISTA and FISTA solvers use backtracking line search and a shrinkage factor of `β=0.8`.
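
A specific solver can be passed via the `solver` keyword of `fit`; a sketch (assuming the exported solver constructors such as `FISTA`, `ISTA` and `NewtonCG`):

```julia
using MLJLinearModels

lasso = LassoRegression(0.7)
theta_fista = fit(lasso, X, y; solver=FISTA())  # accelerated proximal gradient
theta_ista  = fit(lasso, X, y; solver=ISTA())   # plain proximal gradient

logreg = LogisticRegression(0.1)
theta_log = fit(logreg, X, y; solver=NewtonCG())
```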

**Note**: these models were all tested for correctness whenever a direct comparison with another package was possible, usually by comparing the objective function at the returned coefficients (cf. the tests):
- (_against [scikit-learn](https://scikit-learn.org/)_): Lasso, Elastic-Net, Logistic (L1/L2/EN), Multinomial (L1/L2/EN)
- (_against [quantreg](https://cran.r-project.org/web/packages/quantreg/index.html)_): Quantile (0/L1)

Systematic timing benchmarks have not been run yet but they are planned (see [this issue](https://github.com/alan-turing-institute/MLJLinearModels.jl/issues/14)).

## Limitations

Note the current limitations:

* The models are built and tested assuming `n > p`; if this does not hold, tricks should be employed to speed up computations; these have not been implemented yet.
* CV-aware code is not implemented yet (code that re-uses computations when fitting over a number of hyper-parameters); "meta" functionalities such as One-vs-All or cross-validation are left to other packages such as MLJ.
* There is no support yet for sparse matrices.
* Stochastic solvers have not yet been implemented.
* All computations are assumed to be done in Float64; if your data has another element type, convert it first (see the sketch below).
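
A minimal sketch of that conversion (`X32` and `y32` are hypothetical Float32 data):

```julia
X = Float64.(X32)   # element-wise conversion of the feature matrix
y = Float64.(y32)   # element-wise conversion of the response
```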