Skip to content

Commit cb3db4a

Browse files
committed
fix doc
1 parent 8976307 commit cb3db4a

File tree

2 files changed

+7
-11
lines changed

2 files changed

+7
-11
lines changed

docs/src/api.md

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ For example of Gaussian VI, we can construct the flow as follows:
1515
```@julia
1616
using Distributions, Bijectors
1717
T= Float32
18+
@leaf MvNormal # to prevent params in q₀ from being optimized
1819
q₀ = MvNormal(zeros(T, 2), ones(T, 2))
1920
flow = Bijectors.transformed(q₀, Bijectors.Shift(zeros(T,2)) ∘ Bijectors.Scale(ones(T, 2)))
2021
```
@@ -23,7 +24,7 @@ To train the Gaussian VI targeting at distirbution $p$ via ELBO maiximization, w
2324
using NormalizingFlows
2425
2526
sample_per_iter = 10
26-
flow_trained, stats, _ = train_flow(
27+
flow_trained, stats, _ , _ = train_flow(
2728
elbo,
2829
flow,
2930
logp,
@@ -83,11 +84,3 @@ NormalizingFlows.loglikelihood
8384
```@docs
8485
NormalizingFlows.optimize
8586
```
86-
87-
88-
## Utility Functions for Taking Gradient
89-
```@docs
90-
NormalizingFlows.grad!
91-
NormalizingFlows.value_and_gradient!
92-
```
93-

docs/src/example.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ Here we used the `PlanarLayer()` from `Bijectors.jl` to construct a
3636

3737
```julia
3838
using Bijectors, FunctionChains
39+
using Functors
3940

4041
function create_planar_flow(n_layers::Int, q₀)
4142
d = length(q₀)
@@ -45,7 +46,9 @@ function create_planar_flow(n_layers::Int, q₀)
4546
end
4647

4748
# create a 20-layer planar flow
48-
flow = create_planar_flow(20, MvNormal(zeros(Float32, 2), I))
49+
@leaf MvNormal # to prevent params in q₀ from being optimized
50+
q₀ = MvNormal(zeros(Float32, 2), I)
51+
flow = create_planar_flow(20, q₀)
4952
flow_untrained = deepcopy(flow) # keep a copy of the untrained flow for comparison
5053
```
5154
*Notice that here the flow layers are chained together using `fchain` function from [`FunctionChains.jl`](https://github.com/oschulz/FunctionChains.jl).
@@ -116,4 +119,4 @@ plot!(title = "Comparison of Trained and Untrained Flow", xlabel = "X", ylabel=
116119

117120
## Reference
118121

119-
- Rezende, D. and Mohamed, S., 2015. *Variational inference with normalizing flows*. International Conference on Machine Learning
122+
- Rezende, D. and Mohamed, S., 2015. *Variational inference with normalizing flows*. International Conference on Machine Learning

0 commit comments

Comments
 (0)