Skip to content

Commit c112d1f

Browse files
authored
Merge pull request #234 from ReactiveBayes/lazylabels
Implementation of VariableRef aka lazylabel
2 parents 564a369 + d9141da commit c112d1f

File tree

13 files changed

+1366
-916
lines changed

13 files changed

+1366
-916
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GraphPPL"
22
uuid = "b3f8163a-e979-4e85-b43e-1f63d8c8b42c"
33
authors = ["Wouter Nuijten <[email protected]>", "Dmitry Bagaev <[email protected]>"]
4-
version = "4.1.0"
4+
version = "4.2.0"
55

66
[deps]
77
BitSetTuples = "0f2f92aa-23a3-4d05-b791-88071d064721"

docs/src/developers_guide.md

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,17 +119,25 @@ GraphPPL.NodeData
119119
GraphPPL.NodeLabel
120120
GraphPPL.EdgeLabel
121121
GraphPPL.ProxyLabel
122+
GraphPPL.indexed_last
123+
GraphPPL.lift_index
124+
GraphPPL.datalabel
122125
GraphPPL.StaticInterfaces
123-
GraphPPL.LazyIndex
126+
GraphPPL.VariableRef
127+
GraphPPL.makevarref
124128
GraphPPL.MissingCollection
125129
GraphPPL.VariableNodeProperties
126130
GraphPPL.FactorNodeProperties
127131
GraphPPL.VarDict
128132
GraphPPL.AnonymousVariable
129133
GraphPPL.NodeDataExtraKey
130-
GraphPPL.LazyNodeLabel
131134
GraphPPL.IndexedVariable
132135
136+
GraphPPL.VariableKindRandom
137+
GraphPPL.VariableKindData
138+
GraphPPL.VariableKindConstant
139+
GraphPPL.VariableKindUnknown
140+
133141
GraphPPL.Deterministic
134142
GraphPPL.Stochastic
135143
GraphPPL.Atomic
@@ -151,7 +159,6 @@ GraphPPL.add_variable_node!
151159
GraphPPL.add_composite_factor_node!
152160
GraphPPL.copy_markov_blanket_to_child_context
153161
GraphPPL.generate_nodelabel
154-
GraphPPL.check_variate_compatability
155162
156163
GraphPPL.FunctionalIndex
157164
GraphPPL.FunctionalRange

docs/src/getting_started.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,15 +128,15 @@ end
128128

129129
## Instantiating the model
130130

131-
To instantiate the model we need to pass the data `x`. We can do that with the [`GraphPPL.create_model`](@ref) function in combination with [`GraphPPL.LazyIndex`](@ref).
131+
To instantiate the model we need to pass the data `x`. We can do that with the [`GraphPPL.create_model`](@ref) function in combination with [`GraphPPL.datalabel`](@ref).
132132

133133
```@example getting-started
134-
xdata = [ 1.0, 0.0, 0.0, 1.0 ]
134+
data_for_x = [ 1.0, 0.0, 0.0, 1.0 ]
135135
136136
model = GraphPPL.create_model(coin_toss()) do model, context
137137
return (;
138138
# This expression creates data handle for `x` in the model using the `xdata` as the underlying collection
139-
x = GraphPPL.getorcreate!(model, context, GraphPPL.NodeCreationOptions(kind = :data), :x, GraphPPL.LazyIndex(xdata))
139+
x = GraphPPL.datalabel(model, context, GraphPPL.NodeCreationOptions(kind = GraphPPL.VariableKindData), :x, data_for_x)
140140
)
141141
end
142142
nothing #hide

ext/GraphPPLDistributionsExt.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ end
1313
function distributions_ext_default_parametrization(
1414
t::Type{<:Distributions.Distribution}, ::GraphPPL.StaticInterfaces{interfaces}, interface_values
1515
) where {interfaces}
16-
@assert length(interface_values) == length(interfaces) "Distribution $t has $(length(interfaces)) fields $(interfaces) but $(length(interface_values)) values were provided."
16+
if !(length(interface_values) == length(interfaces))
17+
error(lazy"Distribution $t has $(length(interfaces)) fields $(interfaces) but $(length(interface_values)) values were provided.")
18+
end
1719
return NamedTuple{interfaces}(interface_values)
1820
end
1921

src/graph_engine.jl

Lines changed: 474 additions & 330 deletions
Large diffs are not rendered by default.

src/model_generator.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,11 @@ GraphPPL.@model function beta_bernoulli(y, a, b)
5757
end
5858
end
5959
60-
data_y = rand(Bernoulli(0.5), 100)
60+
data_for_y = rand(Bernoulli(0.5), 100)
6161
6262
model = GraphPPL.create_model(beta_bernoulli(a = 1, b = 1)) do model, ctx
6363
# Inject the data into the model
64-
y = GraphPPL.getorcreate!(model, ctx, GraphPPL.NodeCreationOptions(kind = :data), :y, GraphPPL.LazyIndex(data_y))
64+
y = GraphPPL.datalabel(model, ctx, GraphPPL.NodeCreationOptions(kind = GraphPPL.VariableKindData), :y, data_for_y)
6565
return (; y = y, )
6666
end
6767

src/model_macro.jl

Lines changed: 20 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ function add_get_or_create_expression(e::Expr)
350350
if @capture(e, (lhs_ ~ rhs_ where {options__}))
351351
@capture(lhs, (var_[index__]) | (var_))
352352
return quote
353-
$(generate_get_or_create(var, index))
353+
$(generate_get_or_create(var, index, rhs))
354354
$e
355355
end
356356
end
@@ -371,7 +371,7 @@ Generates code to get or create a variable in the graph. This function is used t
371371
# Returns
372372
A `quote` block with the code to get or create the variable in the graph.
373373
"""
374-
generate_get_or_create(s::Symbol, index::Nothing) = generate_get_or_create(s, :((nothing,)))
374+
generate_get_or_create(s::Symbol, index::Nothing, rhs) = generate_get_or_create(s, :((nothing,)), rhs)
375375

376376
"""
377377
generate_get_or_create(s::Symbol, lhs::Expr, index::AbstractArray)
@@ -385,20 +385,15 @@ Generates code to get or create a variable in the graph. This function is used t
385385
# Returns
386386
A `quote` block with the code to get or create the variable in the graph.
387387
"""
388-
generate_get_or_create(s::Symbol, index::AbstractArray) = generate_get_or_create(s, :(($(index...),)))
388+
generate_get_or_create(s::Symbol, index::AbstractArray, rhs) = generate_get_or_create(s, :(($(index...),)), rhs)
389389

390-
function generate_get_or_create(s::Symbol, index::Expr)
390+
function generate_get_or_create(s::Symbol, index::Expr, rhs)
391+
type = @capture(rhs, (f_()) | (f_(args__) | (f_(; kwargs__)) | (f_(args__; kwargs__)))) ? f : :(GraphPPL.Composite())
391392
return quote
392393
$s = if !@isdefined($s)
393-
GraphPPL.getorcreate!(__model__, __context__, $(QuoteNode(s)), $(index)...)
394+
GraphPPL.makevarref($type, __model__, __context__, GraphPPL.NodeCreationOptions(), $(QuoteNode(s)), $(index))
394395
else
395-
(
396-
if GraphPPL.check_variate_compatability($s, $(index)...)
397-
$s
398-
else
399-
GraphPPL.getorcreate!(__model__, __context__, $(QuoteNode(s)), $(index)...)
400-
end
401-
)
396+
$s
402397
end
403398
end
404399
end
@@ -445,7 +440,7 @@ Converts an expression into its proxied equivalent. Used to pass variables in su
445440
julia> x = GraphPPL.NodeLabel(:x, 1)
446441
x_1
447442
julia> GraphPPL.proxy_args(:(y = x))
448-
:(y = GraphPPL.proxylabel(:x, nothing, x))
443+
:(y = GraphPPL.proxylabel(:x, x, nothing, GraphPPL.False()))
449444
```
450445
"""
451446
function proxy_args end
@@ -464,29 +459,21 @@ function proxy_args(arg)
464459
end
465460

466461
function proxy_args_lhs_eq_rhs(lhs, rhs)
467-
@assert isa(lhs, Symbol) "Cannot wrap a ProxyLabel of `$lhs = $rhs` expression. The LHS must be a Symbol."
462+
if !(isa(lhs, Symbol))
463+
error(lazy"Cannot wrap a ProxyLabel of `$lhs = $rhs` expression. The LHS must be a Symbol.")
464+
end
468465
return :($lhs = $(proxy_args_rhs(rhs)))
469466
end
470467

471468
function proxy_args_rhs(rhs)
472469
if isa(rhs, Symbol)
473-
return :(GraphPPL.proxylabel($(QuoteNode(rhs)), nothing, $rhs))
470+
return :(GraphPPL.proxylabel($(QuoteNode(rhs)), $rhs, nothing, GraphPPL.False()))
474471
elseif @capture(rhs, rlabel_[index__])
475-
return :(GraphPPL.proxylabel($(QuoteNode(rlabel)), $(Expr(:tuple, index...)), $rlabel))
472+
return :(GraphPPL.proxylabel($(QuoteNode(rlabel)), $rlabel, $(Expr(:tuple, index...)), GraphPPL.False()))
476473
elseif @capture(rhs, new(rlabel_[index__]))
477-
newrhs = gensym(:force_create)
478-
errmsg = "Cannot force create a new label with the `new($rlabel[$(index...)])`. The label already exists."
479-
return :(
480-
let $newrhs = if isassigned($rlabel, $(index...))
481-
error($errmsg)
482-
else
483-
GraphPPL.getorcreate!(__model__, __context__, $(QuoteNode(rlabel)), $(index...))
484-
end
485-
GraphPPL.proxylabel($(QuoteNode(rlabel)), $(Expr(:tuple, index...)), $newrhs)
486-
end
487-
)
474+
return :(GraphPPL.proxylabel($(QuoteNode(rlabel)), $rlabel, $(Expr(:tuple, index...)), GraphPPL.True()))
488475
end
489-
return rhs
476+
return :(GraphPPL.proxylabel(:anonymous, $rhs, nothing, GraphPPL.False()))
490477
end
491478

492479
"""
@@ -553,10 +540,10 @@ end
553540
combine_broadcast_args(args::Nothing, kwargs::Nothing) = nothing
554541

555542
generate_lhs_proxylabel(var, index::Nothing) = quote
556-
GraphPPL.proxylabel($(QuoteNode(var)), nothing, $var)
543+
GraphPPL.proxylabel($(QuoteNode(var)), $var, nothing, GraphPPL.True())
557544
end
558545
generate_lhs_proxylabel(var, index::AbstractArray) = quote
559-
GraphPPL.proxylabel($(QuoteNode(var)), $(Expr(:tuple, index...)), $var)
546+
GraphPPL.proxylabel($(QuoteNode(var)), $var, $(Expr(:tuple, index...)), GraphPPL.True())
560547
end
561548

562549
__combine_axes() = Base.OneTo(1)
@@ -611,7 +598,7 @@ function convert_tilde_expression(e::Expr)
611598
combinable_args = kwargs === nothing ? args : vcat(args, [kwarg.args[2] for kwarg in kwargs])
612599
@capture(lhs, (var_[index__]) | (var_)) || error("Invalid left-hand side $(lhs). Must be in a `var` or `var[index]` form.")
613600
combinablesym = gensym()
614-
getorcreate_lhs = generate_get_or_create(var, :(GraphPPL.__combine_axes($combinablesym...)))
601+
getorcreate_lhs = generate_get_or_create(var, :(GraphPPL.__combine_axes($combinablesym...)), :(($fform)()))
615602
returnvalsym = gensym()
616603
return quote
617604
$combinablesym = ($(combinable_args...),)
@@ -710,7 +697,7 @@ function get_make_node_function(ms_body, ms_args, ms_name)
710697
__parent_context__::GraphPPL.Context,
711698
__options__::GraphPPL.NodeCreationOptions,
712699
::typeof($ms_name),
713-
__lhs_interface__::Union{GraphPPL.NodeLabel, GraphPPL.ProxyLabel},
700+
__lhs_interface__::Union{GraphPPL.NodeLabel, GraphPPL.ProxyLabel, GraphPPL.VariableRef},
714701
__rhs_interfaces__::NamedTuple,
715702
__n_interfaces__::GraphPPL.StaticInt{$(length(ms_args))}
716703
)
@@ -722,7 +709,7 @@ function get_make_node_function(ms_body, ms_args, ms_name)
722709
__model__, __context__, __options__, $ms_name, __interfaces__, __n_interfaces__
723710
)
724711
GraphPPL.returnval!(__context__, __returnval__)
725-
return __context__, GraphPPL.unroll(__lhs_interface__)
712+
return __context__, __lhs_interface__
726713
end
727714

728715
function GraphPPL.add_terminated_submodel!(

src/plugins/variational_constraints/variational_constraints_engine.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,10 @@ edge case but don't know how we can resolve this, let alone efficiently. Please
677677
end
678678
end
679679

680+
__resolve(model::Model, label::VariableRef) = __resolve(model, getifcreated(model, label.context, label))
681+
__resolve(model::Model, label::AbstractArray{T}) where {T <: VariableRef} =
682+
__resolve(model, map(l -> getifcreated(model, l.context, l), label))
683+
680684
function __resolve(model::Model, label::NodeLabel)
681685
data = model[label]
682686
return __resolve(model, data, getproperties(data), index(getproperties(data)))

src/resizable_array.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@ function recursive_size(::Val{N}, vector::Vector{T}) where {N, T <: Vector}
6868
end
6969

7070
function setindex!(array::ResizableArray{T, V, N}, value, index...) where {T, V, N}
71-
@assert N === length(index) "Invalid index $(index) for $(array)"
71+
if !(N === length(index))
72+
error(lazy"Invalid index $(index) for $(array)")
73+
end
7274
recursive_setindex!(Val(N), array.data, value, index...)
7375
return array
7476
end
@@ -128,7 +130,9 @@ function getindex(array::ResizableArray{T, V, N}, index::Vararg{UnitRange}) wher
128130
end
129131

130132
function getindex(array::ResizableArray{T, V, N}, index::Vararg{Int}) where {T, V, N}
131-
@assert N >= length(index) "Invalid index $(index) for $(array) of shape $(size(array)))"
133+
if !(N >= length(index))
134+
error(lazy"Invalid index $(index) for $(array) of shape $(size(array)))")
135+
end
132136
return recursive_getindex(Val(length(index)), array.data, index...)
133137
end
134138

0 commit comments

Comments
 (0)