Skip to content

Commit 1ed3ac6

Browse files
Merge pull request #21 from gsd-authors/dev
Fix imports and typing
2 parents 4c8a75d + 6d5f107 commit 1ed3ac6

File tree

3 files changed

+3
-4
lines changed

3 files changed

+3
-4
lines changed

src/gsd/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.2.3dev"
1+
__version__ = "0.2.3"
22
from gsd.fit import fit_moments as fit_moments, GSDParams as GSDParams
33
from gsd.gsd import (
44
log_prob as log_prob,

src/gsd/gsd.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import jax.numpy as jnp
55
import numpy as np
66
from jax import Array
7-
from jax.random import PRNGKeyArray
87
from jax.scipy.special import betaln
98
from jax.typing import ArrayLike
109

@@ -113,7 +112,7 @@ def variance(psi: ArrayLike, rho: ArrayLike) -> Array:
113112
return rho * vmin(psi) + (1 - rho) * vmax(psi)
114113

115114

116-
def sample(psi: ArrayLike, rho: ArrayLike, shape: Shape, key: PRNGKeyArray) -> Array:
115+
def sample(psi: ArrayLike, rho: ArrayLike, shape: Shape, key: Array) -> Array:
117116
"""Sample from GSD
118117
119118
:param psi: mean

tests/experimental_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from gsd.experimental.fit import GridEstimator
1717
from gsd.experimental.max_entropy import MaxEntropyGSD, vmax
1818
from gsd.fit import fit_moments, GSDParams, log_pmax, pairs, pmax
19-
from gsd.gsd import make_softvmin, vmax, vmin
19+
from gsd.gsd import make_softvmin, vmin
2020

2121

2222
class FitTestCase(unittest.TestCase):

0 commit comments

Comments
 (0)