Skip to content

Commit 93bcdfc

Browse files
committed
"patch jax optional import bug; add dynamic versioning"
1 parent 73cb671 commit 93bcdfc

File tree

6 files changed

+24
-20
lines changed

6 files changed

+24
-20
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,4 @@ cython_debug/
164164
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
165165
#.idea/
166166
.DS_Store
167+
.claude/settings.local.json

meson.build

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
project('pyensmallen',
22
['cpp'],
3-
version: '0.2.9',
43
default_options: ['cpp_std=c++14'])
54

65
py = import('python').find_installation(pure: false)

pyensmallen/__init__.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@
1717

1818
from ._pyensmallen import *
1919
from .losses import linear_obj, logistic_obj, poisson_obj
20-
from .gmm import EnsmallenEstimator
21-
2220

21+
# Core optimizers and loss functions (always available)
2322
__all__ = [
2423
"L_BFGS",
2524
"FrankWolfe",
@@ -32,5 +31,12 @@
3231
"linear_obj",
3332
"logistic_obj",
3433
"poisson_obj",
35-
"EnsmallenEstimator",
3634
]
35+
36+
# Conditionally import GMM functionality (requires JAX)
37+
try:
38+
from .gmm import EnsmallenEstimator
39+
__all__.append("EnsmallenEstimator")
40+
except ImportError:
41+
# JAX not available - install with: pip install pyensmallen[gmm]
42+
pass

pyensmallen/gmm.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,10 @@
99

1010
from . import _pyensmallen as pe
1111

12-
try:
13-
import jax
14-
import jax.numpy as jnp
12+
import jax
13+
import jax.numpy as jnp
1514

16-
jax.config.update("jax_enable_x64", True)
17-
except ImportError:
18-
warnings.warn(
19-
"JAX is not installed. JAX autodifferentiation will not be available. "
20-
"Install JAX to enable autodiff support for GMM estimation.",
21-
ImportWarning,
22-
)
23-
jax = None
24-
jnp = None
15+
jax.config.update("jax_enable_x64", True)
2516

2617

2718
class EnsmallenEstimator:
@@ -41,6 +32,7 @@ def __init__(
4132
moment_cond: Function that computes moment conditions. Should be JAX-compatible.
4233
weighting_matrix: Either "optimal" for two-step GMM or a custom weighting matrix
4334
"""
35+
4436
self.moment_cond = moment_cond
4537
self.weighting_matrix = weighting_matrix
4638

pyproject.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ requires = [
55
"pybind11>=2.4",
66
"meson-python",
77
"numpy",
8+
"setuptools-scm",
89
]
910
build-backend = "mesonpy"
1011

@@ -13,7 +14,7 @@ skip = ["cp36-*", "*-win32", "*-manylinux_i686"]
1314

1415
[project]
1516
name = 'pyensmallen'
16-
version = '0.2.9'
17+
dynamic = ["version"]
1718
description = 'Python bindings for the Ensmallen library.'
1819
readme = 'README.md'
1920
requires-python = '>=3.10'
@@ -22,3 +23,9 @@ authors = [
2223
{ name = 'Apoorva Lal', email = 'lal.apoorva@gmail.com' },
2324
{ name = 'Matthew Wardrop', email = 'mpwardrop@gmail.com' },
2425
]
26+
27+
[project.optional-dependencies]
28+
gmm = ["jax", "jaxlib"]
29+
30+
[tool.setuptools_scm]
31+
write_to = "pyensmallen/_version.py"

setup.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from setuptools.command.build_ext import build_ext
77
from setuptools.dist import Distribution
88

9-
__version__ = "0.2.9"
109

1110

1211
class BinaryDistribution(Distribution):
@@ -99,14 +98,14 @@ def build_extensions(self):
9998

10099
setup(
101100
name="pyensmallen",
102-
version=__version__,
103101
author="Apoorva Lal",
104102
author_email="lal.apoorva@gmail.com",
105103
description="Python bindings for the ensmallen optimization library",
106104
long_description="",
107105
ext_modules=ext_modules,
108106
install_requires=["pybind11>=2.4"],
109-
setup_requires=["pybind11>=2.4"],
107+
setup_requires=["pybind11>=2.4", "setuptools-scm"],
108+
use_scm_version=True,
110109
cmdclass={"build_ext": BuildExt},
111110
packages=find_packages(),
112111
include_package_data=True,

0 commit comments

Comments
 (0)