Skip to content

Commit 21df935

Browse files
committed
first draft
1 parent 9e946c0 commit 21df935

File tree

1 file changed

+202
-0
lines changed

1 file changed

+202
-0
lines changed
Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
# Proposal for a New LogDensity Function Interface
2+
3+
## Introduction
4+
5+
The goal is to design a flexible and user-friendly interface for log density functions that can handle various model operations, especially in higher-order contexts such as Gibbs sampling. This interface should facilitate:
6+
7+
- **Conditioning**: Incorporating observed data into the model.
8+
- **Fixing**: Fixing certain variables to specific values. (like `do` operator)
9+
- **Generated Quantities**: Computing additional expressions or functions based on the model parameters.
10+
- **Prediction**: Making predictions by fixing parameters and unconditioning on data.
11+
12+
This proposal aims to redefine the interface from the user's perspective, focusing on ease of use and extensibility beyond the traditional probabilistic programming languages (PPLs).
13+
14+
## Proposed Interface
15+
16+
Below is a proposed interface with key functionalities and their implementations.
17+
18+
### Core Functions
19+
20+
#### Check if a Model is Parametric
21+
22+
```julia
23+
# Check if a log density model is parametric
24+
function is_parametric(model::LogDensityModel) -> Bool
25+
...
26+
end
27+
```
28+
29+
- **Description**: Determines if the model has a parameter space with a defined dimension.
30+
-
31+
32+
#### Get the Dimension of a Parametric Model
33+
34+
```julia
35+
# Get the dimension of the parameter space (only defined when is_parametric(model) is true)
36+
function dimension(model::LogDensityModel) -> Int
37+
...
38+
end
39+
```
40+
41+
- **Description**: Returns the dimension of the parameter space for parametric models.
42+
43+
### Log Density Computations
44+
45+
#### Log-Likelihood
46+
47+
```julia
48+
# Compute the log-likelihood given parameters
49+
function loglikelihood(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> Float64
50+
...
51+
end
52+
```
53+
54+
- **Description**: Computes the log-likelihood of the data given the model parameters.
55+
56+
#### Log-Prior
57+
58+
```julia
59+
# Compute the log-prior given parameters
60+
function logprior(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> Float64
61+
...
62+
end
63+
```
64+
65+
- **Description**: Computes the log-prior probability of the model parameters.
66+
67+
#### Log-Joint
68+
69+
```julia
70+
# Compute the log-joint density (log-likelihood + log-prior)
71+
function logjoint(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> Float64
72+
return loglikelihood(model, params) + logprior(model, params)
73+
end
74+
```
75+
76+
- **Description**: Computes the total log density by summing the log-likelihood and log-prior.
77+
78+
### Conditioning and Fixing Variables
79+
80+
#### Conditioning a Model
81+
82+
```julia
83+
# Condition the model on observed data
84+
function condition(model::LogDensityModel, data::NamedTuple) -> ConditionedModel
85+
...
86+
end
87+
```
88+
89+
- **Description**: Incorporates observed data into the model, returning a `ConditionedModel`.
90+
91+
#### Checking if a Model is Conditioned
92+
93+
```julia
94+
# Check if a model is conditioned
95+
function is_conditioned(model::LogDensityModel) -> Bool
96+
...
97+
end
98+
```
99+
100+
- **Description**: Checks whether the model has been conditioned on data.
101+
102+
#### Fixing Variables in a Model
103+
104+
```julia
105+
# Fix certain variables in the model
106+
function fix(model::LogDensityModel, variables::NamedTuple) -> FixedModel
107+
...
108+
end
109+
```
110+
111+
- **Description**: Fixes specific variables in the model to given values, returning a `FixedModel`.
112+
113+
#### Checking if a Model has Fixed Variables
114+
115+
```julia
116+
# Check if a model has fixed variables
117+
function is_fixed(model::LogDensityModel) -> Bool
118+
...
119+
end
120+
```
121+
122+
- **Description**: Determines if any variables in the model have been fixed.
123+
124+
### Specialized Models
125+
126+
#### Conditioned Model Methods
127+
128+
```julia
129+
# Log-likelihood for a conditioned model
130+
function loglikelihood(model::ConditionedModel, params::Union{Vector, NamedTuple, Dict}) -> Float64
131+
...
132+
end
133+
134+
# Log-prior for a conditioned model
135+
function logprior(model::ConditionedModel, params::Union{Vector, NamedTuple, Dict}) -> Float64
136+
...
137+
end
138+
139+
# Log-joint for a conditioned model
140+
function logjoint(model::ConditionedModel, params::Union{Vector, NamedTuple, Dict}) -> Float64
141+
return loglikelihood(model, params) + logprior(model, params)
142+
end
143+
```
144+
145+
- **Description**: Overrides log density computations to account for the conditioned data.
146+
147+
#### Fixed Model Methods
148+
149+
```julia
150+
# Log-likelihood for a fixed model
151+
function loglikelihood(model::FixedModel, data::Union{Vector, NamedTuple, Dict}) -> Float64
152+
...
153+
end
154+
155+
# Log-prior for a fixed model
156+
function logprior(model::FixedModel, data::Union{Vector, NamedTuple, Dict}) -> Float64
157+
...
158+
end
159+
160+
# Log-joint for a fixed model
161+
function logjoint(model::FixedModel, data::Union{Vector, NamedTuple, Dict}) -> Float64
162+
return loglikelihood(model, data) + logprior(model, data)
163+
end
164+
```
165+
166+
- **Description**: Adjusts log density computations based on the fixed variables.
167+
168+
### Additional Functionalities
169+
170+
#### Generated Quantities
171+
172+
```julia
173+
# Compute generated quantities after fixing parameters
174+
function generated_quantities(model::LogDensityModel, fixed_vars::NamedTuple) -> NamedTuple
175+
...
176+
end
177+
```
178+
179+
- **Description**: Computes additional expressions or functions based on the fixed model parameters.
180+
181+
#### Prediction
182+
183+
```julia
184+
# Predict data based on fixed parameters
185+
function predict(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> NamedTuple
186+
...
187+
end
188+
```
189+
190+
- **Description**: Generates predictions by fixing the parameters and unconditioning the data.
191+
192+
## Advantages of the Proposed Interface
193+
194+
- **Flexibility**: Allows for advanced model operations like conditioning and fixing, essential for methods like Gibbs sampling.
195+
196+
- **User-Centric Design**: Focuses on usability from the model user's perspective rather than the PPL implementation side.
197+
198+
- **Consistency**: Maintains a uniform interface for both parametric and non-parametric models, simplifying the learning curve.
199+
200+
## Usage Examples
201+
202+
## Non-Parametric Models

0 commit comments

Comments
 (0)