Skip to content

Commit efd655d

Browse files
committed
add_edge! for multivariate inputs
1 parent ff106fa commit efd655d

File tree

2 files changed

+55
-2
lines changed

2 files changed

+55
-2
lines changed

src/graph_engine.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1666,11 +1666,20 @@ function add_edge!(
16661666
return add_edge!(model, factor_node_id, factor_node_propeties, variable_node_id, interface_name, 1)
16671667
end
16681668

1669+
add_edge!(
1670+
model::Model,
1671+
factor_node_id::NodeLabel,
1672+
factor_node_propeties::FactorNodeProperties,
1673+
variable_node_id::Union{ProxyLabel, VariableRef},
1674+
interface_name::Symbol,
1675+
index
1676+
) = add_edge!(model, factor_node_id, factor_node_propeties, unroll(variable_node_id), interface_name, index)
1677+
16691678
function add_edge!(
16701679
model::Model,
16711680
factor_node_id::NodeLabel,
16721681
factor_node_propeties::FactorNodeProperties,
1673-
variable_node_id::Union{ProxyLabel, NodeLabel, VariableRef},
1682+
variable_node_id::Union{NodeLabel},
16741683
interface_name::Symbol,
16751684
index
16761685
)

test/graph_construction_tests.jl

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1770,4 +1770,48 @@ end
17701770
@test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." create_model(
17711771
test_model(y = 1)
17721772
)
1773-
end
1773+
end
1774+
1775+
@testitem "Multivariate input to function" begin
1776+
using GraphPPL
1777+
import GraphPPL: create_model, getorcreate!, datalabel
1778+
1779+
include("testutils.jl")
1780+
function dot end
1781+
function relu end
1782+
1783+
@model function neuron(in, out)
1784+
local w
1785+
for i in 1:(length(in))
1786+
w[i] ~ Normal(0.0, 1.0)
1787+
end
1788+
bias ~ Normal(0.0, 1.0)
1789+
unactivated := dot(in, w) + bias
1790+
out := relu(unactivated)
1791+
end
1792+
1793+
@model function neural_network_layer(in, out, n)
1794+
for i in 1:n
1795+
out[i] ~ neuron(in = in)
1796+
end
1797+
end
1798+
1799+
@model function neural_net(in, out)
1800+
local softin
1801+
for i in 1:length(in)
1802+
softin[i] ~ Normal(in[i], 1.0)
1803+
end
1804+
h1 ~ neural_network_layer(in = softin, n = 10)
1805+
h2 ~ neural_network_layer(in = h1, n = 16)
1806+
out ~ neural_network_layer(in = h2, n = 2)
1807+
end
1808+
1809+
model = create_model(neural_net()) do model, ctx
1810+
in = datalabel(model, ctx, GraphPPL.NodeCreationOptions(kind = :data), :in, rand(3))
1811+
out = datalabel(model, ctx, GraphPPL.NodeCreationOptions(kind = :data), :out, randn(2))
1812+
return (in = in, out = out)
1813+
end
1814+
@test length(collect(filter(as_node(Normal), model))) == 253
1815+
@test length(collect(filter(as_node(dot), model))) == 28
1816+
@test length(collect(filter(as_variable(:in), model))) == 3
1817+
end

0 commit comments

Comments
 (0)