Skip to content

Commit 5e77111

Browse files
committed
Refactor module exports; update __module__ attributes to 'braintools.param'
1 parent f8adc47 commit 5e77111

File tree

3 files changed

+15
-15
lines changed

3 files changed

+15
-15
lines changed

braintools/param/_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class Transform(ABC):
6262
... def inverse(self, y):
6363
... return jnp.sqrt(y)
6464
"""
65-
__module__ = 'braintools'
65+
__module__ = 'braintools.param'
6666

6767
def __call__(self, x: ArrayLike) -> Array:
6868
r"""
@@ -138,7 +138,7 @@ def inverse(self, y: ArrayLike) -> Array:
138138

139139

140140
class Identity(Transform):
141-
__module__ = 'braintools'
141+
__module__ = 'braintools.param'
142142

143143
def forward(self, x: ArrayLike) -> Array:
144144
return x

braintools/param/_state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ class Param(brainstate.ParamState, u.CustomArray):
109109
the unconstrained ``.value``, which often leads to better optimization dynamics.
110110
111111
"""
112-
__module__ = 'braintools'
112+
__module__ = 'braintools.param'
113113

114114
def __init__(
115115
self,

braintools/param/_transform.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ class Sigmoid(Transform):
133133
>>> y = transform.forward(x)
134134
>>> # y ≈ [0.0]
135135
"""
136-
__module__ = 'braintools'
136+
__module__ = 'braintools.param'
137137

138138
def __init__(self, lower: ArrayLike, upper: ArrayLike) -> None:
139139
r"""
@@ -254,7 +254,7 @@ class Softplus(Transform):
254254
>>> y = transform.forward(x)
255255
>>> # y ≈ [2.693]
256256
"""
257-
__module__ = 'braintools'
257+
__module__ = 'braintools.param'
258258

259259
def __init__(self, lower: ArrayLike) -> None:
260260
"""
@@ -369,7 +369,7 @@ class NegSoftplus(Softplus):
369369
>>> y = transform.forward(x)
370370
>>> # y ≈ [-2.693]
371371
"""
372-
__module__ = 'braintools'
372+
__module__ = 'braintools.param'
373373

374374
def __init__(self, upper: ArrayLike) -> None:
375375
"""
@@ -439,7 +439,7 @@ class Log(Transform):
439439
lower : array_like
440440
Lower bound of the target interval.
441441
"""
442-
__module__ = 'braintools'
442+
__module__ = 'braintools.param'
443443

444444
def __init__(self, lower: ArrayLike) -> None:
445445
super().__init__()
@@ -459,7 +459,7 @@ class Exp(Transform):
459459
460460
Equivalent to Log; provided for explicit naming.
461461
"""
462-
__module__ = 'braintools'
462+
__module__ = 'braintools.param'
463463

464464
def __init__(self, lower: ArrayLike) -> None:
465465
super().__init__()
@@ -480,7 +480,7 @@ class Tanh(Transform):
480480
y = lower + width * (tanh(x) + 1) / 2
481481
x = arctanh(2 * (y - lower) / width - 1)
482482
"""
483-
__module__ = 'braintools'
483+
__module__ = 'braintools.param'
484484

485485
def __init__(self, lower: ArrayLike, upper: ArrayLike) -> None:
486486
super().__init__()
@@ -503,7 +503,7 @@ class Softsign(Transform):
503503
y = lower + width * (x / (1 + |x|) + 1) / 2
504504
x = z / (1 - |z|), where z = 2 * (y - lower) / width - 1, z in (-1, 1)
505505
"""
506-
__module__ = 'braintools'
506+
__module__ = 'braintools.param'
507507

508508
def __init__(self, lower: ArrayLike, upper: ArrayLike) -> None:
509509
super().__init__()
@@ -562,7 +562,7 @@ class Clip(Transform):
562562
>>> y = transform.forward(x)
563563
>>> # y = [0.0, 0.5, 1.0]
564564
"""
565-
__module__ = 'braintools'
565+
__module__ = 'braintools.param'
566566

567567
def __init__(self, lower: ArrayLike, upper: ArrayLike) -> None:
568568
"""
@@ -672,7 +672,7 @@ class Affine(Transform):
672672
>>> fahrenheit = transform.forward(celsius)
673673
>>> # fahrenheit ≈ [32.0, 212.0]
674674
"""
675-
__module__ = 'braintools'
675+
__module__ = 'braintools.param'
676676

677677
def __init__(self, scale: ArrayLike, shift: ArrayLike):
678678
"""
@@ -780,7 +780,7 @@ class Chain(Transform):
780780
>>> softplus = Softplus(0)
781781
>>> chain = Chain(standardize, softplus)
782782
"""
783-
__module__ = 'braintools'
783+
__module__ = 'braintools.param'
784784

785785
def __init__(self, *transforms: Sequence[Transform]) -> None:
786786
"""
@@ -913,7 +913,7 @@ class Masked(Transform):
913913
>>> sigmoid = Sigmoid(-1, 1)
914914
>>> transform = Masked(corr_mask, sigmoid)
915915
"""
916-
__module__ = 'braintools'
916+
__module__ = 'braintools.param'
917917

918918
def __init__(self, mask: ArrayLike, transform: Transform) -> None:
919919
"""
@@ -1059,7 +1059,7 @@ class Custom(Transform):
10591059
... return ((y / 2) + 1) ** 2 - 1
10601060
>>> boxcox = Custom(boxcox_forward, boxcox_inverse)
10611061
"""
1062-
__module__ = 'braintools'
1062+
__module__ = 'braintools.param'
10631063

10641064
def __init__(self, forward_fn: Callable, inverse_fn: Callable) -> None:
10651065
"""

0 commit comments

Comments
 (0)