Skip to content

Commit e607544

Browse files
committed
feat(): meta specification language
1 parent 0cc9a9d commit e607544

File tree

6 files changed

+125
-10
lines changed

6 files changed

+125
-10
lines changed

src/GraphPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,6 @@ __get_current_backend() = ReactiveMPBackend()
99
include("utils.jl")
1010
include("model.jl")
1111
include("constraints.jl")
12+
include("meta.jl")
1213

1314
end # module

src/backends/reactivemp.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,4 +206,14 @@ end
206206

207207
function write_form_constraint_specification(::ReactiveMPBackend, T, args, kwargs)
208208
return :(ReactiveMP.FormConstraintsSpecification($T, $args, $kwargs))
209+
end
210+
211+
## Meta specification language
212+
213+
function write_meta_specification(::ReactiveMPBackend, entries)
214+
return :(ReactiveMP.MetaSpecification($entries))
215+
end
216+
217+
function write_meta_specification_entry(::ReactiveMPBackend, F, N, meta)
218+
return :(ReactiveMP.MetaSpecificationEntry(Val($F), Val($N), $meta))
209219
end

src/constraints.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ end
5656

5757
## Factorisation constraints
5858

59-
struct FactorisationConstraintLHSMeta
59+
struct FactorisationConstraintLHSInfo
6060
name :: String
6161
hash :: UInt
6262
varname :: Symbol
@@ -134,7 +134,7 @@ function generate_constraints_expression(backend, constraints_specification)
134134
cs_args = cs_args === nothing ? [] : cs_args
135135
cs_kwargs = cs_kwargs === nothing ? [] : cs_kwargs
136136

137-
lhs_dict = Dict{UInt, FactorisationConstraintLHSMeta}()
137+
lhs_dict = Dict{UInt, FactorisationConstraintLHSInfo}()
138138

139139
marginals_form_constraints_symbol = gensym(:marginals_form_constraint)
140140
marginals_form_constraints_symbol_init = :($marginals_form_constraints_symbol = (;))
@@ -223,17 +223,17 @@ function generate_constraints_expression(backend, constraints_specification)
223223
(lhs_names == rhs_names) || error("LHS and RHS of the $(expression) expression has different set of variables.")
224224

225225
lhs_hash = hash(lhs)
226-
lhs_meta = if haskey(lhs_dict, lhs_hash)
226+
lhs_info = if haskey(lhs_dict, lhs_hash)
227227
lhs_dict[ lhs_hash ]
228228
else
229229
lhs_name = string("q(", join(names, ", "), ")")
230230
lhs_varname = gensym(lhs_name)
231-
lhs_meta = FactorisationConstraintLHSMeta(lhs_name, lhs_hash, lhs_varname)
232-
lhs_dict[lhs_hash] = lhs_meta
231+
lhs_info = FactorisationConstraintLHSInfo(lhs_name, lhs_hash, lhs_varname)
232+
lhs_dict[lhs_hash] = lhs_info
233233
end
234234

235-
lhs_name = lhs_meta.name
236-
lhs_varname = lhs_meta.varname
235+
lhs_name = lhs_info.name
236+
lhs_varname = lhs_info.varname
237237

238238
new_factorisation_specification = write_factorisation_constraint(backend, :(Val(($(map(QuoteNode, names)...),))), :(Val($(rhs))))
239239
check_is_not_defined = write_check_factorisation_is_not_defined(backend, lhs_varname)
@@ -251,9 +251,9 @@ function generate_constraints_expression(backend, constraints_specification)
251251

252252
# This block write initial variables for factorisation specification
253253
cs_lhs_init_block = map(collect(lhs_dict)) do pair
254-
lhs_meta = last(pair)
255-
lhs_name = lhs_meta.name
256-
lhs_varname = lhs_meta.varname
254+
lhs_info = last(pair)
255+
lhs_name = lhs_info.name
256+
lhs_varname = lhs_info.varname
257257
lhs_symbol = Symbol(lhs_name)
258258
return write_init_factorisation_not_defined(backend, lhs_varname, lhs_symbol)
259259
end

src/meta.jl

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
export @meta
2+
3+
"""
4+
write_meta_specification(backend, entries)
5+
"""
6+
function write_meta_specification end
7+
8+
"""
9+
write_meta_specification_entry(backend, F, N, meta)
10+
"""
11+
function write_meta_specification_entry end
12+
13+
macro meta(meta_specification)
14+
return generate_meta_expression(__get_current_backend(), meta_specification)
15+
end
16+
17+
struct MetaSpecificationLHSInfo
18+
hash :: UInt
19+
checkname :: Symbol
20+
end
21+
22+
function generate_meta_expression(backend, meta_specification)
23+
24+
if isblock(meta_specification)
25+
generatedfname = gensym(:constraints)
26+
generatedfbody = :(function $(generatedfname)() $meta_specification end)
27+
return :($(generate_meta_expression(backend, generatedfbody))())
28+
end
29+
30+
@capture(meta_specification, (function cs_name_(cs_args__; cs_kwargs__) cs_body_ end) | (function cs_name_(cs_args__) cs_body_ end)) ||
31+
error("Meta specification language requires full function definition")
32+
33+
cs_args = cs_args === nothing ? [] : cs_args
34+
cs_kwargs = cs_kwargs === nothing ? [] : cs_kwargs
35+
36+
lhs_dict = Dict{UInt, MetaSpecificationLHSInfo}()
37+
38+
meta_spec_symbol = gensym(:meta)
39+
meta_spec_symbol_init = :($meta_spec_symbol = ())
40+
41+
cs_body = postwalk(cs_body) do expression
42+
if @capture(expression, f_(args__) = meta_)
43+
44+
if !issymbol(f) || any(a -> !issymbol(a), args)
45+
error("Invalid meta specification $(expression)")
46+
end
47+
48+
lhs = :($f($(args...)))
49+
lhs_hash = hash(lhs)
50+
lhs_info = if haskey(lhs_dict, lhs_hash)
51+
lhs_dict[ lhs_hash ]
52+
else
53+
lhs_checkname = gensym(f)
54+
lhs_info = MetaSpecificationLHSInfo(lhs_hash, lhs_checkname)
55+
lhs_dict[lhs_hash] = lhs_info
56+
end
57+
58+
lhs_checkname = lhs_info.checkname
59+
error_msg = "Meta specification $lhs has been redefined"
60+
meta_entry = write_meta_specification_entry(backend, QuoteNode(f), :(($(map(QuoteNode, args)...), )), meta)
61+
62+
return quote
63+
($lhs_checkname) && error($error_msg)
64+
$meta_spec_symbol = ($meta_spec_symbol..., $meta_entry)
65+
$lhs_checkname = true
66+
end
67+
end
68+
return expression
69+
end
70+
71+
lhs_checknames_init = map(collect(pairs(lhs_dict))) do pair
72+
lhs_info = last(pair)
73+
lhs_checkname = lhs_info.checkname
74+
return quote
75+
$lhs_checkname = false
76+
end
77+
end
78+
79+
ret_meta_specification = write_meta_specification(backend, meta_spec_symbol)
80+
81+
res = quote
82+
function $cs_name($(cs_args...); $(cs_kwargs...))
83+
$meta_spec_symbol_init
84+
$(lhs_checknames_init...)
85+
$cs_body
86+
$ret_meta_specification
87+
end
88+
end
89+
90+
return esc(res)
91+
end

src/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11

2+
issymbol(::Symbol) = true
3+
issymbol(any) = false
24

35
isexpr(expr::Expr) = true
46
isexpr(expr) = false

test/utils.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,17 @@ module UtilsTests
33
using Test
44
using GraphPPL
55

6+
@testset "issymbol tests" begin
7+
import GraphPPL: issymbol
8+
9+
@test issymbol(:(f(1))) === false
10+
@test issymbol(:(f(1))) === false
11+
@test issymbol(:(if true 1 else 2 end)) === false
12+
@test issymbol(:hello) === true
13+
@test issymbol(:a) === true
14+
@test issymbol(123) === false
15+
end
16+
617
@testset "isexpr tests" begin
718
import GraphPPL: isexpr
819

0 commit comments

Comments
 (0)