Skip to content

Commit 1eb4129

Browse files
femtomcclaude
andcommitted
Fix markdown-exec code blocks in documentation
- Fix categorical distribution to use logits instead of probs - Fix selection API usage: sel("a", "b") → sel("a") | sel("b") - Fix coin flip model to avoid Python loops (use manual unrolling) - Fix distribution docstring examples to use correct API - All markdown-exec code blocks now execute without errors 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent eead190 commit 1eb4129

File tree

3 files changed

+30
-27
lines changed

3 files changed

+30
-27
lines changed

docs/reference/core.md

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,14 @@ from genjax import gen, distributions
4646
def coin_flip_model(n_flips):
4747
"""A simple coin flipping model with unknown bias."""
4848
bias = distributions.beta(1.0, 1.0) @ "bias"
49-
flips = []
50-
for i in range(n_flips):
51-
flip = distributions.bernoulli(bias) @ f"flip_{i}"
52-
flips.append(flip)
53-
return jnp.array(flips)
49+
50+
# For demonstration, we'll show manual unrolling
51+
# In practice, use Scan combinator for loops
52+
flip_0 = distributions.bernoulli(bias) @ "flip_0"
53+
flip_1 = distributions.bernoulli(bias) @ "flip_1"
54+
flip_2 = distributions.bernoulli(bias) @ "flip_2"
55+
56+
return jnp.array([flip_0, flip_1, flip_2])
5457

5558
print("Model defined successfully!")
5659
```
@@ -66,11 +69,14 @@ from genjax import gen, distributions
6669
def coin_flip_model(n_flips):
6770
"""A simple coin flipping model with unknown bias."""
6871
bias = distributions.beta(1.0, 1.0) @ "bias"
69-
flips = []
70-
for i in range(n_flips):
71-
flip = distributions.bernoulli(bias) @ f"flip_{i}"
72-
flips.append(flip)
73-
return jnp.array(flips)
72+
73+
# For demonstration, we'll show manual unrolling
74+
# In practice, use Scan combinator for loops
75+
flip_0 = distributions.bernoulli(bias) @ "flip_0"
76+
flip_1 = distributions.bernoulli(bias) @ "flip_1"
77+
flip_2 = distributions.bernoulli(bias) @ "flip_2"
78+
79+
return jnp.array([flip_0, flip_1, flip_2])
7480

7581
# Assess the log probability of specific choices
7682
choices = {"bias": 0.7, "flip_0": 1, "flip_1": 1, "flip_2": 0}
@@ -88,10 +94,10 @@ from genjax import sel, Selection
8894

8995
# Create various selections
9096
s1 = sel("bias") # Select only bias
91-
s2 = sel("flip_0", "flip_1") # Select two flips
97+
s2 = sel("flip_0") | sel("flip_1") # Select two flips with OR
9298
s3 = sel("bias") | sel("flip_2") # Select bias OR flip_2
9399

94100
print(f"Selection s1 targets: bias")
95-
print(f"Selection s2 targets: flip_0, flip_1")
101+
print(f"Selection s2 targets: flip_0 or flip_1")
96102
print(f"Selection s3 targets: bias or flip_2")
97103
```

docs/reference/distributions.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,11 @@ from genjax import distributions
4646
log_prob_bern, _ = distributions.bernoulli.assess(1, 0.7)
4747
print(f"Log prob of 1 under Bernoulli(0.7): {log_prob_bern:.3f}")
4848

49-
# Categorical distribution
49+
# Categorical distribution (uses logits, not probs)
5050
probs = jnp.array([0.2, 0.3, 0.5])
51-
log_prob_cat, _ = distributions.categorical.assess(2, probs)
52-
print(f"Log prob of category 2 under Categorical({probs}): {log_prob_cat:.3f}")
51+
logits = jnp.log(probs)
52+
log_prob_cat, _ = distributions.categorical.assess(2, logits)
53+
print(f"Log prob of category 2 under Categorical(probs={probs}): {log_prob_cat:.3f}")
5354

5455
# Poisson distribution
5556
log_prob_pois, _ = distributions.poisson.assess(4, 3.0)
@@ -65,7 +66,7 @@ print("- normal(mu, sigma)")
6566
print("- beta(alpha, beta)")
6667
print("- exponential(rate)")
6768
print("- bernoulli(p)")
68-
print("- categorical(probs)")
69+
print("- categorical(logits) # Note: uses logits, not probs")
6970
print("- poisson(rate)")
7071
print("- gamma(concentration, rate)")
7172
print("- uniform(low, high)")

src/genjax/distributions.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -188,16 +188,13 @@
188188
import jax.numpy as jnp
189189
from genjax import distributions
190190
191-
# Create a key for randomness
192-
key = jax.random.PRNGKey(42)
193-
194-
# Sample from normal distribution
195-
trace = distributions.normal.simulate(key, (0.0, 1.0))
196-
sample = trace.retval
191+
# Sample from normal distribution
192+
trace = distributions.normal.simulate(0.0, 1.0)
193+
sample = trace.get_retval()
197194
print(f"Sample from Normal(0, 1): {sample:.3f}")
198195
199196
# Evaluate log probability
200-
log_prob, _ = distributions.normal.assess(1.5, (0.0, 1.0))
197+
log_prob, _ = distributions.normal.assess(1.5, 0.0, 1.0)
201198
print(f"Log prob of 1.5 under Normal(0, 1): {log_prob:.3f}")
202199
203200
# Use in a generative function
@@ -210,10 +207,9 @@ def model():
210207
return x + y
211208
212209
# Simulate the model
213-
key, subkey = jax.random.split(key)
214-
trace = model.simulate(subkey, ())
215-
print(f"Model output: {trace.retval:.3f}")
216-
print(f"Choices: x={trace['x']:.3f}, y={trace['y']:.3f}")
210+
trace = model.simulate()
211+
print(f"Model output: {trace.get_retval():.3f}")
212+
print(f"Choices: x={trace.get_choices()['x']:.3f}, y={trace.get_choices()['y']:.3f}")
217213
```
218214
219215
References:

0 commit comments

Comments
 (0)