Skip to content

Commit a630f63

Browse files
authored
Merge pull request #260 from ReactiveBayes/node_id
Add node ID plugin and rename node tag plugin
2 parents 7e25c6c + 46c0470 commit a630f63

File tree

14 files changed

+200
-123
lines changed

14 files changed

+200
-123
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GraphPPL"
22
uuid = "b3f8163a-e979-4e85-b43e-1f63d8c8b42c"
33
authors = ["Wouter Nuijten <[email protected]>", "Dmitry Bagaev <[email protected]>"]
4-
version = "4.4.1"
4+
version = "4.5.0"
55

66
[deps]
77
BitSetTuples = "0f2f92aa-23a3-4d05-b791-88071d064721"

docs/make.jl

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,30 +6,11 @@ makedocs(
66
modules = [GraphPPL],
77
clean = true,
88
sitename = "GraphPPL.jl",
9-
pages = [
10-
"Home" => "index.md",
11-
"Getting Started" => "getting_started.md",
12-
"Syntax Guide" => "syntax_guide.md",
13-
"Nested Models" => "nested_models.md",
14-
"Plugins" => [
15-
"Overview" => "plugins/overview.md",
16-
"Variational Inference & Constraints" => "plugins/constraint_specification.md",
17-
"Attaching metadata to nodes" => "plugins/meta_specification.md",
18-
"Tracking creation of nodes" => "plugins/created_by.md",
19-
"Setting ID of nodes" => "plugins/node_id.md",
20-
],
21-
"Migration Guide (from v3 to v4)" => "migration_3_to_4.md",
22-
"Developers Guide" => "developers_guide.md",
23-
"Custom backend" => "custom_backend.md"
24-
],
9+
pages = ["Home" => "index.md", "Getting Started" => "getting_started.md", "Syntax Guide" => "syntax_guide.md", "Nested Models" => "nested_models.md", "Plugins" => ["Overview" => "plugins/overview.md", "Variational Inference & Constraints" => "plugins/constraint_specification.md", "Attaching metadata to nodes" => "plugins/meta_specification.md", "Tracking creation of nodes" => "plugins/created_by.md", "Setting tag of nodes" => "plugins/node_tag.md", "Setting ID of nodes" => "plugins/node_id.md"], "Migration Guide (from v3 to v4)" => "migration_3_to_4.md", "Developers Guide" => "developers_guide.md", "Custom backend" => "custom_backend.md"],
2510
format = Documenter.HTML(prettyurls = get(ENV, "CI", nothing) == "true"),
2611
warnonly = false
2712
)
2813

2914
if get(ENV, "CI", nothing) == "true"
30-
deploydocs(
31-
repo = "github.com/ReactiveBayes/GraphPPL.jl.git",
32-
devbranch = "main",
33-
forcepush = true
34-
)
15+
deploydocs(repo = "github.com/ReactiveBayes/GraphPPL.jl.git", devbranch = "main", forcepush = true)
3516
end

docs/src/plugins/node_id.md

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ GraphPPL provides a built-in plugin to mark factor nodes with a specific ID for
66
GraphPPL.NodeIdPlugin
77
```
88

9-
The plugin allows to specify the `id` in the `where { ... }` block during the node construction. Here how it works:
9+
The plugin sets a field in the `GraphPPL.NodeData` for every factor node that acts as a unique identified. Here how it works:
1010

1111
```@example plugin-node-id
1212
using GraphPPL, Distributions, Test #hide
1313
import GraphPPL: @model #hide
1414
1515
@model function submodel(y, x, z)
16-
y ~ Normal(x, z) where { id = "from submodel" }
16+
y ~ Normal(x, z)
1717
end
1818
1919
@model function mainmodel()
@@ -23,7 +23,7 @@ end
2323
end
2424
```
2525

26-
In this example we have created three `Normal` factor nodes and would like to access the one which has been created within the `submodel`.
26+
In this example we have created three `Normal` factor nodes and would like to discern between them using only the `NodeData`.
2727
To do that, we need to instantiate our model with the [`GraphPPL.NodeIdPlugin`](@ref) plugin.
2828

2929
```@example plugin-node-id
@@ -36,18 +36,12 @@ model = GraphPPL.create_model(
3636
nothing #hide
3737
```
3838

39-
After, we can fetch all the nodes with a specific id using the [`GraphPPL.by_nodeid`](@ref) function.
40-
41-
```@docs
42-
GraphPPL.by_nodeid
43-
```
39+
After, we can fetch all the node ids.
4440

4541
```@example plugin-node-id
46-
labels = collect(filter(GraphPPL.by_nodeid("from submodel"), model))
47-
@test all(label -> GraphPPL.getname(label) == Normal, labels) #hide
48-
@test length(labels) === 1 #hide
42+
labels = collect(filter(GraphPPL.as_node(), model))
4943
foreach(labels) do label
50-
println(GraphPPL.getname(label))
44+
println(GraphPPL.getextra(model[label], :id))
5145
end
5246
nothing #hide
5347
```

docs/src/plugins/node_tag.md

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# [Node tag plugin](@id plugins-node-tag)
2+
3+
GraphPPL provides a built-in plugin to mark factor nodes with a specific tag for later analysis or debugging purposes.
4+
5+
```@docs
6+
GraphPPL.NodeTagPlugin
7+
```
8+
9+
The plugin allows to specify the `tag` in the `where { ... }` block during the node construction. Here how it works:
10+
11+
```@example plugin-node-tag
12+
using GraphPPL, Distributions, Test #hide
13+
import GraphPPL: @model #hide
14+
15+
@model function submodel(y, x, z)
16+
y ~ Normal(x, z) where { tag = "from submodel" }
17+
end
18+
19+
@model function mainmodel()
20+
x ~ Normal(0.0, 1.0)
21+
z ~ Normal(0.0, 1.0)
22+
y ~ submodel(x = x, z = z)
23+
end
24+
```
25+
26+
In this example we have created three `Normal` factor nodes and would like to access the one which has been created within the `submodel`.
27+
To do that, we need to instantiate our model with the [`GraphPPL.NodeTagPlugin`](@ref) plugin.
28+
29+
```@example plugin-node-tag
30+
model = GraphPPL.create_model(
31+
GraphPPL.with_plugins(
32+
mainmodel(),
33+
GraphPPL.PluginsCollection(GraphPPL.NodeTagPlugin())
34+
)
35+
)
36+
nothing #hide
37+
```
38+
39+
After, we can fetch all the nodes with a specific tag using the [`GraphPPL.by_nodetag`](@ref) function.
40+
41+
```@docs
42+
GraphPPL.by_nodetag
43+
```
44+
45+
```@example plugin-node-tag
46+
labels = collect(filter(GraphPPL.by_nodetag("from submodel"), model))
47+
@test all(label -> GraphPPL.getname(label) == Normal, labels) #hide
48+
@test length(labels) === 1 #hide
49+
foreach(labels) do label
50+
println(GraphPPL.getname(label))
51+
end
52+
nothing #hide
53+
```

docs/src/plugins/overview.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ The following plugins are available by default in `GraphPPL`:
1919
- [`GraphPPL.VariationalConstraintsPlugin`](@ref): adds [constraints](@ref constraints-specification) to the model that are used in variational inference.
2020
- [`GraphPPL.MetaPlugin`](@ref): adds arbitrary metadata to nodes in the model.
2121
- [`GraphPPL.NodeCreatedByPlugin`](@ref): adds information about the line of code that created the node.
22-
- [`GraphPPL.NodeIdPlugin`](@ref): allows attaching an `id` to factor nodes for later inspection.
22+
- [`GraphPPL.NodeTagPlugin`](@ref): allows attaching a `tag` to factor nodes for later inspection.
23+
- [`GraphPPL.NodeIdPlugin`](@ref): allows attaching a unique `id` to factor nodes for later inspection.
2324

2425
## Using a plugin
2526

src/GraphPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ include("model_macro.jl")
1111

1212
include("plugins/node_created_by.jl")
1313
include("plugins/node_id.jl")
14+
include("plugins/node_tag.jl")
1415
include("plugins/variational_constraints/variational_constraints.jl")
1516
include("plugins/meta/meta.jl")
1617

src/graph_engine.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,6 @@ Base.broadcastable(label::NodeLabel) = Ref(label)
255255

256256
getname(label::NodeLabel) = label.name
257257
getname(labels::ResizableArray{T, V, N} where {T <: NodeLabel, V, N}) = getname(first(labels))
258-
getid(label::NodeLabel) = label.global_counter
259258
iterate(label::NodeLabel) = (label, nothing)
260259
iterate(label::NodeLabel, any) = nothing
261260

@@ -766,10 +765,9 @@ mutable struct NodeData
766765
const context :: Context
767766
const properties :: Union{VariableNodeProperties, FactorNodeProperties{NodeData}}
768767
const extra :: UnorderedDictionary{Symbol, Any}
769-
const id :: Int
770768
end
771769

772-
NodeData(context, properties, id) = NodeData(context, properties, UnorderedDictionary{Symbol, Any}(), id)
770+
NodeData(context, properties) = NodeData(context, properties, UnorderedDictionary{Symbol, Any}())
773771

774772
function Base.show(io::IO, nodedata::NodeData)
775773
context = getcontext(nodedata)
@@ -1531,7 +1529,7 @@ end
15311529
function __add_variable_node!(model::Model, context::Context, options::NodeCreationOptions, name::Symbol, index)
15321530
# In theory plugins are able to overwrite this
15331531
potential_label = generate_nodelabel(model, name)
1534-
potential_nodedata = NodeData(context, convert(VariableNodeProperties, name, index, options), getid(potential_label))
1532+
potential_nodedata = NodeData(context, convert(VariableNodeProperties, name, index, options))
15351533
label, nodedata = preprocess_plugins(
15361534
UnionPluginType(VariableNodePlugin(), FactorAndVariableNodesPlugin()), model, context, potential_label, potential_nodedata, options
15371535
)
@@ -1645,7 +1643,7 @@ function add_atomic_factor_node!(model::Model, context::Context, options::NodeCr
16451643
factornode_id = generate_factor_nodelabel(context, fform)
16461644

16471645
potential_label = generate_nodelabel(model, fform)
1648-
potential_nodedata = NodeData(context, convert(FactorNodeProperties, fform, options), getid(potential_label))
1646+
potential_nodedata = NodeData(context, convert(FactorNodeProperties, fform, options))
16491647

16501648
label, nodedata = preprocess_plugins(
16511649
UnionPluginType(FactorNodePlugin(), FactorAndVariableNodesPlugin()), model, context, potential_label, potential_nodedata, options

src/plugins/node_id.jl

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
"""
22
NodeIdPlugin
33
4-
A plugin that adds an `id` property to the factor node. This field can be used to
5-
find a node given its `id` with the `GraphPPL.by_nodeid` filter.
4+
A plugin that adds an `id` property to the factor node. This field is unique for every factor node.
65
"""
76
struct NodeIdPlugin end
87

@@ -11,23 +10,6 @@ plugin_type(::NodeIdPlugin) = FactorAndVariableNodesPlugin()
1110
function preprocess_plugin(
1211
::NodeIdPlugin, model::Model, context::Context, label::NodeLabel, nodedata::NodeData, options::NodeCreationOptions
1312
)
14-
if haskey(options, :id)
15-
setextra!(nodedata, :id, getindex(options, :id))
16-
end
13+
setextra!(nodedata, :id, label.global_counter)
1714
return label, nodedata
1815
end
19-
20-
struct FilterById <: AbstractModelFilterPredicate
21-
id
22-
end
23-
24-
"""
25-
by_nodeid(id)
26-
27-
A filter predicate that can be used to find a node given its `id` in a model.
28-
"""
29-
by_nodeid(id) = FilterById(id)
30-
31-
function apply(predicate::FilterById, model, something)
32-
return hasextra(model[something], :id) && isequal(getextra(model[something], :id), predicate.id)
33-
end

src/plugins/node_tag.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""
2+
NodeTagPlugin
3+
4+
A plugin that adds an `tag` property to the factor node. This field can be used to
5+
find a node given its `tag` with the `GraphPPL.by_nodetag` filter.
6+
"""
7+
struct NodeTagPlugin end
8+
9+
plugin_type(::NodeTagPlugin) = FactorAndVariableNodesPlugin()
10+
11+
function preprocess_plugin(
12+
::NodeTagPlugin, model::Model, context::Context, label::NodeLabel, nodedata::NodeData, options::NodeCreationOptions
13+
)
14+
if haskey(options, :tag)
15+
setextra!(nodedata, :tag, getindex(options, :tag))
16+
end
17+
return label, nodedata
18+
end
19+
20+
struct FilterByTag <: AbstractModelFilterPredicate
21+
tag
22+
end
23+
24+
"""
25+
by_nodetag(tag)
26+
27+
A filter predicate that can be used to find a node given its `tag` in a model.
28+
"""
29+
by_nodetag(tag) = FilterByTag(tag)
30+
31+
function apply(predicate::FilterByTag, model, something)
32+
return hasextra(model[something], :tag) && isequal(getextra(model[something], :tag), predicate.tag)
33+
end

test/graph_construction_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2002,4 +2002,4 @@ end
20022002
@test length(collect(filter(as_node(Normal), model))) == 2
20032003
@test length(collect(filter(as_variable(:y), model))) == 1
20042004
@test length(collect(filter(as_variable(:x), model))) == 1
2005-
end
2005+
end

0 commit comments

Comments
 (0)