Commit 41bf090

Fix typos and improve clarity in README.md
1 parent fc24560 commit 41bf090

File tree: 1 file changed, +7 -7 lines changed

README.md

Lines changed: 7 additions & 7 deletions
@@ -46,11 +46,11 @@ The snippets below develop the polynomial regression example from our paper's *Overview* section.

### Vectorizing Generative Functions with vmap

- We begin by expressing the quadratic regression model as a composition of generative functions (`@gen`-decorated Python functions).
+ We begin by expressing the polynomial regression model as a composition of generative functions (`@gen`-decorated Python functions).
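
As a hedged sketch of what such a composition can look like (the model's actual code is elided by this diff; apart from `point`, the names, priors, and noise scale below are illustrative assumptions), GenJAX attaches addresses with the `@` operator:

```python
import jax.numpy as jnp
from genjax import gen, normal

@gen
def coefficients():
    # One addressed random choice per polynomial coefficient.
    a = normal(0.0, 1.0) @ "a"
    b = normal(0.0, 1.0) @ "b"
    c = normal(0.0, 1.0) @ "c"
    return jnp.array([a, b, c])

@gen
def point(x, coeffs):
    # The observation is itself an addressed random choice.
    y = normal(coeffs[0] + coeffs[1] * x + coeffs[2] * x**2, 0.1) @ "obs"
    return y

@gen
def npoint_curve(x):
    # Composition: calling a generative function under an address nests
    # its choices beneath that address in the trace.
    coeffs = coefficients() @ "curve"
    return point(x, coeffs) @ "ys"
```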

- Each random choice (invocation of a generative function) is tagged with a string address (`"a"`, `"b"`, `"c"`, `"obs"`), which is used to construct a structured representation of the model’s latent variables and observed data, called a _trace_.
+ Each random choice (invocation of a generative function) is tagged with a string address (`"a"`, `"b"`, `"c"`, `"obs"`), which is used to construct a structured representation of the model’s random variables, called a _trace_.
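
For instance (a sketch, assuming the classic `simulate(key, args)` interface of Gen-style systems; exact spellings vary across GenJAX versions), the addresses become keys into the trace's choice map:

```python
import jax

# Simulate the sketch model above and read choices back out by address.
key = jax.random.PRNGKey(0)
trace = npoint_curve.simulate(key, (0.5,))

choices = trace.get_choices()
print(choices["curve"]["a"])  # the coefficient sampled at address "a"
print(choices["ys"]["obs"])   # the value sampled at the observation site
```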

- Packaging the coefficients inside a callable `Lambda` Pytree mirrors the notion of sampling a function-valued random variable: downstream computations can call the curve directly while the trace retains access to its parameters.
+ In GenJAX, packaging the coefficients inside a callable `Lambda` Pytree is a convenient way to allow downstream computations to call the curve directly, while the trace retains access to its parameters.
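
GenJAX ships its own Pytree utilities for this; purely to illustrate the idea, a callable `Lambda`-style Pytree can be spelled with plain `jax.tree_util` (the class below is a hypothetical stand-in, not GenJAX's implementation):

```python
import jax

@jax.tree_util.register_pytree_node_class
class Lambda:
    """A callable Pytree: the coefficients are leaves that JAX transforms
    (and traces) can traverse, while downstream code calls the object
    like an ordinary function."""

    def __init__(self, coeffs):
        self.coeffs = coeffs

    def __call__(self, x):
        a, b, c = self.coeffs
        return a + b * x + c * x**2

    def tree_flatten(self):
        return (self.coeffs,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (coeffs,) = children
        return cls(coeffs)

curve = Lambda((0.1, 0.2, 0.3))
print(curve(1.0))  # 0.1 + 0.2 * 1.0 + 0.3 * 1.0**2 = 0.6
```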

```python
from genjax import gen, normal
# ... (rest of the model code elided by the diff)
```

@@ -98,13 +98,13 @@ print(trace.get_choices()["curve"].keys())

```python
print(trace.get_choices()["ys"]["obs"].shape)
```

- Vectorizing the `point` generative function with `vmap` mirrors the Overview section's Figure 3: the resulting trace preserves the hierarchical structure of the coefficients while lifting the observation random choice into a vectorized array valued version. This "structure preserving vectorization is what later enables us to reason about datasets and other inference logic in a vectorized fashion.
+ Vectorizing the `point` generative function with `vmap` mirrors the Overview section's Figure 3: the resulting trace preserves the hierarchical structure of the coefficients of the polynomial while lifting the observation random choice into a vectorized array-valued version. This structure-preserving vectorization is what later enables us to reason about datasets consisting of many points (and other inference logic) in a vectorized fashion.
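
Because traces are Pytrees, the effect is visible even with plain `jax.vmap` (a sketch, again assuming the `simulate(key, args)` signature; GenJAX also provides a dedicated `vmap` combinator for generative functions, not reproduced here):

```python
import jax
import jax.numpy as jnp

# Vectorize simulation of the sketch `point` model across a dataset of xs.
xs = jnp.linspace(-1.0, 1.0, 10)
coeffs = jnp.array([0.1, 0.2, 0.3])
keys = jax.random.split(jax.random.PRNGKey(0), xs.shape[0])

traces = jax.vmap(lambda k, x: point.simulate(k, (x, coeffs)))(keys, xs)

# Array leaves are lifted to a batch dimension: the "obs" choice is now a
# (10,)-shaped array, while the trace's static structure is shared.
print(traces.get_choices()["obs"].shape)  # (10,)
```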

### Vectorized Programmable Inference

- The generative function interface supplies a small set of methods`simulate`, `generate`, `assess`, `update`that we can compose into inference algorithms.
+ The generative function interface supplies a small set of methods - `simulate`, `generate`, `assess`, `update` - that we can compose into inference algorithms.
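
For orientation, the rough shape of these methods, following Gen's design (argument orders and return types vary across GenJAX versions, so this is a sketch rather than a reference):

```python
# trace = gf.simulate(key, args)
#   Run the model forward, recording all choices, the return value,
#   and the log density in a trace.
#
# trace, log_weight = gf.generate(key, constraints, args)
#   Run the model with some addresses constrained to observed values;
#   log_weight is an importance weight accounting for the constraints.
#
# log_density, retval = gf.assess(key, choices, args)
#   Score a complete choice map under the model's density.
#
# new_trace, log_weight, *rest = gf.update(key, trace, constraints, argdiffs)
#   Edit an existing trace (new constraints and/or arguments), returning
#   the incremental weight used by MCMC and SMC moves.
```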

- Here we implement likelihood weighting (importance sampling): a single-particle routine constrains the observation site via the `generate` interface, while a vectorized wrapper increases the number of particles. The logic of guessing (sampling) and checking (computing an importance weight) -- internally implemented in `generate` -- remains the same across particles, only the array dimensions vary with the particle count.
+ Here, we implement likelihood weighting (importance sampling): a single-particle routine constrains the observation site to a fixed observed value via the `generate` interface, while a vectorized wrapper increases the number of particles. The logic of guessing (sampling) and checking (computing an importance weight) -- internally implemented in `generate` -- remains the same across particles; only the array dimensions vary with the particle count.

```python
from jax.scipy.special import logsumexp
# ... (rest of the importance sampler elided by the diff)
```

@@ -138,7 +138,7 @@ traces, log_weights = vectorized_importance_sampling(

```python
print(traces.get_choices()["curve"]["a"].shape, log_marginal_likelihood(log_weights))
```
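
Since the diff elides the sampler's body, here is a self-contained sketch of the same guess-and-check pattern written directly in JAX, hand-rolling what `generate` does internally (all names and the noise scale are illustrative):

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

def single_particle(key, xs, ys):
    # Guess: sample polynomial coefficients from the prior.
    a, b, c = jax.random.normal(key, shape=(3,))
    # Check: score the fixed observations under the likelihood to get
    # this particle's log importance weight.
    mu = a + b * xs + c * xs**2
    log_weight = jnp.sum(norm.logpdf(ys, mu, 0.1))
    return (a, b, c), log_weight

def importance_sampling(key, xs, ys, n_particles):
    # The vectorized wrapper: identical per-particle logic, batched arrays.
    keys = jax.random.split(key, n_particles)
    return jax.vmap(single_particle, in_axes=(0, None, None))(keys, xs, ys)

def log_marginal_likelihood(log_weights):
    # Log of the average importance weight, computed stably in log space.
    return logsumexp(log_weights) - jnp.log(log_weights.shape[0])

xs = jnp.linspace(-1.0, 1.0, 20)
ys = 0.3 + 0.5 * xs - 0.2 * xs**2
samples, log_weights = importance_sampling(jax.random.PRNGKey(0), xs, ys, 1000)
print(log_marginal_likelihood(log_weights))
```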

- Running on a GPU allows us to increase the axis size as far as memory allows, just as in the scaling curves shown in Figure 5 in the paper.
+ Running on a GPU allows us to increase the number of particles as far as memory allows, just as in the scaling curves shown in Figure 5 of the paper.

### Improving Robustness using Stochastic Branching