|
1 | 1 | # Proposal for a New LogDensity Function Interface |
2 | 2 |
|
3 | | -## Introduction |
| 3 | +<https://github.com/TuringLang/DynamicPPL.jl/issues/691> |
4 | 4 |
|
5 | | -The goal is to design a flexible and user-friendly interface for log density functions that can handle various model operations, especially in higher-order contexts such as Gibbs sampling. This interface should facilitate: |
| 5 | +The goal is to design a flexible, user-friendly interface for log density functions that can handle various model operations, especially in higher-order contexts like Gibbs sampling and Bayesian workflows. |
6 | 6 |
|
7 | | -- **Conditioning**: Incorporating observed data into the model. |
8 | | -- **Fixing**: Fixing certain variables to specific values. (like `do` operator) |
9 | | -- **Generated Quantities**: Computing additional expressions or functions based on the model parameters. |
10 | | -- **Prediction**: Making predictions by fixing parameters and unconditioning on data. |
| 7 | +## Evaluation functions: |
11 | 8 |
|
12 | | -This proposal aims to redefine the interface from the user's perspective, focusing on ease of use and extensibility beyond the traditional probabilistic programming languages (PPLs). |
| 9 | +1. `evaluate` |
13 | 10 |
|
14 | | -## Proposed Interface |
| 11 | +## Query functions: |
15 | 12 |
|
16 | | -Below is a proposed interface with key functionalities and their implementations. |
| 13 | +1. `is_parametric(model)` |
| 14 | +2. `dimension(model)` (only defined when `is_parametric(model) == true`) |
| 15 | +3. `is_conditioned(model)` |
| 16 | +4. `is_fixed(model)` |
| 17 | +5. `logjoint(model, params)` |
| 18 | +6. `loglikelihood(model, params)` |
| 19 | +7. `logprior(model, params)` |
17 | 20 |
|
18 | | -### Core Functions |
| 21 | +where `params` can be `Vector`, `NamedTuple`, `Dict`, etc. |
19 | 22 |
|
20 | | -#### Check if a Model is Parametric |
| 23 | +## Transformation functions: |
21 | 24 |
|
22 | | -```julia |
23 | | -# Check if a log density model is parametric |
24 | | -function is_parametric(model::LogDensityModel) -> Bool |
25 | | - ... |
26 | | -end |
27 | | -``` |
| 25 | +1. `condition(model, conditioned_vars)` |
| 26 | +2. `fix(model, fixed_vars)` |
| 27 | +3. `factor(model, variables_in_the_factor)` |
28 | 28 |
|
29 | | -- **Description**: Determines if the model has a parameter space with a defined dimension. |
30 | | -- |
| 29 | +`condition` and `factor` are similar, but `factor` effectively generates a sub-model. |
31 | 30 |
|
32 | | -#### Get the Dimension of a Parametric Model |
| 31 | +## Higher-order functions: |
33 | 32 |
|
34 | | -```julia |
35 | | -# Get the dimension of the parameter space (only defined when is_parametric(model) is true) |
36 | | -function dimension(model::LogDensityModel) -> Int |
37 | | - ... |
38 | | -end |
39 | | -``` |
| 33 | +1. `generated_quantities(model, sample, [, expr])` or `generated_quantities(model, sample, f, args...)` |
| 34 | + 1. `generated_quantities` computes things from the sampling result. |
| 35 | + 2. In `DynamicPPL`, this is the model's return value. For more flexibility, we should allow passing an expression or function. (Currently, users can rewrite the model definition to achieve this in `DynamicPPL`, but with limitations. We want to make this more generic.) |
| 36 | + 3. `rand` is a special case of `generated_quantities` (when no sample is passed). |
| 37 | +2. `predict(model, sample)` |
40 | 38 |
|
41 | | -- **Description**: Returns the dimension of the parameter space for parametric models. |
42 | | - |
43 | | -### Log Density Computations |
44 | | - |
45 | | -#### Log-Likelihood |
46 | | - |
47 | | -```julia |
48 | | -# Compute the log-likelihood given parameters |
49 | | -function loglikelihood(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 |
50 | | - ... |
51 | | -end |
52 | | -``` |
53 | | - |
54 | | -- **Description**: Computes the log-likelihood of the data given the model parameters. |
55 | | - |
56 | | -#### Log-Prior |
57 | | - |
58 | | -```julia |
59 | | -# Compute the log-prior given parameters |
60 | | -function logprior(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 |
61 | | - ... |
62 | | -end |
63 | | -``` |
64 | | - |
65 | | -- **Description**: Computes the log-prior probability of the model parameters. |
66 | | - |
67 | | -#### Log-Joint |
68 | | - |
69 | | -```julia |
70 | | -# Compute the log-joint density (log-likelihood + log-prior) |
71 | | -function logjoint(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 |
72 | | - return loglikelihood(model, params) + logprior(model, params) |
73 | | -end |
74 | | -``` |
75 | | - |
76 | | -- **Description**: Computes the total log density by summing the log-likelihood and log-prior. |
77 | | - |
78 | | -### Conditioning and Fixing Variables |
79 | | - |
80 | | -#### Conditioning a Model |
81 | | - |
82 | | -```julia |
83 | | -# Condition the model on observed data |
84 | | -function condition(model::LogDensityModel, data::NamedTuple) -> ConditionedModel |
85 | | - ... |
86 | | -end |
87 | | -``` |
88 | | - |
89 | | -- **Description**: Incorporates observed data into the model, returning a `ConditionedModel`. |
90 | | - |
91 | | -#### Checking if a Model is Conditioned |
92 | | - |
93 | | -```julia |
94 | | -# Check if a model is conditioned |
95 | | -function is_conditioned(model::LogDensityModel) -> Bool |
96 | | - ... |
97 | | -end |
98 | | -``` |
99 | | - |
100 | | -- **Description**: Checks whether the model has been conditioned on data. |
101 | | - |
102 | | -#### Fixing Variables in a Model |
103 | | - |
104 | | -```julia |
105 | | -# Fix certain variables in the model |
106 | | -function fix(model::LogDensityModel, variables::NamedTuple) -> FixedModel |
107 | | - ... |
108 | | -end |
109 | | -``` |
110 | | - |
111 | | -- **Description**: Fixes specific variables in the model to given values, returning a `FixedModel`. |
112 | | - |
113 | | -#### Checking if a Model has Fixed Variables |
114 | | - |
115 | | -```julia |
116 | | -# Check if a model has fixed variables |
117 | | -function is_fixed(model::LogDensityModel) -> Bool |
118 | | - ... |
119 | | -end |
120 | | -``` |
121 | | - |
122 | | -- **Description**: Determines if any variables in the model have been fixed. |
123 | | - |
124 | | -### Specialized Models |
125 | | - |
126 | | -#### Conditioned Model Methods |
127 | | - |
128 | | -```julia |
129 | | -# Log-likelihood for a conditioned model |
130 | | -function loglikelihood(model::ConditionedModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 |
131 | | - ... |
132 | | -end |
133 | | - |
134 | | -# Log-prior for a conditioned model |
135 | | -function logprior(model::ConditionedModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 |
136 | | - ... |
137 | | -end |
138 | | - |
139 | | -# Log-joint for a conditioned model |
140 | | -function logjoint(model::ConditionedModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 |
141 | | - return loglikelihood(model, params) + logprior(model, params) |
142 | | -end |
143 | | -``` |
144 | | - |
145 | | -- **Description**: Overrides log density computations to account for the conditioned data. |
146 | | - |
147 | | -#### Fixed Model Methods |
148 | | - |
149 | | -```julia |
150 | | -# Log-likelihood for a fixed model |
151 | | -function loglikelihood(model::FixedModel, data::Union{Vector, NamedTuple, Dict}) -> Float64 |
152 | | - ... |
153 | | -end |
154 | | - |
155 | | -# Log-prior for a fixed model |
156 | | -function logprior(model::FixedModel, data::Union{Vector, NamedTuple, Dict}) -> Float64 |
157 | | - ... |
158 | | -end |
159 | | - |
160 | | -# Log-joint for a fixed model |
161 | | -function logjoint(model::FixedModel, data::Union{Vector, NamedTuple, Dict}) -> Float64 |
162 | | - return loglikelihood(model, data) + logprior(model, data) |
163 | | -end |
164 | | -``` |
165 | | - |
166 | | -- **Description**: Adjusts log density computations based on the fixed variables. |
167 | | - |
168 | | -### Additional Functionalities |
169 | | - |
170 | | -#### Generated Quantities |
171 | | - |
172 | | -```julia |
173 | | -# Compute generated quantities after fixing parameters |
174 | | -function generated_quantities(model::LogDensityModel, fixed_vars::NamedTuple) -> NamedTuple |
175 | | - ... |
176 | | -end |
177 | | -``` |
178 | | - |
179 | | -- **Description**: Computes additional expressions or functions based on the fixed model parameters. |
180 | | - |
181 | | -#### Prediction |
182 | | - |
183 | | -```julia |
184 | | -# Predict data based on fixed parameters |
185 | | -function predict(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> NamedTuple |
186 | | - ... |
187 | | -end |
188 | | -``` |
189 | | - |
190 | | -- **Description**: Generates predictions by fixing the parameters and unconditioning the data. |
191 | | - |
192 | | -## Advantages of the Proposed Interface |
193 | | - |
194 | | -- **Flexibility**: Allows for advanced model operations like conditioning and fixing, essential for methods like Gibbs sampling. |
195 | | - |
196 | | -- **User-Centric Design**: Focuses on usability from the model user's perspective rather than the PPL implementation side. |
197 | | - |
198 | | -- **Consistency**: Maintains a uniform interface for both parametric and non-parametric models, simplifying the learning curve. |
199 | | - |
200 | | -## Usage Examples |
201 | | - |
202 | | -## Non-Parametric Models |
| 39 | +`generated_quantities` can be implemented by `fix`ing the model on `sample` and calling `evaluate`. |
| 40 | +`predict` can be implemented by `uncondition`ing the model on `data`, fixing it on `sample`, and calling `evaluate`. |
0 commit comments