Skip to content
Open
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
f3c5d24
Merge pull request #7 from apoorvalal/fw_simplex
apoorvalal Nov 18, 2024
6c4daa3
add unit tests for MLE
apoorvalal Dec 18, 2024
e615fcd
Refactor C++ code into modular files and switch to setuptools build s…
apoorvalal Mar 16, 2025
99efd18
Add comprehensive benchmarking script
apoorvalal Mar 16, 2025
5159f1a
"add full benchmarks"
apoorvalal Mar 16, 2025
2ff25ee
version bump; figure in readme
apoorvalal Mar 16, 2025
ec914ce
Merge pull request #9 from apoorvalal/modular-cpp-structure
apoorvalal Mar 16, 2025
06b3864
Update GitHub Actions to use upload-artifact v4
apoorvalal Mar 16, 2025
182263c
"prelim gmm implementation"
apoorvalal Mar 16, 2025
8be1039
clean up GMM, implement bootstrap, example.
apoorvalal Mar 23, 2025
c3bbf7b
cleanup
apoorvalal Mar 23, 2025
450196c
add tests
apoorvalal Mar 23, 2025
d7af935
version bump
apoorvalal Mar 23, 2025
8380c22
Merge pull request #10 from apoorvalal/gmm
apoorvalal Mar 23, 2025
644b058
changelog generator
apoorvalal Mar 23, 2025
8190b43
version bump- this time its bumpier
apoorvalal Mar 23, 2025
4445066
Bump version to 0.2.1 to resolve PyPI filename conflict
apoorvalal Mar 23, 2025
27d6ae5
Fix wheel building in CI by removing continue-on-error flag
apoorvalal Mar 23, 2025
a617f0f
Bump version to 0.2.2
apoorvalal Mar 23, 2025
7a6b953
Fix packaging conflicts between setup.py and pyproject.toml for versi…
apoorvalal Mar 23, 2025
fd2db7f
Fix Ensmallen header discovery during wheel building
apoorvalal Mar 23, 2025
d2aad73
"revert to last build setup that got wheels working"
apoorvalal Mar 24, 2025
5979c8d
switch back to meson build
apoorvalal Mar 24, 2025
3540bb0
"fix unit test action"
apoorvalal Mar 24, 2025
f513082
"try uv in unit test action"
apoorvalal Mar 24, 2025
1d3888e
"local tests for now"
apoorvalal Mar 24, 2025
69f1f80
"add paper directory"
apoorvalal Mar 30, 2025
e5c23cf
"clean up results dir"
Apr 1, 2025
1c895f5
Merge pull request #12 from apoorvalal/joss
apoorvalal Apr 1, 2025
022d98e
fix image path in readme
apoorvalal Apr 8, 2025
13f23ab
fix gmm notebook
apoorvalal Apr 14, 2025
e4ae933
"fix to export simplex constrained FW again - need for synth"
apoorvalal May 18, 2025
80f400f
"version bump"
apoorvalal May 21, 2025
a2d93a4
add PyReport class and expose it
gauravmanmode May 22, 2025
4713d6c
Merge branch 'master' into master
gauravmanmode Jul 9, 2025
22c08fa
Merge branch 'master' into master
gauravmanmode Aug 12, 2025
171a920
add tests
gauravmanmode Aug 12, 2025
bbdaa53
add notebook
gauravmanmode Aug 12, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ jobs:
runs-on: ubuntu-22.04
if: startsWith(github.ref, 'refs/tags/v')
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v4
Expand All @@ -88,12 +88,12 @@ jobs:
- name: Install Python dependencies
run: |
python -m pip install --upgrade pip
pip install setuptools build meson meson-python pybind11
pip install setuptools build wheel pybind11

- name: Build sdist
run: python -m build --sdist --outdir wheelhouse/

- uses: actions/upload-artifact@v3
- uses: actions/upload-artifact@v4
with:
name: sdist
path: ./wheelhouse/*.tar.gz
Expand Down Expand Up @@ -140,9 +140,9 @@ jobs:
permissions:
contents: write
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4

- uses: actions/download-artifact@v4 # Updated to v4
- uses: actions/download-artifact@v4
with:
path: ./wheelhouse
merge-multiple: true # Downloads all artifacts
Expand Down Expand Up @@ -175,4 +175,4 @@ jobs:
release.upload_asset(f"./wheelhouse/{asset}")

print("Release created and assets uploaded successfully")
EOF
EOF
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,4 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.DS_Store
19 changes: 15 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
# `pyensmallen`: python bindings for the [`ensmallen`](https://ensmallen.org/) library for numerical optimization

Very minimal python bindings for `ensmallen` library. Currently supports

Lightweight python bindings for `ensmallen` library. Currently supports
+ L-BFGS, with intended use for optimisation of smooth objectives for m-estimation
+ ADAM (and variants with different step-size routines) - makes use of ensmallen's templatization.
+ Frank-Wolfe, with intended use for constrained optimization of smooth losses
- constraints are either lp-ball (lasso, ridge, elastic-net) or simplex
+ (Generalized) Method of Moments estimation with ensmallen optimizers.
- this uses ensmallen for optimization [and relies on `jax` for automatic differentiation to get gradients and jacobians]. This is the main use case for `pyensmallen` and is the reason for the bindings.

See [ensmallen docs](https://ensmallen.org/docs.html) for details. The `notebooks/` directory walks through several statistical examples.

## speed
`pyensmallen` is very fast. A comprehensive set of benchmarks is available in the `benchmarks` directory. The benchmarks are run on an intel 12th gen framework laptop. Benchmarks vary data size (sample size and number of covariates) and parametric family (linear, logistic, poisson) and compare `pyensmallen` with `scipy` and `statsmodels` (I initially also tried to keep `cvxpy` in the comparison set but it was far too slow to be in the running). At large data sizes, pyensmallen is roughly an order of magnitude faster than scipy, which in turn is an order of magnitude faster than statsmodels. So, a single statsmodels run takes around as long as a pyensmallen run that naively uses the nonparametric bootstrap for inference. This makes the bootstrap a viable option for inference in large data settings.

See [ensmallen docs](https://ensmallen.org/docs.html) for details.
![](paper/benchmark_time.png)

Installation:
## Installation:

Make sure your system has `blas` installed. On macos, this can be done via brew. Linux systems should have it installed by default. If you are using conda, you can install `blas` via conda-forge.

Then,

__from pypi__

Expand All @@ -25,4 +37,3 @@ __from source__
__from wheel__
- download the appropriate `.whl` for your system from the more recent release listed in `Releases` and run `pip install ./pyensmallen...` OR
- copy the download url and run `pip install https://github.com/apoorvalal/pyensmallen/releases/download/<version>/pyensmallen-<version>-<pyversion>-linux_x86_64.whl`

16 changes: 12 additions & 4 deletions meson.build
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
project('pyensmallen',
['cpp'],
version: '0.0.3',
version: '0.2.5',
default_options: ['cpp_std=c++14'])

py = import('python').find_installation(pure: false)
Expand All @@ -10,15 +10,23 @@ armadillo_dep = dependency('Armadillo')

# Simplify to bare minimum
py.extension_module('_pyensmallen',
'pyensmallen/_pyensmallen.cpp',
'pyensmallen/module.cpp',
dependencies: [pybind11_dep, armadillo_dep, ensmallen_dep],
install: true,
install_dir: py.get_install_dir() / 'pyensmallen'
)

# Install Python sources separately
py.install_sources(
['pyensmallen/__init__.py', 'pyensmallen/losses.py'],
[
'pyensmallen/__init__.py',
'pyensmallen/losses.py',
'pyensmallen/gmm.py',
'pyensmallen/utils.hpp',
'pyensmallen/first_order.hpp',
'pyensmallen/newton_type.hpp',
'pyensmallen/constrained.hpp'
],
pure: false,
subdir: 'pyensmallen'
)
)
111 changes: 73 additions & 38 deletions notebooks/autodiff_mnl.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,11 @@
" logits = multinomial_logit(params, X)\n",
" return -jnp.mean(logits[jnp.arange(y.shape[0]), y])\n",
"\n",
"\n",
"# Create JAX gradient function - autodiff!\n",
"grad_loss = jax.grad(loss)\n",
"\n",
"\n",
"# Define the objective function for pyensmallen\n",
"def objective(params, gradient, X, y):\n",
" params_jax = jax.device_put(params.reshape(D, K - 1))\n",
Expand All @@ -81,6 +83,7 @@
" gradient[:] = np.array(grad).flatten()\n",
" return float(loss_value)\n",
"\n",
"\n",
"# Pyensmallen optimization\n",
"start_time = time.time()\n",
"optimizer = pyensmallen.L_BFGS()\n",
Expand All @@ -107,19 +110,21 @@
"source": [
"# JAX optimization with Optax\n",
"start_time = time.time()\n",
"initial_params = jnp.array(initial_params.reshape(D, K-1))\n",
"initial_params = jnp.array(initial_params.reshape(D, K - 1))\n",
"\n",
"# Define the Optax optimizer (using Adam as an example)\n",
"optimizer = optax.adam(learning_rate=0.01)\n",
"opt_state = optimizer.init(initial_params)\n",
"\n",
"\n",
"@jax.jit\n",
"def step(params, opt_state, X, y):\n",
" loss_value, grads = jax.value_and_grad(loss)(params, X, y)\n",
" updates, opt_state = optimizer.update(grads, opt_state, params)\n",
" params = optax.apply_updates(params, updates)\n",
" return params, opt_state, loss_value\n",
"\n",
"\n",
"params = initial_params\n",
"for i in range(2000):\n",
" params, opt_state, _ = step(params, opt_state, X_jax, y_jax)\n",
Expand All @@ -143,36 +148,7 @@
{
"data": {
"text/plain": [
"(array([[ 1.76405235, 0.40015721, 0.97873798, 0. ],\n",
" [ 1.86755799, -0.97727788, 0.95008842, 0. ],\n",
" [-0.10321885, 0.4105985 , 0.14404357, 0. ],\n",
" [ 0.76103773, 0.12167502, 0.44386323, 0. ],\n",
" [ 1.49407907, -0.20515826, 0.3130677 , 0. ],\n",
" [-2.55298982, 0.6536186 , 0.8644362 , 0. ],\n",
" [ 2.26975462, -1.45436567, 0.04575852, 0. ],\n",
" [ 1.53277921, 1.46935877, 0.15494743, 0. ],\n",
" [-0.88778575, -1.98079647, -0.34791215, 0. ],\n",
" [ 1.23029068, 1.20237985, -0.38732682, 0. ]]),\n",
" array([[ 1.81593705, 0.42660712, 0.9881251 , 0. ],\n",
" [ 1.86311812, -0.99712363, 0.91224389, 0. ],\n",
" [-0.09046709, 0.40371618, 0.1855722 , 0. ],\n",
" [ 0.8096924 , 0.16937907, 0.42813934, 0. ],\n",
" [ 1.5803887 , -0.13466159, 0.33455137, 0. ],\n",
" [-2.57077101, 0.69204522, 0.86878803, 0. ],\n",
" [ 2.34925969, -1.47031258, -0.01797388, 0. ],\n",
" [ 1.63782881, 1.56277552, 0.16238417, 0. ],\n",
" [-0.86976324, -1.94815402, -0.34469109, 0. ],\n",
" [ 1.25135382, 1.17332414, -0.411664 , 0. ]]),\n",
" Array([[ 1.8109069 , 0.42548722, 0.9861881 , 0. ],\n",
" [ 1.8578703 , -0.9970413 , 0.9101146 , 0. ],\n",
" [-0.09009396, 0.40340352, 0.18550713, 0. ],\n",
" [ 0.8079424 , 0.16907169, 0.42753083, 0. ],\n",
" [ 1.5771316 , -0.1348491 , 0.3337257 , 0. ],\n",
" [-2.5637126 , 0.6927647 , 0.8692422 , 0. ],\n",
" [ 2.3421607 , -1.4703182 , -0.01972268, 0. ],\n",
" [ 1.635118 , 1.5612816 , 0.1624512 , 0. ],\n",
" [-0.86919826, -1.9467523 , -0.34499857, 0. ],\n",
" [ 1.2491224 , 1.1719129 , -0.41112202, 0. ]], dtype=float32))"
"(40,)"
]
},
"execution_count": 5,
Expand All @@ -181,20 +157,79 @@
}
],
"source": [
"true_coeffs, estimated_coeffs_ens, estimated_coeffs_jax"
"true_coeffs.reshape(-1).shape"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 1.76405235, 1.81593705, 1.81090689],\n",
" [ 0.40015721, 0.42660712, 0.42548722],\n",
" [ 0.97873798, 0.9881251 , 0.98618811],\n",
" [ 0. , 0. , 0. ],\n",
" [ 1.86755799, 1.86311812, 1.85787034],\n",
" [-0.97727788, -0.99712363, -0.99704129],\n",
" [ 0.95008842, 0.91224389, 0.91011459],\n",
" [ 0. , 0. , 0. ],\n",
" [-0.10321885, -0.09046709, -0.09009396],\n",
" [ 0.4105985 , 0.40371618, 0.40340352],\n",
" [ 0.14404357, 0.1855722 , 0.18550713],\n",
" [ 0. , 0. , 0. ],\n",
" [ 0.76103773, 0.8096924 , 0.80794239],\n",
" [ 0.12167502, 0.16937907, 0.16907169],\n",
" [ 0.44386323, 0.42813934, 0.42753083],\n",
" [ 0. , 0. , 0. ],\n",
" [ 1.49407907, 1.5803887 , 1.57713163],\n",
" [-0.20515826, -0.13466159, -0.1348491 ],\n",
" [ 0.3130677 , 0.33455137, 0.33372569],\n",
" [ 0. , 0. , 0. ],\n",
" [-2.55298982, -2.57077101, -2.5637126 ],\n",
" [ 0.6536186 , 0.69204522, 0.6927647 ],\n",
" [ 0.8644362 , 0.86878803, 0.86924219],\n",
" [ 0. , 0. , 0. ],\n",
" [ 2.26975462, 2.34925969, 2.3421607 ],\n",
" [-1.45436567, -1.47031258, -1.4703182 ],\n",
" [ 0.04575852, -0.01797388, -0.01972268],\n",
" [ 0. , 0. , 0. ],\n",
" [ 1.53277921, 1.63782881, 1.63511801],\n",
" [ 1.46935877, 1.56277552, 1.56128156],\n",
" [ 0.15494743, 0.16238417, 0.16245119],\n",
" [ 0. , 0. , 0. ],\n",
" [-0.88778575, -0.86976324, -0.86919826],\n",
" [-1.98079647, -1.94815402, -1.94675231],\n",
" [-0.34791215, -0.34469109, -0.34499857],\n",
" [ 0. , 0. , 0. ],\n",
" [ 1.23029068, 1.25135382, 1.24912238],\n",
" [ 1.20237985, 1.17332414, 1.17191291],\n",
" [-0.38732682, -0.411664 , -0.41112202],\n",
" [ 0. , 0. , 0. ]])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.c_[true_coeffs.reshape(-1), estimated_coeffs_ens.reshape(-1), estimated_coeffs_jax.reshape(-1)]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pyensmallen optimization time: 0.7720785140991211\n",
"JAX optimization time: 0.6985325813293457\n",
"Pyensmallen optimization time: 1.0235660076141357\n",
"JAX optimization time: 2.4678776264190674\n",
"\n",
"Pyensmallen Mean Absolute Error: 0.026384408255362625\n",
"JAX Mean Absolute Error: 0.025860388\n"
Expand All @@ -215,7 +250,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"outputs": [
{
Expand All @@ -232,11 +267,11 @@
}
],
"source": [
"\n",
"def predict(coeffs, X):\n",
" logits = X @ coeffs\n",
" return np.argmax(logits, axis=1)\n",
"\n",
"\n",
"accuracy_ens = np.mean(predict(estimated_coeffs_ens, X) == y)\n",
"accuracy_jax = np.mean(predict(estimated_coeffs_jax, X) == y)\n",
"\n",
Expand All @@ -253,7 +288,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "econometrics",
"display_name": "metrics",
"language": "python",
"name": "python3"
},
Expand All @@ -267,7 +302,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
"version": "3.12.6"
}
},
"nbformat": 4,
Expand Down
14 changes: 7 additions & 7 deletions notebooks/banana.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@
{
"data": {
"text/plain": [
"array([1.92570882, 3.7081687 ])"
"array([1.84242719, 3.39499252])"
]
},
"execution_count": 5,
Expand All @@ -145,7 +145,7 @@
{
"data": {
"text/plain": [
"array([1.94800205, 3.79459674])"
"array([1.9529393 , 3.81386375])"
]
},
"execution_count": 6,
Expand All @@ -167,7 +167,7 @@
{
"data": {
"text/plain": [
"array([1.77346333, 3.14597227])"
"array([1.5916263 , 2.53414253])"
]
},
"execution_count": 7,
Expand All @@ -189,7 +189,7 @@
{
"data": {
"text/plain": [
"array([0.83742859, 0.69637474])"
"array([0.65721087, 0.42299591])"
]
},
"execution_count": 8,
Expand All @@ -211,7 +211,7 @@
{
"data": {
"text/plain": [
"array([1.74067192, 3.03090621])"
"array([1.90080892, 3.61223393])"
]
},
"execution_count": 9,
Expand All @@ -228,7 +228,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "econometrics",
"display_name": "metrics",
"language": "python",
"name": "python3"
},
Expand All @@ -242,7 +242,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
"version": "3.12.6"
}
},
"nbformat": 4,
Expand Down
Loading