Skip to content

Commit 4076dff

Browse files
committed
Implement nested indexing
1 parent c97718a commit 4076dff

File tree

4 files changed

+76
-7
lines changed

4 files changed

+76
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ Distributions = "0.25"
3333
Documenter = "1.0"
3434
GraphPlot = "0.5, 0.6"
3535
MacroTools = "0.5"
36-
MetaGraphsNext = "0.6, 0.7"
36+
MetaGraphsNext = "~0.7.0"
3737
NamedTupleTools = "0.14"
3838
Static = "0.8, 1"
3939
StaticArrays = "1.6"

src/model_macro.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -465,19 +465,30 @@ function proxy_args_lhs_eq_rhs(lhs, rhs)
465465
return :($lhs = $(proxy_args_rhs(rhs)))
466466
end
467467

468-
function proxy_args_rhs(rhs)
468+
function recursive_rhs_indexing(rhs)
469+
name = rhs
470+
while @capture(name, rlabel_[index__] | new(rlabel_[index__]))
471+
name = rlabel
472+
end
469473
if isa(rhs, Symbol)
470-
return :(GraphPPL.proxylabel($(QuoteNode(rhs)), $rhs, nothing, GraphPPL.False()))
474+
return rhs
471475
elseif @capture(rhs, rlabel_[index__])
472-
return :(GraphPPL.proxylabel($(QuoteNode(rlabel)), $rlabel, $(Expr(:tuple, index...)), GraphPPL.False()))
476+
return :(GraphPPL.proxylabel($(QuoteNode(name)), $(recursive_rhs_indexing(rlabel)), $(Expr(:tuple, index...)), GraphPPL.False()))
473477
elseif @capture(rhs, new(rlabel_[index__]))
474-
return :(GraphPPL.proxylabel($(QuoteNode(rlabel)), $rlabel, $(Expr(:tuple, index...)), GraphPPL.True()))
478+
return :(GraphPPL.proxylabel($(QuoteNode(name)), $(recursive_rhs_indexing(rlabel)), $(Expr(:tuple, index...)), GraphPPL.True()))
475479
elseif @capture(rhs, rlabel_...)
476480
return :(GraphPPL.proxylabel($(QuoteNode(rlabel)), GraphPPL.Splat($rlabel), nothing, GraphPPL.False())...)
477481
end
478482
return :(GraphPPL.proxylabel(:anonymous, $rhs, nothing, GraphPPL.False()))
479483
end
480484

485+
function proxy_args_rhs(rhs)
486+
if isa(rhs, Symbol)
487+
return :(GraphPPL.proxylabel($(QuoteNode(rhs)), $rhs, nothing, GraphPPL.False()))
488+
end
489+
return recursive_rhs_indexing(rhs)
490+
end
491+
481492
"""
482493
combine_args(args::Vector, kwargs::Nothing)
483494

test/graph_construction_tests.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1984,3 +1984,22 @@ end
19841984
@test GraphPPL.getname(model[normal_node, x[2]]) ==
19851985
@test GraphPPL.getname(model[normal_node, y]) == :out
19861986
end
1987+
1988+
@testitem "Multiple indices in rhs statement" begin
1989+
using Distributions
1990+
using GraphPPL
1991+
import GraphPPL: @model, create_model, datalabel, NodeCreationOptions, neighbors
1992+
1993+
@model function multiple_indices(prior_params, y)
1994+
x ~ Normal(prior_params[1][1], prior_params[1][2])
1995+
y ~ Normal(x, 1.0)
1996+
end
1997+
model = create_model(multiple_indices(prior_params = [[1, 2]])) do model, ctx
1998+
y = datalabel(model, ctx, NodeCreationOptions(kind = :data), :y, rand())
1999+
return (y = y,)
2000+
end
2001+
2002+
@test length(collect(filter(as_node(Normal), model))) == 2
2003+
@test length(collect(filter(as_variable(:y), model))) == 1
2004+
@test length(collect(filter(as_variable(:x), model))) == 1
2005+
end

test/model_macro_tests.jl

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,6 +1181,43 @@ end
11811181
@test_expression_generating apply_pipeline(input, convert_anonymous_variables) output
11821182
end
11831183

1184+
@testitem "proxy_args_rhs" begin
1185+
import GraphPPL: proxy_args_rhs, apply_pipeline, recursive_rhs_indexing
1186+
1187+
include("testutils.jl")
1188+
1189+
# Test 1: Input expression with a function call in rhs arguments
1190+
input = :x
1191+
output = quote
1192+
GraphPPL.proxylabel(:x, x, nothing, GraphPPL.False())
1193+
end
1194+
@test_expression_generating proxy_args_rhs(input) output
1195+
1196+
input = quote
1197+
x[1]
1198+
end
1199+
output = quote
1200+
GraphPPL.proxylabel(:x, x, (1,), GraphPPL.False())
1201+
end
1202+
@test_expression_generating proxy_args_rhs(input) output
1203+
1204+
input = quote
1205+
x[1, 2]
1206+
end
1207+
output = quote
1208+
GraphPPL.proxylabel(:x, x, (1, 2), GraphPPL.False())
1209+
end
1210+
@test_expression_generating proxy_args_rhs(input) output
1211+
1212+
input = quote
1213+
x[1][1]
1214+
end
1215+
output = quote
1216+
GraphPPL.proxylabel(:x, GraphPPL.proxylabel(:x, x, (1,), GraphPPL.False()), (1,), GraphPPL.False())
1217+
end
1218+
@test_expression_generating proxy_args_rhs(input) output
1219+
end
1220+
11841221
@testitem "add_get_or_create_expression" begin
11851222
import GraphPPL: add_get_or_create_expression, apply_pipeline
11861223

@@ -2048,7 +2085,7 @@ end
20482085
@test !isnothing(GraphPPL.create_model(somemodel(a = 1, b = 2)))
20492086
end
20502087

2051-
@testitem "model should warn users against incorrect usages of `=` operator with random variables" begin
2088+
@testitem "model should warn users against incorrect usages of `=` operator with random variables" begin
20522089
using GraphPPL, Distributions
20532090
import GraphPPL: @model
20542091

@@ -2058,5 +2095,7 @@ end
20582095
y ~ Normal(0, t)
20592096
end
20602097

2061-
@test_throws "One of the arguments to `exp` is of type `GraphPPL.VariableRef`. Did you mean to create a new random variable with `:=` operator instead?" GraphPPL.create_model(somemodel())
2098+
@test_throws "One of the arguments to `exp` is of type `GraphPPL.VariableRef`. Did you mean to create a new random variable with `:=` operator instead?" GraphPPL.create_model(
2099+
somemodel()
2100+
)
20622101
end

0 commit comments

Comments
 (0)