[![active](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active)
[![ci](https://github.com/ramsey-devs/ramsey/actions/workflows/ci.yaml/badge.svg)](https://github.com/ramsey-devs/ramsey/actions/workflows/ci.yaml)
[![coverage](https://codecov.io/gh/ramsey-devs/ramsey/branch/main/graph/badge.svg?token=dn1xNBSalZ)](https://codecov.io/gh/ramsey-devs/ramsey)
[![documentation](https://readthedocs.org/projects/ramsey/badge/?version=latest)](https://ramsey.readthedocs.io/en/latest/?badge=latest)
[![version](https://img.shields.io/pypi/v/ramsey.svg?colorB=black&style=flat)](https://pypi.org/project/ramsey/)

## About

Ramsey is a library for probabilistic deep learning using [JAX](https://github.com/google/jax),
[Flax](https://github.com/google/flax) and [NumPyro](https://github.com/pyro-ppl/numpyro). Its scope covers

- neural processes (vanilla, attentive, Markovian, convolutional, ...),
- neural Laplace and Fourier operator models,
- etc.

## Example usage

You can, for instance, construct a simple neural process like this:

```python
from flax import nnx

from ramsey import NP
from ramsey.nn import MLP  # just a flax.nnx module


def get_neural_process(in_features, out_features):
    dim = 128
    np = NP(
        latent_encoder=(
            MLP(in_features, [dim, dim], rngs=nnx.Rngs(0)),
            MLP(dim, [dim, dim * 2], rngs=nnx.Rngs(1))
        ),
        decoder=MLP(in_features, [dim, dim, out_features * 2], rngs=nnx.Rngs(2))
    )
    return np


neural_process = get_neural_process(1, 1)
```

The neural process above takes a decoder and a pair of latent encoders as arguments. All of these are typically `flax.nnx` MLPs, but Ramsey is flexible enough that you can swap them, for instance, for CNNs or RNNs.

Ramsey provides a unified interface in which each method implements (at least) a `__call__` and a `loss` function to transform a set of inputs and to compute a training loss, respectively:

```python
from jax import random as jr
from ramsey.data import sample_from_sine_function

data = sample_from_sine_function(jr.key(0))
x_context, y_context = data.x[:, :20, :], data.y[:, :20, :]
x_target, y_target = data.x, data.y

# make a prediction
pred = neural_process(
    x_context=x_context,
    y_context=y_context,
    x_target=x_target,
)

# compute the loss
loss = neural_process.loss(
    x_context=x_context,
    y_context=y_context,
    x_target=x_target,
    y_target=y_target,
)
```
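Once a `loss` value is available, training reduces to standard gradient-based optimization with JAX. Below is a minimal, self-contained sketch of that loop; a toy quadratic objective stands in for `neural_process.loss` so the snippet runs on its own, and the plain-SGD update is an illustrative assumption rather than Ramsey's prescribed training API:

```python
import jax
import jax.numpy as jnp

# toy stand-in objective: in practice this would wrap neural_process.loss
target = jnp.array([1.0, -2.0, 0.5])

def loss_fn(params):
    # squared error between parameters and a fixed target
    return jnp.sum((params - target) ** 2)

# jit-compile the value-and-gradient computation
grad_fn = jax.jit(jax.value_and_grad(loss_fn))

params = jnp.zeros(3)
learning_rate = 0.1
for _ in range(100):
    loss, grads = grad_fn(params)
    params = params - learning_rate * grads  # plain SGD update
```

In practice you would differentiate the model's `loss` with respect to its trainable state and apply the updates with an off-the-shelf optimizer such as one from `optax`.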

## Installation

To install from PyPI, call:
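The package is published on PyPI under the name `ramsey` (as the version badge above indicates):

```shell
pip install ramsey
```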