
Commit f8db9aa

Fix doc mistakes (#701)

* Fix equation formatting
* Clarify JAX gradient error
* Fix punctuation + capitalization
* Fix grammar: a sentence should not begin with "i.e." in English.
* Fix math formatting error
* Fix typo: change parallel _ensample_ chain adaptation to parallel _ensemble_ chain adaptation.
* Add SVGD citation so it appears in the doc. Currently the SVGD paper is only cited in the `kernel` function, which is defined _within_ the `build_kernel` function. Because of this nested function format, the SVGD paper is _not_ cited in the documentation. To fix this, a citation to the SVGD paper was added to the `as_top_level_api` docstring.
* Fix grammar + clarify doc
* Fix typo

Co-authored-by: Junpeng Lao <junpenglao@gmail.com>
1 parent 5764a2b commit f8db9aa

File tree

4 files changed: +10 −11 lines


blackjax/adaptation/meads_adaptation.py

Lines changed: 2 additions & 2 deletions
@@ -36,7 +36,7 @@ class MEADSAdaptationState(NamedTuple):
     alpha
         Value of the alpha parameter of the generalized HMC algorithm.
     delta
-        Value of the alpha parameter of the generalized HMC algorithm.
+        Value of the delta parameter of the generalized HMC algorithm.

     """

@@ -60,7 +60,7 @@ def base():
     with shape.

     This is an implementation of Algorithm 3 of :cite:p:`hoffman2022tuning` using cross-chain
-    adaptation instead of parallel ensample chain adaptation.
+    adaptation instead of parallel ensemble chain adaptation.

     Returns
     -------

blackjax/base.py

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ class SamplingAlgorithm(NamedTuple):
     """A pair of functions that represents a MCMC sampling algorithm.

     Blackjax sampling algorithms are implemented as a pair of pure functions: a
-    kernel, that takes a new samples starting from the current state, and an
+    kernel, that generates a new sample from the current state, and an
     initialization function that creates a kernel state from a chain position.

     As they represent Markov kernels, the kernel functions are pure functions
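The init/step pairing this docstring describes can be sketched with a toy example. This is a hypothetical illustration of the pattern only: the `RWState` type and the random-walk proposal below are invented for the sketch and are not blackjax's actual implementation, which uses JAX PRNG keys rather than Python's `random`.

```python
# Hypothetical illustration of the "pair of pure functions" pattern:
# an initialization function builds a kernel state from a position,
# and a kernel generates a new sample from the current state.
import random
from typing import Callable, NamedTuple


class SamplingAlgorithm(NamedTuple):
    init: Callable  # chain position -> kernel state
    step: Callable  # (rng, state) -> (new state, info)


class RWState(NamedTuple):
    position: float


def init(position: float) -> RWState:
    """Create a kernel state from a chain position."""
    return RWState(position)


def step(rng: random.Random, state: RWState):
    """Generate a new sample from the current state.

    Toy symmetric random-walk proposal only; a real MCMC kernel
    would also accept/reject against the target density.
    """
    proposal = state.position + rng.gauss(0.0, 1.0)
    return RWState(proposal), {"proposal": proposal}


algorithm = SamplingAlgorithm(init, step)

rng = random.Random(42)
state = algorithm.init(0.0)
for _ in range(3):
    state, info = algorithm.step(rng, state)
```

Bundling the two functions in a `NamedTuple` lets callers treat any sampler uniformly: build a state once with `algorithm.init`, then repeatedly apply `algorithm.step`.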

blackjax/vi/svgd.py

Lines changed: 1 addition & 1 deletion
@@ -135,7 +135,7 @@ def as_top_level_api(
     kernel: Callable = rbf_kernel,
     update_kernel_parameters: Callable = update_median_heuristic,
 ):
-    """Implements the (basic) user interface for the svgd algorithm.
+    """Implements the (basic) user interface for the svgd algorithm :cite:p:`liu2016stein`.

     Parameters
     ----------

docs/examples/howto_custom_gradients.md

Lines changed: 6 additions & 7 deletions
@@ -29,10 +29,9 @@ Functions can be defined as the minimum of another one, $f(x) = min_{y} g(x,y)$.
 Our example is taken from the theory of [convex conjugates](https://en.wikipedia.org/wiki/Convex_conjugate), used for example in optimal transport. Let's consider the following function:

 $$
-\begin{align*}
-g(x, y) &= h(y) - \langle x, y\rangle\\
-h(x) &= \frac{1}{p}|x|^p,\qquad p > 1.\\
-\end{align*}
+\begin{equation*}
+g(x, y) = h(y) - \langle x, y\rangle,\qquad h(x) = \frac{1}{p}|x|^p,\qquad p > 1.
+\end{equation*}
 $$

 And define the function $f$ as $f(x) = -min_y g(x, y)$, which can be implemented as:

@@ -69,7 +68,7 @@ Note that we also return the value of $y$ where the minimum of $g$ is achieved (t

 ### Trying to differentiate the function with `jax.grad`

-The gradient of the function $f$ is undefined for JAX, which cannot differentiate through `while` loops, and trying to compute it directly raises an error:
+The gradient of the function $f$ is undefined for JAX, which cannot differentiate through the `while` loops used in BFGS, and trying to compute it directly raises an error:

 ```{code-cell} ipython3
 # We only want the gradient with respect to `x`

@@ -97,15 +96,15 @@ The first order optimality criterion
 \end{equation*}
 ```

-Ensures that:
+ensures that

 ```{math}
 \begin{equation*}
 \frac{df}{dx} = y(x).
 \end{equation*}
 ```

-i.e. the value of the derivative at $x$ is the value $y(x)$ at which the minimum of the function $g$ is achieved.
+In other words, the value of the derivative at $x$ is the value $y(x)$ at which the minimum of the function $g$ is achieved.

 ### Telling JAX to use a custom gradient
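The envelope-theorem argument in this diff ($df/dx = y(x)$, where $y(x)$ is the argmin of $g(x, \cdot)$) can be registered with `jax.custom_vjp`. The following is a self-contained sketch, not the notebook's code: it fixes $p = 3$ and substitutes the closed-form argmin $y(x) = \sqrt{x}$ (valid for $x \ge 0$) for the notebook's BFGS solve, and the function names are chosen here for illustration.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def f(x):
    # f(x) = -min_y g(x, y) with g(x, y) = |y|**3 / 3 - x * y.
    # For x >= 0 the minimum is attained at y(x) = sqrt(x).
    y = jnp.sqrt(x)
    return -(jnp.abs(y) ** 3 / 3.0 - x * y)


def f_fwd(x):
    # Save the argmin y(x) as the residual for the backward pass.
    y = jnp.sqrt(x)
    return f(x), y


def f_bwd(y, cotangent):
    # First-order optimality gives df/dx = y(x), so the VJP is just
    # the cotangent scaled by the saved argmin.
    return (cotangent * y,)


f.defvjp(f_fwd, f_bwd)

print(jax.grad(f)(4.0))  # df/dx = y(4) = 2.0
```

With this registration, `jax.grad(f)` never differentiates through the minimization itself; it only reuses the argmin saved on the forward pass, which is exactly the trick the notebook applies to the BFGS-based version.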
