Skip to content

Commit 2b8d551

Browse files
Ensure only model parameters are included in the sorted parameter list (#292)
If a Julia code to compute log density can be generated for a model, then the topo order of nodes in the model need to be updated to sync with the generated Julia source. The original code used the topologically sorted list of all nodes (`pass.sorted_nodes`) to determine the final ordered list of parameters for the `Model`. However, `pass.sorted_nodes` can include variables that are purely "transformed data" - deterministic nodes computed entirely from provided data constants at compile time (as described in `docs/src/source_gen.md` under "Handling Mixed Data Transformation and Deterministic Assignments"). For example, if `y = [1, 2, missing, missing, 2]` and the model has `x[i] = y[i] + 1`, then `x[1]`, `x[2]`, and `x[5]` are transformed data, while only `x[3]` and `x[4]` (dependent on the missing `y` values) are actual deterministic parameters within the model graph. Including transformed data nodes in the final `parameters` list is incorrect, as they are constants and not part of the model's parameter space to be evaluated or sampled. This PR is a simple fix by filtering the `pass.sorted_nodes` list, keeping only those `VarName`s that are already present in the `parameters` set. This ensures the final `parameters` field of the `Model` struct contains only true model parameters, correctly ordered according to the graph's topology. --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 2213aa2 commit 2b8d551

File tree

4 files changed

+62
-61
lines changed

4 files changed

+62
-61
lines changed

.github/workflows/TestsMacOS.yml

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@ jobs:
1212
test:
1313
name: Julia ${{ matrix.version }} on ${{ matrix.os }} (${{ matrix.arch }})
1414
runs-on: ${{ matrix.os }}
15-
continue-on-error: ${{ matrix.version == 'nightly' }}
15+
continue-on-error: ${{ matrix.version == 'pre' }}
1616
strategy:
1717
fail-fast: false
1818
matrix:
19-
version: ['1', '1.10', 'nightly']
19+
version: ['1', 'min', 'pre']
2020
os: [macOS-latest]
2121
arch: [aarch64]
2222

@@ -55,6 +55,11 @@ jobs:
5555
env:
5656
TEST_GROUP: "log_density"
5757

58+
- name: Running `source_gen` tests
59+
uses: julia-actions/julia-runtest@v1
60+
env:
61+
TEST_GROUP: "source_gen"
62+
5863
- name: Running `gibbs` tests
5964
uses: nick-fields/retry@v3
6065
with:
@@ -77,17 +82,4 @@ jobs:
7782
uses: julia-actions/julia-runtest@v1
7883
env:
7984
TEST_GROUP: "experimental"
80-
81-
- uses: julia-actions/julia-processcoverage@v1
82-
if: matrix.coverage
8385

84-
- uses: codecov/codecov-action@v4
85-
if: matrix.coverage
86-
with:
87-
file: lcov.info
88-
89-
- uses: coverallsapp/github-action@master
90-
if: matrix.coverage
91-
with:
92-
github-token: ${{ secrets.GITHUB_TOKEN }}
93-
path-to-lcov: lcov.info

docs/src/source_gen.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,8 @@ end
457457

458458
We made a simple change to the program to prepare for lowering: we need to distinguish between observations and model parameters (because they correspond to different code). We introduce a new operator into the program `\eqsim` to indicate that the left hand side is an observation.
459459

460+
### Handling Mixed Observations and Parameters
461+
460462
BUGS supports mixing observations and model parameters for different elements of the same array variable.
461463
To support this, we introduce a guard to use conditional logic to decide what computation to do for different iteration of the same statement.
462464

@@ -487,3 +489,27 @@ begin
487489
end
488490
end
489491
```
492+
493+
### Handling Mixed Data Transformation and Deterministic Assignments
494+
495+
For instance
496+
497+
```julia
498+
for i in 1:5
499+
x[i] = y[i] + 1
500+
end
501+
```
502+
503+
if the data is
504+
505+
```julia
506+
y = [1, 2, missing, missing, 2]
507+
```
508+
509+
this is generally allowed in BUGS.
510+
511+
`x[1], x[2], x[5]` can be computed at compile time, so these are "transformed data".
512+
`x[3], x[4]` need to be computed at evaluation time.
513+
And only `x[3]` and `x[4]` are in the compiled graph.
514+
515+
For generated Julia program, if a statement can be eliminated because all the variables stemmed from this statement are "transformed data". While in the above case, where a statements corresponds to both transformed data and deterministic variables. It will be left in the generated program as is. In this case, there will be redundant computation.

src/model.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,9 @@ function BUGSModel(
233233
sorted_nodes = pass.sorted_nodes
234234
original_parameters_length = length(parameters)
235235
parameters = VarName[vn for vn in sorted_nodes if vn in parameters]
236+
sorted_nodes = [
237+
vn for vn in sorted_nodes if vn in flattened_graph_node_data.sorted_nodes
238+
]
236239
@assert length(parameters) == original_parameters_length "there are less parameters in the generated log density function than in the original model"
237240
flattened_graph_node_data = FlattenedGraphNodeData(g, sorted_nodes)
238241
else

test/source_gen.jl

Lines changed: 26 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
using BangBang
22
using Bijectors
33
using JuliaBUGS
4-
using JuliaBUGS: _generate_lowered_model_def, _gen_log_density_computation_function_expr
5-
using JuliaBUGS: CollectSortedNodes
64
using JuliaBUGS.BUGSPrimitives
75
using LogDensityProblems
86
using OrderedCollections
97

10-
# bones, mice, kidney have missings in data
118
test_examples = [
129
:rats,
1310
:pumps,
@@ -32,52 +29,21 @@ test_examples = [
3229
:air,
3330
:birats,
3431
:schools,
32+
:cervix,
3533
]
3634

37-
function _create_model(model_name::Symbol)
38-
(; model_def, data, inits) = getfield(JuliaBUGS.BUGSExamples, model_name)
39-
model = compile(model_def, data, inits)
40-
evaluation_env = deepcopy(model.evaluation_env)
41-
return model, evaluation_env
42-
end
43-
44-
function _create_bugsmodel_with_consistent_sorted_nodes(
45-
model::JuliaBUGS.BUGSModel, reconstructed_model_def
46-
)
47-
pass = CollectSortedNodes(model.evaluation_env)
48-
JuliaBUGS.analyze_block(pass, reconstructed_model_def)
49-
sorted_nodes = pass.sorted_nodes
50-
sorted_parameters = [vn for vn in sorted_nodes if vn in model.parameters]
51-
new_flattened_graph_node_data = JuliaBUGS.FlattenedGraphNodeData(model.g, sorted_nodes)
52-
new_model = BangBang.setproperty!!(model, :parameters, sorted_parameters)
53-
new_model = BangBang.setproperty!!(
54-
new_model, :flattened_graph_node_data, new_flattened_graph_node_data
55-
)
56-
return new_model
57-
end
58-
5935
@testset "source_gen: $example_name" for example_name in test_examples
60-
model, evaluation_env = _create_model(example_name)
61-
lowered_model_def, reconstructed_model_def = _generate_lowered_model_def(
62-
model.model_def, model.g, evaluation_env
63-
)
64-
log_density_computation_expr = _gen_log_density_computation_function_expr(
65-
lowered_model_def, evaluation_env, gensym(example_name)
66-
)
67-
log_density_computation_function = eval(log_density_computation_expr)
68-
69-
model_with_consistent_sorted_nodes = _create_bugsmodel_with_consistent_sorted_nodes(
70-
model, reconstructed_model_def
71-
)
72-
result_with_old_model = JuliaBUGS.evaluate!!(model)[2]
73-
params = JuliaBUGS.getparams(model_with_consistent_sorted_nodes)
74-
result_with_bugsmodel = JuliaBUGS.evaluate!!(
75-
model_with_consistent_sorted_nodes, params
76-
)[2]
77-
result_with_log_density_computation_function = log_density_computation_function(
78-
evaluation_env, params
79-
)
80-
@test result_with_old_model result_with_bugsmodel
36+
(; model_def, data, inits) = getfield(JuliaBUGS.BUGSExamples, example_name)
37+
model = compile(model_def, data, inits)
38+
params = JuliaBUGS.getparams(model)
39+
result_with_bugsmodel = begin
40+
model = JuliaBUGS.set_evaluation_mode(model, JuliaBUGS.UseGraph())
41+
LogDensityProblems.logdensity(model, params)
42+
end
43+
result_with_log_density_computation_function = begin
44+
model = JuliaBUGS.set_evaluation_mode(model, JuliaBUGS.UseGeneratedLogDensityFunction())
45+
LogDensityProblems.logdensity(model, params)
46+
end
8147
@test result_with_log_density_computation_function result_with_bugsmodel
8248
end
8349

@@ -88,3 +54,17 @@ end
8854
end
8955
)
9056
end
57+
58+
@testset "mixed data transformation and deterministic assignments" begin
59+
model_def = @bugs begin
60+
for i in 1:5
61+
y[i] ~ Normal(0, 1)
62+
end
63+
for i in 1:5
64+
x[i] = y[i] + 1
65+
end
66+
end
67+
data = (; y=[1, 2, missing, missing, 2])
68+
69+
model = compile(model_def, data)
70+
end

0 commit comments

Comments
 (0)