|
1 | 1 | # Proposal for a New LogDensity Function Interface
|
2 | 2 |
|
3 |
| -## Introduction |
| 3 | +<https://github.com/TuringLang/DynamicPPL.jl/issues/691> |
4 | 4 |
|
5 |
| -The goal is to design a flexible and user-friendly interface for log density functions that can handle various model operations, especially in higher-order contexts such as Gibbs sampling. This interface should facilitate: |
| 5 | +The goal is to design a flexible, user-friendly interface for log density functions that can handle various model operations, especially in higher-order contexts like Gibbs sampling and Bayesian workflows. |
6 | 6 |
|
7 |
| -- **Conditioning**: Incorporating observed data into the model. |
8 |
| -- **Fixing**: Fixing certain variables to specific values. (like `do` operator) |
9 |
| -- **Generated Quantities**: Computing additional expressions or functions based on the model parameters. |
10 |
| -- **Prediction**: Making predictions by fixing parameters and unconditioning on data. |
| 7 | +## Evaluation functions: |
11 | 8 |
|
12 |
| -This proposal aims to redefine the interface from the user's perspective, focusing on ease of use and extensibility beyond the traditional probabilistic programming languages (PPLs). |
| 9 | +1. `evaluate` |
13 | 10 |
|
14 |
| -## Proposed Interface |
| 11 | +## Query functions: |
15 | 12 |
|
16 |
| -Below is a proposed interface with key functionalities and their implementations. |
| 13 | +1. `is_parametric(model)` |
| 14 | +2. `dimension(model)` (only defined when `is_parametric(model) == true`) |
| 15 | +3. `is_conditioned(model)` |
| 16 | +4. `is_fixed(model)` |
| 17 | +5. `logjoint(model, params)` |
| 18 | +6. `loglikelihood(model, params)` |
| 19 | +7. `logprior(model, params)` |
17 | 20 |
|
18 |
| -### Core Functions |
| 21 | +where `params` can be `Vector`, `NamedTuple`, `Dict`, etc. |
19 | 22 |
|
20 |
| -#### Check if a Model is Parametric |
| 23 | +## Transformation functions: |
21 | 24 |
|
22 |
| -```julia |
23 |
| -# Check if a log density model is parametric |
24 |
| -function is_parametric(model::LogDensityModel) -> Bool |
25 |
| - ... |
26 |
| -end |
27 |
| -``` |
| 25 | +1. `condition(model, conditioned_vars)` |
| 26 | +2. `fix(model, fixed_vars)` |
| 27 | +3. `factor(model, variables_in_the_factor)` |
28 | 28 |
|
29 |
| -- **Description**: Determines if the model has a parameter space with a defined dimension. |
30 |
| -- |
| 29 | +`condition` and `factor` are similar, but `factor` effectively generates a sub-model. |
31 | 30 |
|
32 |
| -#### Get the Dimension of a Parametric Model |
| 31 | +## Higher-order functions: |
33 | 32 |
|
34 |
| -```julia |
35 |
| -# Get the dimension of the parameter space (only defined when is_parametric(model) is true) |
36 |
| -function dimension(model::LogDensityModel) -> Int |
37 |
| - ... |
38 |
| -end |
39 |
| -``` |
| 33 | +1. `generated_quantities(model, sample, [, expr])` or `generated_quantities(model, sample, f, args...)` |
| 34 | + 1. `generated_quantities` computes things from the sampling result. |
| 35 | + 2. In `DynamicPPL`, this is the model's return value. For more flexibility, we should allow passing an expression or function. (Currently, users can rewrite the model definition to achieve this in `DynamicPPL`, but with limitations. We want to make this more generic.) |
| 36 | + 3. `rand` is a special case of `generated_quantities` (when no sample is passed). |
| 37 | +2. `predict(model, sample)` |
40 | 38 |
|
41 |
| -- **Description**: Returns the dimension of the parameter space for parametric models. |
42 |
| - |
43 |
| -### Log Density Computations |
44 |
| - |
45 |
| -#### Log-Likelihood |
46 |
| - |
47 |
| -```julia |
48 |
| -# Compute the log-likelihood given parameters |
49 |
| -function loglikelihood(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 |
50 |
| - ... |
51 |
| -end |
52 |
| -``` |
53 |
| - |
54 |
| -- **Description**: Computes the log-likelihood of the data given the model parameters. |
55 |
| - |
56 |
| -#### Log-Prior |
57 |
| - |
58 |
| -```julia |
59 |
| -# Compute the log-prior given parameters |
60 |
| -function logprior(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 |
61 |
| - ... |
62 |
| -end |
63 |
| -``` |
64 |
| - |
65 |
| -- **Description**: Computes the log-prior probability of the model parameters. |
66 |
| - |
67 |
| -#### Log-Joint |
68 |
| - |
69 |
| -```julia |
70 |
| -# Compute the log-joint density (log-likelihood + log-prior) |
71 |
| -function logjoint(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 |
72 |
| - return loglikelihood(model, params) + logprior(model, params) |
73 |
| -end |
74 |
| -``` |
75 |
| - |
76 |
| -- **Description**: Computes the total log density by summing the log-likelihood and log-prior. |
77 |
| - |
78 |
| -### Conditioning and Fixing Variables |
79 |
| - |
80 |
| -#### Conditioning a Model |
81 |
| - |
82 |
| -```julia |
83 |
| -# Condition the model on observed data |
84 |
| -function condition(model::LogDensityModel, data::NamedTuple) -> ConditionedModel |
85 |
| - ... |
86 |
| -end |
87 |
| -``` |
88 |
| - |
89 |
| -- **Description**: Incorporates observed data into the model, returning a `ConditionedModel`. |
90 |
| - |
91 |
| -#### Checking if a Model is Conditioned |
92 |
| - |
93 |
| -```julia |
94 |
| -# Check if a model is conditioned |
95 |
| -function is_conditioned(model::LogDensityModel) -> Bool |
96 |
| - ... |
97 |
| -end |
98 |
| -``` |
99 |
| - |
100 |
| -- **Description**: Checks whether the model has been conditioned on data. |
101 |
| - |
102 |
| -#### Fixing Variables in a Model |
103 |
| - |
104 |
| -```julia |
105 |
| -# Fix certain variables in the model |
106 |
| -function fix(model::LogDensityModel, variables::NamedTuple) -> FixedModel |
107 |
| - ... |
108 |
| -end |
109 |
| -``` |
110 |
| - |
111 |
| -- **Description**: Fixes specific variables in the model to given values, returning a `FixedModel`. |
112 |
| - |
113 |
| -#### Checking if a Model has Fixed Variables |
114 |
| - |
115 |
| -```julia |
116 |
| -# Check if a model has fixed variables |
117 |
| -function is_fixed(model::LogDensityModel) -> Bool |
118 |
| - ... |
119 |
| -end |
120 |
| -``` |
121 |
| - |
122 |
| -- **Description**: Determines if any variables in the model have been fixed. |
123 |
| - |
124 |
| -### Specialized Models |
125 |
| - |
126 |
| -#### Conditioned Model Methods |
127 |
| - |
128 |
| -```julia |
129 |
| -# Log-likelihood for a conditioned model |
130 |
| -function loglikelihood(model::ConditionedModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 |
131 |
| - ... |
132 |
| -end |
133 |
| - |
134 |
| -# Log-prior for a conditioned model |
135 |
| -function logprior(model::ConditionedModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 |
136 |
| - ... |
137 |
| -end |
138 |
| - |
139 |
| -# Log-joint for a conditioned model |
140 |
| -function logjoint(model::ConditionedModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 |
141 |
| - return loglikelihood(model, params) + logprior(model, params) |
142 |
| -end |
143 |
| -``` |
144 |
| - |
145 |
| -- **Description**: Overrides log density computations to account for the conditioned data. |
146 |
| - |
147 |
| -#### Fixed Model Methods |
148 |
| - |
149 |
| -```julia |
150 |
| -# Log-likelihood for a fixed model |
151 |
| -function loglikelihood(model::FixedModel, data::Union{Vector, NamedTuple, Dict}) -> Float64 |
152 |
| - ... |
153 |
| -end |
154 |
| - |
155 |
| -# Log-prior for a fixed model |
156 |
| -function logprior(model::FixedModel, data::Union{Vector, NamedTuple, Dict}) -> Float64 |
157 |
| - ... |
158 |
| -end |
159 |
| - |
160 |
| -# Log-joint for a fixed model |
161 |
| -function logjoint(model::FixedModel, data::Union{Vector, NamedTuple, Dict}) -> Float64 |
162 |
| - return loglikelihood(model, data) + logprior(model, data) |
163 |
| -end |
164 |
| -``` |
165 |
| - |
166 |
| -- **Description**: Adjusts log density computations based on the fixed variables. |
167 |
| - |
168 |
| -### Additional Functionalities |
169 |
| - |
170 |
| -#### Generated Quantities |
171 |
| - |
172 |
| -```julia |
173 |
| -# Compute generated quantities after fixing parameters |
174 |
| -function generated_quantities(model::LogDensityModel, fixed_vars::NamedTuple) -> NamedTuple |
175 |
| - ... |
176 |
| -end |
177 |
| -``` |
178 |
| - |
179 |
| -- **Description**: Computes additional expressions or functions based on the fixed model parameters. |
180 |
| - |
181 |
| -#### Prediction |
182 |
| - |
183 |
| -```julia |
184 |
| -# Predict data based on fixed parameters |
185 |
| -function predict(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> NamedTuple |
186 |
| - ... |
187 |
| -end |
188 |
| -``` |
189 |
| - |
190 |
| -- **Description**: Generates predictions by fixing the parameters and unconditioning the data. |
191 |
| - |
192 |
| -## Advantages of the Proposed Interface |
193 |
| - |
194 |
| -- **Flexibility**: Allows for advanced model operations like conditioning and fixing, essential for methods like Gibbs sampling. |
195 |
| - |
196 |
| -- **User-Centric Design**: Focuses on usability from the model user's perspective rather than the PPL implementation side. |
197 |
| - |
198 |
| -- **Consistency**: Maintains a uniform interface for both parametric and non-parametric models, simplifying the learning curve. |
199 |
| - |
200 |
| -## Usage Examples |
201 |
| - |
202 |
| -## Non-Parametric Models |
| 39 | +`generated_quantities` can be implemented by `fix`ing the model on `sample` and calling `evaluate`. |
| 40 | +`predict` can be implemented by `uncondition`ing the model on `data`, fixing it on `sample`, and calling `evaluate`. |
0 commit comments