Skip to content

Commit 94045a9

Browse files
committed
Principled linking
1 parent ad5e550 commit 94045a9

33 files changed

+675
-503
lines changed

HISTORY.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ These have all been replaced by three functions
110110

111111
- `setindex!!` is the one to use for simply setting a variable in `VarInfo` to a known value. It works regardless of whether the variable already exists.
112112
- `setindex_internal!!` is the one to use for setting the internal, vectorised representation of a variable. See the docstring for details.
113-
- `setindex_with_dist!!` is to be used when you want to set a value, but choose the internal representation based on which distribution this value is a sample for.
113+
- `setindex_with_dist!!` is to be used when setting a `TransformedValue` into a VarInfo's values. You should really try not to use this unless you absolutely must! It is quite low-level and we much prefer that you use the accumulator API instead.
114114

115115
The order of the arguments for some of these functions has also changed, and now more closely matches the usual convention for `setindex!!`.
116116

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ makedocs(;
4747
"vnt/arraylikeblocks.md",
4848
],
4949
"Initialisation strategies" => "init.md",
50+
"Link strategies" => "link.md",
5051
"Accumulators" => "accumulators.md",
5152
"Model evaluation" => "flow.md",
5253
"Storing values" => "values.md",

docs/src/accumulators.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ end
120120
model = f(2.0)
121121
122122
vi = DynamicPPL.OnlyAccsVarInfo((VarNameLogpAccumulator(),))
123-
_, vi = DynamicPPL.init!!(model, vi, InitFromParams((; x=1.0)))
123+
_, vi = DynamicPPL.init!!(model, vi, InitFromParams((; x=1.0)), UnlinkAll())
124124
125125
# This is why we used a const.
126126
output_acc = DynamicPPL.getacc(vi, Val(VARNAMELOGP_NAME))
@@ -180,7 +180,7 @@ This is slightly hacky, see the warning below and links therein for more discuss
180180

181181
```@example 1
182182
x = 1.0
183-
model = setleafcontext(model, DynamicPPL.InitContext(InitFromParams((; x=x))))
183+
model = setleafcontext(model, DynamicPPL.InitContext(InitFromParams((; x=x)), UnlinkAll()))
184184
_, tsvi = DynamicPPL._evaluate!!(model, tsvi)
185185
tsvi.accs_by_thread
186186
```

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,7 @@ DynamicPPL.link!!
455455
DynamicPPL.invlink!!
456456
DynamicPPL.update_link_status!!
457457
DynamicPPL.generate_linked_value
458+
DynamicPPL.apply_link_strategy
458459
```
459460

460461
```@docs

docs/src/flow.md

Lines changed: 87 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,41 @@
11
# How data flows through a model
22

3-
Having discussed initialisation strategies and accumulators, we can now put all the pieces together to show how data enters a model, is used to perform computations, and how the results are extracted.
3+
This page aims to show how data can enter a model, how we use it to perform computations, and how the results are extracted.
44

5-
**The summary is: initialisation strategies are responsible for telling the model what values to use for its parameters, whereas accumulators act as containers for aggregated outputs.**
5+
As a high-level summary:
66

7-
Thus, there is a clear separation between the *inputs* to the model, and the *outputs* of the model.
7+
- **initialisation strategies** are responsible for generating parameter values
8+
- **link strategies** are responsible for telling the model how to interpret those values (i.e., whether log-Jacobians need to be computed)
9+
- **accumulators** are responsible for aggregating the outputs of the model (e.g. log probabilities, transformed values, etc.)
810

9-
!!! note
11+
In this model, there is a clear separation between the *inputs* to the model, and the *outputs* of the model.
12+
13+
!!! note "DefaultContext"
1014

11-
While `VarInfo` and `DefaultContext` still exist, this is mostly a historical remnant. `DefaultContext` means that the inputs should come from the values of the provided `VarInfo`, and the outputs are stored in the accumulators of the provided `VarInfo`. However, this can easily be refactored such that the values are provided directly as an initialisation strategy. See [this issue](https://github.com/TuringLang/DynamicPPL.jl/issues/1184) for more details.
15+
While `VarInfo` and `DefaultContext` still exist, this is mostly a historical remnant.
16+
`DefaultContext` means that the inputs should come from the values of the provided `VarInfo`, and the outputs are stored in the accumulators of the provided `VarInfo`.
17+
However, this can easily be refactored such that the values are provided directly as an initialisation strategy.
18+
See [this issue](https://github.com/TuringLang/DynamicPPL.jl/issues/1184) for more details.
1219

1320
There are three stages to every tilde-statement:
1421

1522
1. Initialisation: get an `AbstractTransformedValue` from the initialisation strategy.
16-
17-
2. Computation: figure out the untransformed (raw) value; compute the log-Jacobian if necessary.
23+
2. Computation: figure out the untransformed (raw) value and the linked value (where necessary); compute the relevant log-Jacobian.
1824
3. Accumulation: pass all the relevant information to the accumulators, which individually decide what to do with it.
1925

20-
In fact this (more or less) directly translates into three lines of code: see e.g. the method for `tilde_assume!!` in `src/onlyaccs.jl`, which (as of the time of writing) looks like:
26+
In fact this (more or less) directly translates into three lines of code: see e.g. the method for `tilde_assume!!` in `src/contexts/init.jl`, which (as of the time of writing) looks like:
2127

2228
```julia
2329
function DynamicPPL.tilde_assume!!(ctx::InitContext, dist, vn, template, vi)
2430
# 1. Initialisation
25-
tval = DynamicPPL.init(ctx.rng, vn, dist, ctx.strategy)
31+
init_tval = DynamicPPL.init(ctx.rng, vn, dist, ctx.strategy)
2632

2733
# 2. Computation
28-
# (Though see also the warning in the computation section below.)
29-
x, inv_logjac = Bijectors.with_logabsdet_jacobian(
30-
DynamicPPL.get_transform(tval), DynamicPPL.get_internal_value(tval)
31-
)
34+
x, tval, logjac = apply_link_strategy(ctx.link_strategy, init_tval, vn, dist)
3235

3336
# 3. Accumulation
34-
vi = DynamicPPL.accumulate_assume!!(vi, x, tval, -inv_logjac, vn, dist, template)
37+
vi = DynamicPPL.setindex_with_dist!!(vi, tval, dist, vn, template)
38+
vi = DynamicPPL.accumulate_assume!!(vi, x, tval, logjac, vn, dist, template)
3539
return x, vi
3640
end
3741
```
@@ -46,15 +50,15 @@ In the following sections, we stick to the three sections of `tilde_assume!!`.
4650
## Initialisation
4751

4852
```julia
49-
tval = DynamicPPL.init(ctx.rng, vn, dist, ctx.strategy)
53+
init_tval = DynamicPPL.init(ctx.rng, vn, dist, ctx.strategy)
5054
```
5155

5256
The initialisation step is handled by the `init` function, which dispatches on the initialisation strategy.
5357
For example, if `ctx.strategy` is `InitFromPrior()`, then `init()` samples a value from the distribution `dist`.
5458

55-
!!! note
59+
!!! note "DefaultContext"
5660

57-
For `DefaultContext`, this is replaced by looking for the value stored inside `vi`. As described above, this can be refactored in the near future.
61+
For `DefaultContext`, initialisation is handled by looking for the value stored inside `vi`.
5862

5963
As discussed in the [Initialisation strategies](./init.md) page, this step, in general, does not return just the raw value (like `rand(dist)`).
6064
It returns an [`DynamicPPL.AbstractTransformedValue`](@ref), which represents a value that _may_ have been transformed.
@@ -63,59 +67,93 @@ In the case of `InitFromPrior()`, the value is of course not transformed; we ret
6367
However, consider the case where we are using parameters stored inside a `VarInfo`: the value may have been stored either as a vectorised form, or as a linked vectorised form.
6468
In this case, `init()` will return either a [`DynamicPPL.VectorValue`](@ref) or a [`DynamicPPL.LinkedVectorValue`](@ref).
6569

66-
The reason why we return this wrapped value is because sometimes we don't want to eagerly perform the transformation.
67-
Consider the case where we have an accumulator that attempts to store linked values (this is done precisely when linking a VarInfo: the linked values are stored in an accumulator, which then becomes the basis of the linked VarInfo).
68-
In this case, if we eagerly perform the inverse link transformation, we would have to link it again inside the accumulator, which is inefficient!
70+
The reason why we return this wrapped value is because we want to avoid having to perform transformations multiple times.
71+
Each step is responsible for only performing the transformations it needs to.
72+
At this stage, there has not yet been any need for the raw value, so we do not perform any transformations yet.
73+
Thus, the `AbstractTransformedValue` is passed straight through and is used by both the computation and accumulation steps.
6974

70-
The `AbstractTransformedValue` is passed straight through and is used by both the computation and accumulation steps.
75+
!!! note "The return type of init() doesn't matter"
76+
77+
The _type_ of `AbstractTransformedValue` returned by `init()`, in general, has no impact on whether the value is considered to be linked or not.
78+
That is determined solely by the link strategy (see below).
79+
This separation allows us to perform the minimum amount of transformations necessary inside `init()`.
80+
If we were to eagerly transform the value inside `init()`, we could easily end up performing the same transformation multiple times across the different steps.
7181

7282
## Computation
7383

7484
```julia
75-
x, inv_logjac = Bijectors.with_logabsdet_jacobian(
76-
DynamicPPL.get_transform(tval), DynamicPPL.get_internal_value(tval)
77-
)
85+
x, tval, logjac = apply_link_strategy(ctx.link_strategy, init_tval, vn, dist)
7886
```
7987

80-
At *some* point, we do need to perform the transformation to get the actual raw value.
81-
This is because DynamicPPL promises in the model that the variables on the left-hand side of the tilde are actual raw values.
88+
There are three return values in this step, and they correspond to the three things that this step needs to do.
89+
They are all interconnected, which is why they are computed together inside `apply_link_strategy()`: by doing so we can ensure that `with_logabsdet_jacobian` is only called a maximum of once per tilde-statement.
8290

83-
```julia
84-
@model function f()
85-
x ~ dist
86-
# Here, `x` _must_ be the actual raw value.
87-
@show x
88-
end
89-
```
91+
1. **Get the raw (untransformed) value `x`**
92+
93+
At *some* point, we do need to perform the transformation to get the actual raw value.
94+
This is because DynamicPPL promises in the model that the variables on the left-hand side of the tilde are actual raw values.
95+
96+
```julia
97+
@model function f()
98+
x ~ dist
99+
# Here, `x` _must_ be the actual raw value.
100+
@show x
101+
end
102+
```
103+
104+
Thus, regardless of what we are accumulating, we will have to unwrap the transformed value provided by `init()`.
90105

91-
Thus, regardless of what we are accumulating, we will have to unwrap the transformed value provided by `init()`.
92-
We also need to account for the log-Jacobian of the transformation, if any.
106+
2. **Get the (possibly linked) value `tval`**
107+
108+
In addition to the raw value, if the link strategy indicates that we should treat `vn` as being in linked space, we also need to compute the linked value.
109+
This is because some accumulators may need to work with the linked value instead of the raw value.
110+
111+
(If there is a full VarInfo being used, the linked value will also have to be set inside the VarInfo.)
112+
3. **Compute the log-Jacobian `logjac`**
113+
114+
`logjac` is only accumulated if the link strategy indicates that `vn` is linked.
115+
The convention in DynamicPPL is that the log-Jacobian is always computed with respect to the forward link transformation.
93116

94-
!!! note
117+
It is worth emphasising that whether a value is linked or not is determined by the *link strategy* provided to the model (i.e., `ctx.link_strategy`), not the initialisation strategy (`ctx.strategy`).
118+
The reason for this is to allow a separation between the source of the values (initialisation) and how those values are to be interpreted (link strategy).
119+
120+
This allows us to, for example, generate values from the (unlinked) prior but also calculate their log-density in linked space and accumulate linked values by combining `InitFromPrior()` with `LinkAll()`.
121+
It also allows us to read values from an existing `VarInfo` but interpret them as being in a different space by combining `InitFromParams()` with a different link strategy: this corresponds exactly to the act of 'linking' a VarInfo.
122+
123+
!!! note "DefaultContext"
124+
125+
For DefaultContext, whether or not the variable is linked will depend on the `VarInfo` used for evaluation. If the variable is stored as linked in the `VarInfo`, then it will be treated as linked here.
126+
Notice that both the initialisation strategy as well as the link strategy are effectively determined by the `VarInfo` in this case.
127+
The separation described above is not possible when using `DefaultContext`.
128+
129+
The move away from `DefaultContext` and towards `InitContext` is motivated by the desire to separate these two concerns, and to enable a more modular and declarative way of specifying how a model is to be evaluated.
130+
131+
!!! note "Log-Jacobian computation"
95132

96133
In principle, if the log-Jacobian is not of interest to any of the accumulators, we _could_ skip computing it here.
97134
However, that is not easy to determine in practice.
98-
We also cannot defer the log-Jacobian computation to the accumulator, since if multiple accumulators need the log-Jacobian, we would end up computing it multiple times.
135+
We also cannot defer the log-Jacobian computation to the accumulator, since it is often more efficient to compute it at the same time as the transformation (i.e., using `with_logabsdet_jacobian`).
99136
The current situation of computing it once here is the most sensible compromise (for now).
100137

101138
One could envision a future where accumulators declare upfront (via their type) whether they need the log-Jacobian or not. We could then skip computing it if no accumulator needs it.
102139

103-
!!! warning
104-
105-
If you look at the source code for that method, it is more complicated than the above!
106-
Have we lied?
107-
It turns out that there is a subtlety here: the transformation obtained from `DynamicPPL.get_transform(tval)` may in fact be incorrect.
108-
109-
Consider the case where a transform is dependent on the value itself (e.g., a variable whose support depends on another variable).
110-
In this case, setting new values into a VarInfo (via `unflatten!!`) may cause the cached transformations to be invalid.
111-
Where possible, it is better to re-obtain the transformation from `dist`, which is always up-to-date since it is obtained from model execution.
112-
113140
## Accumulation
114141

115142
```julia
116-
vi = DynamicPPL.accumulate_assume!!(vi, x, tval, -inv_logjac, vn, dist, template)
143+
vi = DynamicPPL.setindex_with_dist!!(vi, tval, dist, vn, template)
144+
vi = DynamicPPL.accumulate_assume!!(vi, x, tval, logjac, vn, dist, template)
117145
```
118146

119-
This step is where most of the interesting action happens.
147+
!!! note
148+
149+
The first line, `setindex_with_dist!!`, is only necessary when using a full `VarInfo`.
150+
It essentially stores the value `tval` inside the `VarInfo`, but makes sure to store a vectorised form (i.e., if `tval` is an `UntransformedValue`, it will be converted to a `VectorValue` before being stored).
151+
This is entirely equivalent to using a `VectorValueAccumulator` to store the values; it's just that when using a full `VarInfo` that accumulator is 'built-in' as `vi.values`.
152+
153+
Since conceptually this is the same as an accumulator, we will not discuss it further here.
154+
155+
Here, we pass all of the information we have gathered so far for this tilde-statement to the accumulators.
156+
`accumulate_assume!!(vi::AbstractVarInfo, ...)` will loop over all accumulators stored inside `vi`, and call each of their individual `accumulate_assume!!` methods.
157+
This method is responsible for deciding how to combine the information provided.
120158

121159
Accumulators are described in much more detail on the [Accumulators](./accumulators.md) page; please read that for more information!

docs/src/init.md

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@ nothing # hide
5959
we can then make a proposal for `x` as follows:
6060

6161
```@example 1
62-
new_x, new_vi = DynamicPPL.init!!(model, VarInfo(), InitRandomWalk(x_prev, 0.5))
62+
new_x, new_vi = DynamicPPL.init!!(
63+
model, VarInfo(), InitRandomWalk(x_prev, 0.5), UnlinkAll()
64+
)
6365
nothing # hide
6466
```
6567

@@ -84,10 +86,6 @@ For example, [`DynamicPPL.InitFromParams`](@ref) reads from a set of given param
8486

8587
## The returned `AbstractTransformedValue`
8688

87-
!!! warning
88-
89-
The correctness of this section is contingent on https://github.com/TuringLang/DynamicPPL.jl/pull/1231 being merged.
90-
9189
As mentioned above, the `init` function must return an `AbstractTransformedValue`.
9290
The subtype of `AbstractTransformedValue` used does not affect the result of the model evaluation, but it may have performance implications.
9391
**In particular, the returned subtype does not determine whether the log-Jacobian term is accumulated or not: that is determined by a separate _link strategy_.**

docs/src/link.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Link strategies
2+
3+
Blah blah blah.

0 commit comments

Comments
 (0)