Skip to content

Commit e6ba0c7

Browse files
authored
Add gather function; refine math, linalg, autograd, and constants (#44)
* Bump version to 0.0.18, add gather function, and enhance round function with decimals parameter * Update Publish.yml to enhance package build and publishing process for SAIUnit and BrainUnit * Add @set_module_as decorator to activation functions and update Publish.yml for Python setup and dependency installation * Refactor unit handling in functions and improve type assertions for clarity * Fix unit constants for parsec, barrel, and degree Fahrenheit for accuracy * Fix input dimension handling in FFT functions for improved accuracy * Refactor dynamic slicing functions to improve handling of Quantity objects and maintain unit consistency * Fix references to 'brainunit' in jacobian functions and add allow_int parameter for flexibility * Refactor linalg module to improve QR decomposition handling and clean up code formatting * Remove unused allow_int parameter from jacfwd function and update related documentation * Fix barrel unit scale adjustment in _unit_constants.py and update related test assertions * Update CI configuration to remove Python versions 3.11 and 3.12 from testing matrix
1 parent d26d65f commit e6ba0c7

22 files changed

+230
-89
lines changed

.github/workflows/CI-daily.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ jobs:
3737
strategy:
3838
fail-fast: false
3939
matrix:
40-
python-version: ["3.10", "3.11", "3.12", "3.13"]
40+
python-version: ["3.10", "3.13"]
4141
jax-version: ["0.4.38", "0.5.2", "0.6.0", ""]
4242
# Optional: Exclude incompatible combinations if needed
4343
# exclude:

.github/workflows/CI.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
strategy:
2929
fail-fast: false
3030
matrix:
31-
python-version: [ "3.10", "3.11", "3.12", "3.13" ]
31+
python-version: [ "3.10", "3.13" ]
3232

3333
steps:
3434
- name: Cancel Previous Runs
@@ -58,7 +58,7 @@ jobs:
5858
strategy:
5959
fail-fast: false
6060
matrix:
61-
python-version: [ "3.10", "3.11", "3.12", "3.13" ]
61+
python-version: [ "3.10", "3.13" ]
6262

6363
steps:
6464
- name: Cancel Previous Runs
@@ -124,7 +124,7 @@ jobs:
124124
strategy:
125125
fail-fast: false
126126
matrix:
127-
python-version: [ "3.10", "3.11", "3.12", "3.13" ]
127+
python-version: [ "3.10", "3.13" ]
128128

129129
steps:
130130
- name: Cancel Previous Runs

.github/workflows/Publish.yml

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name: Publish to PyPI.org
22
on:
33
release:
4-
types: [published]
4+
types: [ published ]
55
jobs:
66
pypi:
77
runs-on: ubuntu-latest
@@ -10,8 +10,29 @@ jobs:
1010
uses: actions/checkout@v5
1111
with:
1212
fetch-depth: 0
13-
- run: python setup.py bdist_wheel
14-
- name: Publish package
13+
- name: Set up Python
14+
uses: actions/setup-python@v4
15+
with:
16+
python-version: '3.13'
17+
- name: Install build dependencies
18+
run: |
19+
python -m pip install --upgrade pip
20+
python -m pip install build setuptools wheel
21+
- name: build SAIUnit package
22+
run: |
23+
python -m build
24+
- name: Publish SAIUnit package
25+
uses: pypa/gh-action-pypi-publish@release/v1
26+
with:
27+
password: ${{ secrets.PYPI_API_TOKEN }}
28+
- name: build BrainUnit package
29+
run: |
30+
python -m pip install dist/saiunit* jax jaxlib numpy typing_extensions
31+
python make_brainunit_setup.py
32+
cd ./brainunit
33+
python -m build
34+
- name: Publish BrainUnit package
1535
uses: pypa/gh-action-pypi-publish@release/v1
1636
with:
1737
password: ${{ secrets.PYPI_API_TOKEN }}
38+
packages-dir: brainunit/dist/

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[build-system]
2-
requires = ["setuptools", "numpy", 'jax', 'jaxlib']
2+
requires = ["setuptools"]
33
build-backend = "setuptools.build_meta"
44

55

saiunit/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515

16-
__version__ = "0.0.16"
16+
__version__ = "0.0.18"
1717

1818
from . import _matplotlib_compat
1919
from . import autograd

saiunit/_unit_constants.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@
8484
au = astronomical_unit = Unit.create(meter.dim, name="astronomical unit", dispname="AU", scale=meter.scale + 11,
8585
factor=1.495978707)
8686
light_year = Unit.create(meter.dim, name="light year", dispname="ly", scale=meter.scale + 15, factor=9.460730777119564)
87-
parsec = Unit.create(meter.dim, name="parsec", dispname="pc", scale=meter.scale + 16, factor=3.085677581491367e16)
87+
parsec = Unit.create(meter.dim, name="parsec", dispname="pc", scale=meter.scale + 16, factor=3.085677581491367)
8888

8989
# ----- Pressure -----
9090
atm = atmosphere = Unit.create(pascal.dim, name="atmosphere", dispname="atm", scale=pascal.scale + 5, factor=1.013249966)
@@ -104,7 +104,7 @@
104104
factor=2.95735295625)
105105
fluid_ounce_imp = Unit.create(meter3.dim, name="imperial fluid ounce", dispname="fl oz imp", scale=meter3.scale - 5,
106106
factor=2.84130742)
107-
bbl = barrel = Unit.create(meter3.dim, name="barrel", dispname="bbl", scale=meter3.scale + 2, factor=1.5898729493)
107+
bbl = barrel = Unit.create(meter3.dim, name="barrel", dispname="bbl", scale=meter3.scale + 2, factor=1.58987294928)
108108

109109
# ----- Speed -----
110110
speed_unit = meter / second
@@ -118,8 +118,8 @@
118118
# ----- Temperature -----
119119
# TODO: The relationship between Celsius and Kelvin should be linear, but the current implementation is not.
120120
# zero_Celsius = Unit.create(kelvin.dim, name="zero Celsius", dispname="0°C", scale=kelvin.scale, factor=273.15)
121-
degree_Fahrenheit = Unit.create(kelvin.dim, name="degree Fahrenheit", dispname="°F", scale=kelvin.scale + 2,
122-
factor=2.55927778)
121+
degree_Fahrenheit = Unit.create(kelvin.dim, name="degree Fahrenheit", dispname="°F", scale=kelvin.scale,
122+
factor=5/9)
123123

124124
# ----- Energy -----
125125
eV = electron_volt = Unit.create(joule.dim, name="electronvolt", dispname="eV", scale=joule.scale - 19, factor=1.602176565)
@@ -130,7 +130,7 @@
130130
Btu = Btu_IT = Unit.create(joule.dim, name="British thermal unit (International Table)", dispname="Btu IT",
131131
scale=joule.scale + 3, factor=1.05505585262)
132132
Btu_th = Unit.create(joule.dim, name="British thermal unit (thermochemical)", dispname="Btu th", scale=joule.scale + 3,
133-
factor=1.0543499999744)
133+
factor=1.05435026444)
134134
ton_TNT = Unit.create(joule.dim, name="ton of TNT", dispname="ton TNT", scale=joule.scale + 9, factor=4.184)
135135

136136
# ----- Power -----

saiunit/autograd/_jacobian.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
_check_output_dtype_jacfwd
3030
)
3131
from jax.api_util import argnums_partial
32-
from jax.extend import linear_util
3332

3433
from saiunit._base import Quantity, maybe_decimal, get_magnitude, get_unit
3534
from saiunit._compatible_import import safe_map, wrap_init
@@ -134,15 +133,15 @@ def jacrev(
134133
In particular, an array is produced (with no pytrees involved) when the
135134
function input ``x`` and output ``fun(x)`` are each a single array, as in the
136135
``simple_function`` example above. If ``fun(x)`` has shape ``(out1, out2, ...)`` and ``x``
137-
has shape ``(in1, in2, ...)`` then ``saiunit.autograd.jacrec(fun)(x)`` has shape
136+
has shape ``(in1, in2, ...)`` then ``saiunit.autograd.jacrev(fun)(x)`` has shape
138137
``(out1, out2, ..., in1, in2, ..., in1, in2, ...)``. To flatten pytrees into
139138
1D vectors, consider using :py:func:`jax.flatten_util.flatten_pytree`.
140139
"""
141140
_check_callable(fun)
142141

143142
@wraps(fun)
144143
def jacfun(*args, **kwargs):
145-
f = wrap_init(fun, args, kwargs, 'brainunit.autograd.jacrev')
144+
f = wrap_init(fun, args, kwargs, 'saiunit.autograd.jacrev')
146145
f_partial, dyn_args = argnums_partial(f, argnums, args, require_static_args_hashable=False)
147146
jax.tree.map(partial(_check_input_dtype_jacrev, holomorphic, allow_int), dyn_args)
148147
if not has_aux:
@@ -238,7 +237,7 @@ def jacfwd(
238237
fun: Callable,
239238
argnums: int | Sequence[int] = 0,
240239
has_aux: bool = False,
241-
holomorphic: bool = False
240+
holomorphic: bool = False,
242241
) -> Callable:
243242
"""
244243
Physical unit-aware version of `jax.jacfwd <https://jax.readthedocs.io/en/latest/_autosummary/jax.jacfwd.html>`_.
@@ -276,7 +275,7 @@ def jacfwd(
276275
[[3., 0.],
277276
[0., 4.]] * ohm)
278277
279-
`jacfwd` is a generalization of the usual definition of the JacFwd(Jacobian Reverse Mode).
278+
`jacfwd` is a generalization of the usual definition of the JacFwd(Jacobian Forward Mode).
280279
that supports nested Python containers (i.e. pytrees) as inputs and outputs.
281280
The tree structure of ``saiunit.autograd.jacfwd(fun)(x)`` is given by forming a tree
282281
product of the structure of ``fun(x)`` with a tree product of two copies of
@@ -324,7 +323,7 @@ def jacfwd(
324323
In particular, an array is produced (with no pytrees involved) when the
325324
function input ``x`` and output ``fun(x)`` are each a single array, as in the
326325
``simple_function`` example above. If ``fun(x)`` has shape ``(out1, out2, ...)`` and ``x``
327-
has shape ``(in1, in2, ...)`` then ``saiunit.autograd.jacrec(fun)(x)`` has shape
326+
has shape ``(in1, in2, ...)`` then ``saiunit.autograd.jacfwd(fun)(x)`` has shape
328327
``(out1, out2, ..., in1, in2, ..., in1, in2, ...)``. To flatten pytrees into
329328
1D vectors, consider using :py:func:`jax.flatten_util.flatten_pytree`.
330329
"""
@@ -333,7 +332,7 @@ def jacfwd(
333332

334333
@wraps(fun)
335334
def jacfun(*args, **kwargs):
336-
f = wrap_init(fun, args, kwargs, 'brainunit.autograd.jacfwd')
335+
f = wrap_init(fun, args, kwargs, 'saiunit.autograd.jacfwd')
337336
f_partial, dyn_args = argnums_partial(f, argnums, args, require_static_args_hashable=False)
338337
jax.tree.map(partial(_check_input_dtype_jacfwd, holomorphic), dyn_args)
339338
if not has_aux:

saiunit/constants.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
# ==============================================================================
1515

1616
r"""
17-
A module providing some physical units as `Quantity` objects. Note that these
18-
units are not imported by wildcard imports, they
17+
A module providing some physical constants as `Quantity` objects. Note that these
18+
constants are not imported by wildcard imports, they
1919
have to be imported explicitly. You can use ``import ... as ...`` to import them
2020
with shorter names, e.g.::
2121
22-
from saiunit.constants import faraday_constant as F
22+
from saiunit.constants import faraday as F
2323
2424
The available constants are:
2525
@@ -61,18 +61,18 @@
6161
from ._unit_constants import speed_unit
6262

6363
__all__ = [
64-
'arcmin', 'arcminute', 'arcsec', 'arcsecond', 'atomic_mass', 'au', 'astronomical_unit',
65-
'angstrom', 'atm', 'atmosphere', 'avogadro', 'bar', 'blob', 'boltzmann', 'Btu', 'Btu_IT',
64+
'acre', 'arcmin', 'arcminute', 'arcsec', 'arcsecond', 'atomic_mass', 'au', 'astronomical_unit',
65+
'angstrom', 'atm', 'atmosphere', 'avogadro', 'bar', 'barrel', 'bbl', 'blob', 'boltzmann', 'Btu', 'Btu_IT',
6666
'Btu_th', 'carat', 'calorie', 'calorie_IT', 'calorie_th', 'day', 'degree', 'degree_Fahrenheit',
6767
'dyn', 'dyne', 'eV', 'electron_mass', 'electric', 'electronvolt', 'elementary_charge', 'erg',
6868
'faraday', 'fermi', 'fluid_ounce', 'fluid_ounce_US', 'fluid_ounce_imp', 'foot', 'gas', 'grain',
6969
'gallon', 'gallon_US', 'gallon_imp', 'gram', 'hectare', 'hour', 'hp', 'horsepower', 'IMF',
70-
'inch', 'julian_year', 'kelvin', 'kgf', 'kilogram_force', 'knot', 'lb', 'lbf', 'light_year',
71-
'long_ton', 'mach', 'magnetic', 'meter', 'metric_ton', 'micron', 'mil', 'mile', 'minute',
70+
'inch', 'julian_year', 'kelvin', 'kgf', 'kilogram_force', 'kmh', 'knot', 'lb', 'lbf', 'light_year',
71+
'long_ton', 'mach', 'magnetic', 'meter', 'metric_ton', 'micron', 'mil', 'mile', 'minute', 'mmHg',
7272
'molar_mass', 'month', 'mph', 'nautical_mile', 'newton', 'ounce', 'oz', 'parsec', 'pica',
73-
'point', 'pound', 'psi', 'radian', 'second', 'short_ton', 'slug', 'slinch', 'speed_unit',
73+
'point', 'pound', 'pound_force', 'psi', 'radian', 'second', 'short_ton', 'slug', 'slinch', 'speed_unit',
7474
'stone', 'survey_foot', 'survey_mile', 'torr', 'troy_ounce', 'troy_pound', 'ton_TNT', 'week',
75-
'watt', 'year', 'zero_celsius'
75+
'watt', 'yard', 'year', 'zero_celsius'
7676
]
7777

7878
#: Avogadro constant (http://physics.nist.gov/cgi-bin/cuu/Value?na)
@@ -90,9 +90,9 @@
9090
#: gas constant (http://physics.nist.gov/cgi-bin/cuu/Value?r)
9191
gas = np.asarray(8.3144598) * (joule / mole / kelvin)
9292
#: Magnetic constant (http://physics.nist.gov/cgi-bin/cuu/Value?mu0)
93-
magnetic = np.asarray(4 * np.pi * 1e-7) * (newton / amp ** 2)
93+
magnetic = np.asarray(1.25663706212e-6) * (newton / amp ** 2)
9494
#: Molar mass constant (http://physics.nist.gov/cgi-bin/cuu/Value?mu)
95-
molar_mass = np.asarray(1.) * (gram / mole)
95+
molar_mass = np.asarray(1e-3) * (kilogram / mole)
9696
#: zero degree Celsius
9797
zero_celsius = np.asarray(273.15) * kelvin
9898

@@ -161,7 +161,9 @@
161161
bbl = barrel = np.asarray(1.58987294928e2) * meter3 # Barrel (oil)
162162

163163
# ----- Temperature -----
164-
degree_Fahrenheit = np.asarray(2.55927778e2) * kelvin # Fahrenheit
164+
# Note: Fahrenheit is a temperature scale, not a unit. Use conversion functions instead.
165+
# This constant represents the conversion factor from Fahrenheit to Kelvin degrees
166+
degree_Fahrenheit = np.asarray(5/9) * kelvin # Fahrenheit degree size in Kelvin
165167

166168
# ----- Speed -----
167169
kmh = np.asarray(2.77777778e-1) * speed_unit # Kilometer per hour

saiunit/constants_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,9 @@ def test_quantity_constants_and_unit_constants(self):
7979
import saiunit.constants as quantity_constants
8080
import saiunit._unit_constants as unit_constants
8181
for c in constants_list:
82+
print(c)
8283
q_c = getattr(quantity_constants, c)
8384
u_c = getattr(unit_constants, c)
8485
assert u.math.isclose(
85-
q_c.to_decimal(q_c.unit),
86-
(1. * u_c).to_decimal(q_c.unit)
86+
q_c.to_decimal(q_c.unit), (1. * u_c).to_decimal(q_c.unit)
8787
), f"Mismatch between {c} in quantity_constants and unit_constants"

saiunit/fft/_fft_change_unit.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -625,7 +625,8 @@ def fftn(
625625
>>> u.math.allclose(x, u.fft.ifftn(x_fftn))
626626
Array(True, dtype=bool)
627627
"""
628-
n = _calculate_fftn_dimension(a.ndim, axes)
628+
input_ndim = a.ndim if hasattr(a, 'ndim') else jnp.asarray(a).ndim
629+
n = _calculate_fftn_dimension(input_ndim, axes)
629630
_unit_change_fun = lambda u: u * (second ** n)
630631
# TODO: may cause computation overhead?
631632
fftn._unit_change_fun = _unit_change_fun
@@ -730,7 +731,8 @@ def rfftn(
730731
>>> u.fft.rfftn(x1)
731732
ArrayImpl([10.+0.j, -2.+2.j, -2.+0.j], dtype=complex64) * meter * second
732733
"""
733-
n = _calculate_fftn_dimension(a.ndim, axes)
734+
input_ndim = a.ndim if hasattr(a, 'ndim') else jnp.asarray(a).ndim
735+
n = _calculate_fftn_dimension(input_ndim, axes)
734736
_unit_change_fun = lambda u: u * (second ** n)
735737
# TODO: may cause computation overhead?
736738
rfftn._unit_change_fun = _unit_change_fun
@@ -978,7 +980,8 @@ def ifftn(
978980
ArrayImpl([[ 2.5 +0.j , 0. -0.58j, 0. +0.58j],
979981
[ 0.17+0.j , -0.83-0.29j, -0.83+0.29j]], dtype=complex64) * meter / second2
980982
"""
981-
n = _calculate_fftn_dimension(a.ndim, axes)
983+
input_ndim = a.ndim if hasattr(a, 'ndim') else jnp.asarray(a).ndim
984+
n = _calculate_fftn_dimension(input_ndim, axes)
982985
_unit_change_fun = lambda u: u / (second ** n)
983986
# TODO: may cause computation overhead?
984987
ifftn._unit_change_fun = _unit_change_fun
@@ -1067,7 +1070,8 @@ def irfftn(
10671070
[[-2., -2., -2.],
10681071
[-2., -2., -2.]]], dtype=float32) * meter / second
10691072
"""
1070-
n = _calculate_fftn_dimension(a.ndim, axes)
1073+
input_ndim = a.ndim if hasattr(a, 'ndim') else jnp.asarray(a).ndim
1074+
n = _calculate_fftn_dimension(input_ndim, axes)
10711075
_unit_change_fun = lambda u: u / (second ** n)
10721076
# TODO: may cause computation overhead?
10731077
irfftn._unit_change_fun = _unit_change_fun
@@ -1188,7 +1192,7 @@ def rfftfreq(
11881192
) -> Union[Quantity, jax.typing.ArrayLike]:
11891193
"""Return sample frequencies for the discrete Fourier transform.
11901194
1191-
saiunit implementation of :func:`numpy.fft.fftfreq`. Returns frequencies appropriate
1195+
saiunit implementation of :func:`numpy.fft.rfftfreq`. Returns frequencies appropriate
11921196
for use with the outputs of :func:`~saiunit.fft.rfft` and
11931197
:func:`~saiunit.fft.irfft`.
11941198
@@ -1210,8 +1214,8 @@ def rfftfreq(
12101214
Example:
12111215
>>> import saiunit as u
12121216
>>> import jax.numpy as jnp
1213-
>>> x = jnp.array([1, 2, 3, 4]) * u.second
1214-
>>> u.fft.rfftfreq(4, x)
1217+
>>> d = 1 * u.second
1218+
>>> u.fft.rfftfreq(4, d)
12151219
ArrayImpl([0. , 0.25, 0.5 ], dtype=float32) * hertz
12161220
"""
12171221
if isinstance(d, Quantity):

0 commit comments

Comments
 (0)