Skip to content

Commit 3ab5d44

Browse files
authored
Merge pull request #266 from ReactiveBayes/save-model-string
Add functionality to access the original source code of the model macro
2 parents 4dab3d2 + a86807c commit 3ab5d44

17 files changed

+344
-66
lines changed

docs/src/developers_guide.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ GraphPPL.ResizableArray
116116

117117
```@docs
118118
GraphPPL.Context
119-
GraphPPL.ModelGenerator
120119
GraphPPL.FactorID
121120
GraphPPL.NodeData
122121
GraphPPL.NodeLabel
@@ -156,6 +155,11 @@ GraphPPL.hasextra
156155
GraphPPL.getextra
157156
GraphPPL.setextra!
158157
158+
GraphPPL.ModelGenerator
159+
GraphPPL.with_plugins
160+
GraphPPL.with_backend
161+
GraphPPL.with_source
162+
159163
GraphPPL.make_node!
160164
GraphPPL.add_atomic_factor_node!
161165
GraphPPL.add_toplevel_model!

docs/src/plugins/overview.md

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,7 @@ The following plugins are available by default in `GraphPPL`:
2424

2525
## Using a plugin
2626

27-
To use a plugin, call the `with_plugins` function when constructing a model:
28-
29-
```@docs
30-
GraphPPL.with_plugins
31-
```
27+
To use a plugin, call the [`GraphPPL.with_plugins`](@ref) function when constructing a model:
3228

3329
The `PluginsCollection` is a collection of plugins that will be applied to the model. The order of plugins in the collection is important, as the `preprocess_plugin` and `postprocess_plugin` functions are called in the order of the plugins in the collection.
3430

src/graph_engine.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -214,12 +214,14 @@ Fields:
214214
- `graph`: A `MetaGraph` object representing the factor graph.
215215
- `plugins`: A `PluginsCollection` object representing the plugins enabled in the model.
216216
- `backend`: A `Backend` object representing the backend used in the model.
217+
- `source`: A `Source` object representing the original source code of the model (typically a `String` object).
217218
- `counter`: A `Base.RefValue{Int64}` object keeping track of the number of nodes in the graph.
218219
"""
219-
struct Model{G, P, B}
220+
struct Model{G, P, B, S}
220221
graph::G
221222
plugins::P
222223
backend::B
224+
source::S
223225
counter::Base.RefValue{Int64}
224226
end
225227

@@ -228,6 +230,7 @@ Base.isempty(model::Model) = iszero(nv(model.graph)) && iszero(ne(model.graph))
228230

229231
getplugins(model::Model) = model.plugins
230232
getbackend(model::Model) = model.backend
233+
getsource(model::Model) = model.source
231234
getcounter(model::Model) = model.counter[]
232235
setcounter!(model::Model, value) = model.counter[] = value
233236

@@ -1190,20 +1193,20 @@ function Base.convert(::Type{NamedTuple}, ::StaticInterfaces{I}, t::Tuple) where
11901193
return NamedTuple{I}(t)
11911194
end
11921195

1193-
function Model(graph::MetaGraph, plugins::PluginsCollection, backend)
1194-
return Model(graph, plugins, backend, Base.RefValue(0))
1196+
function Model(graph::MetaGraph, plugins::PluginsCollection, backend, source)
1197+
return Model(graph, plugins, backend, source, Base.RefValue(0))
11951198
end
11961199

11971200
function Model(fform::F, plugins::PluginsCollection) where {F}
1198-
return Model(fform, plugins, default_backend(fform))
1201+
return Model(fform, plugins, default_backend(fform), nothing)
11991202
end
12001203

1201-
function Model(fform::F, plugins::PluginsCollection, backend) where {F}
1204+
function Model(fform::F, plugins::PluginsCollection, backend, source) where {F}
12021205
label_type = NodeLabel
12031206
edge_data_type = EdgeLabel
12041207
vertex_data_type = NodeData
12051208
graph = MetaGraph(Graph(), label_type, vertex_data_type, edge_data_type, Context(fform))
1206-
model = Model(graph, plugins, backend)
1209+
model = Model(graph, plugins, backend, source)
12071210
return model
12081211
end
12091212

src/model_generator.jl

Lines changed: 76 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,101 @@
1-
21
"""
3-
ModelGenerator(model, kwargs, plugins)
2+
ModelGenerator(model, kwargs, [ plugins ])
3+
4+
The `ModelGenerator` structure is used to lazily create the model with the given `model` and `kwargs` and (optional) `plugins`.
5+
6+
# Fields
7+
- `model`: The model function to be used for creating the graph
8+
- `kwargs`: Named tuple of keyword arguments to be passed to the model
9+
- `plugins`: Collection of plugins to be used (optional)
10+
- `backend`: Backend to be used for model creation (defaults to model's default_backend)
11+
- `source`: Original source code of the model (for debugging purposes)
12+
13+
# Extended Functionality
14+
The ModelGenerator supports several extension methods:
15+
16+
- `with_plugins(generator, plugins)`: Create new generator with updated plugins
17+
- `with_backend(generator, backend)`: Create new generator with different backend
18+
- `with_source(generator, source)`: Create new generator with different source code
19+
20+
# Examples
21+
22+
```jldoctest
23+
julia> import GraphPPL: @model
424
5-
The `ModelGenerator` structure is used to lazily create
6-
the model with the given `model` and `kwargs` and `plugins`.
25+
julia> @model function beta_bernoulli(y)
26+
θ ~ Beta(1, 1)
27+
for i = eachindex(y)
28+
y[i] ~ Bernoulli(θ)
29+
end
30+
end
31+
32+
julia> generator = beta_bernoulli(y = rand(10));
33+
34+
julia> struct CustomBackend end
35+
36+
julia> generator_with_backend = GraphPPL.with_backend(generator, CustomBackend());
37+
38+
julia> generator_with_plugins = GraphPPL.with_plugins(generator, GraphPPL.PluginsCollection());
39+
40+
julia> println(GraphPPL.getsource(generator))
41+
function beta_bernoulli(y)
42+
θ ~ Beta(1, 1)
43+
for i = eachindex(y)
44+
y[i] ~ Bernoulli(θ)
45+
end
46+
end
47+
48+
julia> generator_with_source = GraphPPL.with_source(generator, "Hello, world!");
49+
50+
julia> println(GraphPPL.getsource(generator_with_source))
51+
Hello, world!
52+
```
53+
54+
See also: [`with_plugins`](@ref), [`with_backend`](@ref), [`with_source`](@ref)
755
"""
8-
struct ModelGenerator{G, K, P, B}
56+
struct ModelGenerator{G, K, P, B, S}
957
model::G
1058
kwargs::K
1159
plugins::P
1260
backend::B
61+
source::S
1362
end
1463

15-
ModelGenerator(model::G, kwargs::K) where {G, K} = ModelGenerator(model, kwargs, PluginsCollection(), default_backend(model))
64+
ModelGenerator(model, kwargs) = ModelGenerator(model, kwargs, PluginsCollection())
65+
ModelGenerator(model, kwargs, plugins) = ModelGenerator(model, kwargs, plugins, default_backend(model))
66+
ModelGenerator(model, kwargs, plugins, backend) = ModelGenerator(model, kwargs, plugins, backend, nothing)
1667

1768
getmodel(generator::ModelGenerator) = generator.model
1869
getkwargs(generator::ModelGenerator) = generator.kwargs
1970
getplugins(generator::ModelGenerator) = generator.plugins
2071
getbackend(generator::ModelGenerator) = generator.backend
72+
getsource(generator::ModelGenerator) = generator.source
2173

2274
"""
2375
with_plugins(generator::ModelGenerator, plugins::PluginsCollection)
2476
25-
Attaches the `plugins` to the `generator`. For example:
26-
```julia
27-
plugins = GraphPPL.PluginsCollection(GraphPPL.NodeCreatedByPlugin())
28-
new_generator = GraphPPL.with_plugins(generator, plugins)
29-
```
77+
Overwrites the `plugins` specified in the `generator`.
3078
"""
3179
function with_plugins(generator::ModelGenerator, plugins::PluginsCollection)
32-
return ModelGenerator(generator.model, generator.kwargs, generator.plugins + plugins, generator.backend)
80+
return ModelGenerator(generator.model, generator.kwargs, generator.plugins + plugins, generator.backend, generator.source)
3381
end
3482

83+
"""
84+
with_backend(generator::ModelGenerator, plugins::PluginsCollection)
85+
86+
Overwrites the `backend` specified in the `generator`.
87+
"""
3588
function with_backend(generator::ModelGenerator, backend)
36-
return ModelGenerator(generator.model, generator.kwargs, generator.plugins, backend)
89+
return ModelGenerator(generator.model, generator.kwargs, generator.plugins, backend, generator.source)
90+
end
91+
92+
"""
93+
with_source(generator::ModelGenerator, source)
94+
95+
Overwrites the `source` specified in the `generator`.
96+
"""
97+
function with_source(generator::ModelGenerator, source)
98+
return ModelGenerator(generator.model, generator.kwargs, generator.plugins, generator.backend, source)
3799
end
38100

39101
function create_model(generator::ModelGenerator)
@@ -71,7 +133,7 @@ true
71133
```
72134
"""
73135
function create_model(callback, generator::ModelGenerator)
74-
model = Model(getmodel(generator), getplugins(generator), getbackend(generator))
136+
model = Model(getmodel(generator), getplugins(generator), getbackend(generator), getsource(generator))
75137
context = getcontext(model)
76138

77139
extrakwargs = callback(model, context)

src/model_macro.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -689,8 +689,7 @@ function preprocess_interface_expression(arg::Expr; warn = true)
689689
end
690690
end
691691

692-
function get_make_node_function(ms_body, ms_args, ms_name)
693-
# TODO (bvdmitri): prettify
692+
function get_make_node_function(model_specification, ms_body, ms_args, ms_name)
694693
ms_arg_names = map((arg) -> preprocess_interface_expression(arg; warn = false), ms_args)
695694
init_input_arguments = map(zip(ms_args, ms_arg_names)) do (arg, arg_name)
696695
error_msg = "Missing interface $(arg_name)"
@@ -704,6 +703,11 @@ function get_make_node_function(ms_body, ms_args, ms_name)
704703
unsupported_positional_arguments_errmsg = """
705704
The `$(ms_name)` model macro does not support positional arguments. Use keyword arguments `$(ms_name)($(join(map(a -> string(a, " = ..."), ms_arg_names), ", ")))` instead.
706705
"""
706+
707+
# ModelGenerator saves the source code of the function
708+
ms_string = string(MacroTools.prewalk(MacroTools.rmlines, model_specification))
709+
ms_string_symbol = gensym(string(ms_name, :source_code))
710+
707711
make_node_function = quote
708712
function GraphPPL.make_node!(
709713
::GraphPPL.Composite,
@@ -749,8 +753,12 @@ function get_make_node_function(ms_body, ms_args, ms_name)
749753
error("Model $(__fform__) is not defined for $(N) interfaces ($(keys(__interfaces__))).")
750754
end
751755

756+
$(ms_string_symbol)::String = $(ms_string)
757+
752758
function ($ms_name)(; kwargs...)
753-
return GraphPPL.ModelGenerator($ms_name, kwargs)
759+
generator = GraphPPL.ModelGenerator($ms_name, kwargs)
760+
generator = GraphPPL.with_source(generator, $(ms_string_symbol))
761+
return generator
754762
end
755763

756764
function ($ms_name)(args...; kwargs...)
@@ -797,7 +805,7 @@ function model_macro_interior(backend_type, model_specification)
797805
pipeline_collection = GraphPPL.model_macro_interior_pipelines(instantiate(backend_type))
798806
ms_body = apply_pipeline_collection(ms_body, pipeline_collection)
799807

800-
make_node_function = get_make_node_function(ms_body, ms_args, ms_name)
808+
make_node_function = get_make_node_function(model_specification, ms_body, ms_args, ms_name)
801809
result = quote
802810
$boilerplate_functions
803811
$make_node_function

src/plugins/meta/meta_engine.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ getmetainfo(m::MetaObject) = m.meta_object
3737
struct MetaSpecification
3838
meta_objects::Vector
3939
submodel_meta::Vector
40+
source_code::String
4041
end
4142

4243
function Base.show(io::IO, c::MetaSpecification)
@@ -68,6 +69,8 @@ getgeneralsubmodelmeta(m::MetaSpecification) = filter(m -> is_generalsubmodelmet
6869
getspecificsubmodelmeta(m::MetaSpecification, tag::Any) = get(filter(m -> getsubmodel(m) == tag, getsubmodelmeta(m)), 1, nothing)
6970
getgeneralsubmodelmeta(m::MetaSpecification, fform::Any) = get(filter(m -> getsubmodel(m) == fform, getsubmodelmeta(m)), 1, nothing)
7071

72+
source_code(m::MetaSpecification) = m.source_code
73+
7174
struct SpecificSubModelMeta
7275
tag::FactorID
7376
meta_objects::MetaSpecification
@@ -106,7 +109,8 @@ function Base.show(io::IO, constraint::SubModelMeta)
106109
)
107110
end
108111

109-
MetaSpecification() = MetaSpecification(Vector{MetaObject}(), Vector{SubModelMeta}())
112+
MetaSpecification() = MetaSpecification("")
113+
MetaSpecification(source_code::String) = MetaSpecification(Vector{MetaObject}(), Vector{SubModelMeta}(), source_code)
110114

111115
Base.push!(m::MetaSpecification, o::MetaObject) = push!(m.meta_objects, o)
112116
Base.push!(m::MetaSpecification, o::SubModelMeta) = push!(m.submodel_meta, o)

src/plugins/meta/meta_macro.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,27 @@ and assigns it to the `__meta__` variable. It then evaluates the given expressio
1717
- `e::Expr`: The expression that will generate the `GraphPPL.MetaSpecification` object.
1818
"""
1919
function add_meta_construction(e::Expr)
20+
c_body_string = string(MacroTools.unblock(MacroTools.prewalk(MacroTools.rmlines, e)))
21+
c_body_string_symbol = gensym(:meta_source_code)
22+
2023
if @capture(e, (function m_name_(m_args__; m_kwargs__)
2124
c_body_
2225
end) | (function m_name_(m_args__)
2326
c_body_
2427
end))
2528
m_kwargs = m_kwargs === nothing ? [] : m_kwargs
2629
return quote
30+
$(c_body_string_symbol)::String = $(c_body_string)
2731
function $m_name($(m_args...); $(m_kwargs...))
28-
__meta__ = GraphPPL.MetaSpecification()
32+
__meta__ = GraphPPL.MetaSpecification($(c_body_string_symbol))
2933
$c_body
3034
return __meta__
3135
end
3236
end
3337
else
3438
return quote
35-
let __meta__ = GraphPPL.MetaSpecification()
39+
$(c_body_string_symbol)::String = $(c_body_string)
40+
let __meta__ = GraphPPL.MetaSpecification($(c_body_string_symbol))
3641
$e
3742
__meta__
3843
end

src/plugins/variational_constraints/variational_constraints.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ struct VariationalConstraintsPlugin{C}
3131
constraints::C
3232
end
3333

34-
const UnspecifiedConstraints = Constraints((), (), (), (;), (;))
34+
const UnspecifiedConstraints = Constraints((), (), (), (;), (;), "UnspecifiedConstraints")
3535

3636
default_constraints(::Any) = UnspecifiedConstraints
3737

src/plugins/variational_constraints/variational_constraints_engine.jl

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -306,29 +306,33 @@ getconstraint(c::SpecificSubModelConstraints) = c.constraints
306306
307307
An instance of `Constraints` represents a set of constraints to be applied to a variational posterior in a factor graph model.
308308
"""
309-
struct Constraints{F, P, M, G, S}
309+
struct Constraints{F, P, M, G, S, C}
310310
factorization_constraints::F
311311
marginal_form_constraints::P
312312
message_form_constraints::M
313313
general_submodel_constraints::G
314314
specific_submodel_constraints::S
315+
source_code::C
315316
end
316317

317318
factorization_constraints(c::Constraints) = c.factorization_constraints
318319
marginal_form_constraints(c::Constraints) = c.marginal_form_constraints
319320
message_form_constraints(c::Constraints) = c.message_form_constraints
320321
general_submodel_constraints(c::Constraints) = c.general_submodel_constraints
321322
specific_submodel_constraints(c::Constraints) = c.specific_submodel_constraints
322-
323-
function Constraints()
324-
return Constraints(
325-
Vector{FactorizationConstraint}(),
326-
Vector{MarginalFormConstraint}(),
327-
Vector{MessageFormConstraint}(),
328-
Dict{Function, GeneralSubModelConstraints}(),
329-
Dict{FactorID, SpecificSubModelConstraints}()
330-
)
331-
end
323+
source_code(c::Constraints) = c.source_code
324+
325+
# By default `Constraints` are being created with an empty source code
326+
Constraints() = Constraints("")
327+
328+
Constraints(source_code::String) = Constraints(
329+
Vector{FactorizationConstraint}(),
330+
Vector{MarginalFormConstraint}(),
331+
Vector{MessageFormConstraint}(),
332+
Dict{Function, GeneralSubModelConstraints}(),
333+
Dict{FactorID, SpecificSubModelConstraints}(),
334+
source_code
335+
)
332336

333337
Constraints(constraints::Vector) = begin
334338
c = Constraints()

src/plugins/variational_constraints/variational_constraints_macro.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,27 @@ end
1818
check_for_returns_constraints = (x) -> check_for_returns(x; tag = "constraints")
1919

2020
function add_constraints_construction(e::Expr)
21+
c_body_string = string(MacroTools.unblock(MacroTools.prewalk(MacroTools.rmlines, e)))
22+
c_body_string_symbol = gensym(:constraints_source_code)
23+
2124
if @capture(e, (function c_name_(c_args__; c_kwargs__)
2225
c_body_
2326
end) | (function c_name_(c_args__)
2427
c_body_
2528
end))
2629
c_kwargs = c_kwargs === nothing ? [] : c_kwargs
2730
return quote
31+
$(c_body_string_symbol)::String = $(c_body_string)
2832
function $c_name($(c_args...); $(c_kwargs...))
29-
__constraints__ = GraphPPL.Constraints()
33+
__constraints__ = GraphPPL.Constraints($(c_body_string_symbol))
3034
$c_body
3135
return __constraints__
3236
end
3337
end
3438
else
3539
return quote
36-
let __constraints__ = GraphPPL.Constraints()
40+
$(c_body_string_symbol)::String = $(c_body_string)
41+
let __constraints__ = GraphPPL.Constraints($(c_body_string_symbol))
3742
$e
3843
__constraints__
3944
end

0 commit comments

Comments
 (0)