Commit 6117e4e ("simplify")
1 parent: 21df935

1 file changed: 26 additions & 188 deletions
@@ -1,202 +1,40 @@
 # Proposal for a New LogDensity Function Interface
 
-## Introduction
+<https://github.com/TuringLang/DynamicPPL.jl/issues/691>
 
-The goal is to design a flexible and user-friendly interface for log density functions that can handle various model operations, especially in higher-order contexts such as Gibbs sampling. This interface should facilitate:
+The goal is to design a flexible, user-friendly interface for log density functions that can handle various model operations, especially in higher-order contexts like Gibbs sampling and Bayesian workflows.
 
-- **Conditioning**: Incorporating observed data into the model.
-- **Fixing**: Fixing certain variables to specific values. (like `do` operator)
-- **Generated Quantities**: Computing additional expressions or functions based on the model parameters.
-- **Prediction**: Making predictions by fixing parameters and unconditioning on data.
+## Evaluation functions:
 
-This proposal aims to redefine the interface from the user's perspective, focusing on ease of use and extensibility beyond the traditional probabilistic programming languages (PPLs).
+1. `evaluate`
 
-## Proposed Interface
+## Query functions:
 
-Below is a proposed interface with key functionalities and their implementations.
+1. `is_parametric(model)`
+2. `dimension(model)` (only defined when `is_parametric(model) == true`)
+3. `is_conditioned(model)`
+4. `is_fixed(model)`
+5. `logjoint(model, params)`
+6. `loglikelihood(model, params)`
+7. `logprior(model, params)`
 
-### Core Functions
+where `params` can be `Vector`, `NamedTuple`, `Dict`, etc.
 
-#### Check if a Model is Parametric
+## Transformation functions:
 
-```julia
-# Check if a log density model is parametric
-function is_parametric(model::LogDensityModel) -> Bool
-...
-end
-```
+1. `condition(model, conditioned_vars)`
+2. `fix(model, fixed_vars)`
+3. `factor(model, variables_in_the_factor)`
 
-- **Description**: Determines if the model has a parameter space with a defined dimension.
--
+`condition` and `factor` are similar, but `factor` effectively generates a sub-model.
 
-#### Get the Dimension of a Parametric Model
+## Higher-order functions:
 
-```julia
-# Get the dimension of the parameter space (only defined when is_parametric(model) is true)
-function dimension(model::LogDensityModel) -> Int
-...
-end
-```
+1. `generated_quantities(model, sample, [, expr])` or `generated_quantities(model, sample, f, args...)`
+    1. `generated_quantities` computes things from the sampling result.
+    2. In `DynamicPPL`, this is the model's return value. For more flexibility, we should allow passing an expression or function. (Currently, users can rewrite the model definition to achieve this in `DynamicPPL`, but with limitations. We want to make this more generic.)
+    3. `rand` is a special case of `generated_quantities` (when no sample is passed).
+2. `predict(model, sample)`
 
-- **Description**: Returns the dimension of the parameter space for parametric models.
-
-### Log Density Computations
-
-#### Log-Likelihood
-
-```julia
-# Compute the log-likelihood given parameters
-function loglikelihood(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> Float64
-...
-end
-```
-
-- **Description**: Computes the log-likelihood of the data given the model parameters.
-
-#### Log-Prior
-
-```julia
-# Compute the log-prior given parameters
-function logprior(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> Float64
-...
-end
-```
-
-- **Description**: Computes the log-prior probability of the model parameters.
-
-#### Log-Joint
-
-```julia
-# Compute the log-joint density (log-likelihood + log-prior)
-function logjoint(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> Float64
-return loglikelihood(model, params) + logprior(model, params)
-end
-```
-
-- **Description**: Computes the total log density by summing the log-likelihood and log-prior.
-
-### Conditioning and Fixing Variables
-
-#### Conditioning a Model
-
-```julia
-# Condition the model on observed data
-function condition(model::LogDensityModel, data::NamedTuple) -> ConditionedModel
-...
-end
-```
-
-- **Description**: Incorporates observed data into the model, returning a `ConditionedModel`.
-
-#### Checking if a Model is Conditioned
-
-```julia
-# Check if a model is conditioned
-function is_conditioned(model::LogDensityModel) -> Bool
-...
-end
-```
-
-- **Description**: Checks whether the model has been conditioned on data.
-
-#### Fixing Variables in a Model
-
-```julia
-# Fix certain variables in the model
-function fix(model::LogDensityModel, variables::NamedTuple) -> FixedModel
-...
-end
-```
-
-- **Description**: Fixes specific variables in the model to given values, returning a `FixedModel`.
-
-#### Checking if a Model has Fixed Variables
-
-```julia
-# Check if a model has fixed variables
-function is_fixed(model::LogDensityModel) -> Bool
-...
-end
-```
-
-- **Description**: Determines if any variables in the model have been fixed.
-
-### Specialized Models
-
-#### Conditioned Model Methods
-
-```julia
-# Log-likelihood for a conditioned model
-function loglikelihood(model::ConditionedModel, params::Union{Vector, NamedTuple, Dict}) -> Float64
-...
-end
-
-# Log-prior for a conditioned model
-function logprior(model::ConditionedModel, params::Union{Vector, NamedTuple, Dict}) -> Float64
-...
-end
-
-# Log-joint for a conditioned model
-function logjoint(model::ConditionedModel, params::Union{Vector, NamedTuple, Dict}) -> Float64
-return loglikelihood(model, params) + logprior(model, params)
-end
-```
-
-- **Description**: Overrides log density computations to account for the conditioned data.
-
-#### Fixed Model Methods
-
-```julia
-# Log-likelihood for a fixed model
-function loglikelihood(model::FixedModel, data::Union{Vector, NamedTuple, Dict}) -> Float64
-...
-end
-
-# Log-prior for a fixed model
-function logprior(model::FixedModel, data::Union{Vector, NamedTuple, Dict}) -> Float64
-...
-end
-
-# Log-joint for a fixed model
-function logjoint(model::FixedModel, data::Union{Vector, NamedTuple, Dict}) -> Float64
-return loglikelihood(model, data) + logprior(model, data)
-end
-```
-
-- **Description**: Adjusts log density computations based on the fixed variables.
-
-### Additional Functionalities
-
-#### Generated Quantities
-
-```julia
-# Compute generated quantities after fixing parameters
-function generated_quantities(model::LogDensityModel, fixed_vars::NamedTuple) -> NamedTuple
-...
-end
-```
-
-- **Description**: Computes additional expressions or functions based on the fixed model parameters.
-
-#### Prediction
-
-```julia
-# Predict data based on fixed parameters
-function predict(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> NamedTuple
-...
-end
-```
-
-- **Description**: Generates predictions by fixing the parameters and unconditioning the data.
-
-## Advantages of the Proposed Interface
-
-- **Flexibility**: Allows for advanced model operations like conditioning and fixing, essential for methods like Gibbs sampling.
-
-- **User-Centric Design**: Focuses on usability from the model user's perspective rather than the PPL implementation side.
-
-- **Consistency**: Maintains a uniform interface for both parametric and non-parametric models, simplifying the learning curve.
-
-## Usage Examples
-
-## Non-Parametric Models
+`generated_quantities` can be implemented by `fix`ing the model on `sample` and calling `evaluate`.
+`predict` can be implemented by `uncondition`ing the model on `data`, fixing it on `sample`, and calling `evaluate`.
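To make the query functions listed in the new version concrete, here is a minimal sketch for a toy model `y ~ Normal(mu, 1)` with prior `mu ~ Normal(0, 10)`. Everything named `ToyGaussian`, `normlogpdf`, or `getmu` is invented for illustration; only the function names `is_parametric`, `dimension`, `is_conditioned`, `is_fixed`, `logprior`, `loglikelihood`, and `logjoint` come from the proposal, and this is not how DynamicPPL implements them.

```julia
# Invented toy model type: holds observations, or `nothing` if unconditioned.
struct ToyGaussian
    data::Union{Nothing, Vector{Float64}}
end
ToyGaussian() = ToyGaussian(nothing)

# Log-density of Normal(mu, sigma) at x (hand-rolled to avoid dependencies).
normlogpdf(x, mu, sigma) = -0.5 * ((x - mu) / sigma)^2 - log(sigma) - 0.5 * log(2pi)

# Query functions from the proposal, specialized for the toy model.
is_parametric(::ToyGaussian) = true
dimension(::ToyGaussian) = 1                        # a single free parameter, mu
is_conditioned(m::ToyGaussian) = m.data !== nothing
is_fixed(::ToyGaussian) = false                     # `fix` is handled in the next sketch

# `params` can be a `Vector` or a `NamedTuple`, as the proposal allows.
getmu(params::AbstractVector) = params[1]
getmu(params::NamedTuple) = params.mu

logprior(m::ToyGaussian, params) = normlogpdf(getmu(params), 0.0, 10.0)
function loglikelihood(m::ToyGaussian, params)
    is_conditioned(m) || return 0.0                 # no data, no likelihood term
    return sum(normlogpdf(y, getmu(params), 1.0) for y in m.data)
end
logjoint(m::ToyGaussian, params) = logprior(m, params) + loglikelihood(m, params)

model = ToyGaussian([1.2, 0.8, 1.1])
logjoint(model, [1.0]) == logjoint(model, (mu = 1.0,))   # true
```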
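Continuing the same toy example, a sketch of the transformation functions `condition` and `fix` (the `factor` sub-model case is omitted). The wrapper type `FixedToyGaussian`, and the choice that a fixed variable contributes no prior term (in the spirit of the `do` operator), are assumptions made for illustration, not part of the proposal.

```julia
# `condition` attaches observations; `fix` pins mu, removing it from the parameter space.
condition(m::ToyGaussian, data::Vector{Float64}) = ToyGaussian(data)

struct FixedToyGaussian
    inner::ToyGaussian
    mu::Float64                      # the value mu is fixed to
end
fix(m::ToyGaussian, fixed::NamedTuple) = FixedToyGaussian(m, fixed.mu)

is_parametric(::FixedToyGaussian) = true
dimension(::FixedToyGaussian) = 0                # mu is no longer free
is_conditioned(m::FixedToyGaussian) = is_conditioned(m.inner)
is_fixed(::FixedToyGaussian) = true

# A fixed variable contributes no prior term; the likelihood uses the fixed value.
logprior(::FixedToyGaussian, params) = 0.0
loglikelihood(m::FixedToyGaussian, params) = loglikelihood(m.inner, (mu = m.mu,))
logjoint(m::FixedToyGaussian, params) = logprior(m, params) + loglikelihood(m, params)

m  = condition(ToyGaussian(), [1.2, 0.8, 1.1])
fm = fix(m, (mu = 1.0,))
is_conditioned(fm), is_fixed(fm)                 # (true, true)
logjoint(fm, Float64[])                          # likelihood at mu = 1.0 only
```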
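Finally, a sketch of the higher-order functions following the two implementation notes in the new version: `generated_quantities` is `fix` plus `evaluate`, and `predict` is `uncondition` plus `fix` plus `evaluate`. The `uncondition` methods and the particular return value of `evaluate` (the data plus a derived mean, standing in for a model's return value) are invented for this toy model and are not the DynamicPPL implementation.

```julia
# Drop observations; keep any fixed values.
uncondition(m::ToyGaussian) = ToyGaussian(nothing)
uncondition(m::FixedToyGaussian) = FixedToyGaussian(uncondition(m.inner), m.mu)

# Run the model forward at its fixed values and return "generated quantities".
function evaluate(m::FixedToyGaussian; n = 3)
    # Use the observed data when conditioned, otherwise simulate new data.
    y = is_conditioned(m) ? m.inner.data : m.mu .+ randn(n)
    return (y = y, mu = m.mu, mean_y = sum(y) / length(y))
end

generated_quantities(model, sample::NamedTuple) = evaluate(fix(model, sample))
predict(model, sample::NamedTuple) = evaluate(fix(uncondition(model), sample)).y

m = condition(ToyGaussian(), [1.2, 0.8, 1.1])
generated_quantities(m, (mu = 1.0,))   # quantities derived from a posterior draw
predict(m, (mu = 1.0,))                # fresh simulated data at mu = 1.0
```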
