Skip to content

Commit 74e2916

Browse files
Add __main__ for mle
1 parent 95f5760 commit 74e2916

File tree

3 files changed

+21
-12
lines changed

3 files changed

+21
-12
lines changed

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,13 @@ Gsd uses unitest for testing. To run the tests, use the following command:
4848
```
4949
$ hatch run test
5050
```
51+
52+
### Standalone estimator
53+
54+
You can quickly estimate GSD parameters from a command line interface
55+
56+
```shell
57+
python3 -m gsd 0 12 13 4 0
58+
```
59+
60+
GSDParams(psi=Array(2.6272388, dtype=float32), rho=Array(0.9041536, dtype=float32))

src/gsd/__main__.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
1-
import jax
1+
import argparse
22
import jax.numpy as jnp
3-
4-
import gsd
3+
from gsd import fit_mle
54

65
if __name__ == '__main__':
7-
gsd.log_prob(1., 0.5, 2)
8-
m = gsd.mean(3., 0.7)
9-
v = gsd.variance(3., 0.7)
10-
k = jax.random.key(43)
11-
s = gsd.sample(3., 0.7, (24,), k)
6+
parser = argparse.ArgumentParser(description='GSD estimator')
127

13-
jnp.mean(s), jnp.var(s)
8+
parser.add_argument("response", nargs=5, type=int,
9+
metavar=("num1", "num2", "num3", "num4", "num5"),
10+
help="List of 5 counts")
1411

15-
# jax.vmap(gsd.log_prob, in_axes=(None,None,0))(3.,0.7,s)
12+
args = parser.parse_args()
1613

17-
print('test')
14+
hat,_ = fit_mle(data=jnp.asarray(args.response, dtype=jnp.float32))
15+
print(hat)

src/gsd/fit.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from jax import Array
77
from jax.typing import ArrayLike
88

9-
from gsd.gsd import vmax, vmin, log_prob
9+
from .gsd import vmax, vmin, log_prob
1010

1111

1212
class GSDParams(NamedTuple):
@@ -108,3 +108,4 @@ def cond_fun(state: OptState) -> bool:
108108
OptState(params=theta0, previous_params=jtu.tree_map(lambda _: jnp.inf, theta0),
109109
count=0))
110110
return opt_state.params, opt_state
111+

0 commit comments

Comments
 (0)