Skip to content

Commit d883ed8

Browse files
committed
Fix variate compatibility for lazy node label
1 parent fb78414 commit d883ed8

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

src/graph_engine.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1210,8 +1210,11 @@ __lazy_node_label_check_variate_compatability(label::LazyNodeLabel, collection::
12101210
function __lazy_node_label_check_variate_compatability(label::LazyNodeLabel, collection, indices)
12111211
# The empty indices may be passed as a result of the `combine_axes` function in the broadcasting
12121212
# In this case the `indices` are `Tuple{}`
1213-
if !isempty(indices)::Bool && !(checkbounds(Bool, collection, indices...)::Bool)
1214-
error(BoundsError(label.name, indices))
1213+
if !isempty(indices)::Bool
1214+
# The `Tuple{Nothing}` indices may be passed as a result of the `~` operation without indices on LHS
1215+
if !(isone(length(indices)) && isnothing(first(indices))) && !(checkbounds(Bool, collection, indices...)::Bool)
1216+
error(BoundsError(label.name, indices))
1217+
end
12151218
end
12161219
return true
12171220
end

test/graph_construction_tests.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1499,3 +1499,22 @@ end
14991499
@test_broken length(collect(filter(as_variable(:tmp), model))) == 10
15001500
end
15011501
end
1502+
1503+
@testitem "LazyIndex should support empty indices if array is passed" begin
1504+
import GraphPPL: create_model, getorcreate!, NodeCreationOptions, LazyIndex
1505+
1506+
include("testutils.jl")
1507+
1508+
@model function foo(y)
1509+
x ~ MvNormal([1, 1], [1 0.0; 0.0 1.0])
1510+
y ~ MvNormal(x, [1.0 0.0; 0.0 1.0])
1511+
end
1512+
1513+
model = create_model(foo()) do model, ctx
1514+
return (; y = getorcreate!(model, ctx, NodeCreationOptions(kind = :data, factorized = true), :y, LazyIndex([ 1.0, 1.0 ])))
1515+
end
1516+
1517+
@test length(collect(filter(as_node(MvNormal), model))) == 2
1518+
@test length(collect(filter(as_variable(:x), model))) == 1
1519+
@test length(collect(filter(as_variable(:y), model))) == 1
1520+
end

0 commit comments

Comments
 (0)