|
3 | 3 |
|
4 | 4 | import GraphPPL: |
5 | 5 | NodeIdPlugin, |
6 | | - EmptyID, |
7 | 6 | NodeCreationOptions, |
8 | 7 | PluginsCollection, |
9 | 8 | add_atomic_factor_node!, |
10 | 9 | create_model, |
11 | 10 | getcontext, |
12 | 11 | hasextra, |
13 | 12 | getextra, |
14 | | - by_nodeid |
| 13 | + create_model, |
| 14 | + with_plugins |
15 | 15 |
|
16 | 16 | include("../testutils.jl") |
17 | | - |
18 | | - model = create_test_model(plugins = PluginsCollection(NodeIdPlugin())) |
| 17 | + @model function node_with_two_anonymous() |
| 18 | + x[1] ~ Normal(0, 1) |
| 19 | + y[1] ~ Normal(0, 1) |
| 20 | + for i in 2:10 |
| 21 | + y[i] ~ Normal(0, 1) |
| 22 | + x[i] ~ Normal(y[i - 1] + 1, y[i] + 1) |
| 23 | + end |
| 24 | + end |
| 25 | + model = create_model(with_plugins(node_with_two_anonymous(), GraphPPL.PluginsCollection(NodeIdPlugin()))) |
19 | 26 | ctx = getcontext(model) |
20 | 27 |
|
21 | 28 | @testset begin |
22 | | - label1, nodedata1, properties1 = add_atomic_factor_node!(model, ctx, NodeCreationOptions(id = 1), Normal) |
23 | | - label2, nodedata2, properties2 = add_atomic_factor_node!(model, ctx, NodeCreationOptions(id = "2"), Normal) |
24 | | - label3, nodedata3, properties3 = add_atomic_factor_node!(model, ctx, NodeCreationOptions(id = :id3), Normal) |
25 | | - label4, nodedata4, properties4 = add_atomic_factor_node!(model, ctx, NodeCreationOptions(), Normal) |
26 | | - label5, nodedata5, properties5 = add_atomic_factor_node!(model, ctx, NodeCreationOptions(id = 4), Normal) |
27 | | - label6, nodedata6, properties6 = add_atomic_factor_node!(model, ctx, NodeCreationOptions(id = 4), Normal) |
28 | | - |
29 | | - @test length(collect(filter(as_node(Normal), model))) === 6 |
30 | | - # Not all have the `id` label associated with them |
31 | | - @test !all(n -> hasextra(model[n], :id), collect(filter(as_node(Normal), model))) |
32 | | - # But at least some should have the `id` label associated with it |
33 | | - @test any(n -> hasextra(model[n], :id), collect(filter(as_node(Normal), model))) |
34 | | - |
35 | | - # id = 1 |
36 | | - @test length(collect(filter(by_nodeid(1), model))) === 1 |
37 | | - @test model[first(collect(filter(by_nodeid(1), model)))] === nodedata1 |
38 | | - |
39 | | - # id = "2" |
40 | | - @test length(collect(filter(by_nodeid("2"), model))) === 1 |
41 | | - @test model[first(collect(filter(by_nodeid("2"), model)))] === nodedata2 |
42 | | - |
43 | | - # id = :id3 |
44 | | - @test length(collect(filter(by_nodeid(:id3), model))) === 1 |
45 | | - @test model[first(collect(filter(by_nodeid(:id3), model)))] === nodedata3 |
| 29 | + nodes = collect(filter(as_node(), model)) |
| 30 | + nodedata = getindex.(Ref(model), nodes) |
| 31 | + for node in nodedata |
| 32 | + @test hasextra(node, :id) |
| 33 | + end |
46 | 34 |
|
47 | | - # id = 4 |
48 | | - @test length(collect(filter(by_nodeid(4), model))) === 2 |
49 | | - @test model[collect(filter(by_nodeid(4), model))[1]] === nodedata5 |
50 | | - @test model[collect(filter(by_nodeid(4), model))[2]] === nodedata6 |
| 35 | + @test length(unique(getextra.(nodedata, :id))) == length(nodedata) |
51 | 36 | end |
52 | 37 | end |
0 commit comments