Skip to content

Commit c8bf796

Browse files
committed
feat(): support for custom functional dependencies and new pipeline from ReactiveMP
1 parent d2e9b8a commit c8bf796

File tree

1 file changed

+33
-5
lines changed

1 file changed

+33
-5
lines changed

src/backends/reactivemp.jl

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ function write_node_options(::ReactiveMPBackend, fform, variables, options)
3939
if @capture(option, q = fconstraint_)
4040
return write_fconstraint_option(fform, variables, fconstraint)
4141
elseif @capture(option, meta = fmeta_)
42-
return write_meta_option(fmeta)
42+
return write_meta_option(fform, fmeta)
4343
elseif @capture(option, pipeline = fpipeline_)
44-
return write_pipeline_option(fpipeline)
44+
return write_pipeline_option(fform, fpipeline)
4545
end
4646

4747
error("Unknown option '$option' for '$fform' node")
@@ -50,14 +50,42 @@ end
5050

5151
# Meta helper functions
5252

53-
function write_meta_option(fmeta)
53+
function write_meta_option(fform, fmeta)
5454
return :(meta = $fmeta)
5555
end
5656

5757
# Pipeline helper functions
5858

59-
function write_pipeline_option(fpipeline)
60-
return :(pipeline = $fpipeline)
59+
function write_pipeline_option(fform, fpipeline)
60+
if @capture(fpipeline, +(stages__))
61+
return :(pipeline = ReactiveMP.FactorNodePipeline(+($(map(stage -> write_pipeline_stage(fform, stage), stages)...))))
62+
else
63+
return :(pipeline = ReactiveMP.FactorNodePipeline($(write_pipeline_stage(fform, fpipeline))))
64+
end
65+
end
66+
67+
function write_pipeline_stage(fform, stage)
68+
if @capture(stage, Default())
69+
return :(ReactiveMP.DefaultFunctionalDependencies())
70+
elseif @capture(stage, RequireInbound(args__))
71+
72+
specs = map(args) do arg
73+
if @capture(arg, name_Symbol)
74+
return (name, :nothing)
75+
elseif @capture(arg, name_Symbol = dist_)
76+
return (name, dist)
77+
else
78+
error("Invalid arg specification in node's WithInbound dependencies list: $(arg). Should be either `name` or `name = initial` expression")
79+
end
80+
end
81+
82+
indices = Expr(:tuple, map(s -> :(ReactiveMP.interface_get_index(Val{ $(GraphPPL.fquote(fform)) }, Val{ $(GraphPPL.fquote(first(s))) })), specs)...)
83+
initials = Expr(:tuple, map(s -> :($(last(s))), specs)...)
84+
85+
return :(RequireInboundFunctionalDependencies($indices, $initials))
86+
else
87+
return stage
88+
end
6189
end
6290

6391
# Factorisation constraint helper functions

0 commit comments

Comments
 (0)