Skip to content

Commit dd25ec7

Browse files
authored
[jax] Improve exprel numerical stability and add gradient tests (#57)
* Bump version to 0.1.4 and improve numerical stability in exprel function * Fix formatting in pyproject.toml for version and keywords * Implement numerically stable exprel function and add comprehensive tests * Update lax array creation to use jax.numpy for zero initialization and update CI jax version * Remove redundant lax array creation tests for improved clarity * Update changelog for version 0.1.4: add exprel function, improve lax array creation, and enhance CI compatibility
1 parent c4b356b commit dd25ec7

File tree

11 files changed

+1109
-91
lines changed

11 files changed

+1109
-91
lines changed

.github/workflows/CI-daily.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ jobs:
3838
fail-fast: false
3939
matrix:
4040
python-version: ["3.13"]
41-
jax-version: ["0.4.38", "0.5.2", "0.6.0", "0.7.2", ""]
41+
jax-version: ["0.4.38", "0.5.2", "0.6.0", "0.7.2", "0.8.0", ""]
4242

4343
steps:
4444
- name: Cancel Previous Runs

changelog.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
# Release Notes
22

3+
## Version 0.1.4
4+
5+
- Added numerically stable `exprel` function with comprehensive test coverage
6+
- Updated lax array creation to use `jax.numpy` for zero initialization
7+
- Updated CI JAX version for improved compatibility
8+
- Improved code quality and removed redundant tests
9+
310
## Version 0.1.3
411

512
- Compatible with `jax>=0.8.2`

pyproject.toml

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55

66
[tool.setuptools.dynamic]
7-
version = {attr = "saiunit.__version__"}
7+
version = { attr = "saiunit.__version__" }
88

99
[tool.setuptools.packages.find]
1010
exclude = [
@@ -44,7 +44,7 @@ classifiers = [
4444
'Topic :: Software Development :: Libraries',
4545
]
4646

47-
keywords = ['physical unit', 'physical quantity', 'brain modeling', 'scientific computing', 'AI for science',]
47+
keywords = ['physical unit', 'physical quantity', 'brain modeling', 'scientific computing', 'AI for science', ]
4848

4949
dependencies = [
5050
'jax',
@@ -64,9 +64,7 @@ repository = 'https://github.com/chaobrain/saiunit'
6464
"Documentation" = "https://saiunit.readthedocs.io/"
6565

6666
[project.optional-dependencies]
67-
testing = [
68-
'pytest',
69-
]
67+
testing = ['pytest']
7068
cpu = ["jax[cpu]"]
7169
cuda12 = ["jax[cuda12]"]
7270
cuda13 = ["jax[cuda13]"]

saiunit/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515

16-
__version__ = "0.1.3"
16+
__version__ = "0.1.4"
1717
__version_info__ = tuple(map(int, __version__.split(".")))
1818

1919
from . import _matplotlib_compat

saiunit/_compatible_import.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,19 @@
2424
'safe_map',
2525
'unzip2',
2626
'wrap_init',
27+
'Primitive',
2728
]
2829

2930
T = TypeVar("T")
3031
T1 = TypeVar("T1")
3132
T2 = TypeVar("T2")
3233
T3 = TypeVar("T3")
3334

35+
if jax.__version_info__ < (0, 4, 38):
36+
from jax.core import Primitive
37+
else:
38+
from jax.extend.core import Primitive
39+
3440

3541
def wrap_init(fun: Callable, args: tuple, kwargs: dict, name: str):
3642
if jax.__version_info__ < (0, 6, 0):

saiunit/lax/_lax_array_creation.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import jax
1919
from jax import lax
20+
import jax.numpy as jnp
2021

2122
from .._base import Unit, Quantity
2223
from .._misc import set_module_as, maybe_custom_array
@@ -28,7 +29,8 @@
2829
'zeros_like_array',
2930

3031
# array creation(misc)
31-
'iota', 'broadcasted_iota',
32+
'iota',
33+
'broadcasted_iota',
3234
]
3335

3436

@@ -43,13 +45,13 @@ def zeros_like_array(
4345
if unit is not None:
4446
assert isinstance(unit, Unit), 'unit must be an instance of Unit.'
4547
x = x.in_unit(unit)
46-
return Quantity(lax.zeros_like_array(x.mantissa), unit=x.unit)
48+
return Quantity(jnp.zeros_like(x.mantissa), unit=x.unit)
4749
else:
4850
if unit is not None:
4951
assert isinstance(unit, Unit), 'unit must be an instance of Unit.'
50-
return lax.zeros_like_array(x) * unit
52+
return jnp.zeros_like(x) * unit
5153
else:
52-
return lax.zeros_like_array(x)
54+
return jnp.zeros_like(x)
5355

5456

5557
# array creation (misc)

saiunit/lax/_lax_array_creation_test.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -37,25 +37,6 @@ def __init__(self, *args, **kwargs):
3737

3838
print()
3939

40-
@parameterized.product(
41-
array=[jnp.array([1.0, 2.0]), jnp.array([[1.0, 2.0], [3.0, 4.0]])],
42-
unit=[second, meter]
43-
)
44-
def test_lax_array_creation_given_array(self, array, unit):
45-
bulax_fun_list = [getattr(bulax, fun) for fun in lax_array_creation_given_array]
46-
lax_fun_list = [getattr(lax, fun) for fun in lax_array_creation_given_array]
47-
48-
for bulax_fun, lax_fun in zip(bulax_fun_list, lax_fun_list):
49-
print(f'fun: {bulax_fun.__name__}')
50-
51-
result = bulax_fun(array)
52-
expected = lax_fun(array)
53-
assert_quantity(result, expected)
54-
55-
result = bulax_fun(array, unit=unit)
56-
expected = lax_fun(array)
57-
assert_quantity(result, expected, unit=unit)
58-
5940
@parameterized.product(
6041
value=[1, 10, 100],
6142
unit=[second, meter]

0 commit comments

Comments
 (0)