Commit d034d65

sunxd3, mhauru, yebai, and github-actions[bot] authored
Implement Turing-like model macro (#291)
Address #281

Edited: I wrote some documentation in markdown for this PR and its design; there is a slight change from last week's Monday meeting. The markdown doc should be concise enough, and I don't think I can make it much shorter without losing essential information. Could the reviewers please give https://github.com/TuringLang/JuliaBUGS.jl/blob/sunxd/new_model_syntax/docs/src/julia_syntax.md a quick read?

---------

Co-authored-by: Markus Hauru
Co-authored-by: Hong Ge
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 098451b commit d034d65

5 files changed: 587 additions and 8 deletions

docs/src/julia_syntax.md

Lines changed: 187 additions & 0 deletions
@@ -0,0 +1,187 @@
# How to Specify and Create a `BUGSModel`

Creating a `BUGSModel` requires two key components: a BUGS program that defines the model structure and values for specific variables that parameterize the model.

To understand how to specify a model properly, it is important to distinguish between the different types of values you can provide to the JuliaBUGS compiler:

* **Constants**: Values used in loop bounds and index resolution
    * These are essential for model specification as they determine the model's dimensionality (how many variables are created) and establish the dependency structure between variables

* **Independent variables** (also called features, predictors, or covariates): Non-stochastic inputs required for forward simulation of the model
    * Examples include predictor variables in a regression model or time points in a time series model

* **Observations**: Values for stochastic variables that you wish to condition on
    * These are not necessary to specify the model structure, but when provided, they become the data that your model is conditioned on
    * (Note: In some advanced cases, stochastic variables can contribute to the log density without being part of a strictly generative model)

* **Initialization values**: Starting points for MCMC sampling
    * While optional in many cases, some models (particularly those with weakly informative priors or complex structures) require carefully chosen initialization values for effective sampling

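To make the four categories concrete, here is a minimal sketch of how they might be organized as plain Julia `NamedTuple`s for a model shaped like the seeds example used later in this document. The numeric values are invented for illustration and are not the real seeds dataset; grouping the values into separate tuples before merging is a convention chosen for this sketch, not something JuliaBUGS requires.

```julia
# Hypothetical toy values; NOT the real seeds dataset.

# Constants: used in loop bounds and indexing, so they fix the model's size.
constants = (N=3,)

# Independent variables: non-stochastic inputs needed for forward simulation.
covariates = (x1=[0, 0, 1], x2=[0, 1, 1], n=[39, 62, 81])

# Observations: values of stochastic variables to condition on.
observations = (r=[10, 23, 23],)

# The values passed to the compiler can be gathered into a single NamedTuple.
data = merge(constants, covariates, observations)

# Initialization values: optional starting points for MCMC sampling.
initial_params = (alpha0=0.0, alpha1=0.0, alpha2=0.0, alpha12=0.0, tau=1.0)

println(keys(data))
```

Keeping the categories separate until the final `merge` makes it easy to drop the observations and reuse the same constants and covariates for forward simulation.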
## Syntax from Previous BUGS Software and Their R Packages

Traditionally, BUGS models were created through a software interface following these steps:
1. Write the model in a text file
2. Check the model syntax (parsing)
3. Compile the model with program text and data
4. Initialize the sampling process (optional)

R interface packages for BUGS maintained this workflow pattern through text-based interfaces that closely mirrored the original software.

JuliaBUGS initially adopted this familiar workflow to accommodate users with prior BUGS experience. Specifically, JuliaBUGS provides a `@bugs` macro that accepts model definitions either as strings or within a `begin...end` block:

```julia
# Example using string macro
@bugs"""
model {
    for( i in 1 : N ) {
        r[i] ~ dbin(p[i],n[i])
        b[i] ~ dnorm(0.0,tau)
        logit(p[i]) <- alpha0 + alpha1 * x1[i] + alpha2 * x2[i] +
            alpha12 * x1[i] * x2[i] + b[i]
    }
    alpha0 ~ dnorm(0.0,1.0E-6)
    alpha1 ~ dnorm(0.0,1.0E-6)
    alpha2 ~ dnorm(0.0,1.0E-6)
    alpha12 ~ dnorm(0.0,1.0E-6)
    tau ~ dgamma(0.001,0.001)
    sigma <- 1 / sqrt(tau)
}
"""

# Example using block macro
@bugs begin
    for i in 1:N
        r[i] ~ dbin(p[i], n[i])
        b[i] ~ dnorm(0.0, tau)
        p[i] = logistic(alpha0 + alpha1 * x1[i] + alpha2 * x2[i] + alpha12 * x1[i] * x2[i] +
                        b[i])
    end
    alpha0 ~ dnorm(0.0, 1.0e-6)
    alpha1 ~ dnorm(0.0, 1.0e-6)
    alpha2 ~ dnorm(0.0, 1.0e-6)
    alpha12 ~ dnorm(0.0, 1.0e-6)
    tau ~ dgamma(0.001, 0.001)
    sigma = 1 / sqrt(tau)
end
```

In both cases, the macro returns a Julia AST representation of the model. The `compile` function then takes this AST and user-provided values (as a `NamedTuple`) to create a `BUGSModel` instance.
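
The two-step workflow can be sketched as follows, assuming JuliaBUGS is installed and that `@bugs` and `compile` are available after `using JuliaBUGS`. The `compile(model_def, data, initial_params)` signature is taken from `src/JuliaBUGS.jl` in this commit; the toy model and its data are invented for illustration and are not part of the PR.

```julia
using JuliaBUGS

# Step 1: the macro turns the model definition into a Julia AST (an `Expr`).
model_def = @bugs begin
    mu ~ dnorm(0.0, 1.0e-6)
    for i in 1:N
        y[i] ~ dnorm(mu, 1.0)
    end
end

# Step 2: supply values as a NamedTuple. `N` is a constant (a loop bound);
# `y` holds observations for the stochastic variables `y[i]`.
data = (N=3, y=[1.2, 0.7, 1.9])

# `compile` combines the AST and the values into a `BUGSModel`.
# Initial parameter values are optional and omitted here.
model = compile(model_def, data)
```

Because `mu` has no value in `data`, it remains the only free parameter of the compiled model.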

While we maintain this interface for compatibility, we now also offer a more idiomatic Julia approach.

## The Interface

JuliaBUGS provides a Julian interface inspired by Turing.jl's model macro syntax. The `@model` macro creates a "model-creating function" that returns a model object supporting operations like `AbstractMCMC.sample` (which samples MCMC chains) and `condition` (which modifies the model by incorporating observations).

### The `@model` Macro

```julia
JuliaBUGS.@model function model_definition((; r, b, alpha0, alpha1, alpha2, alpha12, tau)::SeedsParams, x1, x2, N, n)
    for i in 1:N
        r[i] ~ dbin(p[i], n[i])
        b[i] ~ dnorm(0.0, tau)
        p[i] = logistic(alpha0 + alpha1 * x1[i] + alpha2 * x2[i] + alpha12 * x1[i] * x2[i] + b[i])
    end
    alpha0 ~ dnorm(0.0, 1.0E-6)
    alpha1 ~ dnorm(0.0, 1.0E-6)
    alpha2 ~ dnorm(0.0, 1.0E-6)
    alpha12 ~ dnorm(0.0, 1.0E-6)
    tau ~ dgamma(0.001, 0.001)
    sigma = 1 / sqrt(tau)
end
```

The `@model` macro requires a specific function signature:

1. The first argument must declare stochastic parameters (variables defined with `~`) using destructuring assignment with the format `(; param1, param2, ...)`.
2. We recommend providing a type annotation for the first argument (e.g., `(; r, b, ...)::SeedsParams`). If `SeedsParams` is defined using `@parameters`, that macro also defines a constructor `SeedsParams(model::BUGSModel)` for extracting parameter values from the model.
3. Alternatively, you can use a `NamedTuple` instead of a custom type. In this case, no type annotation is needed, but you need to manually create a `NamedTuple` with `ParameterPlaceholder()` values or arrays of `missing` values for parameters that don't have observations.
4. The remaining arguments must specify all constants and independent variables required by the model (variables used on the right-hand side of a statement but never on the left-hand side).

The `@parameters` macro simplifies creating structs to hold model parameters:

```julia
JuliaBUGS.@parameters struct SeedsParams
    r
    b
    alpha0
    alpha1
    alpha2
    alpha12
    tau
end
```

This macro applies `Base.@kwdef` to enable keyword initialization and creates a no-argument constructor. By default, fields are initialized to `JuliaBUGS.ParameterPlaceholder()`. The concrete types and sizes of parameters are determined during compilation when the model function is called with constants. A constructor `SeedsParams(::BUGSModel)` is created for easy extraction of parameter values.
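
To illustrate the keyword-initialization behaviour described above, here is a rough, self-contained sketch using only `Base.@kwdef`. It is not the actual expansion of `@parameters`: `Placeholder` is a hypothetical stand-in for `JuliaBUGS.ParameterPlaceholder`, `ToyParams` is an invented struct, and the `BUGSModel` constructor is left out entirely.

```julia
# Hypothetical stand-in for JuliaBUGS.ParameterPlaceholder.
struct Placeholder end

# Untyped fields with placeholder defaults; Base.@kwdef generates a
# keyword constructor, so every field may be omitted at construction time.
Base.@kwdef struct ToyParams
    alpha0 = Placeholder()
    tau = Placeholder()
end

# No-argument construction: every field is a placeholder.
p0 = ToyParams()

# Keyword construction: supply a value for only some of the fields.
p1 = ToyParams(; tau=2.0)

println(p0.alpha0 isa Placeholder)
println(p1.tau)
```

Leaving the fields untyped is what lets the same struct hold either a placeholder or a concrete value, which fits the note that concrete types and sizes are only known once the model is compiled.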

### Example

```julia
julia> @model function seeds(
           (; r, b, alpha0, alpha1, alpha2, alpha12, tau)::SeedsParams, x1, x2, N, n
       )
           for i in 1:N
               r[i] ~ dbin(p[i], n[i])
               b[i] ~ dnorm(0.0, tau)
               p[i] = logistic(
                   alpha0 + alpha1 * x1[i] + alpha2 * x2[i] + alpha12 * x1[i] * x2[i] + b[i]
               )
           end
           alpha0 ~ dnorm(0.0, 1.0E-6)
           alpha1 ~ dnorm(0.0, 1.0E-6)
           alpha2 ~ dnorm(0.0, 1.0E-6)
           alpha12 ~ dnorm(0.0, 1.0E-6)
           tau ~ dgamma(0.001, 0.001)
           sigma = 1 / sqrt(tau)
       end
seeds (generic function with 1 method)

julia> (; x1, x2, N, n) = JuliaBUGS.BUGSExamples.seeds.data; # extract data from existing BUGS example

julia> @parameters struct SeedsParams
           r
           b
           alpha0
           alpha1
           alpha2
           alpha12
           tau
       end

julia> m = seeds(SeedsParams(), x1, x2, N, n)
BUGSModel (parameters are in transformed (unconstrained) space, with dimension 47):

  Model parameters:
    alpha2
    b[21], b[20], b[19], b[18], b[17], b[16], b[15], b[14], b[13], b[12], b[11], b[10], b[9], b[8], b[7], b[6], b[5], b[4], b[3], b[2], b[1]
    r[21], r[20], r[19], r[18], r[17], r[16], r[15], r[14], r[13], r[12], r[11], r[10], r[9], r[8], r[7], r[6], r[5], r[4], r[3], r[2], r[1]
    tau
    alpha12
    alpha1
    alpha0

  Variable sizes and types:
    b: size = (21,), type = Vector{Float64}
    p: size = (21,), type = Vector{Float64}
    n: size = (21,), type = Vector{Int64}
    alpha2: type = Float64
    sigma: type = Float64
    alpha12: type = Float64
    alpha0: type = Float64
    N: type = Int64
    tau: type = Float64
    alpha1: type = Float64
    r: size = (21,), type = Vector{Float64}
    x1: size = (21,), type = Vector{Int64}
    x2: size = (21,), type = Vector{Int64}

julia> SeedsParams(m)
SeedsParams:
  r = [0.0, 0.0, 0.0, 0.0, 39.0, 0.0, 0.0, 72.0, 0.0, 0.0 0.0, 0.0, 0.0, 0.0, 4.0, 12.0, 0.0, 0.0, 0.0, 0.0]
  b = [-Inf, -Inf, -Inf, -Inf, Inf, -Inf, -Inf, Inf, -Inf, -Inf -Inf, -Inf, -Inf, -Inf, Inf, Inf, -Inf, -Inf, -Inf, -Inf]
  alpha0 = -1423.52
  alpha1 = 1981.99
  alpha2 = -545.664
  alpha12 = 1338.25
  tau = 0.0
```
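
The reported dimension of 47 can be cross-checked against the "Model parameters" list in the output: five scalar parameters plus 21 elements each of `b` and `r`. Note that `r` is listed as a parameter here, presumably because `SeedsParams()` supplies no observations for it. A small arithmetic sketch (the variable names are illustrative only):

```julia
# Scalar parameters listed under "Model parameters" in the output above.
scalar_params = [:alpha0, :alpha1, :alpha2, :alpha12, :tau]

# Length of the indexed parameters b[1:21] and r[1:21].
N = 21

# One dimension per scalar, plus N for `b` and N for `r`.
dimension = length(scalar_params) + N + N

println(dimension)
```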

src/JuliaBUGS.jl

Lines changed: 30 additions & 8 deletions
@@ -43,20 +43,30 @@ include("source_gen.jl")
 include("BUGSExamples/BUGSExamples.jl")

 function check_input(input::NamedTuple)
+    valid_pairs = Pair{Symbol,Any}[]
     for (k, v) in pairs(input)
-        if v isa AbstractArray
-            if !(eltype(v) <: Union{Int,Float64,Missing})
+        if v === missing
+            continue # Skip missing values
+        elseif v isa AbstractArray
+            # Allow arrays containing Int, Float64, or Missing
+            allowed_eltypes = Union{Int,Float64,Missing}
+            if !(eltype(v) <: allowed_eltypes)
                 error(
-                    "For array input, only Int, Float64, or Missing types are supported. Received: $(typeof(v)).",
+                    "For array input '$k', only elements of type $allowed_eltypes are supported. Received array with eltype: $(eltype(v)).",
                 )
             end
-        elseif v === missing
-            error("Scalars cannot be missing. Received: $k")
-        elseif !(v isa Union{Int,Float64})
-            error("Scalars must be of type Int or Float64. Received: $k")
+            push!(valid_pairs, k => v)
+        elseif v isa Union{Int,Float64}
+            # Allow scalar Int or Float64
+            push!(valid_pairs, k => v)
+        else
+            # Error for other scalar types
+            error(
+                "Scalar input '$k' must be of type Int or Float64. Received: $(typeof(v))."
+            )
         end
     end
-    return input
+    return NamedTuple(valid_pairs)
 end
 function check_input(input::Dict{KT,VT}) where {KT,VT}
     if isempty(input)
@@ -177,6 +187,16 @@ function compile(model_def::Expr, data::NamedTuple, initial_params::NamedTuple=N
     )
     return BUGSModel(g, nonmissing_eval_env, model_def, data, initial_params)
 end
+# function compile(
+#     model_str::String,
+#     data::NamedTuple,
+#     initial_params::NamedTuple=NamedTuple();
+#     replace_period::Bool=true,
+#     no_enclosure::Bool=false,
+# )
+#     model_def = _bugs_string_input(model_str, replace_period, no_enclosure)
+#     return compile(model_def, data, initial_params)
+# end

 """
     @register_primitive(expr)
@@ -253,6 +273,8 @@ Only defined with `MCMCChains` extension.
 """
 function gen_chains end

+include("model_macro.jl")
+
 include("experimental/ProbabilisticGraphicalModels/ProbabilisticGraphicalModels.jl")

 end
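
The behavioural change to `check_input(::NamedTuple)` in the first hunk can be exercised in isolation. The sketch below copies the new logic into a standalone function so it runs without JuliaBUGS; the name `check_input_sketch` and the sample inputs are invented for illustration. It shows that a scalar `missing` is now silently dropped instead of raising an error, while unsupported scalar types are still rejected.

```julia
# Standalone copy of the revised validation logic (hypothetical name).
function check_input_sketch(input::NamedTuple)
    valid_pairs = Pair{Symbol,Any}[]
    for (k, v) in pairs(input)
        if v === missing
            continue  # scalar `missing` entries are skipped, not rejected
        elseif v isa AbstractArray
            allowed_eltypes = Union{Int,Float64,Missing}
            if !(eltype(v) <: allowed_eltypes)
                error("For array input '$k', only elements of type $allowed_eltypes are supported. Received array with eltype: $(eltype(v)).")
            end
            push!(valid_pairs, k => v)
        elseif v isa Union{Int,Float64}
            push!(valid_pairs, k => v)
        else
            error("Scalar input '$k' must be of type Int or Float64. Received: $(typeof(v)).")
        end
    end
    return NamedTuple(valid_pairs)
end

# `z` is a scalar missing and is dropped; `y` keeps its missing element.
cleaned = check_input_sketch((N=2, y=[1.0, missing], z=missing))
println(keys(cleaned))

# A String scalar is not supported and raises an ErrorException.
rejected = try
    check_input_sketch((label="a",))
    false
catch err
    err isa ErrorException
end
println(rejected)
```

One consequence worth noting for reviewers: because the function now returns a filtered `NamedTuple` rather than the input itself, callers that previously relied on every key being preserved will no longer see keys whose value is `missing`.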
