From 31944dbc8d6e558a41ffeebb3cf14b9befa996b9 Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Fri, 2 May 2025 13:44:50 +0200 Subject: [PATCH 1/6] Refactor the base package --- .cursor/rules/julia.mdc | 105 ++ Project.toml | 2 + src/GraphPPL.jl | 92 +- src/backends/default.jl | 2 +- src/core/abstract_types.jl | 37 + src/core/errors.jl | 5 + src/core/functional_indices.jl | 100 ++ src/core/node_types.jl | 49 + src/{ => core}/resizable_array.jl | 0 src/{ => generators}/model_generator.jl | 0 src/graph/interfaces.jl | 57 + src/graph/node_data.jl | 88 + src/graph/node_labels.jl | 65 + src/graph_engine.jl | 2191 ----------------------- src/{ => macros}/model_macro.jl | 63 - src/model/context.jl | 175 ++ src/model/indexed_variable.jl | 22 + src/model/model.jl | 174 ++ src/model/model_filtering.jl | 65 + src/model/node_creation.jl | 202 +++ src/model/proxy_label.jl | 139 ++ src/model/var_dict.jl | 42 + src/model/variable_ref.jl | 300 ++++ src/nodes/node_materialization.jl | 572 ++++++ src/nodes/node_properties.jl | 92 + src/plugins/plugin_processing.jl | 37 + src/{ => plugins}/plugins_collection.jl | 0 src/utils/macro_utils.jl | 66 + test/testutils.jl | 2 +- 29 files changed, 2480 insertions(+), 2264 deletions(-) create mode 100644 .cursor/rules/julia.mdc create mode 100644 src/core/abstract_types.jl create mode 100644 src/core/errors.jl create mode 100644 src/core/functional_indices.jl create mode 100644 src/core/node_types.jl rename src/{ => core}/resizable_array.jl (100%) rename src/{ => generators}/model_generator.jl (100%) create mode 100644 src/graph/interfaces.jl create mode 100644 src/graph/node_data.jl create mode 100644 src/graph/node_labels.jl delete mode 100644 src/graph_engine.jl rename src/{ => macros}/model_macro.jl (93%) create mode 100644 src/model/context.jl create mode 100644 src/model/indexed_variable.jl create mode 100644 src/model/model.jl create mode 100644 src/model/model_filtering.jl create mode 100644 src/model/node_creation.jl create mode 100644 src/model/proxy_label.jl create mode 100644 src/model/var_dict.jl create mode 100644 src/model/variable_ref.jl create mode 100644 src/nodes/node_materialization.jl create mode 100644 src/nodes/node_properties.jl create mode 100644 src/plugins/plugin_processing.jl rename src/{ => plugins}/plugins_collection.jl (100%) create mode 100644 src/utils/macro_utils.jl diff --git a/.cursor/rules/julia.mdc b/.cursor/rules/julia.mdc new file mode 100644 index 00000000..8914defe --- /dev/null +++ b/.cursor/rules/julia.mdc @@ -0,0 +1,105 @@ +--- +description: +globs: +alwaysApply: true +--- +You are an expert in Julia language programming, data science, and numerical computing. + +Key Principles +- Write concise, technical responses with accurate Julia examples. +- Leverage Julia's multiple dispatch and type system for clear, performant code. +- Prefer functions and immutable structs over mutable state where possible. +- Use descriptive variable names with auxiliary verbs (e.g., is_active, has_permission). +- Use lowercase with underscores for directories and files (e.g., src/data_processing.jl). +- Favor named exports for functions and types. +- Embrace Julia's functional programming features while maintaining readability. + +Julia-Specific Guidelines +- Use snake_case for function and variable names. +- Use PascalCase for type names (structs and abstract types). +- Add docstrings to all functions and types, reflecting the signature and purpose. +- Keep docstrings up to date with all the changes and argument signatures. +- Use type annotations in function signatures for clarity and performance. +- Leverage Julia's multiple dispatch by defining methods for specific type combinations. +- Use the `@kwdef` macro for structs to enable keyword constructors. +- Implement custom `show` methods for user-defined types. +- Use modules to organize code and control namespace. + +Function Definitions +- Use descriptive names that convey the function's purpose. +- Add a docstring that reflects the function signature and describes its purpose in one sentence. +- Update docstrings if implementation changed. +- Describe the return value in the docstring. +- Example: + ```julia + """ + process_data(data::Vector{Float64}, threshold::Float64) + + Process the input `data` by applying a `threshold` filter and return the filtered result. + """ + function process_data(data::Vector{Float64}, threshold::Float64) + # Function implementation + end + ``` + +Struct Definitions +- Always use the `@kwdef` macro to enable keyword constructors. +- Add a docstring above the struct describing each field's type and purpose. +- Implement a custom `show` method for better struct printing. + +Error Handling and Validation +- Use Julia's exception system for error handling. +- Create custom exception types for specific error cases. +- Use guard clauses to handle preconditions and invalid states early. +- Implement proper error logging and user-friendly error messages. +- Example: + ```julia + struct InvalidInputError <: Exception + msg::String + end + + function process_positive_number(x::Number) + x <= 0 && throw(InvalidInputError("Input must be positive")) + # Process the number + end + ``` + +Performance Optimization +- Use type annotations where necessary to avoid type instabilities. +- Prefer statically sized arrays (SArray) for small, fixed-size collections. +- Use views (@views macro) to avoid unnecessary array copies. +- Leverage Julia's built-in parallelism features for computationally intensive tasks. +- Use benchmarking tools (BenchmarkTools.jl) to identify and optimize bottlenecks. +- We store benchmarks in the benchmark folder in the root of the repository. +- Create benchmarks for performance sensitive pieces of code. + +Testing +- For each file in the source code create a test file with the `_tests.jl` suffix, e.g. `src/folder/subfoldeer/file.jl` -> `test/folder/subfolder/file_tests.jl` +- Create small individual tests in `@testitem` blocks +- Write test cases of increasing difficulty with comments explaining what is being tested. +- Use individual `@test` calls for each assertion, not for blocks. +- Example: + ```julia +@testitem "A function can be called" begin + import Package: function + + @test function(2) === 3 +end + ``` + +Dependencies +- Use the built-in package manager (Pkg) for managing dependencies. +- Specify version constraints in the Project.toml file. +- Consider using compatibility bounds (e.g., "Package" = "1.2, 2") to balance stability and updates. + +Code Organization +- Use modules to organize related functionality. +- Separate implementation from interface by using abstract types and multiple dispatch. +- Use include() to split large modules into multiple files. +- Follow a consistent project structure (e.g., src/, test/, docs/). + +Documentation +- Write comprehensive docstrings for all public functions and types. +- Use Julia's built-in documentation system (Documenter.jl) for generating documentation. +- Include examples in docstrings to demonstrate usage. +- Keep documentation up-to-date with code changes. \ No newline at end of file diff --git a/Project.toml b/Project.toml index 99eba2e7..6cd06e53 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Wouter Nuijten ", "Dmitry Bagaev index = GraphPPL.FunctionalIndex{:begin}(firstindex) +(begin) + +julia> index([ 2.0, 3.0 ]) +1 + +julia> (index + 1)([ 2.0, 3.0 ]) +2 + +julia> index = GraphPPL.FunctionalIndex{:end}(lastindex) +(end) + +julia> index([ 2.0, 3.0 ]) +2 +``` +""" +(index::FunctionalIndex{R, F})(collection) where {R, F} = __functional_index_apply(R, index.f, collection)::Integer + +Base.getindex(x::AbstractArray, fi::FunctionalIndex) = x[fi(x)] +# Base.getindex(x::NodeLabel, index::FunctionalIndex) = index(x) + +__functional_index_apply(::Symbol, f, collection) = f(collection) +__functional_index_apply(subindex::FunctionalIndex, f::Tuple{typeof(+), <:Integer}, collection) = subindex(collection) .+ f[2] +__functional_index_apply(subindex::FunctionalIndex, f::Tuple{typeof(-), <:Integer}, collection) = subindex(collection) .- f[2] + +Base.:(+)(left::FunctionalIndex, index::Integer) = FunctionalIndex{left}((+, index)) +Base.:(-)(left::FunctionalIndex, index::Integer) = FunctionalIndex{left}((-, index)) + +__functional_index_print(io::IO, f::typeof(firstindex)) = nothing +__functional_index_print(io::IO, f::typeof(lastindex)) = nothing +__functional_index_print(io::IO, f::Tuple{typeof(+), <:Integer}) = print(io, " + ", f[2]) +__functional_index_print(io::IO, f::Tuple{typeof(-), <:Integer}) = print(io, " - ", f[2]) + +function Base.show(io::IO, index::FunctionalIndex{R, F}) where {R, F} + print(io, "(") + print(io, R) + __functional_index_print(io, index.f) + print(io, ")") +end + +""" + FunctionalRange(left, range) + +A range can handle `FunctionalIndex` as one of (or both) the bounds. + +```jldoctest +julia> first = GraphPPL.FunctionalIndex{:begin}(firstindex) +(begin) + +julia> last = GraphPPL.FunctionalIndex{:end}(lastindex) +(end) + +julia> range = GraphPPL.FunctionalRange(first + 1, last - 1) +((begin) + 1):((end) - 1) + +julia> [ 1.0, 2.0, 3.0, 4.0 ][range] +2-element Vector{Float64}: + 2.0 + 3.0 +``` +""" +struct FunctionalRange{L, R} + left::L + right::R +end + +(::Colon)(left, right::FunctionalIndex) = FunctionalRange(left, right) +(::Colon)(left::FunctionalIndex, right) = FunctionalRange(left, right) +(::Colon)(left::FunctionalIndex, right::FunctionalIndex) = FunctionalRange(left, right) + +Base.getindex(collection::AbstractArray, range::FunctionalRange{L, R}) where {L, R <: FunctionalIndex} = + collection[(range.left):range.right(collection)] +Base.getindex(collection::AbstractArray, range::FunctionalRange{L, R}) where {L <: FunctionalIndex, R} = + collection[range.left(collection):(range.right)] +Base.getindex(collection::AbstractArray, range::FunctionalRange{L, R}) where {L <: FunctionalIndex, R <: FunctionalIndex} = + collection[range.left(collection):range.right(collection)] + +function Base.show(io::IO, range::FunctionalRange) + print(io, range.left, ":", range.right) +end \ No newline at end of file diff --git a/src/core/node_types.jl b/src/core/node_types.jl new file mode 100644 index 00000000..cfc7dfe6 --- /dev/null +++ b/src/core/node_types.jl @@ -0,0 +1,49 @@ +""" + NodeType + +Abstract type representing either `Composite` or `Atomic` trait for a given object. By default is `Atomic` unless specified otherwise. +""" +abstract type NodeType end + +""" + Composite + +`Composite` object used as a trait of structs and functions that are composed of multiple nodes and therefore implement `make_node!`. +""" +struct Composite <: NodeType end + +""" + Atomic +`Atomic` object used as a trait of structs and functions that are composed of a single node and are therefore materialized as a single node in the factor graph. +""" +struct Atomic <: NodeType end + +NodeType(backend, fform) = error("Backend $backend must implement a method for `NodeType` for `$(fform)`.") + +""" + NodeBehaviour + +Abstract type representing either `Deterministic` or `Stochastic` for a given object. By default is `Deterministic` unless specified otherwise. +""" +abstract type NodeBehaviour end + +""" + Stochastic + +`Stochastic` object used to parametrize factor node object with stochastic type of relationship between variables. +""" +struct Stochastic <: NodeBehaviour end + +""" + Deterministic + +`Deterministic` object used to parametrize factor node object with determinstic type of relationship between variables. +""" +struct Deterministic <: NodeBehaviour end + +""" + NodeBehaviour(backend, fform) + +Returns a `NodeBehaviour` object for a given `backend` and `fform`. +""" +NodeBehaviour(backend, fform) = error("Backend $backend must implement a method for `NodeBehaviour` for `$(fform)`.") \ No newline at end of file diff --git a/src/resizable_array.jl b/src/core/resizable_array.jl similarity index 100% rename from src/resizable_array.jl rename to src/core/resizable_array.jl diff --git a/src/model_generator.jl b/src/generators/model_generator.jl similarity index 100% rename from src/model_generator.jl rename to src/generators/model_generator.jl diff --git a/src/graph/interfaces.jl b/src/graph/interfaces.jl new file mode 100644 index 00000000..161af642 --- /dev/null +++ b/src/graph/interfaces.jl @@ -0,0 +1,57 @@ +""" + aliases(backend, fform) + +Returns a collection of aliases for `fform` depending on the `backend`. +""" +aliases(backend, fform) = error("Backend $backend must implement a method for `aliases` for `$(fform)`.") +aliases(model::AbstractModel, fform::F) where {F} = aliases(getbackend(model), fform) +""" + factor_alias(backend, fform, interfaces) + +Returns the alias for a given `fform` and `interfaces` with a given `backend`. +""" +function factor_alias end + +factor_alias(backend, fform, interfaces) = + error("The backend $backend must implement a method for `factor_alias` for `$(fform)` and `$(interfaces)`.") +factor_alias(model::AbstractModel, fform::F, interfaces) where {F} = factor_alias(getbackend(model), fform, interfaces) + +""" + interfaces(backend, fform, ::StaticInt{N}) where N + +Returns the interfaces for a given `fform` and `backend` with a given amount of interfaces `N`. +""" +function interfaces end + +interfaces(backend, fform, ninputs) = + error("The backend $(backend) must implement a method for `interfaces` for `$(fform)` and `$(ninputs)` number of inputs.") +interfaces(model::AbstractModel, fform::F, ninputs) where {F} = interfaces(getbackend(model), fform, ninputs) + +""" + interface_aliases(backend, fform) + +Returns the aliases for a given `fform` and `backend`. +""" +function interface_aliases end + +interface_aliases(backend, fform) = error("The backend $backend must implement a method for `interface_aliases` for `$(fform)`.") +interface_aliases(model::AbstractModel, fform::F) where {F} = interface_aliases(getbackend(model), fform) + +""" + default_parametrization(backend, fform, rhs) + +Returns the default parametrization for a given `fform` and `backend` with a given `rhs`. +""" +function default_parametrization end + +default_parametrization(backend, nodetype, fform, rhs) = + error("The backend $backend must implement a method for `default_parametrization` for `$(fform)` (`$(nodetype)`) and `$(rhs)`.") +default_parametrization(model::AbstractModel, nodetype, fform::F, rhs) where {F} = + default_parametrization(getbackend(model), nodetype, fform, rhs) + +""" + instantiate(::Type{Backend}) + +Instantiates a default backend object of the specified type. Should be implemented for all backends. +""" +instantiate(backendtype) = error("The backend of type $backendtype must implement a method for `instantiate`.") \ No newline at end of file diff --git a/src/graph/node_data.jl b/src/graph/node_data.jl new file mode 100644 index 00000000..1d75d5e1 --- /dev/null +++ b/src/graph/node_data.jl @@ -0,0 +1,88 @@ +""" + NodeData(context, properties, plugins) + +Data associated with a node in a probabilistic graphical model. +The `context` field stores the context of the node. +The `properties` field stores the properties of the node. +The `extra` field stores additional properties of the node depending on which plugins were enabled. +""" +mutable struct NodeData <: AbstractNodeData + const context :: Context + const properties :: Union{VariableNodeProperties, FactorNodeProperties{NodeData}} + const extra :: UnorderedDictionary{Symbol, Any} +end + +NodeData(context, properties) = NodeData(context, properties, UnorderedDictionary{Symbol, Any}()) + +function Base.show(io::IO, nodedata::NodeData) + context = getcontext(nodedata) + properties = getproperties(nodedata) + print(io, "NodeData in context ", shortname(context), " with properties ", properties) + extra = getextra(nodedata) + if !isempty(extra) + print(io, " with extra: ") + print(io, extra) + end +end + +getcontext(node::NodeData) = node.context +getproperties(node::NodeData) = node.properties +getextra(node::NodeData) = node.extra + +is_constant(node::NodeData) = is_constant(getproperties(node)) + +""" + hasextra(node::NodeData, key::Symbol) + +Checks if `NodeData` has an extra property with the given key. +""" +hasextra(node::NodeData, key::Symbol) = haskey(node.extra, key) +""" + getextra(node::NodeData, key::Symbol, [ default ]) + +Returns the extra property with the given key. Optionally, if the property does not exist, returns the default value. +""" +getextra(node::NodeData, key::Symbol) = getindex(node.extra, key) +getextra(node::NodeData, key::Symbol, default) = hasextra(node, key) ? getextra(node, key) : default + +""" + setextra!(node::NodeData, key::Symbol, value) + +Sets the extra property with the given key to the given value. +""" +setextra!(node::NodeData, key::Symbol, value) = insert!(node.extra, key, value) + +""" +A compile time key to access the `extra` properties of the `NodeData` structure. +""" +struct NodeDataExtraKey{K, T} end + +getkey(::NodeDataExtraKey{K, T}) where {K, T} = K + +function hasextra(node::NodeData, key::NodeDataExtraKey{K}) where {K} + return haskey(node.extra, K) +end +function getextra(node::NodeData, key::NodeDataExtraKey{K, T})::T where {K, T} + return getindex(node.extra, K)::T +end +function getextra(node::NodeData, key::NodeDataExtraKey{K, T}, default::T)::T where {K, T} + return hasextra(node, key) ? (getextra(node, key)::T) : default +end +function setextra!(node::NodeData, key::NodeDataExtraKey{K}, value::T) where {K, T} + return insert!(node.extra, K, value) +end + +""" + is_factor(nodedata::NodeData) + +Returns `true` if the node data is associated with a factor node, `false` otherwise. +See also: [`is_variable`](@ref), +""" +is_factor(node::NodeData) = is_factor(getproperties(node)) +""" + is_variable(nodedata::NodeData) + +Returns `true` if the node data is associated with a variable node, `false` otherwise. +See also: [`is_factor`](@ref), +""" +is_variable(node::NodeData) = is_variable(getproperties(node)) \ No newline at end of file diff --git a/src/graph/node_labels.jl b/src/graph/node_labels.jl new file mode 100644 index 00000000..68478370 --- /dev/null +++ b/src/graph/node_labels.jl @@ -0,0 +1,65 @@ +""" + NodeLabel(name, global_counter::Int64) + +Unique identifier for a node (factor or variable) in a probabilistic graphical model. +""" +mutable struct NodeLabel + const name::Any + const global_counter::Int64 +end + +Base.length(label::NodeLabel) = 1 +Base.size(label::NodeLabel) = () +Base.getindex(label::NodeLabel, any) = label +Base.:(<)(left::NodeLabel, right::NodeLabel) = left.global_counter < right.global_counter +Base.broadcastable(label::NodeLabel) = Ref(label) + +getname(label::NodeLabel) = label.name +getname(labels::ResizableArray{T, V, N} where {T <: NodeLabel, V, N}) = getname(first(labels)) +iterate(label::NodeLabel) = (label, nothing) +iterate(label::NodeLabel, any) = nothing + +to_symbol(label::NodeLabel) = to_symbol(label.name, label.global_counter) +to_symbol(name::Any, index::Int) = Symbol(string(name, "_", index)) + +Base.show(io::IO, label::NodeLabel) = print(io, label.name, "_", label.global_counter) +Base.:(==)(label1::NodeLabel, label2::NodeLabel) = label1.name == label2.name && label1.global_counter == label2.global_counter +Base.hash(label::NodeLabel, h::UInt) = hash(label.global_counter, h) + +""" + EdgeLabel(symbol, index) + +A unique identifier for an edge in a probabilistic graphical model. +""" +mutable struct EdgeLabel + const name::Symbol + const index::Union{Int, Nothing} +end + +getname(label::EdgeLabel) = label.name +getname(labels::Tuple) = map(group -> getname(group), labels) + +to_symbol(label::EdgeLabel) = to_symbol(label, label.index) +to_symbol(label::EdgeLabel, ::Nothing) = label.name +to_symbol(label::EdgeLabel, ::Int64) = Symbol(string(label.name) * "[" * string(label.index) * "]") + +Base.show(io::IO, label::EdgeLabel) = print(io, to_symbol(label)) +Base.:(==)(label1::EdgeLabel, label2::EdgeLabel) = label1.name == label2.name && label1.index == label2.index +Base.hash(label::EdgeLabel, h::UInt) = hash(label.name, hash(label.index, h)) + +""" + FactorID(fform, index) + +A unique identifier for a factor node in a probabilistic graphical model. +""" +mutable struct FactorID{F} + const fform::F + const index::Int64 +end + +fform(id::FactorID) = id.fform +index(id::FactorID) = id.index + +Base.show(io::IO, id::FactorID) = print(io, "(", fform(id), ", ", index(id), ")") +Base.:(==)(id1::FactorID{F}, id2::FactorID{T}) where {F, T} = id1.fform == id2.fform && id1.index == id2.index +Base.hash(id::FactorID, h::UInt) = hash(id.fform, hash(id.index, h)) \ No newline at end of file diff --git a/src/graph_engine.jl b/src/graph_engine.jl deleted file mode 100644 index 27d0d55a..00000000 --- a/src/graph_engine.jl +++ /dev/null @@ -1,2191 +0,0 @@ -using MetaGraphsNext, MetaGraphsNext.Graphs, MetaGraphsNext.JLD2 -using BitSetTuples -using Static -using NamedTupleTools -using Dictionaries - -import Base: put!, haskey, getindex, getproperty, setproperty!, setindex!, vec, iterate, showerror, Exception -import MetaGraphsNext.Graphs: neighbors, degree - -export as_node, as_variable, as_context, savegraph, loadgraph - -struct NotImplementedError <: Exception - message::String -end - -showerror(io::IO, e::NotImplementedError) = print(io, "NotImplementedError: " * e.message) - -""" - FunctionalIndex - -A special type of an index that represents a function that can be used only in pair with a collection. -An example of a `FunctionalIndex` can be `firstindex` or `lastindex`, but more complex use cases are possible too, -e.g. `firstindex + 1`. Important part of the implementation is that the resulting structure is `isbitstype(...) = true`, -that allows to store it in parametric type as valtype. One use case for this structure is to dispatch on and to replace `begin` or `end` -(or more complex use cases, e.g. `begin + 1`). -""" -struct FunctionalIndex{R, F} - f::F - FunctionalIndex{R}(f::F) where {R, F} = new{R, F}(f) -end - -""" - (index::FunctionalIndex)(collection) - -Returns the result of applying the function `f` to the collection. - -```jldoctest -julia> index = GraphPPL.FunctionalIndex{:begin}(firstindex) -(begin) - -julia> index([ 2.0, 3.0 ]) -1 - -julia> (index + 1)([ 2.0, 3.0 ]) -2 - -julia> index = GraphPPL.FunctionalIndex{:end}(lastindex) -(end) - -julia> index([ 2.0, 3.0 ]) -2 -``` -""" -(index::FunctionalIndex{R, F})(collection) where {R, F} = __functional_index_apply(R, index.f, collection)::Integer - -Base.getindex(x::AbstractArray, fi::FunctionalIndex) = x[fi(x)] -# Base.getindex(x::NodeLabel, index::FunctionalIndex) = index(x) - -__functional_index_apply(::Symbol, f, collection) = f(collection) -__functional_index_apply(subindex::FunctionalIndex, f::Tuple{typeof(+), <:Integer}, collection) = subindex(collection) .+ f[2] -__functional_index_apply(subindex::FunctionalIndex, f::Tuple{typeof(-), <:Integer}, collection) = subindex(collection) .- f[2] - -Base.:(+)(left::FunctionalIndex, index::Integer) = FunctionalIndex{left}((+, index)) -Base.:(-)(left::FunctionalIndex, index::Integer) = FunctionalIndex{left}((-, index)) - -__functional_index_print(io::IO, f::typeof(firstindex)) = nothing -__functional_index_print(io::IO, f::typeof(lastindex)) = nothing -__functional_index_print(io::IO, f::Tuple{typeof(+), <:Integer}) = print(io, " + ", f[2]) -__functional_index_print(io::IO, f::Tuple{typeof(-), <:Integer}) = print(io, " - ", f[2]) - -function Base.show(io::IO, index::FunctionalIndex{R, F}) where {R, F} - print(io, "(") - print(io, R) - __functional_index_print(io, index.f) - print(io, ")") -end - -""" - FunctionalRange(left, range) - -A range can handle `FunctionalIndex` as one of (or both) the bounds. - -```jldoctest -julia> first = GraphPPL.FunctionalIndex{:begin}(firstindex) -(begin) - -julia> last = GraphPPL.FunctionalIndex{:end}(lastindex) -(end) - -julia> range = GraphPPL.FunctionalRange(first + 1, last - 1) -((begin) + 1):((end) - 1) - -julia> [ 1.0, 2.0, 3.0, 4.0 ][range] -2-element Vector{Float64}: - 2.0 - 3.0 -``` -""" -struct FunctionalRange{L, R} - left::L - right::R -end - -(::Colon)(left, right::FunctionalIndex) = FunctionalRange(left, right) -(::Colon)(left::FunctionalIndex, right) = FunctionalRange(left, right) -(::Colon)(left::FunctionalIndex, right::FunctionalIndex) = FunctionalRange(left, right) - -Base.getindex(collection::AbstractArray, range::FunctionalRange{L, R}) where {L, R <: FunctionalIndex} = - collection[(range.left):range.right(collection)] -Base.getindex(collection::AbstractArray, range::FunctionalRange{L, R}) where {L <: FunctionalIndex, R} = - collection[range.left(collection):(range.right)] -Base.getindex(collection::AbstractArray, range::FunctionalRange{L, R}) where {L <: FunctionalIndex, R <: FunctionalIndex} = - collection[range.left(collection):range.right(collection)] - -function Base.show(io::IO, range::FunctionalRange) - print(io, range.left, ":", range.right) -end - -""" - IndexedVariable(name, index) - -`IndexedVariable` represents a reference to a variable named `name` with index `index`. -""" -struct IndexedVariable{T} - name::Symbol - index::T -end - -getname(index::IndexedVariable) = index.name -index(index::IndexedVariable) = index.index - -Base.length(index::IndexedVariable{T} where {T}) = 1 -Base.iterate(index::IndexedVariable{T} where {T}) = (index, nothing) -Base.iterate(index::IndexedVariable{T} where {T}, any) = nothing -Base.:(==)(left::IndexedVariable, right::IndexedVariable) = (left.name == right.name && left.index == right.index) -Base.show(io::IO, variable::IndexedVariable{Nothing}) = print(io, variable.name) -Base.show(io::IO, variable::IndexedVariable) = print(io, variable.name, "[", variable.index, "]") - -""" - NodeType - -Abstract type representing either `Composite` or `Atomic` trait for a given object. By default is `Atomic` unless specified otherwise. -""" -abstract type NodeType end - -""" - Composite - -`Composite` object used as a trait of structs and functions that are composed of multiple nodes and therefore implement `make_node!`. -""" -struct Composite <: NodeType end - -""" - Atomic -`Atomic` object used as a trait of structs and functions that are composed of a single node and are therefore materialized as a single node in the factor graph. -""" -struct Atomic <: NodeType end - -NodeType(backend, fform) = error("Backend $backend must implement a method for `NodeType` for `$(fform)`.") - -""" - NodeBehaviour - -Abstract type representing either `Deterministic` or `Stochastic` for a given object. By default is `Deterministic` unless specified otherwise. -""" -abstract type NodeBehaviour end - -""" - Stochastic - -`Stochastic` object used to parametrize factor node object with stochastic type of relationship between variables. -""" -struct Stochastic <: NodeBehaviour end - -""" - Deterministic - -`Deterministic` object used to parametrize factor node object with determinstic type of relationship between variables. -""" -struct Deterministic <: NodeBehaviour end - -""" - NodeBehaviour(backend, fform) - -Returns a `NodeBehaviour` object for a given `backend` and `fform`. -""" -NodeBehaviour(backend, fform) = error("Backend $backend must implement a method for `NodeBehaviour` for `$(fform)`.") - -""" - FactorID(fform, index) - -A unique identifier for a factor node in a probabilistic graphical model. -""" -mutable struct FactorID{F} - const fform::F - const index::Int64 -end - -fform(id::FactorID) = id.fform -index(id::FactorID) = id.index - -Base.show(io::IO, id::FactorID) = print(io, "(", fform(id), ", ", index(id), ")") -Base.:(==)(id1::FactorID{F}, id2::FactorID{T}) where {F, T} = id1.fform == id2.fform && id1.index == id2.index -Base.hash(id::FactorID, h::UInt) = hash(id.fform, hash(id.index, h)) - -""" - Model(graph::MetaGraph) - -A structure representing a probabilistic graphical model. It contains a `MetaGraph` object -representing the factor graph and a `Base.RefValue{Int64}` object to keep track of the number -of nodes in the graph. - -Fields: -- `graph`: A `MetaGraph` object representing the factor graph. -- `plugins`: A `PluginsCollection` object representing the plugins enabled in the model. -- `backend`: A `Backend` object representing the backend used in the model. -- `source`: A `Source` object representing the original source code of the model (typically a `String` object). -- `counter`: A `Base.RefValue{Int64}` object keeping track of the number of nodes in the graph. -""" -struct Model{G, P, B, S} - graph::G - plugins::P - backend::B - source::S - counter::Base.RefValue{Int64} -end - -labels(model::Model) = MetaGraphsNext.labels(model.graph) -Base.isempty(model::Model) = iszero(nv(model.graph)) && iszero(ne(model.graph)) - -getplugins(model::Model) = model.plugins -getbackend(model::Model) = model.backend -getsource(model::Model) = model.source -getcounter(model::Model) = model.counter[] -setcounter!(model::Model, value) = model.counter[] = value - -Graphs.savegraph(file::AbstractString, model::GraphPPL.Model) = save(file, "__model__", model) -Graphs.loadgraph(file::AbstractString, ::Type{GraphPPL.Model}) = load(file, "__model__") - -NodeType(model::Model, fform::F) where {F} = NodeType(getbackend(model), fform) -NodeBehaviour(model::Model, fform::F) where {F} = NodeBehaviour(getbackend(model), fform) - -""" - NodeLabel(name, global_counter::Int64) - -Unique identifier for a node (factor or variable) in a probabilistic graphical model. -""" -mutable struct NodeLabel - const name::Any - const global_counter::Int64 -end - -Base.length(label::NodeLabel) = 1 -Base.size(label::NodeLabel) = () -Base.getindex(label::NodeLabel, any) = label -Base.:(<)(left::NodeLabel, right::NodeLabel) = left.global_counter < right.global_counter -Base.broadcastable(label::NodeLabel) = Ref(label) - -getname(label::NodeLabel) = label.name -getname(labels::ResizableArray{T, V, N} where {T <: NodeLabel, V, N}) = getname(first(labels)) -iterate(label::NodeLabel) = (label, nothing) -iterate(label::NodeLabel, any) = nothing - -to_symbol(label::NodeLabel) = to_symbol(label.name, label.global_counter) -to_symbol(name::Any, index::Int) = Symbol(string(name, "_", index)) - -Base.show(io::IO, label::NodeLabel) = print(io, label.name, "_", label.global_counter) -Base.:(==)(label1::NodeLabel, label2::NodeLabel) = label1.name == label2.name && label1.global_counter == label2.global_counter -Base.hash(label::NodeLabel, h::UInt) = hash(label.global_counter, h) - -""" - EdgeLabel(symbol, index) - -A unique identifier for an edge in a probabilistic graphical model. -""" -mutable struct EdgeLabel - const name::Symbol - const index::Union{Int, Nothing} -end - -getname(label::EdgeLabel) = label.name -getname(labels::Tuple) = map(group -> getname(group), labels) - -to_symbol(label::EdgeLabel) = to_symbol(label, label.index) -to_symbol(label::EdgeLabel, ::Nothing) = label.name -to_symbol(label::EdgeLabel, ::Int64) = Symbol(string(label.name) * "[" * string(label.index) * "]") - -Base.show(io::IO, label::EdgeLabel) = print(io, to_symbol(label)) -Base.:(==)(label1::EdgeLabel, label2::EdgeLabel) = label1.name == label2.name && label1.index == label2.index -Base.hash(label::EdgeLabel, h::UInt) = hash(label.name, hash(label.index, h)) - -""" - Splat{T} - -A type used to represent splatting in the model macro. Any call on the right hand side of ~ that uses splatting will be wrapped in this type. -""" -struct Splat{T} - collection::T -end - -""" - ProxyLabel(name, index, proxied) - -A label that proxies another label in a probabilistic graphical model. -The proxied objects must implement the `is_proxied(::Type) = True()`. -The proxy labels may spawn new variables in a model, if `maycreate` is set to `True()`. -""" -mutable struct ProxyLabel{P, I, M} - const name::Symbol - const proxied::P - const index::I - const maycreate::M -end - -is_proxied(any) = is_proxied(typeof(any)) -is_proxied(::Type) = False() -is_proxied(::Type{T}) where {T <: NodeLabel} = True() -is_proxied(::Type{T}) where {T <: ProxyLabel} = True() -is_proxied(::Type{T}) where {T <: AbstractArray} = is_proxied(eltype(T)) - -proxylabel(name::Symbol, proxied::Splat{T}, index, maycreate) where {T} = - [proxylabel(name, proxiedelement, index, maycreate) for proxiedelement in proxied.collection] - -# By default, `proxylabel` set `maycreate` to `False` -proxylabel(name::Symbol, proxied, index) = proxylabel(name, proxied, index, False()) -proxylabel(name::Symbol, proxied, index, maycreate) = proxylabel(is_proxied(proxied), name, proxied, index, maycreate) - -# In case if `is_proxied` returns `False` we simply return the original object, because the object cannot be proxied -proxylabel(::False, name::Symbol, proxied::Any, index::Nothing, maycreate) = proxied -proxylabel(::False, name::Symbol, proxied::Any, index::Tuple, maycreate) = proxied[index...] - -# In case if `is_proxied` returns `True`, we wrap the object into the `ProxyLabel` for later `unroll`-ing -function proxylabel(::True, name::Symbol, proxied::Any, index::Any, maycreate::Any) - return ProxyLabel(name, proxied, index, maycreate) -end - -# In case if `proxied` is another `ProxyLabel` we take `|` operation with its `maycreate` to lift it further -# This is a useful operation for `datalabels`, since they define `maycreate = True()` on their creation time -# That means that all subsequent usages of data labels will always create a new label, even when used on right hand side from `~` -function proxylabel(::True, name::Symbol, proxied::ProxyLabel, index::Any, maycreate::Any) - return ProxyLabel(name, proxied, index, proxied.maycreate | maycreate) -end - -getname(label::ProxyLabel) = label.name -index(label::ProxyLabel) = label.index - -# This function allows to overwrite the `maycreate` flag on a proxy label, might be useful for situations where code should -# definitely not create a new variable, e.g in the variational constraints plugin -set_maycreate(proxylabel::ProxyLabel, maycreate::Union{True, False}) = - ProxyLabel(proxylabel.name, proxylabel.proxied, proxylabel.index, maycreate) -set_maycreate(something, maycreate::Union{True, False}) = something - -function unroll(something) - return something -end - -function unroll(proxylabel::ProxyLabel) - return unroll(proxylabel, proxylabel.proxied, proxylabel.index, proxylabel.maycreate, proxylabel.index) -end - -function unroll(proxylabel::ProxyLabel, proxied::ProxyLabel, index, maycreate, liftedindex) - # In case of a chain of proxy-labels we should lift the index, that potentially might - # be used to create a new collection of variables - liftedindex = lift_index(maycreate, index, liftedindex) - unrolled = unroll(proxied, proxied.proxied, proxied.index, proxied.maycreate, liftedindex) - return checked_getindex(unrolled, index) -end - -function unroll(proxylabel::ProxyLabel, something::Any, index, maycreate, liftedindex) - return checked_getindex(something, index) -end - -checked_getindex(something, index::FunctionalIndex) = Base.getindex(something, index) -checked_getindex(something, index::Tuple) = Base.getindex(something, index...) -checked_getindex(something, index::Nothing) = something - -checked_getindex(nodelabel::NodeLabel, index::Nothing) = nodelabel -checked_getindex(nodelabel::NodeLabel, index::Tuple) = - error("Indexing a single node label `$(getname(nodelabel))` with an index `[$(join(index, ", "))]` is not allowed.") -checked_getindex(nodelabel::NodeLabel, index) = - error("Indexing a single node label `$(getname(nodelabel))` with an index `$index` is not allowed.") - -""" -The `lift_index` function "lifts" (or tracks) the index that is going to be used to determine the shape of the container upon creation -for a variable during the unrolling of the `ProxyLabel`. This index is used only if the container is set to be created and is not used if -variable container already exists. -""" -function lift_index end - -lift_index(::True, ::Nothing, ::Nothing) = nothing -lift_index(::True, current, ::Nothing) = current -lift_index(::True, ::Nothing, previous) = previous -lift_index(::True, current, previous) = current -lift_index(::False, current, previous) = previous - -Base.show(io::IO, proxy::ProxyLabel) = show_proxy(io, getname(proxy), index(proxy)) -show_proxy(io::IO, name::Symbol, index::Nothing) = print(io, name) -show_proxy(io::IO, name::Symbol, index::Tuple) = print(io, name, "[", join(index, ","), "]") -show_proxy(io::IO, name::Symbol, index::Any) = print(io, name, "[", index, "]") - -Base.last(label::ProxyLabel) = last(label.proxied, label) -Base.last(proxied::ProxyLabel, ::ProxyLabel) = last(proxied) -Base.last(proxied, ::ProxyLabel) = proxied - -Base.:(==)(proxy1::ProxyLabel, proxy2::ProxyLabel) = - proxy1.name == proxy2.name && proxy1.index == proxy2.index && proxy1.proxied == proxy2.proxied -Base.hash(proxy::ProxyLabel, h::UInt) = hash(proxy.maycreate, hash(proxy.name, hash(proxy.index, hash(proxy.proxied, h)))) - -# Iterator's interface methods -Base.IteratorSize(proxy::ProxyLabel) = Base.IteratorSize(indexed_last(proxy)) -Base.IteratorEltype(proxy::ProxyLabel) = Base.IteratorEltype(indexed_last(proxy)) -Base.eltype(proxy::ProxyLabel) = Base.eltype(indexed_last(proxy)) - -Base.length(proxy::ProxyLabel) = length(indexed_last(proxy)) -Base.size(proxy::ProxyLabel, dims...) = size(indexed_last(proxy), dims...) -Base.firstindex(proxy::ProxyLabel) = firstindex(indexed_last(proxy)) -Base.lastindex(proxy::ProxyLabel) = lastindex(indexed_last(proxy)) -Base.eachindex(proxy::ProxyLabel) = eachindex(indexed_last(proxy)) -Base.axes(proxy::ProxyLabel) = axes(indexed_last(proxy)) -Base.getindex(proxy::ProxyLabel, indices...) = getindex(indexed_last(proxy), indices...) -Base.size(proxy::ProxyLabel) = size(indexed_last(proxy)) -Base.broadcastable(proxy::ProxyLabel) = Base.broadcastable(indexed_last(proxy)) - -postprocess_returnval(proxy::ProxyLabel) = postprocess_returnval(indexed_last(proxy)) - -"""Similar to `Base.last` when applied on `ProxyLabel`, but also applies `checked_getindex` while unrolling""" -function indexed_last end - -indexed_last(proxy::ProxyLabel) = checked_getindex(indexed_last(proxy.proxied), proxy.index) -indexed_last(something) = something - -""" - Context - -Contains all information about a submodel in a probabilistic graphical model. -""" -struct Context - depth::Int64 - fform::Function - prefix::String - parent::Union{Context, Nothing} - submodel_counts::UnorderedDictionary{Any, Int} - children::UnorderedDictionary{FactorID, Context} - factor_nodes::UnorderedDictionary{FactorID, NodeLabel} - individual_variables::UnorderedDictionary{Symbol, NodeLabel} - vector_variables::UnorderedDictionary{Symbol, ResizableArray{NodeLabel, Vector{NodeLabel}, 1}} - tensor_variables::UnorderedDictionary{Symbol, ResizableArray{NodeLabel}} - proxies::UnorderedDictionary{Symbol, ProxyLabel} - returnval::Ref{Any} -end - -function Context(depth::Int, fform::Function, prefix::String, parent) - return Context( - depth, - fform, - prefix, - parent, - UnorderedDictionary{Any, Int}(), - UnorderedDictionary{FactorID, Context}(), - UnorderedDictionary{FactorID, NodeLabel}(), - UnorderedDictionary{Symbol, NodeLabel}(), - UnorderedDictionary{Symbol, ResizableArray{NodeLabel, Vector{NodeLabel}, 1}}(), - UnorderedDictionary{Symbol, ResizableArray{NodeLabel}}(), - UnorderedDictionary{Symbol, ProxyLabel}(), - Ref{Any}() - ) -end - -Context(parent::Context, model_fform::Function) = - Context(parent.depth + 1, model_fform, (parent.prefix == "" ? parent.prefix : parent.prefix * "_") * getname(model_fform), parent) -Context(fform) = Context(0, fform, "", nothing) -Context() = Context(identity) - -fform(context::Context) = context.fform -parent(context::Context) = context.parent -individual_variables(context::Context) = context.individual_variables -vector_variables(context::Context) = context.vector_variables -tensor_variables(context::Context) = context.tensor_variables -factor_nodes(context::Context) = context.factor_nodes -proxies(context::Context) = context.proxies -children(context::Context) = context.children -count(context::Context, fform::F) where {F} = haskey(context.submodel_counts, fform) ? context.submodel_counts[fform] : 0 -shortname(context::Context) = string(context.prefix) - -returnval(context::Context) = context.returnval[] - -function returnval!(context::Context, value) - context.returnval[] = postprocess_returnval(value) -end - -# We do not want to return `VariableRef` from the model -# In this case we replace them with the actual node labels -postprocess_returnval(value) = value -postprocess_returnval(value::Tuple) = map(postprocess_returnval, value) - -path_to_root(::Nothing) = [] -path_to_root(context::Context) = [context, path_to_root(parent(context))...] - -function generate_factor_nodelabel(context::Context, fform::F) where {F} - if count(context, fform) == 0 - set!(context.submodel_counts, fform, 1) - else - context.submodel_counts[fform] += 1 - end - return FactorID(fform, count(context, fform)) -end - -function Base.show(io::IO, mime::MIME"text/plain", context::Context) - iscompact = get(io, :compact, false)::Bool - - if iscompact - print(io, "Context(", shortname(context), " | ") - nvariables = - length(context.individual_variables) + - length(context.vector_variables) + - length(context.tensor_variables) + - length(context.proxies) - nfactornodes = length(context.factor_nodes) - print(io, nvariables, " variables, ", nfactornodes, " factor nodes") - if !isempty(context.children) - print(io, ", ", length(context.children), " children") - end - print(io, ")") - else - indentation = get(io, :indentation, 0)::Int - indentationstr = " "^indentation - indentationstrp1 = " "^(indentation + 1) - println(io, indentationstr, "Context(", shortname(context), ")") - println(io, indentationstrp1, "Individual variables: ", keys(individual_variables(context))) - println(io, indentationstrp1, "Vector variables: ", keys(vector_variables(context))) - println(io, indentationstrp1, "Tensor variables: ", keys(tensor_variables(context))) - println(io, indentationstrp1, "Proxies: ", keys(proxies(context))) - println(io, indentationstrp1, "Factor nodes: ", collect(keys(factor_nodes(context)))) - if !isempty(context.children) - println(io, indentationstrp1, "Children: ", map(shortname, values(context.children))) - end - end -end - -getname(f::Function) = String(Symbol(f)) - -haskey(context::Context, key::Symbol) = - haskey(context.individual_variables, key) || - haskey(context.vector_variables, key) || - haskey(context.tensor_variables, key) || - haskey(context.proxies, key) - -haskey(context::Context, key::FactorID) = haskey(context.factor_nodes, key) || haskey(context.children, key) - -function Base.getindex(c::Context, key::Symbol) - if haskey(c.individual_variables, key) - return c.individual_variables[key] - elseif haskey(c.vector_variables, key) - return c.vector_variables[key] - elseif haskey(c.tensor_variables, key) - return c.tensor_variables[key] - elseif haskey(c.proxies, key) - return c.proxies[key] - end - throw(KeyError(key)) -end - -function Base.getindex(c::Context, key::FactorID) - if haskey(c.factor_nodes, key) - return c.factor_nodes[key] - elseif haskey(c.children, key) - return c.children[key] - end - throw(KeyError(key)) -end - -Base.getindex(c::Context, fform, index::Int) = c[FactorID(fform, index)] - -Base.setindex!(c::Context, val::NodeLabel, key::Symbol) = set!(c.individual_variables, key, val) -Base.setindex!(c::Context, val::NodeLabel, key::Symbol, index::Nothing) = set!(c.individual_variables, key, val) -Base.setindex!(c::Context, val::NodeLabel, key::Symbol, index::Int) = c.vector_variables[key][index] = val -Base.setindex!(c::Context, val::NodeLabel, key::Symbol, index::NTuple{N, Int64} where {N}) = c.tensor_variables[key][index...] = val -Base.setindex!(c::Context, val::ResizableArray{NodeLabel, T, 1} where {T}, key::Symbol) = set!(c.vector_variables, key, val) -Base.setindex!(c::Context, val::ResizableArray{NodeLabel, T, N} where {T, N}, key::Symbol) = set!(c.tensor_variables, key, val) -Base.setindex!(c::Context, val::ProxyLabel, key::Symbol) = set!(c.proxies, key, val) -Base.setindex!(c::Context, val::ProxyLabel, key::Symbol, index::Nothing) = set!(c.proxies, key, val) -Base.setindex!(c::Context, val::Context, key::FactorID) = set!(c.children, key, val) -Base.setindex!(c::Context, val::NodeLabel, key::FactorID) = set!(c.factor_nodes, key, val) - -""" - VarDict - -A recursive dictionary structure that contains all variables in a probabilistic graphical model. -Iterates over all variables in the model and their children in a linear fashion, but preserves the recursive nature of the actual model. -""" -struct VarDict{T} - variables::UnorderedDictionary{Symbol, T} - children::UnorderedDictionary{FactorID, VarDict} -end - -function VarDict(context::Context) - dictvariables = merge(individual_variables(context), vector_variables(context), tensor_variables(context)) - dictchildren = convert(UnorderedDictionary{FactorID, VarDict}, map(child -> VarDict(child), children(context))) - return VarDict(dictvariables, dictchildren) -end - -variables(vardict::VarDict) = vardict.variables -children(vardict::VarDict) = vardict.children - -haskey(vardict::VarDict, key::Symbol) = haskey(vardict.variables, key) -haskey(vardict::VarDict, key::Tuple{T, Int} where {T}) = haskey(vardict.children, FactorID(first(key), last(key))) -haskey(vardict::VarDict, key::FactorID) = haskey(vardict.children, key) - -Base.getindex(vardict::VarDict, key::Symbol) = vardict.variables[key] -Base.getindex(vardict::VarDict, f, index::Int) = vardict.children[FactorID(f, index)] -Base.getindex(vardict::VarDict, key::Tuple{T, Int} where {T}) = vardict.children[FactorID(first(key), last(key))] -Base.getindex(vardict::VarDict, key::FactorID) = vardict.children[key] - -function Base.map(f, vardict::VarDict) - mapped_variables = map(f, variables(vardict)) - mapped_children = convert(UnorderedDictionary{FactorID, VarDict}, map(child -> map(f, child), children(vardict))) - return VarDict(mapped_variables, mapped_children) -end - -function Base.filter(f, vardict::VarDict) - filtered_variables = filter(f, variables(vardict)) - filtered_children = convert(UnorderedDictionary{FactorID, VarDict}, map(child -> filter(f, child), children(vardict))) - return VarDict(filtered_variables, filtered_children) -end - -Base.:(==)(left::VarDict, right::VarDict) = left.variables == right.variables && left.children == right.children - -""" - NodeCreationOptions(namedtuple) - -Options for creating a node in a probabilistic graphical model. These are typically coming from the `where {}` block -in the `@model` macro, but can also be created manually. Expects a `NamedTuple` as an input. -""" -struct NodeCreationOptions{N} - options::N -end - -const EmptyNodeCreationOptions = NodeCreationOptions{Nothing}(nothing) - -NodeCreationOptions(; kwargs...) = convert(NodeCreationOptions, kwargs) - -Base.convert(::Type{NodeCreationOptions}, ::@Kwargs{}) = NodeCreationOptions(nothing) -Base.convert(::Type{NodeCreationOptions}, options) = NodeCreationOptions(NamedTuple(options)) - -Base.haskey(options::NodeCreationOptions, key::Symbol) = haskey(options.options, key) -Base.getindex(options::NodeCreationOptions, keys...) = getindex(options.options, keys...) -Base.getindex(options::NodeCreationOptions, keys::NTuple{N, Symbol}) where {N} = NodeCreationOptions(getindex(options.options, keys)) -Base.keys(options::NodeCreationOptions) = keys(options.options) -Base.get(options::NodeCreationOptions, key::Symbol, default) = get(options.options, key, default) - -# Fast fallback for empty options -Base.haskey(::NodeCreationOptions{Nothing}, key::Symbol) = false -Base.getindex(::NodeCreationOptions{Nothing}, keys...) = error("type `NodeCreationOptions{Nothing}` has no field $(keys)") -Base.keys(::NodeCreationOptions{Nothing}) = () -Base.get(::NodeCreationOptions{Nothing}, key::Symbol, default) = default - -withopts(::NodeCreationOptions{Nothing}, options::NamedTuple) = NodeCreationOptions(options) -withopts(options::NodeCreationOptions, extra::NamedTuple) = NodeCreationOptions((; options.options..., extra...)) - -withoutopts(::NodeCreationOptions{Nothing}, ::Val) = NodeCreationOptions(nothing) - -function withoutopts(options::NodeCreationOptions, ::Val{K}) where {K} - newoptions = options.options[filter(key -> key ∉ K, keys(options.options))] - # Should be compiled out, there are tests for it - if isempty(newoptions) - return NodeCreationOptions(nothing) - else - return NodeCreationOptions(newoptions) - end -end - -""" - VariableNodeProperties(name, index, kind, link, value) - -Data associated with a variable node in a probabilistic graphical model. -""" -struct VariableNodeProperties - name::Symbol - index::Any - kind::Symbol - link::Any - value::Any -end - -VariableNodeProperties(; name, index, kind = VariableKindRandom, link = nothing, value = nothing) = - VariableNodeProperties(name, index, kind, link, value) - -is_factor(::VariableNodeProperties) = false -is_variable(::VariableNodeProperties) = true - -function Base.convert(::Type{VariableNodeProperties}, name::Symbol, index, options::NodeCreationOptions) - return VariableNodeProperties( - name = name, - index = index, - kind = get(options, :kind, VariableKindRandom), - link = get(options, :link, nothing), - value = get(options, :value, nothing) - ) -end - -getname(properties::VariableNodeProperties) = properties.name -getlink(properties::VariableNodeProperties) = properties.link -index(properties::VariableNodeProperties) = properties.index -value(properties::VariableNodeProperties) = properties.value - -"Defines a `random` (or `latent`) kind for a variable in a probabilistic graphical model." -const VariableKindRandom = :random -"Defines a `data` kind for a variable in a probabilistic graphical model." -const VariableKindData = :data -"Defines a `constant` kind for a variable in a probabilistic graphical model." -const VariableKindConstant = :constant -"Placeholder for a variable kind in a probabilistic graphical model." -const VariableKindUnknown = :unknown - -is_kind(properties::VariableNodeProperties, kind) = properties.kind === kind -is_kind(properties::VariableNodeProperties, ::Val{kind}) where {kind} = properties.kind === kind -is_random(properties::VariableNodeProperties) = is_kind(properties, Val(VariableKindRandom)) -is_data(properties::VariableNodeProperties) = is_kind(properties, Val(VariableKindData)) -is_constant(properties::VariableNodeProperties) = is_kind(properties, Val(VariableKindConstant)) - -const VariableNameAnonymous = :anonymous_var_graphppl - -is_anonymous(properties::VariableNodeProperties) = properties.name === VariableNameAnonymous - -function Base.show(io::IO, properties::VariableNodeProperties) - print(io, "name = ", properties.name, ", index = ", properties.index) - if !isnothing(properties.link) - print(io, ", linked to ", properties.link) - end -end - -""" - FactorNodeProperties(fform, neighbours) - -Data associated with a factor node in a probabilistic graphical model. -""" -struct FactorNodeProperties{D} - fform::Any - neighbors::Vector{Tuple{NodeLabel, EdgeLabel, D}} -end - -FactorNodeProperties(; fform, neighbors = Tuple{NodeLabel, EdgeLabel, NodeData}[]) = FactorNodeProperties(fform, neighbors) - -is_factor(::FactorNodeProperties) = true -is_variable(::FactorNodeProperties) = false - -function Base.convert(::Type{FactorNodeProperties}, fform, options::NodeCreationOptions) - return FactorNodeProperties(fform = fform, neighbors = get(options, :neighbors, Tuple{NodeLabel, EdgeLabel, NodeData}[])) -end - -getname(properties::FactorNodeProperties) = string(properties.fform) -prettyname(properties::FactorNodeProperties) = prettyname(properties.fform) -prettyname(fform::Any) = string(fform) # Can be overloaded for custom pretty names - -fform(properties::FactorNodeProperties) = properties.fform -neighbors(properties::FactorNodeProperties) = properties.neighbors -addneighbor!(properties::FactorNodeProperties, variable::NodeLabel, edge::EdgeLabel, data) = - push!(properties.neighbors, (variable, edge, data)) -neighbor_data(properties::FactorNodeProperties) = Iterators.map(neighbor -> neighbor[3], neighbors(properties)) - -function Base.show(io::IO, properties::FactorNodeProperties) - print(io, "fform = ", properties.fform, ", neighbors = ", properties.neighbors) -end - -""" - NodeData(context, properties, plugins) - -Data associated with a node in a probabilistic graphical model. -The `context` field stores the context of the node. -The `properties` field stores the properties of the node. -The `plugins` field stores additional properties of the node depending on which plugins were enabled. -""" -mutable struct NodeData - const context :: Context - const properties :: Union{VariableNodeProperties, FactorNodeProperties{NodeData}} - const extra :: UnorderedDictionary{Symbol, Any} -end - -NodeData(context, properties) = NodeData(context, properties, UnorderedDictionary{Symbol, Any}()) - -function Base.show(io::IO, nodedata::NodeData) - context = getcontext(nodedata) - properties = getproperties(nodedata) - print(io, "NodeData in context ", shortname(context), " with properties ", properties) - extra = getextra(nodedata) - if !isempty(extra) - print(io, " with extra: ") - print(io, extra) - end -end - -getcontext(node::NodeData) = node.context -getproperties(node::NodeData) = node.properties -getextra(node::NodeData) = node.extra - -is_constant(node::NodeData) = is_constant(getproperties(node)) - -""" - hasextra(node::NodeData, key::Symbol) - -Checks if `NodeData` has an extra property with the given key. -""" -hasextra(node::NodeData, key::Symbol) = haskey(node.extra, key) -""" - getextra(node::NodeData, key::Symbol, [ default ]) - -Returns the extra property with the given key. Optionally, if the property does not exist, returns the default value. -""" -getextra(node::NodeData, key::Symbol) = getindex(node.extra, key) -getextra(node::NodeData, key::Symbol, default) = hasextra(node, key) ? getextra(node, key) : default - -""" - setextra!(node::NodeData, key::Symbol, value) - -Sets the extra property with the given key to the given value. -""" -setextra!(node::NodeData, key::Symbol, value) = insert!(node.extra, key, value) - -""" -A compile time key to access the `extra` properties of the `NodeData` structure. -""" -struct NodeDataExtraKey{K, T} end - -getkey(::NodeDataExtraKey{K, T}) where {K, T} = K - -function hasextra(node::NodeData, key::NodeDataExtraKey{K}) where {K} - return haskey(node.extra, K) -end -function getextra(node::NodeData, key::NodeDataExtraKey{K, T})::T where {K, T} - return getindex(node.extra, K)::T -end -function getextra(node::NodeData, key::NodeDataExtraKey{K, T}, default::T)::T where {K, T} - return hasextra(node, key) ? (getextra(node, key)::T) : default -end -function setextra!(node::NodeData, key::NodeDataExtraKey{K}, value::T) where {K, T} - return insert!(node.extra, K, value) -end - -""" - is_factor(nodedata::NodeData) - -Returns `true` if the node data is associated with a factor node, `false` otherwise. -See also: [`is_variable`](@ref), -""" -is_factor(node::NodeData) = is_factor(getproperties(node)) -""" - is_variable(nodedata::NodeData) - -Returns `true` if the node data is associated with a variable node, `false` otherwise. -See also: [`is_factor`](@ref), -""" -is_variable(node::NodeData) = is_variable(getproperties(node)) - -factor_nodes(model::Model) = Iterators.filter(node -> is_factor(model[node]), labels(model)) -variable_nodes(model::Model) = Iterators.filter(node -> is_variable(model[node]), labels(model)) - -""" -A version `factor_nodes(model)` that uses a callback function to process the factor nodes. -The callback function accepts both the label and the node data. -""" -function factor_nodes(callback::F, model::Model) where {F} - for label in labels(model) - nodedata = model[label] - if is_factor(nodedata) - callback((label::NodeLabel), (nodedata::NodeData)) - end - end -end - -""" -A version `variable_nodes(model)` that uses a callback function to process the variable nodes. -The callback function accepts both the label and the node data. -""" -function variable_nodes(callback::F, model::Model) where {F} - for label in labels(model) - nodedata = model[label] - if is_variable(nodedata) - callback((label::NodeLabel), (nodedata::NodeData)) - end - end -end - -""" - VariableRef(model::Model, context::Context, name::Symbol, index, external_collection = nothing) - -`VariableRef` implements a lazy reference to a variable in the model. -The reference does not create an actual variable in the model immediatelly, but postpones the creation -until strictly necessarily, which is hapenning inside the `unroll` function. The postponed creation allows users to define -pass a single variable into a submodel, e.g. `y ~ submodel(x = x)`, but use it as an array inside the submodel, -e.g. `y[i] ~ Normal(x[i], 1.0)`. - -Optionally accepts an `external_collection`, which defines the upper limit on the shape of the underlying collection. -For example, an external collection `[ 1, 2, 3 ]` can be used both as `y ~ ...` and `y[i] ~ ...`, but not as `y[i, j] ~ ...`. -By default, the `MissingCollection` is used for the `external_collection`, which does not restrict the shape of the underlying collection. - -The `index` is always a `Tuple`. By default, `(nothing, )` is used, to indicate empty indices with no restrictions on the shape of the underlying collection. -If "non-nothing" index is supplied, e.g. `(1, )` the shape of the udnerlying collection will be fixed to match the index -(1-dimensional in case of `(1, )`, 2-dimensional in case of `(1, 1)` and so on). -""" -struct VariableRef{M, C, O, I, E, L} - model::M - context::C - options::O - name::Symbol - index::I - external_collection::E - internal_collection::L -end - -Base.:(==)(left::VariableRef, right::VariableRef) = - left.model == right.model && left.context == right.context && left.name == right.name && left.index == right.index - -function Base.:(==)(left::VariableRef, right) - error( - "Comparing Factor Graph variable `$left` with a value. This is not possible as the value of `$left` is not known at model construction time." - ) -end -Base.:(==)(left, right::VariableRef) = right == left - -Base.:(>)(left::VariableRef, right) = left == right -Base.:(>)(left, right::VariableRef) = left == right -Base.:(<)(left::VariableRef, right) = left == right -Base.:(<)(left, right::VariableRef) = left == right -Base.:(>=)(left::VariableRef, right) = left == right -Base.:(>=)(left, right::VariableRef) = left == right -Base.:(<=)(left::VariableRef, right) = left == right -Base.:(<=)(left, right::VariableRef) = left == right - -is_proxied(::Type{T}) where {T <: VariableRef} = True() - -external_collection_typeof(::Type{VariableRef{M, C, O, I, E, L}}) where {M, C, O, I, E, L} = E -internal_collection_typeof(::Type{VariableRef{M, C, O, I, E, L}}) where {M, C, O, I, E, L} = L - -external_collection(ref::VariableRef) = ref.external_collection -internal_collection(ref::VariableRef) = ref.internal_collection - -Base.show(io::IO, ref::VariableRef) = variable_ref_show(io, ref.name, ref.index) -variable_ref_show(io::IO, name::Symbol, index::Nothing) = print(io, name) -variable_ref_show(io::IO, name::Symbol, index::Tuple{Nothing}) = print(io, name) -variable_ref_show(io::IO, name::Symbol, index::Tuple) = print(io, name, "[", join(index, ","), "]") -variable_ref_show(io::IO, name::Symbol, index::Any) = print(io, name, "[", index, "]") - -""" - makevarref(fform::F, model::Model, context::Context, options::NodeCreationOptions, name::Symbol, index::Tuple) - -A function that creates `VariableRef`, but takes the `fform` into account. When `fform` happens to be `Atomic` creates -the underlying variable immediatelly without postponing. When `fform` is `Composite` does not create the actual variable, -but waits until strictly necessarily. -""" -function makevarref end - -function makevarref(fform::F, model::Model, context::Context, options::NodeCreationOptions, name::Symbol, index::Tuple) where {F} - return makevarref(NodeType(model, fform), model, context, options, name, index) -end - -function makevarref(::Atomic, model::Model, context::Context, options::NodeCreationOptions, name::Symbol, index::Tuple) - # In the case of `Atomic` variable reference, we always create the variable - # (unless the index is empty, which may happen during broadcasting) - internal_collection = isempty(index) ? nothing : getorcreate!(model, context, name, index...) - return VariableRef(model, context, options, name, index, nothing, internal_collection) -end - -function makevarref(::Composite, model::Model, context::Context, options::NodeCreationOptions, name::Symbol, index::Tuple) - # In the case of `Composite` variable reference, we create it immediatelly only when the variable is instantiated - # with indexing operation - internal_collection = if !all(isnothing, index) - getorcreate!(model, context, name, index...) - else - nothing - end - return VariableRef(model, context, options, name, index, nothing, internal_collection) -end - -function VariableRef( - model::Model, - context::Context, - options::NodeCreationOptions, - name::Symbol, - index::Tuple, - external_collection = nothing, - internal_collection = nothing -) - M = typeof(model) - C = typeof(context) - O = typeof(options) - I = typeof(index) - E = typeof(external_collection) - L = typeof(internal_collection) - return VariableRef{M, C, O, I, E, L}(model, context, options, name, index, external_collection, internal_collection) -end - -function unroll(p::ProxyLabel, ref::VariableRef, index, maycreate, liftedindex) - liftedindex = lift_index(maycreate, index, liftedindex) - if maycreate === False() - return checked_getindex(getifcreated(ref.model, ref.context, ref, liftedindex), index) - elseif maycreate === True() - return checked_getindex(getorcreate!(ref.model, ref.context, ref, liftedindex), index) - end - error("Unreachable. The `maycreate` argument in the `unroll` function for the `VariableRef` must be either `True` or `False`.") -end - -function getifcreated(model::Model, context::Context, ref::VariableRef) - return getifcreated(model, context, ref, ref.index) -end - -function getifcreated(model::Model, context::Context, ref::VariableRef, index) - if !isnothing(ref.external_collection) - return getorcreate!(ref.model, ref.context, ref, index) - elseif !isnothing(ref.internal_collection) - return ref.internal_collection - elseif haskey(ref.context, ref.name) - return ref.context[ref.name] - else - error(lazy"The variable `$ref` has been used, but has not been instantiated.") - end -end - -function getorcreate!(model::Model, context::Context, ref::VariableRef, index::Nothing) - check_external_collection_compatibility(ref, index) - return getorcreate!(model, context, ref.options, ref.name, index) -end - -function getorcreate!(model::Model, context::Context, ref::VariableRef, index::Tuple) - check_external_collection_compatibility(ref, index) - return getorcreate!(model, context, ref.options, ref.name, index...) -end - -Base.IteratorSize(ref::VariableRef) = Base.IteratorSize(typeof(ref)) -Base.IteratorEltype(ref::VariableRef) = Base.IteratorEltype(typeof(ref)) -Base.eltype(ref::VariableRef) = Base.eltype(typeof(ref)) - -Base.IteratorSize(::Type{R}) where {R <: VariableRef} = - variable_ref_iterator_size(external_collection_typeof(R), internal_collection_typeof(R)) -variable_ref_iterator_size(::Type{Nothing}, ::Type{Nothing}) = Base.SizeUnknown() -variable_ref_iterator_size(::Type{E}, ::Type{L}) where {E, L} = Base.IteratorSize(E) -variable_ref_iterator_size(::Type{Nothing}, ::Type{L}) where {L} = Base.IteratorSize(L) - -Base.IteratorEltype(::Type{R}) where {R <: VariableRef} = - variable_ref_iterator_eltype(external_collection_typeof(R), internal_collection_typeof(R)) -variable_ref_iterator_eltype(::Type{Nothing}, ::Type{Nothing}) = Base.EltypeUnknown() -variable_ref_iterator_eltype(::Type{E}, ::Type{L}) where {E, L} = Base.IteratorEltype(E) -variable_ref_iterator_eltype(::Type{Nothing}, ::Type{L}) where {L} = Base.IteratorEltype(L) - -Base.eltype(::Type{R}) where {R <: VariableRef} = variable_ref_eltype(external_collection_typeof(R), internal_collection_typeof(R)) -variable_ref_eltype(::Type{Nothing}, ::Type{Nothing}) = Any -variable_ref_eltype(::Type{E}, ::Type{L}) where {E, L} = Base.eltype(E) -variable_ref_eltype(::Type{Nothing}, ::Type{L}) where {L} = Base.eltype(L) - -function variableref_checked_collection_typeof(::VariableRef) - return variableref_checked_iterator_call(typeof, :typeof, ref) -end - -Base.length(ref::VariableRef) = variableref_checked_iterator_call(Base.length, :length, ref) -Base.firstindex(ref::VariableRef) = variableref_checked_iterator_call(Base.firstindex, :firstindex, ref) -Base.lastindex(ref::VariableRef) = variableref_checked_iterator_call(Base.lastindex, :lastindex, ref) -Base.eachindex(ref::VariableRef) = variableref_checked_iterator_call(Base.eachindex, :eachindex, ref) -Base.axes(ref::VariableRef) = variableref_checked_iterator_call(Base.axes, :axes, ref) - -Base.size(ref::VariableRef, dims...) = variableref_checked_iterator_call((c) -> Base.size(c, dims...), :size, ref) -Base.getindex(ref::VariableRef, indices...) = variableref_checked_iterator_call((c) -> Base.getindex(c, indices...), :getindex, ref) - -function variableref_checked_iterator_call(f::F, fsymbol::Symbol, ref::VariableRef) where {F} - if !isnothing(ref.external_collection) - return f(ref.external_collection) - elseif !isnothing(ref.internal_collection) - return f(ref.internal_collection) - elseif haskey(ref.context, ref.name) - return f(ref.context[ref.name]) - end - error(lazy"Cannot call `$(fsymbol)` on variable reference `$(ref.name)`. The variable `$(ref.name)` has not been instantiated.") -end - -""" - datalabel(model, context, options, name, collection = MissingCollection()) - -A function for creating proxy data labels to pass into the model upon creation. -Can be useful in combination with `ModelGenerator` and `create_model`. -""" -function datalabel(model, context, options, name, collection = MissingCollection()) - kind = get(options, :kind, VariableKindUnknown) - if !isequal(kind, VariableKindData) - error("`datalabel` only supports `VariableKindData` in `NodeCreationOptions`") - end - return proxylabel(name, VariableRef(model, context, options, name, (nothing,), collection), nothing, True()) -end - -function postprocess_returnval(ref::VariableRef) - if haskey(ref.context, ref.name) - return ref.context[ref.name] - end - error("Cannot `return $(ref)`. The variable has not been instantiated.") -end - -""" -A placeholder collection for `VariableRef` when the actual external collection is not yet available. -""" -struct MissingCollection end - -__err_missing_collection_missing_method(method::Symbol) = - error("The `$method` method is not defined for a lazy node label without data attached.") - -Base.IteratorSize(::Type{MissingCollection}) = __err_missing_collection_missing_method(:IteratorSize) -Base.IteratorEltype(::Type{MissingCollection}) = __err_missing_collection_missing_method(:IteratorEltype) -Base.eltype(::Type{MissingCollection}) = __err_missing_collection_missing_method(:eltype) -Base.length(::MissingCollection) = __err_missing_collection_missing_method(:length) -Base.size(::MissingCollection, dims...) = __err_missing_collection_missing_method(:size) -Base.firstindex(::MissingCollection) = __err_missing_collection_missing_method(:firstindex) -Base.lastindex(::MissingCollection) = __err_missing_collection_missing_method(:lastindex) -Base.eachindex(::MissingCollection) = __err_missing_collection_missing_method(:eachindex) -Base.axes(::MissingCollection) = __err_missing_collection_missing_method(:axes) - -function check_external_collection_compatibility(ref::VariableRef, index) - if !isnothing(external_collection(ref)) && !__check_external_collection_compatibility(ref, index) - error( - """ - The index `[$(!isnothing(index) ? join(index, ", ") : nothing)]` is not compatible with the underlying collection provided for the label `$(ref.name)`. - The underlying data provided for `$(ref.name)` is `$(external_collection(ref))`. - """ - ) - end - return nothing -end - -function __check_external_collection_compatibility(ref::VariableRef, index::Nothing) - # We assume that index `nothing` is always compatible with the underlying collection - # Eg. a matrix `Σ` can be used both as it is `Σ`, but also as `Σ[1]` or `Σ[1, 1]` - return true -end - -function __check_external_collection_compatibility(ref::VariableRef, index::Tuple) - return __check_external_collection_compatibility(ref, external_collection(ref), index) -end - -# We can't really check if the data compatible or not if we get the `MissingCollection` -__check_external_collection_compatibility(label::VariableRef, ::MissingCollection, index::Tuple) = true -__check_external_collection_compatibility(label::VariableRef, collection::AbstractArray, indices::Tuple) = - checkbounds(Bool, collection, indices...) -__check_external_collection_compatibility(label::VariableRef, collection::Tuple, indices::Tuple) = - length(indices) === 1 && first(indices) ∈ 1:length(collection) -# A number cannot really be queried with non-empty indices -__check_external_collection_compatibility(label::VariableRef, collection::Number, indices::Tuple) = false -# For all other we simply don't know so we assume we are compatible -__check_external_collection_compatibility(label::VariableRef, collection, indices::Tuple) = true - -function Base.iterate(ref::VariableRef, state) - if !isnothing(external_collection(ref)) - return iterate(external_collection(ref), state) - elseif !isnothing(internal_collection(ref)) - return iterate(internal_collection(ref), state) - elseif haskey(ref.context, ref.name) - return iterate(ref.context[ref.name], state) - end - error("Cannot iterate over $(ref.name). The underlying collection for `$(ref.name)` has undefined shape.") -end - -function Base.iterate(ref::VariableRef) - if !isnothing(external_collection(ref)) - return iterate(external_collection(ref)) - elseif !isnothing(internal_collection(ref)) - return iterate(internal_collection(ref)) - elseif haskey(ref.context, ref.name) - return iterate(ref.context[ref.name]) - end - error("Cannot iterate over $(ref.name). The underlying collection for `$(ref.name)` has undefined shape.") -end - -function Base.broadcastable(ref::VariableRef) - if !isnothing(external_collection(ref)) - # If we have an underlying collection (e.g. data), we should instantiate all variables at the point of broadcasting - # in order to support something like `y .~ ` where `y` is a data label - return collect( - Iterators.map( - I -> checked_getindex(getorcreate!(ref.model, ref.context, ref.options, ref.name, I.I...), I.I), CartesianIndices(axes(ref)) - ) - ) - elseif !isnothing(internal_collection(ref)) - return Base.broadcastable(internal_collection(ref)) - elseif haskey(ref.context, ref.name) - return Base.broadcastable(ref.context[ref.name]) - end - error("Cannot broadcast over $(ref.name). The underlying collection for `$(ref.name)` has undefined shape.") -end - -""" -A structure that holds interfaces of a node in the type argument `I`. Used for dispatch. -""" -struct StaticInterfaces{I} end - -StaticInterfaces(I::Tuple) = StaticInterfaces{I}() -Base.getindex(::StaticInterfaces{I}, index) where {I} = I[index] - -function Base.convert(::Type{NamedTuple}, ::StaticInterfaces{I}, t::Tuple) where {I} - return NamedTuple{I}(t) -end - -function Model(graph::MetaGraph, plugins::PluginsCollection, backend, source) - return Model(graph, plugins, backend, source, Base.RefValue(0)) -end - -function Model(fform::F, plugins::PluginsCollection) where {F} - return Model(fform, plugins, default_backend(fform), nothing) -end - -function Model(fform::F, plugins::PluginsCollection, backend, source) where {F} - label_type = NodeLabel - edge_data_type = EdgeLabel - vertex_data_type = NodeData - graph = MetaGraph(Graph(), label_type, vertex_data_type, edge_data_type, Context(fform)) - model = Model(graph, plugins, backend, source) - return model -end - -Base.setindex!(model::Model, val::NodeData, key::NodeLabel) = Base.setindex!(model.graph, val, key) -Base.setindex!(model::Model, val::EdgeLabel, src::NodeLabel, dst::NodeLabel) = Base.setindex!(model.graph, val, src, dst) -Base.getindex(model::Model) = Base.getindex(model.graph) -Base.getindex(model::Model, key::NodeLabel) = Base.getindex(model.graph, key) -Base.getindex(model::Model, src::NodeLabel, dst::NodeLabel) = Base.getindex(model.graph, src, dst) -Base.getindex(model::Model, keys::AbstractArray{NodeLabel}) = map(key -> model[key], keys) -Base.getindex(model::Model, keys::NTuple{N, NodeLabel}) where {N} = collect(map(key -> model[key], keys)) - -Base.getindex(model::Model, keys::Base.Generator) = [model[key] for key in keys] - -Graphs.nv(model::Model) = Graphs.nv(model.graph) -Graphs.ne(model::Model) = Graphs.ne(model.graph) -Graphs.edges(model::Model) = Graphs.edges(model.graph) - -Graphs.neighbors(model::Model, node::NodeLabel) = Graphs.neighbors(model, node, model[node]) -Graphs.neighbors(model::Model, nodes::AbstractArray{<:NodeLabel}) = Iterators.flatten(map(node -> Graphs.neighbors(model, node), nodes)) - -Graphs.neighbors(model::Model, node::NodeLabel, nodedata::NodeData) = Graphs.neighbors(model, node, nodedata, getproperties(nodedata)) -Graphs.neighbors(model::Model, node::NodeLabel, nodedata::NodeData, properties::FactorNodeProperties) = map(neighbor -> neighbor[1], neighbors(properties)) -Graphs.neighbors(model::Model, node::NodeLabel, nodedata::NodeData, properties::VariableNodeProperties) = MetaGraphsNext.neighbor_labels(model.graph, node) - -Graphs.edges(model::Model, node::NodeLabel) = Graphs.edges(model, node, model[node]) -Graphs.edges(model::Model, nodes::AbstractArray{<:NodeLabel}) = Iterators.flatten(map(node -> Graphs.edges(model, node), nodes)) - -Graphs.edges(model::Model, node::NodeLabel, nodedata::NodeData) = Graphs.edges(model, node, nodedata, getproperties(nodedata)) -Graphs.edges(model::Model, node::NodeLabel, nodedata::NodeData, properties::FactorNodeProperties) = - map(neighbor -> neighbor[2], neighbors(properties)) - -function Graphs.edges(model::Model, node::NodeLabel, nodedata::NodeData, properties::VariableNodeProperties) - return (model[node, dst] for dst in MetaGraphsNext.neighbor_labels(model.graph, node)) -end - -Graphs.degree(model::Model, label::NodeLabel) = Graphs.degree(model.graph, MetaGraphsNext.code_for(model.graph, label)) - -abstract type AbstractModelFilterPredicate end - -struct FactorNodePredicate{N} <: AbstractModelFilterPredicate end - -function apply(::FactorNodePredicate{N}, model, something) where {N} - return apply(IsFactorNode(), model, something) && fform(getproperties(model[something])) ∈ aliases(model, N) -end - -struct IsFactorNode <: AbstractModelFilterPredicate end - -function apply(::IsFactorNode, model, something) - return is_factor(model[something]) -end - -struct VariableNodePredicate{V} <: AbstractModelFilterPredicate end - -function apply(::VariableNodePredicate{N}, model, something) where {N} - return apply(IsVariableNode(), model, something) && getname(getproperties(model[something])) === N -end - -struct IsVariableNode <: AbstractModelFilterPredicate end - -function apply(::IsVariableNode, model, something) - return is_variable(model[something]) -end - -struct SubmodelPredicate{S, C} <: AbstractModelFilterPredicate end - -function apply(::SubmodelPredicate{S, False}, model, something) where {S} - return fform(getcontext(model[something])) === S -end - -function apply(::SubmodelPredicate{S, True}, model, something) where {S} - return S ∈ fform.(path_to_root(getcontext(model[something]))) -end - -struct AndNodePredicate{L, R} <: AbstractModelFilterPredicate - left::L - right::R -end - -function apply(and::AndNodePredicate, model, something) - return apply(and.left, model, something) && apply(and.right, model, something) -end - -struct OrNodePredicate{L, R} <: AbstractModelFilterPredicate - left::L - right::R -end - -function apply(or::OrNodePredicate, model, something) - return apply(or.left, model, something) || apply(or.right, model, something) -end - -Base.:(|)(left::AbstractModelFilterPredicate, right::AbstractModelFilterPredicate) = OrNodePredicate(left, right) -Base.:(&)(left::AbstractModelFilterPredicate, right::AbstractModelFilterPredicate) = AndNodePredicate(left, right) - -as_node(any) = FactorNodePredicate{any}() -as_node() = IsFactorNode() -as_variable(any) = VariableNodePredicate{any}() -as_variable() = IsVariableNode() -as_context(any; children = false) = SubmodelPredicate{any, typeof(static(children))}() - -function Base.filter(predicate::AbstractModelFilterPredicate, model::Model) - return Iterators.filter(something -> apply(predicate, model, something), labels(model)) -end - -""" - generate_nodelabel(model::Model, name::Symbol) - -Generate a new `NodeLabel` object with a unique identifier based on the specified name and the -number of nodes already in the model. - -Arguments: -- `model`: A `Model` object representing the probabilistic graphical model. -- `name`: A symbol representing the name of the node. -- `variable_type`: A UInt8 representing the type of the variable. 0 = factor, 1 = individual variable, 2 = vector variable, 3 = tensor variable -- `index`: An integer or tuple of integers representing the index of the variable. -""" -function generate_nodelabel(model::Model, name) - nextcounter = setcounter!(model, getcounter(model) + 1) - return NodeLabel(name, nextcounter) -end - -""" - getcontext(model::Model) - -Retrieves the context of a model. The context of a model contains the complete hierarchy of variables and factor nodes. -Additionally, contains all child submodels and their respective contexts. The Context supplies a mapping from symbols to `GraphPPL.NodeLabel` structures -with which the model can be queried. -""" -getcontext(model::Model) = model[] - -function get_principal_submodel(model::Model) - context = getcontext(model) - return context -end - -Base.getindex(context::Context, ivar::IndexedVariable{Nothing}) = context[getname(ivar)] -Base.getindex(context::Context, ivar::IndexedVariable) = context[getname(ivar)][index(ivar)] - -""" - aliases(backend, fform) - -Returns a collection of aliases for `fform` depending on the `backend`. -""" -aliases(backend, fform) = error("Backend $backend must implement a method for `aliases` for `$(fform)`.") -aliases(model::Model, fform::F) where {F} = aliases(getbackend(model), fform) - -function add_vertex!(model::Model, label, data) - # This is an unsafe procedure that implements behaviour from `MetaGraphsNext`. - code = nv(model) + 1 - model.graph.vertex_labels[code] = label - model.graph.vertex_properties[label] = (code, data) - Graphs.add_vertex!(model.graph.graph) -end - -function add_edge!(model::Model, src, dst, data) - # This is an unsafe procedure that implements behaviour from `MetaGraphsNext`. - code_src, code_dst = MetaGraphsNext.code_for(model.graph, src), MetaGraphsNext.code_for(model.graph, dst) - model.graph.edge_data[(src, dst)] = data - return Graphs.add_edge!(model.graph.graph, code_src, code_dst) -end - -function has_edge(model::Model, src, dst) - code_src, code_dst = MetaGraphsNext.code_for(model.graph, src), MetaGraphsNext.code_for(model.graph, dst) - return Graphs.has_edge(model.graph.graph, code_src, code_dst) -end - -""" - copy_markov_blanket_to_child_context(child_context::Context, interfaces::NamedTuple) - -Copy the variables in the Markov blanket of a parent context to a child context, using a mapping specified by a named tuple. - -The Markov blanket of a node or model in a Factor Graph is defined as the set of its outgoing interfaces. -This function copies the variables in the Markov blanket of the parent context specified by the named tuple `interfaces` to the child context `child_context`, - by setting each child variable in `child_context.individual_variables` to its corresponding parent variable in `interfaces`. - -# Arguments -- `child_context::Context`: The child context to which to copy the Markov blanket variables. -- `interfaces::NamedTuple`: A named tuple that maps child variable names to parent variable names. -""" -function copy_markov_blanket_to_child_context(child_context::Context, interfaces::NamedTuple) - foreach(pairs(interfaces)) do (name_in_child, object_in_parent) - add_to_child_context(child_context, name_in_child, object_in_parent) - end -end - -function add_to_child_context(child_context::Context, name_in_child::Symbol, object_in_parent::ProxyLabel) - set!(child_context.proxies, name_in_child, object_in_parent) - return nothing -end - -function add_to_child_context(child_context::Context, name_in_child::Symbol, object_in_parent) - # By default, we assume that `object_in_parent` is a constant, so there is no need to save it in the context - return nothing -end - -throw_if_individual_variable(context::Context, name::Symbol) = - haskey(context.individual_variables, name) ? error("Variable $name is already an individual variable in the model") : nothing -throw_if_vector_variable(context::Context, name::Symbol) = - haskey(context.vector_variables, name) ? error("Variable $name is already a vector variable in the model") : nothing -throw_if_tensor_variable(context::Context, name::Symbol) = - haskey(context.tensor_variables, name) ? error("Variable $name is already a tensor variable in the model") : nothing - -""" - getorcreate!(model::Model, context::Context, options::NodeCreationOptions, name, index) - -Get or create a variable (name) from a factor graph model and context, using an index if provided. - -This function searches for a variable (name) in the factor graph model and context specified by the arguments `model` and `context`. If the variable exists, -it returns it. Otherwise, it creates a new variable and returns it. - -# Arguments -- `model::Model`: The factor graph model to search for or create the variable in. -- `context::Context`: The context to search for or create the variable in. -- `options::NodeCreationOptions`: Options for creating the variable. Must be a `NodeCreationOptions` object. -- `name`: The variable (name) to search for or create. Must be a symbol. -- `index`: Optional index for the variable. Can be an integer, a collection of integers, or `nothing`. If the index is `nothing` creates a single variable. -If the index is an integer creates a vector-like variable. If the index is a collection of integers creates a tensor-like variable. - -# Returns -The variable (name) found or created in the factor graph model and context. -""" -function getorcreate! end - -getorcreate!(::Model, ::Context, name::Symbol) = error("Index is required in the `getorcreate!` function for variable `$(name)`") -getorcreate!(::Model, ::Context, options::NodeCreationOptions, name::Symbol) = - error("Index is required in the `getorcreate!` function for variable `$(name)`") - -function getorcreate!(model::Model, ctx::Context, name::Symbol, index...) - return getorcreate!(model, ctx, EmptyNodeCreationOptions, name, index...) -end - -function getorcreate!(model::Model, ctx::Context, options::NodeCreationOptions, name::Symbol, index::Nothing) - throw_if_vector_variable(ctx, name) - throw_if_tensor_variable(ctx, name) - return get(() -> add_variable_node!(model, ctx, options, name, index), ctx.individual_variables, name) -end - -function getorcreate!(model::Model, ctx::Context, options::NodeCreationOptions, name::Symbol, index::Integer) - throw_if_individual_variable(ctx, name) - throw_if_tensor_variable(ctx, name) - if !haskey(ctx.vector_variables, name) - ctx[name] = ResizableArray(NodeLabel, Val(1)) - end - vectorvar = ctx.vector_variables[name] - if !isassigned(vectorvar, index) - vectorvar[index] = add_variable_node!(model, ctx, options, name, index) - end - return vectorvar -end - -function getorcreate!(model::Model, ctx::Context, options::NodeCreationOptions, name::Symbol, i1::Integer, is::Vararg{Integer}) - throw_if_individual_variable(ctx, name) - throw_if_vector_variable(ctx, name) - if !haskey(ctx.tensor_variables, name) - ctx[name] = ResizableArray(NodeLabel, Val(1 + length(is))) - end - tensorvar = ctx.tensor_variables[name] - if !isassigned(tensorvar, i1, is...) - tensorvar[i1, is...] = add_variable_node!(model, ctx, options, name, (i1, is...)) - end - return tensorvar -end - -function getorcreate!(model::Model, ctx::Context, options::NodeCreationOptions, name::Symbol, range::AbstractRange) - isempty(range) && error("Empty range is not allowed in the `getorcreate!` function for variable `$(name)`") - foreach(range) do i - getorcreate!(model, ctx, options, name, i) - end - return getorcreate!(model, ctx, options, name, first(range)) -end - -function getorcreate!(model::Model, ctx::Context, options::NodeCreationOptions, name::Symbol, r1::AbstractRange, rs::Vararg{AbstractRange}) - (isempty(r1) || any(isempty, rs)) && error("Empty range is not allowed in the `getorcreate!` function for variable `$(name)`") - foreach(Iterators.product(r1, rs...)) do i - getorcreate!(model, ctx, options, name, i...) - end - return getorcreate!(model, ctx, options, name, first(r1), first.(rs)...) -end - -function getorcreate!(model::Model, ctx::Context, options::NodeCreationOptions, name::Symbol, indices...) - if haskey(ctx, name) - var = ctx[name] - return var - end - error(lazy"Cannot create a variable named `$(name)` with non-standard indices $(indices)") -end - -getifcreated(model::Model, context::Context, var::NodeLabel) = var -getifcreated(model::Model, context::Context, var::ResizableArray) = var -getifcreated(model::Model, context::Context, var::Union{Tuple, AbstractArray{T}} where {T <: Union{NodeLabel, ProxyLabel, VariableRef}}) = - map((v) -> getifcreated(model, context, v), var) -getifcreated(model::Model, context::Context, var::ProxyLabel) = var -getifcreated(model::Model, context::Context, var) = - add_constant_node!(model, context, NodeCreationOptions(value = var, kind = :constant), :constvar, nothing) - -""" - add_variable_node!(model::Model, context::Context, options::NodeCreationOptions, name::Symbol, index) - -Add a variable node to the model with the given `name` and `index`. -This function is unsafe (doesn't check if a variable with the given name already exists in the model). - -Args: - - `model::Model`: The model to which the node is added. - - `context::Context`: The context to which the symbol is added. - - `options::NodeCreationOptions`: The options for the creation process. - - `name::Symbol`: The ID of the variable. - - `index`: The index of the variable. - -Returns: - - The generated symbol for the variable. -""" -function add_variable_node! end - -function add_variable_node!(model::Model, context::Context, name::Symbol, index) - return add_variable_node!(model, context, EmptyNodeCreationOptions, name, index) -end - -function add_variable_node!(model::Model, context::Context, options::NodeCreationOptions, name::Symbol, index) - label = __add_variable_node!(model, context, options, name, index) - context[name, index] = label -end - -function add_constant_node!(model::Model, context::Context, options::NodeCreationOptions, name::Symbol, index) - label = __add_variable_node!(model, context, options, name, index) - context[to_symbol(name, label.global_counter), index] = label # to_symbol(label) is type unstable and we know the type of label.name here from name - return label -end - -function __add_variable_node!(model::Model, context::Context, options::NodeCreationOptions, name::Symbol, index) - # In theory plugins are able to overwrite this - potential_label = generate_nodelabel(model, name) - potential_nodedata = NodeData(context, convert(VariableNodeProperties, name, index, options)) - label, nodedata = preprocess_plugins( - UnionPluginType(VariableNodePlugin(), FactorAndVariableNodesPlugin()), model, context, potential_label, potential_nodedata, options - ) - add_vertex!(model, label, nodedata) - return label -end - -""" - AnonymousVariable(model, context) - -Defines a lazy structure for anonymous variables. -The actual anonymous variables materialize only in `make_node!` upon calling, because it needs arguments to the `make_node!` in order to create proper links. -""" -struct AnonymousVariable{M, C} - model::M - context::C -end - -Base.broadcastable(v::AnonymousVariable) = Ref(v) - -create_anonymous_variable!(model::Model, context::Context) = AnonymousVariable(model, context) - -function materialize_anonymous_variable!(anonymous::AnonymousVariable, fform, args) - model = anonymous.model - return materialize_anonymous_variable!(NodeBehaviour(model, fform), model, anonymous.context, fform, args) -end - -# Deterministic nodes can create links to variables in the model -# This might be important for better factorization constraints resolution -function materialize_anonymous_variable!(::Deterministic, model::Model, context::Context, fform, args) - linked = getindex.(Ref(model), unroll.(filter(is_nodelabel, args))) - - # Check if all links are either `data` or `constants` - # In this case it is not necessary to create a new random variable, but rather a data variable - # with `value = fform` - link_const, link_const_or_data = reduce(linked; init = (true, true)) do accum, link - check_is_all_constant, check_is_all_constant_or_data = accum - check_is_all_constant = check_is_all_constant && anonymous_arg_is_constanst(link) - check_is_all_constant_or_data = check_is_all_constant_or_data && anonymous_arg_is_constanst_or_data(link) - return (check_is_all_constant, check_is_all_constant_or_data) - end - - if !link_const && !link_const_or_data - # Most likely case goes first, we need to create a new factor node and a new random variable - (true, add_variable_node!(model, context, NodeCreationOptions(link = linked), VariableNameAnonymous, nothing)) - elseif link_const - # If all `links` are constant nodes we can evaluate the `fform` here and create another constant rather than creating a new factornode - val = fform(map(arg -> arg isa NodeLabel ? value(getproperties(model[arg])) : arg, unroll.(args))...) - ( - false, - add_variable_node!( - model, context, NodeCreationOptions(kind = :constant, value = val, link = linked), VariableNameAnonymous, nothing - ) - ) - elseif link_const_or_data - # If all `links` are constant or data we can create a new data variable with `fform` attached to it as a value rather than creating a new factornode - ( - false, - add_variable_node!( - model, - context, - NodeCreationOptions(kind = :data, value = (fform, unroll.(args)), link = linked), - VariableNameAnonymous, - nothing - ) - ) - else - # This should not really happen - error("Unreachable reached in `materialize_anonymous_variable!` for `Deterministic` node behaviour.") - end -end - -anonymous_arg_is_constanst(data) = true -anonymous_arg_is_constanst(data::NodeData) = is_constant(getproperties(data)) -anonymous_arg_is_constanst(data::AbstractArray) = all(anonymous_arg_is_constanst, data) - -anonymous_arg_is_constanst_or_data(data) = is_constant(data) -anonymous_arg_is_constanst_or_data(data::NodeData) = - let props = getproperties(data) - is_constant(props) || is_data(props) - end -anonymous_arg_is_constanst_or_data(data::AbstractArray) = all(anonymous_arg_is_constanst_or_data, data) - -function materialize_anonymous_variable!(::Deterministic, model::Model, context::Context, fform, args::NamedTuple) - return materialize_anonymous_variable!(Deterministic(), model, context, fform, values(args)) -end - -function materialize_anonymous_variable!(::Stochastic, model::Model, context::Context, fform, _) - return (true, add_variable_node!(model, context, NodeCreationOptions(), VariableNameAnonymous, nothing)) -end - -""" - add_atomic_factor_node!(model::Model, context::Context, options::NodeCreationOptions, fform) - -Add an atomic factor node to the model with the given name. -The function generates a new symbol for the node and adds it to the model with -the generated symbol as the key and a `FactorNodeData` struct. - -Args: - - `model::Model`: The model to which the node is added. - - `context::Context`: The context to which the symbol is added. - - `options::NodeCreationOptions`: The options for the creation process. - - `fform::Any`: The functional form of the node. - -Returns: - - The generated label for the node. -""" -function add_atomic_factor_node! end - -function add_atomic_factor_node!(model::Model, context::Context, options::NodeCreationOptions, fform::F) where {F} - factornode_id = generate_factor_nodelabel(context, fform) - - potential_label = generate_nodelabel(model, fform) - potential_nodedata = NodeData(context, convert(FactorNodeProperties, fform, options)) - - label, nodedata = preprocess_plugins( - UnionPluginType(FactorNodePlugin(), FactorAndVariableNodesPlugin()), model, context, potential_label, potential_nodedata, options - ) - - add_vertex!(model, label, nodedata) - context[factornode_id] = label - - return label, nodedata, convert(FactorNodeProperties, getproperties(nodedata)) -end - -""" - factor_alias(backend, fform, interfaces) - -Returns the alias for a given `fform` and `interfaces` with a given `backend`. -""" -function factor_alias end - -factor_alias(backend, fform, interfaces) = - error("The backend $backend must implement a method for `factor_alias` for `$(fform)` and `$(interfaces)`.") -factor_alias(model::Model, fform::F, interfaces) where {F} = factor_alias(getbackend(model), fform, interfaces) - -""" -Add a composite factor node to the model with the given name. - -The function generates a new symbol for the node and adds it to the model with -the generated symbol as the key and a `NodeData` struct with `is_variable` set to -`false` and `node_name` set to the given name. - -Args: - - `model::Model`: The model to which the node is added. - - `parent_context::Context`: The context to which the symbol is added. - - `context::Context`: The context of the composite factor node. - - `node_name::Symbol`: The name of the node. - -Returns: - - The generated id for the node. -""" -function add_composite_factor_node!(model::Model, parent_context::Context, context::Context, node_name) - node_id = generate_factor_nodelabel(parent_context, node_name) - parent_context[node_id] = context - return node_id -end - -function add_edge!( - model::Model, - factor_node_id::NodeLabel, - factor_node_propeties::FactorNodeProperties, - variable_node_id::Union{ProxyLabel, NodeLabel, VariableRef}, - interface_name::Symbol -) - return add_edge!(model, factor_node_id, factor_node_propeties, variable_node_id, interface_name, nothing) -end - -function add_edge!( - model::Model, - factor_node_id::NodeLabel, - factor_node_propeties::FactorNodeProperties, - variable_node_id::Union{AbstractArray, Tuple, NamedTuple}, - interface_name::Symbol -) - return add_edge!(model, factor_node_id, factor_node_propeties, variable_node_id, interface_name, 1) -end - -add_edge!( - model::Model, - factor_node_id::NodeLabel, - factor_node_propeties::FactorNodeProperties, - variable_node_id::Union{ProxyLabel, VariableRef}, - interface_name::Symbol, - index -) = add_edge!(model, factor_node_id, factor_node_propeties, unroll(variable_node_id), interface_name, index) - -function add_edge!( - model::Model, - factor_node_id::NodeLabel, - factor_node_propeties::FactorNodeProperties, - variable_node_id::Union{NodeLabel}, - interface_name::Symbol, - index -) - label = EdgeLabel(interface_name, index) - neighbor_node_label = unroll(variable_node_id) - addneighbor!(factor_node_propeties, neighbor_node_label, label, model[neighbor_node_label]) - edge_added = add_edge!(model, neighbor_node_label, factor_node_id, label) - if !edge_added - # Double check if the edge has already been added - if has_edge(model, neighbor_node_label, factor_node_id) - error( - lazy"Trying to create duplicate edge $(label) between variable $(neighbor_node_label) and factor node $(factor_node_id). Make sure that all the arguments to the `~` operator are unique (both left hand side and right hand side)." - ) - else - error(lazy"Cannot create an edge $(label) between variable $(neighbor_node_label) and factor node $(factor_node_id).") - end - end - return label -end - -function add_edge!( - model::Model, - factor_node_id::NodeLabel, - factor_node_propeties::FactorNodeProperties, - variable_nodes::Union{AbstractArray, Tuple, NamedTuple}, - interface_name::Symbol, - index -) - for variable_node in variable_nodes - add_edge!(model, factor_node_id, factor_node_propeties, variable_node, interface_name, index) - index += increase_index(variable_node) - end -end - -increase_index(any) = 1 -increase_index(x::AbstractArray) = length(x) - -struct MixedArguments{A <: Tuple, K <: NamedTuple} - args::A - kwargs::K -end - -""" - interfaces(backend, fform, ::StaticInt{N}) where N - -Returns the interfaces for a given `fform` and `backend` with a given amount of interfaces `N`. -""" -function interfaces end - -interfaces(backend, fform, ninputs) = - error("The backend $(backend) must implement a method for `interfaces` for `$(fform)` and `$(ninputs)` number of inputs.") -interfaces(model::Model, fform::F, ninputs) where {F} = interfaces(getbackend(model), fform, ninputs) - -struct StaticInterfaceAliases{A} end - -StaticInterfaceAliases(A::Tuple) = StaticInterfaceAliases{A}() - -""" - interface_aliases(backend, fform) - -Returns the aliases for a given `fform` and `backend`. -""" -function interface_aliases end - -interface_aliases(backend, fform) = error("The backend $backend must implement a method for `interface_aliases` for `$(fform)`.") -interface_aliases(model::Model, fform::F) where {F} = interface_aliases(getbackend(model), fform) -interface_aliases(model::Model, fform::F, interfaces::StaticInterfaces) where {F} = - interface_aliases(interface_aliases(model, fform), interfaces) - -function interface_aliases(::StaticInterfaceAliases{aliases}, ::StaticInterfaces{interfaces}) where {aliases, interfaces} - return StaticInterfaces( - reduce(aliases; init = interfaces) do acc, alias - from, to = alias - return replace(acc, from => to) - end - ) -end - -""" - missing_interfaces(node_type, val, known_interfaces) - -Returns the interfaces that are missing for a node. This is used when inferring the interfaces for a node that is composite. - -# Arguments -- `node_type`: The type of the node as a Function object. -- `val`: The value of the amount of interfaces the node is supposed to have. This is a `Static.StaticInt` object. -- `known_interfaces`: The known interfaces for the node. - -# Returns -- `missing_interfaces`: A `Vector` of the missing interfaces. -""" -function missing_interfaces(model::Model, fform::F, val, known_interfaces::NamedTuple) where {F} - return missing_interfaces(interfaces(model, fform, val), StaticInterfaces(keys(known_interfaces))) -end - -function missing_interfaces( - ::StaticInterfaces{all_interfaces}, ::StaticInterfaces{present_interfaces} -) where {all_interfaces, present_interfaces} - return StaticInterfaces(filter(interface -> interface ∉ present_interfaces, all_interfaces)) -end - -function prepare_interfaces(model::Model, fform::F, lhs_interface, rhs_interfaces::NamedTuple) where {F} - missing_interface = missing_interfaces(model, fform, static(length(rhs_interfaces)) + static(1), rhs_interfaces) - return prepare_interfaces(missing_interface, fform, lhs_interface, rhs_interfaces) -end - -function prepare_interfaces(::StaticInterfaces{I}, fform::F, lhs_interface, rhs_interfaces::NamedTuple) where {I, F} - if !(length(I) == 1) - error( - lazy"Expected only one missing interface, got $I of length $(length(I)) (node $fform with interfaces $(keys(rhs_interfaces)))" - ) - end - missing_interface = first(I) - return NamedTuple{(missing_interface, keys(rhs_interfaces)...)}((lhs_interface, values(rhs_interfaces)...)) -end - -function materialize_interface(model, context, interface) - return getifcreated(model, context, unroll(interface)) -end - -function materialze_interfaces(model, context, interfaces) - return map(interface -> materialize_interface(model, context, interface), interfaces) -end - -""" - default_parametrization(backend, fform, rhs) - -Returns the default parametrization for a given `fform` and `backend` with a given `rhs`. -""" -function default_parametrization end - -default_parametrization(backend, nodetype, fform, rhs) = - error("The backend $backend must implement a method for `default_parametrization` for `$(fform)` (`$(nodetype)`) and `$(rhs)`.") -default_parametrization(model::Model, nodetype, fform::F, rhs) where {F} = default_parametrization(getbackend(model), nodetype, fform, rhs) - -""" - instantiate(::Type{Backend}) - -Instantiates a default backend object of the specified type. Should be implemented for all backends. -""" -instantiate(backendtype) = error("The backend of type $backendtype must implement a method for `instantiate`.") - -# maybe change name - -is_nodelabel(x) = false -is_nodelabel(x::AbstractArray) = any(element -> is_nodelabel(element), x) -is_nodelabel(x::GraphPPL.NodeLabel) = true -is_nodelabel(x::ProxyLabel) = true -is_nodelabel(x::VariableRef) = true - -function contains_nodelabel(collection::Tuple) - return any(element -> is_nodelabel(element), collection) ? True() : False() -end - -function contains_nodelabel(collection::NamedTuple) - return any(element -> is_nodelabel(element), values(collection)) ? True() : False() -end - -function contains_nodelabel(collection::MixedArguments) - return contains_nodelabel(collection.args) | contains_nodelabel(collection.kwargs) -end - -# TODO improve documentation - -function make_node!(model::Model, ctx::Context, fform::F, lhs_interfaces, rhs_interfaces) where {F} - return make_node!(model, ctx, EmptyNodeCreationOptions, fform, lhs_interfaces, rhs_interfaces) -end - -make_node!(model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface, rhs_interfaces) where {F} = - make_node!(NodeType(model, fform), model, ctx, options, fform, lhs_interface, rhs_interfaces) - -# if it is composite, we assume it should be materialized and it is stochastic -# TODO: shall we not assume that the `Composite` node is necessarily stochastic? -make_node!( - nodetype::Composite, model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface, rhs_interfaces -) where {F} = make_node!(True(), nodetype, Stochastic(), model, ctx, options, fform, lhs_interface, rhs_interfaces) - -# If a node is an object and not a function, we materialize it as a stochastic atomic node -make_node!(model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface, rhs_interfaces::Nothing) where {F} = - make_node!(True(), Atomic(), Stochastic(), model, ctx, options, fform, lhs_interface, NamedTuple{}()) - -# If node is Atomic, check stochasticity -make_node!(::Atomic, model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface, rhs_interfaces) where {F} = - make_node!(Atomic(), NodeBehaviour(model, fform), model, ctx, options, fform, lhs_interface, rhs_interfaces) - -#If a node is deterministic, we check if there are any NodeLabel objects in the rhs_interfaces (direct check if node should be materialized) -make_node!( - atomic::Atomic, - deterministic::Deterministic, - model::Model, - ctx::Context, - options::NodeCreationOptions, - fform::F, - lhs_interface, - rhs_interfaces -) where {F} = - make_node!(contains_nodelabel(rhs_interfaces), atomic, deterministic, model, ctx, options, fform, lhs_interface, rhs_interfaces) - -# If the node should not be materialized (if it's Atomic, Deterministic and contains no NodeLabel objects), we return the `fform` evaluated at the interfaces -# This works only if the `lhs_interface` is `AnonymousVariable` (or the corresponding `ProxyLabel` with `AnonymousVariable` as the proxied variable) -__evaluate_fform(fform::F, args::Tuple) where {F} = fform(args...) -__evaluate_fform(fform::F, args::NamedTuple) where {F} = fform(; args...) -__evaluate_fform(fform::F, args::MixedArguments) where {F} = fform(args.args...; args.kwargs...) - -make_node!( - ::False, - ::Atomic, - ::Deterministic, - model::Model, - ctx::Context, - options::NodeCreationOptions, - fform::F, - lhs_interface::Union{AnonymousVariable, ProxyLabel{<:T, <:AnonymousVariable} where {T}}, - rhs_interfaces::Union{Tuple, NamedTuple, MixedArguments} -) where {F} = (nothing, __evaluate_fform(fform, rhs_interfaces)) - -# In case if the `lhs_interface` is something else we throw an error saying that `fform` cannot be instantiated since -# arguments are not stochastic and the `fform` is not stochastic either, thus the usage of `~` is invalid -make_node!( - ::False, - ::Atomic, - ::Deterministic, - model::Model, - ctx::Context, - options::NodeCreationOptions, - fform::F, - lhs_interface, - rhs_interfaces::Union{Tuple, NamedTuple, MixedArguments} -) where {F} = error("`$(fform)` cannot be used as a factor node. Both the arguments and the node are not stochastic.") - -# If a node is Stochastic, we always materialize. -make_node!( - ::Atomic, ::Stochastic, model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface, rhs_interfaces -) where {F} = make_node!(True(), Atomic(), Stochastic(), model, ctx, options, fform, lhs_interface, rhs_interfaces) - -function make_node!( - materialize::True, - node_type::NodeType, - behaviour::NodeBehaviour, - model::Model, - ctx::Context, - options::NodeCreationOptions, - fform::F, - lhs_interface::AnonymousVariable, - rhs_interfaces -) where {F} - (noderequired, lhs_materialized) = materialize_anonymous_variable!(lhs_interface, fform, rhs_interfaces)::Tuple{Bool, NodeLabel} - node_materialized = if noderequired - node, _ = make_node!(materialize, node_type, behaviour, model, ctx, options, fform, lhs_materialized, rhs_interfaces) - node - else - nothing - end - return node_materialized, lhs_materialized -end - -# If we have to materialize but the rhs_interfaces argument is not a NamedTuple, we convert it -make_node!( - materialize::True, - node_type::NodeType, - behaviour::NodeBehaviour, - model::Model, - ctx::Context, - options::NodeCreationOptions, - fform::F, - lhs_interface::Union{NodeLabel, ProxyLabel, VariableRef}, - rhs_interfaces::Tuple -) where {F} = make_node!( - materialize, - node_type, - behaviour, - model, - ctx, - options, - fform, - lhs_interface, - GraphPPL.default_parametrization(model, node_type, fform, rhs_interfaces) -) - -make_node!( - ::True, - node_type::NodeType, - behaviour::NodeBehaviour, - model::Model, - ctx::Context, - options::NodeCreationOptions, - fform::F, - lhs_interface::Union{NodeLabel, ProxyLabel, VariableRef}, - rhs_interfaces::MixedArguments -) where {F} = error("MixedArguments not supported for rhs_interfaces when node has to be materialized") - -make_node!( - materialize::True, - node_type::Composite, - behaviour::Stochastic, - model::Model, - ctx::Context, - options::NodeCreationOptions, - fform::F, - lhs_interface::Union{NodeLabel, ProxyLabel, VariableRef}, - rhs_interfaces::Tuple{} -) where {F} = make_node!(materialize, node_type, behaviour, model, ctx, options, fform, lhs_interface, NamedTuple{}()) - -make_node!( - materialize::True, - node_type::Composite, - behaviour::Stochastic, - model::Model, - ctx::Context, - options::NodeCreationOptions, - fform::F, - lhs_interface::Union{NodeLabel, ProxyLabel, VariableRef}, - rhs_interfaces::Tuple -) where {F} = error(lazy"Composite node $fform cannot should be called with explicitly naming the interface names") - -make_node!( - materialize::True, - node_type::Composite, - behaviour::Stochastic, - model::Model, - ctx::Context, - options::NodeCreationOptions, - fform::F, - lhs_interface::Union{NodeLabel, ProxyLabel, VariableRef}, - rhs_interfaces::NamedTuple -) where {F} = make_node!(Composite(), model, ctx, options, fform, lhs_interface, rhs_interfaces, static(length(rhs_interfaces) + 1)) - -""" - make_node! - -Make a new factor node in the Model and specified Context, attach it to the specified interfaces, and return the interface that is on the lhs of the `~` operator. - -# Arguments -- `model::Model`: The model to add the node to. -- `ctx::Context`: The context in which to add the node. -- `fform`: The function that the node represents. -- `lhs_interface`: The interface that is on the lhs of the `~` operator. -- `rhs_interfaces`: The interfaces that are the arguments of fform on the rhs of the `~` operator. -- `__parent_options__::NamedTuple = nothing`: The options to attach to the node. -- `__debug__::Bool = false`: Whether to attach debug information to the factor node. -""" -function make_node!( - materialize::True, - node_type::Atomic, - behaviour::NodeBehaviour, - model::Model, - context::Context, - options::NodeCreationOptions, - fform::F, - lhs_interface::Union{NodeLabel, ProxyLabel, VariableRef}, - rhs_interfaces::NamedTuple -) where {F} - aliased_rhs_interfaces = convert( - NamedTuple, interface_aliases(model, fform, StaticInterfaces(keys(rhs_interfaces))), values(rhs_interfaces) - ) - aliased_fform = factor_alias(model, fform, StaticInterfaces(keys(aliased_rhs_interfaces))) - prepared_interfaces = prepare_interfaces(model, aliased_fform, lhs_interface, aliased_rhs_interfaces) - sorted_interfaces = sort_interfaces(model, aliased_fform, prepared_interfaces) - interfaces = materialze_interfaces(model, context, sorted_interfaces) - nodeid, _, _ = materialize_factor_node!(model, context, options, aliased_fform, interfaces) - return nodeid, unroll(lhs_interface) -end - -function sort_interfaces(model::Model, fform::F, defined_interfaces::NamedTuple) where {F} - return sort_interfaces(interfaces(model, fform, static(length(defined_interfaces))), defined_interfaces) -end - -function sort_interfaces(::StaticInterfaces{I}, defined_interfaces::NamedTuple) where {I} - return defined_interfaces[I] -end - -function materialize_factor_node!(model::Model, context::Context, options::NodeCreationOptions, fform::F, interfaces::NamedTuple) where {F} - factor_node_id, factor_node_data, factor_node_properties = add_atomic_factor_node!(model, context, options, fform) - foreach(pairs(interfaces)) do (interface_name, interface) - add_edge!(model, factor_node_id, factor_node_properties, interface, interface_name) - end - return factor_node_id, factor_node_data, factor_node_properties -end - -function add_terminated_submodel!(model::Model, context::Context, fform, interfaces::NamedTuple) - return add_terminated_submodel!(model, context, NodeCreationOptions((; created_by = () -> :($QuoteNode(fform)))), fform, interfaces) -end - -function add_terminated_submodel!(model::Model, context::Context, options::NodeCreationOptions, fform, interfaces::NamedTuple) - returnval = add_terminated_submodel!(model, context, options, fform, interfaces, static(length(interfaces))) - returnval!(context, returnval) - return returnval -end - -""" -Add the `fform` as the toplevel model to the `model` and `context` with the specified `interfaces`. -Calls the postprocess logic for the attached plugins of the model. Should be called only once for a given `Model` object. -""" -function add_toplevel_model! end - -function add_toplevel_model!(model::Model, fform, interfaces) - return add_toplevel_model!(model, getcontext(model), fform, interfaces) -end - -function add_toplevel_model!(model::Model, context::Context, fform, interfaces) - add_terminated_submodel!(model, context, fform, interfaces) - foreach(getplugins(model)) do plugin - postprocess_plugin(plugin, model) - end - return model -end - -""" - prune!(m::Model) - -Remove all nodes from the model that are not connected to any other node. -""" -function prune!(m::Model) - degrees = degree(m.graph) - nodes_to_remove = keys(degrees)[degrees .== 0] - nodes_to_remove = sort(nodes_to_remove, rev = true) - rem_vertex!.(Ref(m.graph), nodes_to_remove) -end - -## Plugin steps - -""" -A trait object for plugins that add extra functionality for factor nodes. -""" -struct FactorNodePlugin <: AbstractPluginTraitType end - -""" -A trait object for plugins that add extra functionality for variable nodes. -""" -struct VariableNodePlugin <: AbstractPluginTraitType end - -""" -A trait object for plugins that add extra functionality both for factor and variable nodes. -""" -struct FactorAndVariableNodesPlugin <: AbstractPluginTraitType end - -""" - preprocess_plugin(plugin, model, context, label, nodedata, options) - -Call a plugin specific logic for a node with label and nodedata upon their creation. -""" -function preprocess_plugin end - -""" - postprocess_plugin(plugin, model) - -Calls a plugin specific logic after the model has been created. By default does nothing. -""" -postprocess_plugin(plugin, model) = nothing - -function preprocess_plugins( - type::AbstractPluginTraitType, model::Model, context::Context, label::NodeLabel, nodedata::NodeData, options -)::Tuple{NodeLabel, NodeData} - plugins = filter(type, getplugins(model)) - return foldl(plugins; init = (label, nodedata)) do (label, nodedata), plugin - return preprocess_plugin(plugin, model, context, label, nodedata, options)::Tuple{NodeLabel, NodeData} - end::Tuple{NodeLabel, NodeData} -end diff --git a/src/model_macro.jl b/src/macros/model_macro.jl similarity index 93% rename from src/model_macro.jl rename to src/macros/model_macro.jl index 2a667c08..ac3a1e3c 100644 --- a/src/model_macro.jl +++ b/src/macros/model_macro.jl @@ -3,69 +3,6 @@ import MacroTools: postwalk, prewalk, @capture, walk using NamedTupleTools using Static -__guard_f(f, e::Expr) = f(e) -__guard_f(f, x) = x - -struct guarded_walk{f} - guard::f -end - -function (w::guarded_walk)(f, x) - return w.guard(x) ? x : walk(x, x -> w(f, x), f) -end - -struct walk_until_occurrence{E} - patterns::E -end - -not_enter_indexed_walk = guarded_walk((x) -> (x isa Expr && x.head == :ref) || (x isa Expr && x.head == :call && x.args[1] == :new)) -not_created_by = guarded_walk((x) -> (x isa Expr && !isempty(x.args) && x.args[1] == :created_by)) - -function (w::walk_until_occurrence{E})(f, x) where {E <: Tuple} - return walk(x, z -> any(pattern -> @capture(x, $(pattern)), w.patterns) ? z : w(f, z), f) -end - -function (w::walk_until_occurrence{E})(f, x) where {E <: Expr} - return walk(x, z -> @capture(x, $(w.patterns)) ? z : w(f, z), f) -end - -what_walk(anything) = postwalk - -""" - apply_pipeline(e::Expr, pipeline) - -Apply a pipeline function to an expression. - -The `apply_pipeline` function takes an expression `e` and a `pipeline` function and applies the function in the pipeline to `e` when walking over it. The walk utilized can be specified by implementing `what_walk` for a pipeline funciton. - -# Arguments -- `e::Expr`: An expression to apply the pipeline to. -- `pipeline`: A function to apply to the expressions in `e`. - -# Returns -The result of applying the pipeline function to `e`. -""" -function apply_pipeline(e::Expr, pipeline::F) where {F} - walk = what_walk(pipeline) - return walk(x -> __guard_f(pipeline, x), e) -end - -""" - apply_pipeline_collection(e::Expr, collection) - -Similar to [`apply_pipeline`](@ref), but applies a collection of pipeline functions to an expression. - -# Arguments -- `e::Expr`: An expression to apply the pipeline to. -- `collection`: A collection of functions to apply to the expressions in `e`. - -# Returns -The result of applying the pipeline function to `e`. -""" -function apply_pipeline_collection(e::Expr, collection) - return reduce((e, pipeline) -> apply_pipeline(e, pipeline), collection, init = e) -end - """ check_reserved_variable_names_model(expr::Expr) diff --git a/src/model/context.jl b/src/model/context.jl new file mode 100644 index 00000000..43a8f8f9 --- /dev/null +++ b/src/model/context.jl @@ -0,0 +1,175 @@ +""" + Context + +Contains all information about a submodel in a probabilistic graphical model. +""" +struct Context + depth::Int64 + fform::Function + prefix::String + parent::Union{Context, Nothing} + submodel_counts::UnorderedDictionary{Any, Int} + children::UnorderedDictionary{FactorID, Context} + factor_nodes::UnorderedDictionary{FactorID, NodeLabel} + individual_variables::UnorderedDictionary{Symbol, NodeLabel} + vector_variables::UnorderedDictionary{Symbol, ResizableArray{NodeLabel, Vector{NodeLabel}, 1}} + tensor_variables::UnorderedDictionary{Symbol, ResizableArray{NodeLabel}} + proxies::UnorderedDictionary{Symbol, ProxyLabel} + returnval::Ref{Any} +end + +function Context(depth::Int, fform::Function, prefix::String, parent) + return Context( + depth, + fform, + prefix, + parent, + UnorderedDictionary{Any, Int}(), + UnorderedDictionary{FactorID, Context}(), + UnorderedDictionary{FactorID, NodeLabel}(), + UnorderedDictionary{Symbol, NodeLabel}(), + UnorderedDictionary{Symbol, ResizableArray{NodeLabel, Vector{NodeLabel}, 1}}(), + UnorderedDictionary{Symbol, ResizableArray{NodeLabel}}(), + UnorderedDictionary{Symbol, ProxyLabel}(), + Ref{Any}() + ) +end + +Context(parent::Context, model_fform::Function) = + Context(parent.depth + 1, model_fform, (parent.prefix == "" ? parent.prefix : parent.prefix * "_") * getname(model_fform), parent) +Context(fform) = Context(0, fform, "", nothing) +Context() = Context(identity) + +fform(context::Context) = context.fform +parent(context::Context) = context.parent +individual_variables(context::Context) = context.individual_variables +vector_variables(context::Context) = context.vector_variables +tensor_variables(context::Context) = context.tensor_variables +factor_nodes(context::Context) = context.factor_nodes +proxies(context::Context) = context.proxies +children(context::Context) = context.children +count(context::Context, fform::F) where {F} = haskey(context.submodel_counts, fform) ? context.submodel_counts[fform] : 0 +shortname(context::Context) = string(context.prefix) + +returnval(context::Context) = context.returnval[] + +function returnval!(context::Context, value) + context.returnval[] = postprocess_returnval(value) +end + +# We do not want to return `VariableRef` from the model +# In this case we replace them with the actual node labels +postprocess_returnval(value) = value +postprocess_returnval(value::Tuple) = map(postprocess_returnval, value) + +path_to_root(::Nothing) = [] +path_to_root(context::Context) = [context, path_to_root(parent(context))...] + +function generate_factor_nodelabel(context::Context, fform::F) where {F} + if count(context, fform) == 0 + set!(context.submodel_counts, fform, 1) + else + context.submodel_counts[fform] += 1 + end + return FactorID(fform, count(context, fform)) +end + +function Base.show(io::IO, mime::MIME"text/plain", context::Context) + iscompact = get(io, :compact, false)::Bool + + if iscompact + print(io, "Context(", shortname(context), " | ") + nvariables = + length(context.individual_variables) + + length(context.vector_variables) + + length(context.tensor_variables) + + length(context.proxies) + nfactornodes = length(context.factor_nodes) + print(io, nvariables, " variables, ", nfactornodes, " factor nodes") + if !isempty(context.children) + print(io, ", ", length(context.children), " children") + end + print(io, ")") + else + indentation = get(io, :indentation, 0)::Int + indentationstr = " "^indentation + indentationstrp1 = " "^(indentation + 1) + println(io, indentationstr, "Context(", shortname(context), ")") + println(io, indentationstrp1, "Individual variables: ", keys(individual_variables(context))) + println(io, indentationstrp1, "Vector variables: ", keys(vector_variables(context))) + println(io, indentationstrp1, "Tensor variables: ", keys(tensor_variables(context))) + println(io, indentationstrp1, "Proxies: ", keys(proxies(context))) + println(io, indentationstrp1, "Factor nodes: ", collect(keys(factor_nodes(context)))) + if !isempty(context.children) + println(io, indentationstrp1, "Children: ", map(shortname, values(context.children))) + end + end +end + +getname(f::Function) = String(Symbol(f)) + +haskey(context::Context, key::Symbol) = + haskey(context.individual_variables, key) || + haskey(context.vector_variables, key) || + haskey(context.tensor_variables, key) || + haskey(context.proxies, key) + +haskey(context::Context, key::FactorID) = haskey(context.factor_nodes, key) || haskey(context.children, key) + +function Base.getindex(c::Context, key::Symbol) + if haskey(c.individual_variables, key) + return c.individual_variables[key] + elseif haskey(c.vector_variables, key) + return c.vector_variables[key] + elseif haskey(c.tensor_variables, key) + return c.tensor_variables[key] + elseif haskey(c.proxies, key) + return c.proxies[key] + end + throw(KeyError(key)) +end + +function Base.getindex(c::Context, key::FactorID) + if haskey(c.factor_nodes, key) + return c.factor_nodes[key] + elseif haskey(c.children, key) + return c.children[key] + end + throw(KeyError(key)) +end + +Base.getindex(c::Context, fform, index::Int) = c[FactorID(fform, index)] + +Base.setindex!(c::Context, val::NodeLabel, key::Symbol) = set!(c.individual_variables, key, val) +Base.setindex!(c::Context, val::NodeLabel, key::Symbol, index::Nothing) = set!(c.individual_variables, key, val) +Base.setindex!(c::Context, val::NodeLabel, key::Symbol, index::Int) = c.vector_variables[key][index] = val +Base.setindex!(c::Context, val::NodeLabel, key::Symbol, index::NTuple{N, Int64} where {N}) = c.tensor_variables[key][index...] = val +Base.setindex!(c::Context, val::ResizableArray{NodeLabel, T, 1} where {T}, key::Symbol) = set!(c.vector_variables, key, val) +Base.setindex!(c::Context, val::ResizableArray{NodeLabel, T, N} where {T, N}, key::Symbol) = set!(c.tensor_variables, key, val) +Base.setindex!(c::Context, val::ProxyLabel, key::Symbol) = set!(c.proxies, key, val) +Base.setindex!(c::Context, val::ProxyLabel, key::Symbol, index::Nothing) = set!(c.proxies, key, val) +Base.setindex!(c::Context, val::Context, key::FactorID) = set!(c.children, key, val) +Base.setindex!(c::Context, val::NodeLabel, key::FactorID) = set!(c.factor_nodes, key, val) + +function copy_markov_blanket_to_child_context(child_context::Context, interfaces::NamedTuple) + foreach(pairs(interfaces)) do (name_in_child, object_in_parent) + add_to_child_context(child_context, name_in_child, object_in_parent) + end +end + +function add_to_child_context(child_context::Context, name_in_child::Symbol, object_in_parent::ProxyLabel) + set!(child_context.proxies, name_in_child, object_in_parent) + return nothing +end + +function add_to_child_context(child_context::Context, name_in_child::Symbol, object_in_parent) + # By default, we assume that `object_in_parent` is a constant, so there is no need to save it in the context + return nothing +end + +throw_if_individual_variable(context::Context, name::Symbol) = + haskey(context.individual_variables, name) ? error("Variable $name is already an individual variable in the model") : nothing +throw_if_vector_variable(context::Context, name::Symbol) = + haskey(context.vector_variables, name) ? error("Variable $name is already a vector variable in the model") : nothing +throw_if_tensor_variable(context::Context, name::Symbol) = + haskey(context.tensor_variables, name) ? error("Variable $name is already a tensor variable in the model") : nothing \ No newline at end of file diff --git a/src/model/indexed_variable.jl b/src/model/indexed_variable.jl new file mode 100644 index 00000000..38f0d6f5 --- /dev/null +++ b/src/model/indexed_variable.jl @@ -0,0 +1,22 @@ +""" + IndexedVariable(name, index) + +`IndexedVariable` represents a reference to a variable named `name` with index `index`. +""" +struct IndexedVariable{T} + name::Symbol + index::T +end + +getname(index::IndexedVariable) = index.name +index(index::IndexedVariable) = index.index + +Base.length(index::IndexedVariable{T} where {T}) = 1 +Base.iterate(index::IndexedVariable{T} where {T}) = (index, nothing) +Base.iterate(index::IndexedVariable{T} where {T}, any) = nothing +Base.:(==)(left::IndexedVariable, right::IndexedVariable) = (left.name == right.name && left.index == right.index) +Base.show(io::IO, variable::IndexedVariable{Nothing}) = print(io, variable.name) +Base.show(io::IO, variable::IndexedVariable) = print(io, variable.name, "[", variable.index, "]") + +Base.getindex(context::T, ivar::IndexedVariable{Nothing}) where {T} = Base.getindex(context, getname(ivar)) +Base.getindex(context::T, ivar::IndexedVariable) where {T} = Base.getindex(context, getname(ivar))[index(ivar)] \ No newline at end of file diff --git a/src/model/model.jl b/src/model/model.jl new file mode 100644 index 00000000..84326cdd --- /dev/null +++ b/src/model/model.jl @@ -0,0 +1,174 @@ +""" + Model(graph::MetaGraph) + +A structure representing a probabilistic graphical model. It contains a `MetaGraph` object +representing the factor graph and a `Base.RefValue{Int64}` object to keep track of the number +of nodes in the graph. + +Fields: +- `graph`: A `MetaGraph` object representing the factor graph. +- `plugins`: A `PluginsCollection` object representing the plugins enabled in the model. +- `backend`: A `Backend` object representing the backend used in the model. +- `source`: A `Source` object representing the original source code of the model (typically a `String` object). +- `counter`: A `Base.RefValue{Int64}` object keeping track of the number of nodes in the graph. +""" +struct Model{G, P, B, S} <: AbstractModel + graph::G + plugins::P + backend::B + source::S + counter::Base.RefValue{Int64} +end + +labels(model::Model) = MetaGraphsNext.labels(model.graph) +Base.isempty(model::Model) = iszero(nv(model.graph)) && iszero(ne(model.graph)) + +getplugins(model::Model) = model.plugins +getbackend(model::Model) = model.backend +getsource(model::Model) = model.source +getcounter(model::Model) = model.counter[] +setcounter!(model::Model, value) = model.counter[] = value + +Graphs.savegraph(file::AbstractString, model::GraphPPL.Model) = save(file, "__model__", model) +Graphs.loadgraph(file::AbstractString, ::Type{GraphPPL.Model}) = load(file, "__model__") + +NodeType(model::Model, fform::F) where {F} = NodeType(getbackend(model), fform) +NodeBehaviour(model::Model, fform::F) where {F} = NodeBehaviour(getbackend(model), fform) + +function Model(graph::MetaGraph, plugins::PluginsCollection, backend, source) + return Model(graph, plugins, backend, source, Base.RefValue(0)) +end + +function Model(fform::F, plugins::PluginsCollection) where {F} + return Model(fform, plugins, default_backend(fform), nothing) +end + +function Model(fform::F, plugins::PluginsCollection, backend, source) where {F} + label_type = NodeLabel + edge_data_type = EdgeLabel + vertex_data_type = NodeData + graph = MetaGraph(Graph(), label_type, vertex_data_type, edge_data_type, Context(fform)) + model = Model(graph, plugins, backend, source) + return model +end + +Base.setindex!(model::Model, val::NodeData, key::NodeLabel) = Base.setindex!(model.graph, val, key) +Base.setindex!(model::Model, val::EdgeLabel, src::NodeLabel, dst::NodeLabel) = Base.setindex!(model.graph, val, src, dst) +Base.getindex(model::Model) = Base.getindex(model.graph) +Base.getindex(model::Model, key::NodeLabel) = Base.getindex(model.graph, key) +Base.getindex(model::Model, src::NodeLabel, dst::NodeLabel) = Base.getindex(model.graph, src, dst) +Base.getindex(model::Model, keys::AbstractArray{NodeLabel}) = map(key -> model[key], keys) +Base.getindex(model::Model, keys::NTuple{N, NodeLabel}) where {N} = collect(map(key -> model[key], keys)) + +Base.getindex(model::Model, keys::Base.Generator) = [model[key] for key in keys] + +Graphs.nv(model::Model) = Graphs.nv(model.graph) +Graphs.ne(model::Model) = Graphs.ne(model.graph) +Graphs.edges(model::Model) = Graphs.edges(model.graph) + +Graphs.neighbors(model::Model, node::NodeLabel) = Graphs.neighbors(model, node, model[node]) +Graphs.neighbors(model::Model, nodes::AbstractArray{<:NodeLabel}) = Iterators.flatten(map(node -> Graphs.neighbors(model, node), nodes)) + +Graphs.neighbors(model::Model, node::NodeLabel, nodedata::NodeData) = Graphs.neighbors(model, node, nodedata, getproperties(nodedata)) +Graphs.neighbors(model::Model, node::NodeLabel, nodedata::NodeData, properties::FactorNodeProperties) = map(neighbor -> neighbor[1], neighbors(properties)) +Graphs.neighbors(model::Model, node::NodeLabel, nodedata::NodeData, properties::VariableNodeProperties) = MetaGraphsNext.neighbor_labels(model.graph, node) + +Graphs.edges(model::Model, node::NodeLabel) = Graphs.edges(model, node, model[node]) +Graphs.edges(model::Model, nodes::AbstractArray{<:NodeLabel}) = Iterators.flatten(map(node -> Graphs.edges(model, node), nodes)) + +Graphs.edges(model::Model, node::NodeLabel, nodedata::NodeData) = Graphs.edges(model, node, nodedata, getproperties(nodedata)) +Graphs.edges(model::Model, node::NodeLabel, nodedata::NodeData, properties::FactorNodeProperties) = + map(neighbor -> neighbor[2], neighbors(properties)) + +function Graphs.edges(model::Model, node::NodeLabel, nodedata::NodeData, properties::VariableNodeProperties) + return (model[node, dst] for dst in MetaGraphsNext.neighbor_labels(model.graph, node)) +end + +Graphs.degree(model::Model, label::NodeLabel) = Graphs.degree(model.graph, MetaGraphsNext.code_for(model.graph, label)) + +function add_vertex!(model::Model, label, data) + # This is an unsafe procedure that implements behaviour from `MetaGraphsNext`. + code = nv(model) + 1 + model.graph.vertex_labels[code] = label + model.graph.vertex_properties[label] = (code, data) + Graphs.add_vertex!(model.graph.graph) +end + +function add_edge!(model::Model, src, dst, data) + # This is an unsafe procedure that implements behaviour from `MetaGraphsNext`. + code_src, code_dst = MetaGraphsNext.code_for(model.graph, src), MetaGraphsNext.code_for(model.graph, dst) + model.graph.edge_data[(src, dst)] = data + return Graphs.add_edge!(model.graph.graph, code_src, code_dst) +end + +function has_edge(model::Model, src, dst) + code_src, code_dst = MetaGraphsNext.code_for(model.graph, src), MetaGraphsNext.code_for(model.graph, dst) + return Graphs.has_edge(model.graph.graph, code_src, code_dst) +end + +function generate_nodelabel(model::Model, name::Symbol) + nextcounter = setcounter!(model, getcounter(model) + 1) + return NodeLabel(name, nextcounter) +end + +""" + getcontext(model::Model) + +Retrieves the context of a model. The context of a model contains the complete hierarchy of variables and factor nodes. +Additionally, contains all child submodels and their respective contexts. The Context supplies a mapping from symbols to `GraphPPL.NodeLabel` structures +with which the model can be queried. +""" +getcontext(model::Model) = model[] + +function get_principal_submodel(model::Model) + context = getcontext(model) + return context +end + +""" + aliases(backend, fform) + +Returns a collection of aliases for `fform` depending on the `backend`. +""" +aliases(model::Model, fform::F) where {F} = aliases(getbackend(model), fform) + +factor_nodes(model::Model) = Iterators.filter(node -> is_factor(model[node]), labels(model)) +variable_nodes(model::Model) = Iterators.filter(node -> is_variable(model[node]), labels(model)) + +""" +A version `factor_nodes(model)` that uses a callback function to process the factor nodes. +The callback function accepts both the label and the node data. +""" +function factor_nodes(callback::F, model::Model) where {F} + for label in labels(model) + nodedata = model[label] + if is_factor(nodedata) + callback((label::NodeLabel), (nodedata::NodeData)) + end + end +end + +""" +A version `variable_nodes(model)` that uses a callback function to process the variable nodes. +The callback function accepts both the label and the node data. +""" +function variable_nodes(callback::F, model::Model) where {F} + for label in labels(model) + nodedata = model[label] + if is_variable(nodedata) + callback((label::NodeLabel), (nodedata::NodeData)) + end + end +end + +""" + prune!(m::Model) + +Remove all nodes from the model that are not connected to any other node. +""" +function prune!(m::Model) + degrees = degree(m.graph) + nodes_to_remove = keys(degrees)[degrees .== 0] + nodes_to_remove = sort(nodes_to_remove, rev = true) + rem_vertex!.(Ref(m.graph), nodes_to_remove) +end \ No newline at end of file diff --git a/src/model/model_filtering.jl b/src/model/model_filtering.jl new file mode 100644 index 00000000..ba7bece5 --- /dev/null +++ b/src/model/model_filtering.jl @@ -0,0 +1,65 @@ + +struct FactorNodePredicate{N} <: AbstractModelFilterPredicate end + +function apply(::FactorNodePredicate{N}, model, something) where {N} + return apply(IsFactorNode(), model, something) && fform(getproperties(model[something])) ∈ aliases(model, N) +end + +struct IsFactorNode <: AbstractModelFilterPredicate end + +function apply(::IsFactorNode, model, something) + return is_factor(model[something]) +end + +struct VariableNodePredicate{V} <: AbstractModelFilterPredicate end + +function apply(::VariableNodePredicate{N}, model, something) where {N} + return apply(IsVariableNode(), model, something) && getname(getproperties(model[something])) === N +end + +struct IsVariableNode <: AbstractModelFilterPredicate end + +function apply(::IsVariableNode, model, something) + return is_variable(model[something]) +end + +struct SubmodelPredicate{S, C} <: AbstractModelFilterPredicate end + +function apply(::SubmodelPredicate{S, False}, model, something) where {S} + return fform(getcontext(model[something])) === S +end + +function apply(::SubmodelPredicate{S, True}, model, something) where {S} + return S ∈ fform.(path_to_root(getcontext(model[something]))) +end + +struct AndNodePredicate{L, R} <: AbstractModelFilterPredicate + left::L + right::R +end + +function apply(and::AndNodePredicate, model, something) + return apply(and.left, model, something) && apply(and.right, model, something) +end + +struct OrNodePredicate{L, R} <: AbstractModelFilterPredicate + left::L + right::R +end + +function apply(or::OrNodePredicate, model, something) + return apply(or.left, model, something) || apply(or.right, model, something) +end + +Base.:(|)(left::AbstractModelFilterPredicate, right::AbstractModelFilterPredicate) = OrNodePredicate(left, right) +Base.:(&)(left::AbstractModelFilterPredicate, right::AbstractModelFilterPredicate) = AndNodePredicate(left, right) + +as_node(any) = FactorNodePredicate{any}() +as_node() = IsFactorNode() +as_variable(any) = VariableNodePredicate{any}() +as_variable() = IsVariableNode() +as_context(any; children = false) = SubmodelPredicate{any, typeof(static(children))}() + +function Base.filter(predicate::AbstractModelFilterPredicate, model::Model) + return Iterators.filter(something -> apply(predicate, model, something), labels(model)) +end \ No newline at end of file diff --git a/src/model/node_creation.jl b/src/model/node_creation.jl new file mode 100644 index 00000000..10fd414d --- /dev/null +++ b/src/model/node_creation.jl @@ -0,0 +1,202 @@ +""" + NodeCreationOptions(namedtuple) + +Options for creating a node in a probabilistic graphical model. These are typically coming from the `where {}` block +in the `@model` macro, but can also be created manually. Expects a `NamedTuple` as an input. +""" +struct NodeCreationOptions{N} + options::N +end + +const EmptyNodeCreationOptions = NodeCreationOptions{Nothing}(nothing) + +NodeCreationOptions(; kwargs...) = convert(NodeCreationOptions, kwargs) + +Base.convert(::Type{NodeCreationOptions}, ::@Kwargs{}) = NodeCreationOptions(nothing) +Base.convert(::Type{NodeCreationOptions}, options) = NodeCreationOptions(NamedTuple(options)) + +Base.haskey(options::NodeCreationOptions, key::Symbol) = haskey(options.options, key) +Base.getindex(options::NodeCreationOptions, keys...) = getindex(options.options, keys...) +Base.getindex(options::NodeCreationOptions, keys::NTuple{N, Symbol}) where {N} = NodeCreationOptions(getindex(options.options, keys)) +Base.keys(options::NodeCreationOptions) = keys(options.options) +Base.get(options::NodeCreationOptions, key::Symbol, default) = get(options.options, key, default) + +# Fast fallback for empty options +Base.haskey(::NodeCreationOptions{Nothing}, key::Symbol) = false +Base.getindex(::NodeCreationOptions{Nothing}, keys...) = error("type `NodeCreationOptions{Nothing}` has no field $(keys)") +Base.keys(::NodeCreationOptions{Nothing}) = () +Base.get(::NodeCreationOptions{Nothing}, key::Symbol, default) = default + +withopts(::NodeCreationOptions{Nothing}, options::NamedTuple) = NodeCreationOptions(options) +withopts(options::NodeCreationOptions, extra::NamedTuple) = NodeCreationOptions((; options.options..., extra...)) + +withoutopts(::NodeCreationOptions{Nothing}, ::Val) = NodeCreationOptions(nothing) + +function withoutopts(options::NodeCreationOptions, ::Val{K}) where {K} + newoptions = options.options[filter(key -> key ∉ K, keys(options.options))] + # Should be compiled out, there are tests for it + if isempty(newoptions) + return NodeCreationOptions(nothing) + else + return NodeCreationOptions(newoptions) + end +end + +""" + getorcreate!(model::AbstractModel, context::Context, options::NodeCreationOptions, name, index) + +Get or create a variable (name) from a factor graph model and context, using an index if provided. + +This function searches for a variable (name) in the factor graph model and context specified by the arguments `model` and `context`. If the variable exists, +it returns it. Otherwise, it creates a new variable and returns it. + +# Arguments +- `model::AbstractModel`: The factor graph model to search for or create the variable in. +- `context::Context`: The context to search for or create the variable in. +- `options::NodeCreationOptions`: Options for creating the variable. Must be a `NodeCreationOptions` object. +- `name`: The variable (name) to search for or create. Must be a symbol. +- `index`: Optional index for the variable. Can be an integer, a collection of integers, or `nothing`. If the index is `nothing` creates a single variable. +If the index is an integer creates a vector-like variable. If the index is a collection of integers creates a tensor-like variable. + +# Returns +The variable (name) found or created in the factor graph model and context. +""" +function getorcreate! end + +getorcreate!(::AbstractModel, ::Context, name::Symbol) = error("Index is required in the `getorcreate!` function for variable `$(name)`") +getorcreate!(::AbstractModel, ::Context, options::NodeCreationOptions, name::Symbol) = + error("Index is required in the `getorcreate!` function for variable `$(name)`") + +function getorcreate!(model::AbstractModel, ctx::Context, name::Symbol, index...) + return getorcreate!(model, ctx, EmptyNodeCreationOptions, name, index...) +end + +function getorcreate!(model::AbstractModel, ctx::Context, options::NodeCreationOptions, name::Symbol, index::Nothing) + throw_if_vector_variable(ctx, name) + throw_if_tensor_variable(ctx, name) + return get(() -> add_variable_node!(model, ctx, options, name, index), ctx.individual_variables, name) +end + +function getorcreate!(model::AbstractModel, ctx::Context, options::NodeCreationOptions, name::Symbol, index::Integer) + throw_if_individual_variable(ctx, name) + throw_if_tensor_variable(ctx, name) + if !haskey(ctx.vector_variables, name) + ctx[name] = ResizableArray(NodeLabel, Val(1)) + end + vectorvar = ctx.vector_variables[name] + if !isassigned(vectorvar, index) + vectorvar[index] = add_variable_node!(model, ctx, options, name, index) + end + return vectorvar +end + +function getorcreate!(model::AbstractModel, ctx::Context, options::NodeCreationOptions, name::Symbol, i1::Integer, is::Vararg{Integer}) + throw_if_individual_variable(ctx, name) + throw_if_vector_variable(ctx, name) + if !haskey(ctx.tensor_variables, name) + ctx[name] = ResizableArray(NodeLabel, Val(1 + length(is))) + end + tensorvar = ctx.tensor_variables[name] + if !isassigned(tensorvar, i1, is...) + tensorvar[i1, is...] = add_variable_node!(model, ctx, options, name, (i1, is...)) + end + return tensorvar +end + +function getorcreate!(model::AbstractModel, ctx::Context, options::NodeCreationOptions, name::Symbol, range::AbstractRange) + isempty(range) && error("Empty range is not allowed in the `getorcreate!` function for variable `$(name)`") + foreach(range) do i + getorcreate!(model, ctx, options, name, i) + end + return getorcreate!(model, ctx, options, name, first(range)) +end + +function getorcreate!( + model::AbstractModel, ctx::Context, options::NodeCreationOptions, name::Symbol, r1::AbstractRange, rs::Vararg{AbstractRange} +) + (isempty(r1) || any(isempty, rs)) && error("Empty range is not allowed in the `getorcreate!` function for variable `$(name)`") + foreach(Iterators.product(r1, rs...)) do i + getorcreate!(model, ctx, options, name, i...) + end + return getorcreate!(model, ctx, options, name, first(r1), first.(rs)...) +end + +function getorcreate!(model::AbstractModel, ctx::Context, options::NodeCreationOptions, name::Symbol, indices...) + if haskey(ctx, name) + var = ctx[name] + return var + end + error(lazy"Cannot create a variable named `$(name)` with non-standard indices $(indices)") +end + +getifcreated(model::AbstractModel, context::Context, var::NodeLabel) = var +getifcreated(model::AbstractModel, context::Context, var::ResizableArray) = var +getifcreated( + model::AbstractModel, + context::Context, + var::Union{Tuple, AbstractArray{T}} where {T <: Union{NodeLabel, ProxyLabel, <:AbstractVariableReference}} +) = map((v) -> getifcreated(model, context, v), var) +getifcreated(model::AbstractModel, context::Context, var::ProxyLabel) = var +getifcreated(model::AbstractModel, context::Context, var) = + add_constant_node!(model, context, NodeCreationOptions(value = var, kind = :constant), :constvar, nothing) + +""" + add_variable_node!(model::AbstractModel, context::Context, options::NodeCreationOptions, name::Symbol, index) + +Add a variable node to the model with the given `name` and `index`. +This function is unsafe (doesn't check if a variable with the given name already exists in the model). + +Args: + - `model::AbstractModel`: The model to which the node is added. + - `context::Context`: The context to which the symbol is added. + - `options::NodeCreationOptions`: The options for the creation process. + - `name::Symbol`: The ID of the variable. + - `index`: The index of the variable. + +Returns: + - The generated symbol for the variable. +""" +function add_variable_node! end + +function add_variable_node!(model::AbstractModel, context::Context, name::Symbol, index) + return add_variable_node!(model, context, EmptyNodeCreationOptions, name, index) +end + +function add_variable_node!(model::AbstractModel, context::Context, options::NodeCreationOptions, name::Symbol, index) + label = __add_variable_node!(model, context, options, name, index) + context[name, index] = label +end + +function add_constant_node!(model::AbstractModel, context::Context, options::NodeCreationOptions, name::Symbol, index) + label = __add_variable_node!(model, context, options, name, index) + context[to_symbol(name, label.global_counter), index] = label # to_symbol(label) is type unstable and we know the type of label.name here from name + return label +end + +function __add_variable_node!(model::AbstractModel, context::Context, options::NodeCreationOptions, name::Symbol, index) + # In theory plugins are able to overwrite this + potential_label = generate_nodelabel(model, name) + potential_nodedata = NodeData(context, convert(VariableNodeProperties, name, index, options)) + label, nodedata = preprocess_plugins( + UnionPluginType(VariableNodePlugin(), FactorAndVariableNodesPlugin()), model, context, potential_label, potential_nodedata, options + ) + add_vertex!(model, label, nodedata) + return label +end + +""" + generate_nodelabel(model::AbstractModel, name::Symbol) + +Generate a new `NodeLabel` object with a unique identifier based on the specified name and the +number of nodes already in the model. + +Arguments: +- `model`: A `AbstractModel` object representing the probabilistic graphical model. +- `name`: A symbol representing the name of the node. +- `variable_type`: A UInt8 representing the type of the variable. 0 = factor, 1 = individual variable, 2 = vector variable, 3 = tensor variable +- `index`: An integer or tuple of integers representing the index of the variable. +""" +function generate_nodelabel(model::AbstractModel, name) + nextcounter = setcounter!(model, getcounter(model) + 1) + return NodeLabel(name, nextcounter) +end \ No newline at end of file diff --git a/src/model/proxy_label.jl b/src/model/proxy_label.jl new file mode 100644 index 00000000..247f3a32 --- /dev/null +++ b/src/model/proxy_label.jl @@ -0,0 +1,139 @@ +""" + Splat{T} + +A type used to represent splatting in the model macro. Any call on the right hand side of ~ that uses splatting will be wrapped in this type. +""" +struct Splat{T} + collection::T +end + +""" + ProxyLabel(name, index, proxied) + +A label that proxies another label in a probabilistic graphical model. +The proxied objects must implement the `is_proxied(::Type) = True()`. +The proxy labels may spawn new variables in a model, if `maycreate` is set to `True()`. +""" +mutable struct ProxyLabel{P, I, M} + const name::Symbol + const proxied::P + const index::I + const maycreate::M +end + +is_proxied(any) = is_proxied(typeof(any)) +is_proxied(::Type) = False() +is_proxied(::Type{T}) where {T <: NodeLabel} = True() +is_proxied(::Type{T}) where {T <: ProxyLabel} = True() +is_proxied(::Type{T}) where {T <: AbstractArray} = is_proxied(eltype(T)) + +proxylabel(name::Symbol, proxied::Splat{T}, index, maycreate) where {T} = + [proxylabel(name, proxiedelement, index, maycreate) for proxiedelement in proxied.collection] + +# By default, `proxylabel` set `maycreate` to `False` +proxylabel(name::Symbol, proxied, index) = proxylabel(name, proxied, index, False()) +proxylabel(name::Symbol, proxied, index, maycreate) = proxylabel(is_proxied(proxied), name, proxied, index, maycreate) + +# In case if `is_proxied` returns `False` we simply return the original object, because the object cannot be proxied +proxylabel(::False, name::Symbol, proxied::Any, index::Nothing, maycreate) = proxied +proxylabel(::False, name::Symbol, proxied::Any, index::Tuple, maycreate) = proxied[index...] + +# In case if `is_proxied` returns `True`, we wrap the object into the `ProxyLabel` for later `unroll`-ing +function proxylabel(::True, name::Symbol, proxied::Any, index::Any, maycreate::Any) + return ProxyLabel(name, proxied, index, maycreate) +end + +# In case if `proxied` is another `ProxyLabel` we take `|` operation with its `maycreate` to lift it further +# This is a useful operation for `datalabels`, since they define `maycreate = True()` on their creation time +# That means that all subsequent usages of data labels will always create a new label, even when used on right hand side from `~` +function proxylabel(::True, name::Symbol, proxied::ProxyLabel, index::Any, maycreate::Any) + return ProxyLabel(name, proxied, index, proxied.maycreate | maycreate) +end + +getname(label::ProxyLabel) = label.name +index(label::ProxyLabel) = label.index + +# This function allows to overwrite the `maycreate` flag on a proxy label, might be useful for situations where code should +# definitely not create a new variable, e.g in the variational constraints plugin +set_maycreate(proxylabel::ProxyLabel, maycreate::Union{True, False}) = + ProxyLabel(proxylabel.name, proxylabel.proxied, proxylabel.index, maycreate) +set_maycreate(something, maycreate::Union{True, False}) = something + +function unroll(something) + return something +end + +function unroll(proxylabel::ProxyLabel) + return unroll(proxylabel, proxylabel.proxied, proxylabel.index, proxylabel.maycreate, proxylabel.index) +end + +function unroll(proxylabel::ProxyLabel, proxied::ProxyLabel, index, maycreate, liftedindex) + # In case of a chain of proxy-labels we should lift the index, that potentially might + # be used to create a new collection of variables + liftedindex = lift_index(maycreate, index, liftedindex) + unrolled = unroll(proxied, proxied.proxied, proxied.index, proxied.maycreate, liftedindex) + return checked_getindex(unrolled, index) +end + +function unroll(proxylabel::ProxyLabel, something::Any, index, maycreate, liftedindex) + return checked_getindex(something, index) +end + +checked_getindex(something, index::FunctionalIndex) = Base.getindex(something, index) +checked_getindex(something, index::Tuple) = Base.getindex(something, index...) +checked_getindex(something, index::Nothing) = something + +checked_getindex(nodelabel::NodeLabel, index::Nothing) = nodelabel +checked_getindex(nodelabel::NodeLabel, index::Tuple) = + error("Indexing a single node label `$(getname(nodelabel))` with an index `[$(join(index, ", "))]` is not allowed.") +checked_getindex(nodelabel::NodeLabel, index) = + error("Indexing a single node label `$(getname(nodelabel))` with an index `$index` is not allowed.") + +""" +The `lift_index` function "lifts" (or tracks) the index that is going to be used to determine the shape of the container upon creation +for a variable during the unrolling of the `ProxyLabel`. This index is used only if the container is set to be created and is not used if +variable container already exists. +""" +function lift_index end + +lift_index(::True, ::Nothing, ::Nothing) = nothing +lift_index(::True, current, ::Nothing) = current +lift_index(::True, ::Nothing, previous) = previous +lift_index(::True, current, previous) = current +lift_index(::False, current, previous) = previous + +Base.show(io::IO, proxy::ProxyLabel) = show_proxy(io, getname(proxy), index(proxy)) +show_proxy(io::IO, name::Symbol, index::Nothing) = print(io, name) +show_proxy(io::IO, name::Symbol, index::Tuple) = print(io, name, "[", join(index, ","), "]") +show_proxy(io::IO, name::Symbol, index::Any) = print(io, name, "[", index, "]") + +Base.last(label::ProxyLabel) = last(label.proxied, label) +Base.last(proxied::ProxyLabel, ::ProxyLabel) = last(proxied) +Base.last(proxied, ::ProxyLabel) = proxied + +Base.:(==)(proxy1::ProxyLabel, proxy2::ProxyLabel) = + proxy1.name == proxy2.name && proxy1.index == proxy2.index && proxy1.proxied == proxy2.proxied +Base.hash(proxy::ProxyLabel, h::UInt) = hash(proxy.maycreate, hash(proxy.name, hash(proxy.index, hash(proxy.proxied, h)))) + +# Iterator's interface methods +Base.IteratorSize(proxy::ProxyLabel) = Base.IteratorSize(indexed_last(proxy)) +Base.IteratorEltype(proxy::ProxyLabel) = Base.IteratorEltype(indexed_last(proxy)) +Base.eltype(proxy::ProxyLabel) = Base.eltype(indexed_last(proxy)) + +Base.length(proxy::ProxyLabel) = length(indexed_last(proxy)) +Base.size(proxy::ProxyLabel, dims...) = size(indexed_last(proxy), dims...) +Base.firstindex(proxy::ProxyLabel) = firstindex(indexed_last(proxy)) +Base.lastindex(proxy::ProxyLabel) = lastindex(indexed_last(proxy)) +Base.eachindex(proxy::ProxyLabel) = eachindex(indexed_last(proxy)) +Base.axes(proxy::ProxyLabel) = axes(indexed_last(proxy)) +Base.getindex(proxy::ProxyLabel, indices...) = getindex(indexed_last(proxy), indices...) +Base.size(proxy::ProxyLabel) = size(indexed_last(proxy)) +Base.broadcastable(proxy::ProxyLabel) = Base.broadcastable(indexed_last(proxy)) + +postprocess_returnval(proxy::ProxyLabel) = postprocess_returnval(indexed_last(proxy)) + +"""Similar to `Base.last` when applied on `ProxyLabel`, but also applies `checked_getindex` while unrolling""" +function indexed_last end + +indexed_last(proxy::ProxyLabel) = checked_getindex(indexed_last(proxy.proxied), proxy.index) +indexed_last(something) = something \ No newline at end of file diff --git a/src/model/var_dict.jl b/src/model/var_dict.jl new file mode 100644 index 00000000..2059045a --- /dev/null +++ b/src/model/var_dict.jl @@ -0,0 +1,42 @@ +""" + VarDict + +A recursive dictionary structure that contains all variables in a probabilistic graphical model. +Iterates over all variables in the model and their children in a linear fashion, but preserves the recursive nature of the actual model. +""" +struct VarDict{T} + variables::UnorderedDictionary{Symbol, T} + children::UnorderedDictionary{FactorID, VarDict} +end + +function VarDict(context::Context) + dictvariables = merge(individual_variables(context), vector_variables(context), tensor_variables(context)) + dictchildren = convert(UnorderedDictionary{FactorID, VarDict}, map(child -> VarDict(child), children(context))) + return VarDict(dictvariables, dictchildren) +end + +variables(vardict::VarDict) = vardict.variables +children(vardict::VarDict) = vardict.children + +haskey(vardict::VarDict, key::Symbol) = haskey(vardict.variables, key) +haskey(vardict::VarDict, key::Tuple{T, Int} where {T}) = haskey(vardict.children, FactorID(first(key), last(key))) +haskey(vardict::VarDict, key::FactorID) = haskey(vardict.children, key) + +Base.getindex(vardict::VarDict, key::Symbol) = vardict.variables[key] +Base.getindex(vardict::VarDict, f, index::Int) = vardict.children[FactorID(f, index)] +Base.getindex(vardict::VarDict, key::Tuple{T, Int} where {T}) = vardict.children[FactorID(first(key), last(key))] +Base.getindex(vardict::VarDict, key::FactorID) = vardict.children[key] + +function Base.map(f, vardict::VarDict) + mapped_variables = map(f, variables(vardict)) + mapped_children = convert(UnorderedDictionary{FactorID, VarDict}, map(child -> map(f, child), children(vardict))) + return VarDict(mapped_variables, mapped_children) +end + +function Base.filter(f, vardict::VarDict) + filtered_variables = filter(f, variables(vardict)) + filtered_children = convert(UnorderedDictionary{FactorID, VarDict}, map(child -> filter(f, child), children(vardict))) + return VarDict(filtered_variables, filtered_children) +end + +Base.:(==)(left::VarDict, right::VarDict) = left.variables == right.variables && left.children == right.children \ No newline at end of file diff --git a/src/model/variable_ref.jl b/src/model/variable_ref.jl new file mode 100644 index 00000000..6111c8ac --- /dev/null +++ b/src/model/variable_ref.jl @@ -0,0 +1,300 @@ +""" + VariableRef(model::AbstractModel, context::Context, name::Symbol, index, external_collection = nothing) + +`VariableRef` implements a lazy reference to a variable in the model. +The reference does not create an actual variable in the model immediatelly, but postpones the creation +until strictly necessarily, which is hapenning inside the `unroll` function. The postponed creation allows users to define +pass a single variable into a submodel, e.g. `y ~ submodel(x = x)`, but use it as an array inside the submodel, +e.g. `y[i] ~ Normal(x[i], 1.0)`. + +Optionally accepts an `external_collection`, which defines the upper limit on the shape of the underlying collection. +For example, an external collection `[ 1, 2, 3 ]` can be used both as `y ~ ...` and `y[i] ~ ...`, but not as `y[i, j] ~ ...`. +By default, the `MissingCollection` is used for the `external_collection`, which does not restrict the shape of the underlying collection. + +The `index` is always a `Tuple`. By default, `(nothing, )` is used, to indicate empty indices with no restrictions on the shape of the underlying collection. +If "non-nothing" index is supplied, e.g. `(1, )` the shape of the udnerlying collection will be fixed to match the index +(1-dimensional in case of `(1, )`, 2-dimensional in case of `(1, 1)` and so on). +""" +struct VariableRef{M, C, O, I, E, L} <: AbstractVariableReference + model::M + context::C + options::O + name::Symbol + index::I + external_collection::E + internal_collection::L +end + +Base.:(==)(left::VariableRef, right::VariableRef) = + left.model == right.model && left.context == right.context && left.name == right.name && left.index == right.index + +function Base.:(==)(left::VariableRef, right) + error( + "Comparing Factor Graph variable `$left` with a value. This is not possible as the value of `$left` is not known at model construction time." + ) +end +Base.:(==)(left, right::VariableRef) = right == left + +Base.:(>)(left::VariableRef, right) = left == right +Base.:(>)(left, right::VariableRef) = left == right +Base.:(<)(left::VariableRef, right) = left == right +Base.:(<)(left, right::VariableRef) = left == right +Base.:(>=)(left::VariableRef, right) = left == right +Base.:(>=)(left, right::VariableRef) = left == right +Base.:(<=)(left::VariableRef, right) = left == right +Base.:(<=)(left, right::VariableRef) = left == right + +is_proxied(::Type{T}) where {T <: VariableRef} = True() + +external_collection_typeof(::Type{VariableRef{M, C, O, I, E, L}}) where {M, C, O, I, E, L} = E +internal_collection_typeof(::Type{VariableRef{M, C, O, I, E, L}}) where {M, C, O, I, E, L} = L + +external_collection(ref::VariableRef) = ref.external_collection +internal_collection(ref::VariableRef) = ref.internal_collection + +Base.show(io::IO, ref::VariableRef) = variable_ref_show(io, ref.name, ref.index) +variable_ref_show(io::IO, name::Symbol, index::Nothing) = print(io, name) +variable_ref_show(io::IO, name::Symbol, index::Tuple{Nothing}) = print(io, name) +variable_ref_show(io::IO, name::Symbol, index::Tuple) = print(io, name, "[", join(index, ","), "]") +variable_ref_show(io::IO, name::Symbol, index::Any) = print(io, name, "[", index, "]") + +""" + makevarref(fform::F, model::AbstractModel, context::Context, options::NodeCreationOptions, name::Symbol, index::Tuple) + +A function that creates `VariableRef`, but takes the `fform` into account. When `fform` happens to be `Atomic` creates +the underlying variable immediatelly without postponing. When `fform` is `Composite` does not create the actual variable, +but waits until strictly necessarily. +""" +function makevarref end + +function makevarref(fform::F, model::AbstractModel, context::Context, options::NodeCreationOptions, name::Symbol, index::Tuple) where {F} + return makevarref(NodeType(model, fform), model, context, options, name, index) +end + +function makevarref(::Atomic, model::AbstractModel, context::Context, options::NodeCreationOptions, name::Symbol, index::Tuple) + # In the case of `Atomic` variable reference, we always create the variable + # (unless the index is empty, which may happen during broadcasting) + internal_collection = isempty(index) ? nothing : getorcreate!(model, context, name, index...) + return VariableRef(model, context, options, name, index, nothing, internal_collection) +end + +function makevarref(::Composite, model::AbstractModel, context::Context, options::NodeCreationOptions, name::Symbol, index::Tuple) + # In the case of `Composite` variable reference, we create it immediatelly only when the variable is instantiated + # with indexing operation + internal_collection = if !all(isnothing, index) + getorcreate!(model, context, name, index...) + else + nothing + end + return VariableRef(model, context, options, name, index, nothing, internal_collection) +end + +function VariableRef( + model::AbstractModel, + context::Context, + options::NodeCreationOptions, + name::Symbol, + index::Tuple, + external_collection = nothing, + internal_collection = nothing +) + M = typeof(model) + C = typeof(context) + O = typeof(options) + I = typeof(index) + E = typeof(external_collection) + L = typeof(internal_collection) + return VariableRef{M, C, O, I, E, L}(model, context, options, name, index, external_collection, internal_collection) +end + +function unroll(p::ProxyLabel, ref::VariableRef, index, maycreate, liftedindex) + liftedindex = lift_index(maycreate, index, liftedindex) + if maycreate === False() + return checked_getindex(getifcreated(ref.model, ref.context, ref, liftedindex), index) + elseif maycreate === True() + return checked_getindex(getorcreate!(ref.model, ref.context, ref, liftedindex), index) + end + error("Unreachable. The `maycreate` argument in the `unroll` function for the `VariableRef` must be either `True` or `False`.") +end + +function getifcreated(model::AbstractModel, context::Context, ref::VariableRef) + return getifcreated(model, context, ref, ref.index) +end + +function getifcreated(model::AbstractModel, context::Context, ref::VariableRef, index) + if !isnothing(ref.external_collection) + return getorcreate!(ref.model, ref.context, ref, index) + elseif !isnothing(ref.internal_collection) + return ref.internal_collection + elseif haskey(ref.context, ref.name) + return ref.context[ref.name] + else + error(lazy"The variable `$ref` has been used, but has not been instantiated.") + end +end + +function getorcreate!(model::AbstractModel, context::Context, ref::VariableRef, index::Nothing) + check_external_collection_compatibility(ref, index) + return getorcreate!(model, context, ref.options, ref.name, index) +end + +function getorcreate!(model::AbstractModel, context::Context, ref::VariableRef, index::Tuple) + check_external_collection_compatibility(ref, index) + return getorcreate!(model, context, ref.options, ref.name, index...) +end + +Base.IteratorSize(ref::VariableRef) = Base.IteratorSize(typeof(ref)) +Base.IteratorEltype(ref::VariableRef) = Base.IteratorEltype(typeof(ref)) +Base.eltype(ref::VariableRef) = Base.eltype(typeof(ref)) + +Base.IteratorSize(::Type{R}) where {R <: VariableRef} = + variable_ref_iterator_size(external_collection_typeof(R), internal_collection_typeof(R)) +variable_ref_iterator_size(::Type{Nothing}, ::Type{Nothing}) = Base.SizeUnknown() +variable_ref_iterator_size(::Type{E}, ::Type{L}) where {E, L} = Base.IteratorSize(E) +variable_ref_iterator_size(::Type{Nothing}, ::Type{L}) where {L} = Base.IteratorSize(L) + +Base.IteratorEltype(::Type{R}) where {R <: VariableRef} = + variable_ref_iterator_eltype(external_collection_typeof(R), internal_collection_typeof(R)) +variable_ref_iterator_eltype(::Type{Nothing}, ::Type{Nothing}) = Base.EltypeUnknown() +variable_ref_iterator_eltype(::Type{E}, ::Type{L}) where {E, L} = Base.IteratorEltype(E) +variable_ref_iterator_eltype(::Type{Nothing}, ::Type{L}) where {L} = Base.IteratorEltype(L) + +Base.eltype(::Type{R}) where {R <: VariableRef} = variable_ref_eltype(external_collection_typeof(R), internal_collection_typeof(R)) +variable_ref_eltype(::Type{Nothing}, ::Type{Nothing}) = Any +variable_ref_eltype(::Type{E}, ::Type{L}) where {E, L} = Base.eltype(E) +variable_ref_eltype(::Type{Nothing}, ::Type{L}) where {L} = Base.eltype(L) + +function variableref_checked_collection_typeof(::VariableRef) + return variableref_checked_iterator_call(typeof, :typeof, ref) +end + +Base.length(ref::VariableRef) = variableref_checked_iterator_call(Base.length, :length, ref) +Base.firstindex(ref::VariableRef) = variableref_checked_iterator_call(Base.firstindex, :firstindex, ref) +Base.lastindex(ref::VariableRef) = variableref_checked_iterator_call(Base.lastindex, :lastindex, ref) +Base.eachindex(ref::VariableRef) = variableref_checked_iterator_call(Base.eachindex, :eachindex, ref) +Base.axes(ref::VariableRef) = variableref_checked_iterator_call(Base.axes, :axes, ref) + +Base.size(ref::VariableRef, dims...) = variableref_checked_iterator_call((c) -> Base.size(c, dims...), :size, ref) +Base.getindex(ref::VariableRef, indices...) = variableref_checked_iterator_call((c) -> Base.getindex(c, indices...), :getindex, ref) + +function variableref_checked_iterator_call(f::F, fsymbol::Symbol, ref::VariableRef) where {F} + if !isnothing(ref.external_collection) + return f(ref.external_collection) + elseif !isnothing(ref.internal_collection) + return f(ref.internal_collection) + elseif haskey(ref.context, ref.name) + return f(ref.context[ref.name]) + end + error(lazy"Cannot call `$(fsymbol)` on variable reference `$(ref.name)`. The variable `$(ref.name)` has not been instantiated.") +end + +function postprocess_returnval(ref::VariableRef) + if haskey(ref.context, ref.name) + return ref.context[ref.name] + end + error("Cannot `return $(ref)`. The variable has not been instantiated.") +end + +""" +A placeholder collection for `VariableRef` when the actual external collection is not yet available. +""" +struct MissingCollection end + +__err_missing_collection_missing_method(method::Symbol) = + error("The `$method` method is not defined for a lazy node label without data attached.") + +Base.IteratorSize(::Type{MissingCollection}) = __err_missing_collection_missing_method(:IteratorSize) +Base.IteratorEltype(::Type{MissingCollection}) = __err_missing_collection_missing_method(:IteratorEltype) +Base.eltype(::Type{MissingCollection}) = __err_missing_collection_missing_method(:eltype) +Base.length(::MissingCollection) = __err_missing_collection_missing_method(:length) +Base.size(::MissingCollection, dims...) = __err_missing_collection_missing_method(:size) +Base.firstindex(::MissingCollection) = __err_missing_collection_missing_method(:firstindex) +Base.lastindex(::MissingCollection) = __err_missing_collection_missing_method(:lastindex) +Base.eachindex(::MissingCollection) = __err_missing_collection_missing_method(:eachindex) +Base.axes(::MissingCollection) = __err_missing_collection_missing_method(:axes) + +function check_external_collection_compatibility(ref::VariableRef, index) + if !isnothing(external_collection(ref)) && !__check_external_collection_compatibility(ref, index) + error( + """ + The index `[$(!isnothing(index) ? join(index, ", ") : nothing)]` is not compatible with the underlying collection provided for the label `$(ref.name)`. + The underlying data provided for `$(ref.name)` is `$(external_collection(ref))`. + """ + ) + end + return nothing +end + +function __check_external_collection_compatibility(ref::VariableRef, index::Nothing) + # We assume that index `nothing` is always compatible with the underlying collection + # Eg. a matrix `Σ` can be used both as it is `Σ`, but also as `Σ[1]` or `Σ[1, 1]` + return true +end + +function __check_external_collection_compatibility(ref::VariableRef, index::Tuple) + return __check_external_collection_compatibility(ref, external_collection(ref), index) +end + +# We can't really check if the data compatible or not if we get the `MissingCollection` +__check_external_collection_compatibility(label::VariableRef, ::MissingCollection, index::Tuple) = true +__check_external_collection_compatibility(label::VariableRef, collection::AbstractArray, indices::Tuple) = + checkbounds(Bool, collection, indices...) +__check_external_collection_compatibility(label::VariableRef, collection::Tuple, indices::Tuple) = + length(indices) === 1 && first(indices) ∈ 1:length(collection) +# A number cannot really be queried with non-empty indices +__check_external_collection_compatibility(label::VariableRef, collection::Number, indices::Tuple) = false +# For all other we simply don't know so we assume we are compatible +__check_external_collection_compatibility(label::VariableRef, collection, indices::Tuple) = true + +function Base.iterate(ref::VariableRef, state) + if !isnothing(external_collection(ref)) + return iterate(external_collection(ref), state) + elseif !isnothing(internal_collection(ref)) + return iterate(internal_collection(ref), state) + elseif haskey(ref.context, ref.name) + return iterate(ref.context[ref.name], state) + end + error("Cannot iterate over $(ref.name). The underlying collection for `$(ref.name)` has undefined shape.") +end + +function Base.iterate(ref::VariableRef) + if !isnothing(external_collection(ref)) + return iterate(external_collection(ref)) + elseif !isnothing(internal_collection(ref)) + return iterate(internal_collection(ref)) + elseif haskey(ref.context, ref.name) + return iterate(ref.context[ref.name]) + end + error("Cannot iterate over $(ref.name). The underlying collection for `$(ref.name)` has undefined shape.") +end + +function Base.broadcastable(ref::VariableRef) + if !isnothing(external_collection(ref)) + # If we have an underlying collection (e.g. data), we should instantiate all variables at the point of broadcasting + # in order to support something like `y .~ ` where `y` is a data label + return collect( + Iterators.map( + I -> checked_getindex(getorcreate!(ref.model, ref.context, ref.options, ref.name, I.I...), I.I), CartesianIndices(axes(ref)) + ) + ) + elseif !isnothing(internal_collection(ref)) + return Base.broadcastable(internal_collection(ref)) + elseif haskey(ref.context, ref.name) + return Base.broadcastable(ref.context[ref.name]) + end + error("Cannot broadcast over $(ref.name). The underlying collection for `$(ref.name)` has undefined shape.") +end + +""" + datalabel(model, context, options, name, collection = MissingCollection()) + +A function for creating proxy data labels to pass into the model upon creation. +Can be useful in combination with `AbstractModelGenerator` and `create_model`. +""" +function datalabel(model, context, options, name, collection = MissingCollection()) + kind = get(options, :kind, VariableKindUnknown) + if !isequal(kind, VariableKindData) + error("`datalabel` only supports `VariableKindData` in `NodeCreationOptions`") + end + return proxylabel(name, VariableRef(model, context, options, name, (nothing,), collection), nothing, True()) +end \ No newline at end of file diff --git a/src/nodes/node_materialization.jl b/src/nodes/node_materialization.jl new file mode 100644 index 00000000..f5e25bdf --- /dev/null +++ b/src/nodes/node_materialization.jl @@ -0,0 +1,572 @@ +""" + AnonymousVariable(model, context) + +Defines a lazy structure for anonymous variables. +The actual anonymous variables materialize only in `make_node!` upon calling, because it needs arguments to the `make_node!` in order to create proper links. +""" +struct AnonymousVariable{M, C} + model::M + context::C +end + +Base.broadcastable(v::AnonymousVariable) = Ref(v) + +create_anonymous_variable!(model::AbstractModel, context::Context) = AnonymousVariable(model, context) + +function materialize_anonymous_variable!(anonymous::AnonymousVariable, fform, args) + model = anonymous.model + return materialize_anonymous_variable!(NodeBehaviour(model, fform), model, anonymous.context, fform, args) +end + +# Deterministic nodes can create links to variables in the model +# This might be important for better factorization constraints resolution +function materialize_anonymous_variable!(::Deterministic, model::AbstractModel, context::Context, fform, args) + linked = getindex.(Ref(model), unroll.(filter(is_nodelabel, args))) + + # Check if all links are either `data` or `constants` + # In this case it is not necessary to create a new random variable, but rather a data variable + # with `value = fform` + link_const, link_const_or_data = reduce(linked; init = (true, true)) do accum, link + check_is_all_constant, check_is_all_constant_or_data = accum + check_is_all_constant = check_is_all_constant && anonymous_arg_is_constanst(link) + check_is_all_constant_or_data = check_is_all_constant_or_data && anonymous_arg_is_constanst_or_data(link) + return (check_is_all_constant, check_is_all_constant_or_data) + end + + if !link_const && !link_const_or_data + # Most likely case goes first, we need to create a new factor node and a new random variable + (true, add_variable_node!(model, context, NodeCreationOptions(link = linked), VariableNameAnonymous, nothing)) + elseif link_const + # If all `links` are constant nodes we can evaluate the `fform` here and create another constant rather than creating a new factornode + val = fform(map(arg -> arg isa NodeLabel ? value(getproperties(model[arg])) : arg, unroll.(args))...) + ( + false, + add_variable_node!( + model, context, NodeCreationOptions(kind = :constant, value = val, link = linked), VariableNameAnonymous, nothing + ) + ) + elseif link_const_or_data + # If all `links` are constant or data we can create a new data variable with `fform` attached to it as a value rather than creating a new factornode + ( + false, + add_variable_node!( + model, + context, + NodeCreationOptions(kind = :data, value = (fform, unroll.(args)), link = linked), + VariableNameAnonymous, + nothing + ) + ) + else + # This should not really happen + error("Unreachable reached in `materialize_anonymous_variable!` for `Deterministic` node behaviour.") + end +end + +anonymous_arg_is_constanst(data) = true +anonymous_arg_is_constanst(data::AbstractNodeData) = is_constant(getproperties(data)) +anonymous_arg_is_constanst(data::AbstractArray) = all(anonymous_arg_is_constanst, data) + +anonymous_arg_is_constanst_or_data(data) = is_constant(data) +anonymous_arg_is_constanst_or_data(data::AbstractNodeData) = + let props = getproperties(data) + is_constant(props) || is_data(props) + end +anonymous_arg_is_constanst_or_data(data::AbstractArray) = all(anonymous_arg_is_constanst_or_data, data) + +function materialize_anonymous_variable!(::Deterministic, model::AbstractModel, context::Context, fform, args::NamedTuple) + return materialize_anonymous_variable!(Deterministic(), model, context, fform, values(args)) +end + +function materialize_anonymous_variable!(::Stochastic, model::AbstractModel, context::Context, fform, _) + return (true, add_variable_node!(model, context, NodeCreationOptions(), VariableNameAnonymous, nothing)) +end + +""" + add_atomic_factor_node!(model::AbstractModel, context::Context, options::NodeCreationOptions, fform) + +Add an atomic factor node to the model with the given name. +The function generates a new symbol for the node and adds it to the model with +the generated symbol as the key and a `FactorAbstractNodeData` struct. + +Args: + - `model::AbstractModel`: The model to which the node is added. + - `context::Context`: The context to which the symbol is added. + - `options::NodeCreationOptions`: The options for the creation process. + - `fform::Any`: The functional form of the node. + +Returns: + - The generated label for the node. +""" +function add_atomic_factor_node! end + +function add_atomic_factor_node!(model::AbstractModel, context::Context, options::NodeCreationOptions, fform::F) where {F} + factornode_id = generate_factor_nodelabel(context, fform) + + potential_label = generate_nodelabel(model, fform) + potential_nodedata = NodeData(context, convert(FactorNodeProperties, fform, options)) + + label, nodedata = preprocess_plugins( + UnionPluginType(FactorNodePlugin(), FactorAndVariableNodesPlugin()), model, context, potential_label, potential_nodedata, options + ) + + add_vertex!(model, label, nodedata) + context[factornode_id] = label + + return label, nodedata, convert(FactorNodeProperties, getproperties(nodedata)) +end + +""" +Add a composite factor node to the model with the given name. + +The function generates a new symbol for the node and adds it to the model with +the generated symbol as the key and a `AbstractNodeData` struct with `is_variable` set to +`false` and `node_name` set to the given name. + +Args: + - `model::AbstractModel`: The model to which the node is added. + - `parent_context::Context`: The context to which the symbol is added. + - `context::Context`: The context of the composite factor node. + - `node_name::Symbol`: The name of the node. + +Returns: + - The generated id for the node. +""" +function add_composite_factor_node!(model::AbstractModel, parent_context::Context, context::Context, node_name) + node_id = generate_factor_nodelabel(parent_context, node_name) + parent_context[node_id] = context + return node_id +end + +function add_edge!( + model::AbstractModel, + factor_node_id::NodeLabel, + factor_node_propeties::FactorNodeProperties, + variable_node_id::Union{ProxyLabel, NodeLabel, AbstractVariableReference}, + interface_name::Symbol +) + return add_edge!(model, factor_node_id, factor_node_propeties, variable_node_id, interface_name, nothing) +end + +function add_edge!( + model::AbstractModel, + factor_node_id::NodeLabel, + factor_node_propeties::FactorNodeProperties, + variable_node_id::Union{AbstractArray, Tuple, NamedTuple}, + interface_name::Symbol +) + return add_edge!(model, factor_node_id, factor_node_propeties, variable_node_id, interface_name, 1) +end + +add_edge!( + model::AbstractModel, + factor_node_id::NodeLabel, + factor_node_propeties::FactorNodeProperties, + variable_node_id::Union{ProxyLabel, AbstractVariableReference}, + interface_name::Symbol, + index +) = add_edge!(model, factor_node_id, factor_node_propeties, unroll(variable_node_id), interface_name, index) + +function add_edge!( + model::AbstractModel, + factor_node_id::NodeLabel, + factor_node_propeties::FactorNodeProperties, + variable_node_id::Union{NodeLabel}, + interface_name::Symbol, + index +) + label = EdgeLabel(interface_name, index) + neighbor_node_label = unroll(variable_node_id) + addneighbor!(factor_node_propeties, neighbor_node_label, label, model[neighbor_node_label]) + edge_added = add_edge!(model, neighbor_node_label, factor_node_id, label) + if !edge_added + # Double check if the edge has already been added + if has_edge(model, neighbor_node_label, factor_node_id) + error( + lazy"Trying to create duplicate edge $(label) between variable $(neighbor_node_label) and factor node $(factor_node_id). Make sure that all the arguments to the `~` operator are unique (both left hand side and right hand side)." + ) + else + error(lazy"Cannot create an edge $(label) between variable $(neighbor_node_label) and factor node $(factor_node_id).") + end + end + return label +end + +function add_edge!( + model::AbstractModel, + factor_node_id::NodeLabel, + factor_node_propeties::FactorNodeProperties, + variable_nodes::Union{AbstractArray, Tuple, NamedTuple}, + interface_name::Symbol, + index +) + for variable_node in variable_nodes + add_edge!(model, factor_node_id, factor_node_propeties, variable_node, interface_name, index) + index += increase_index(variable_node) + end +end + +increase_index(any) = 1 +increase_index(x::AbstractArray) = length(x) + +struct MixedArguments{A <: Tuple, K <: NamedTuple} + args::A + kwargs::K +end + +""" + StaticInterfaces{I} + +A type that represents a statically defined set of interfaces for a node in a probabilistic graphical model. +The interfaces are encoded in the type parameter `I` as a tuple of symbols, enabling compile-time reasoning +about interface names and structure. + +This implementation provides better performance through type stability and compile-time validation, +but requires that interface names are known at compile time. +""" +struct StaticInterfaces{I} <: AbstractInterfaces end + +StaticInterfaces(I::Tuple) = StaticInterfaces{I}() +Base.getindex(::StaticInterfaces{I}, index) where {I} = I[index] + +function Base.convert(::Type{NamedTuple}, ::StaticInterfaces{I}, t::Tuple) where {I} + return NamedTuple{I}(t) +end + +""" + StaticInterfaceAliases{A} + +A type that represents a statically defined set of interface aliases for a node in a probabilistic graphical model. +The aliases are encoded in the type parameter `A` as a tuple of pairs of symbols, where each pair maps an alias +to its corresponding interface name. + +This implementation provides better performance through type stability and compile-time validation, +but requires that interface aliases are known at compile time. +""" +struct StaticInterfaceAliases{A} <: AbstractInterfaceAliases end + +StaticInterfaceAliases(A::Tuple) = StaticInterfaceAliases{A}() + +interface_aliases(model::AbstractModel, fform::F, interfaces::StaticInterfaces) where {F} = + interface_aliases(interface_aliases(model, fform), interfaces) + +function interface_aliases(::StaticInterfaceAliases{aliases}, ::StaticInterfaces{interfaces}) where {aliases, interfaces} + return StaticInterfaces( + reduce(aliases; init = interfaces) do acc, alias + from, to = alias + return replace(acc, from => to) + end + ) +end + +""" + missing_interfaces(node_type, val, known_interfaces) + +Returns the interfaces that are missing for a node. This is used when inferring the interfaces for a node that is composite. + +# Arguments +- `node_type`: The type of the node as a Function object. +- `val`: The value of the amount of interfaces the node is supposed to have. This is a `Static.StaticInt` object. +- `known_interfaces`: The known interfaces for the node. + +# Returns +- `missing_interfaces`: A `Vector` of the missing interfaces. +""" +function missing_interfaces(model::AbstractModel, fform::F, val, known_interfaces::NamedTuple) where {F} + return missing_interfaces(interfaces(model, fform, val), StaticInterfaces(keys(known_interfaces))) +end + +function missing_interfaces( + ::StaticInterfaces{all_interfaces}, ::StaticInterfaces{present_interfaces} +) where {all_interfaces, present_interfaces} + return StaticInterfaces(filter(interface -> interface ∉ present_interfaces, all_interfaces)) +end + +function prepare_interfaces(model::AbstractModel, fform::F, lhs_interface, rhs_interfaces::NamedTuple) where {F} + missing_interface = missing_interfaces(model, fform, static(length(rhs_interfaces)) + static(1), rhs_interfaces) + return prepare_interfaces(missing_interface, fform, lhs_interface, rhs_interfaces) +end + +function prepare_interfaces(::StaticInterfaces{I}, fform::F, lhs_interface, rhs_interfaces::NamedTuple) where {I, F} + if !(length(I) == 1) + error( + lazy"Expected only one missing interface, got $I of length $(length(I)) (node $fform with interfaces $(keys(rhs_interfaces)))" + ) + end + missing_interface = first(I) + return NamedTuple{(missing_interface, keys(rhs_interfaces)...)}((lhs_interface, values(rhs_interfaces)...)) +end + +function materialize_interface(model, context, interface) + return getifcreated(model, context, unroll(interface)) +end + +function materialze_interfaces(model, context, interfaces) + return map(interface -> materialize_interface(model, context, interface), interfaces) +end + +function sort_interfaces(model::AbstractModel, fform::F, defined_interfaces::NamedTuple) where {F} + return sort_interfaces(interfaces(model, fform, static(length(defined_interfaces))), defined_interfaces) +end + +function sort_interfaces(::StaticInterfaces{I}, defined_interfaces::NamedTuple) where {I} + return defined_interfaces[I] +end + +function materialize_factor_node!( + model::AbstractModel, context::Context, options::NodeCreationOptions, fform::F, interfaces::NamedTuple +) where {F} + factor_node_id, factor_node_data, factor_node_properties = add_atomic_factor_node!(model, context, options, fform) + foreach(pairs(interfaces)) do (interface_name, interface) + add_edge!(model, factor_node_id, factor_node_properties, interface, interface_name) + end + return factor_node_id, factor_node_data, factor_node_properties +end + +# maybe change name +is_nodelabel(x) = false +is_nodelabel(x::AbstractArray) = any(element -> is_nodelabel(element), x) +is_nodelabel(x::GraphPPL.NodeLabel) = true +is_nodelabel(x::ProxyLabel) = true +is_nodelabel(x::AbstractVariableReference) = true + +function contains_nodelabel(collection::Tuple) + return any(element -> is_nodelabel(element), collection) ? True() : False() +end + +function contains_nodelabel(collection::NamedTuple) + return any(element -> is_nodelabel(element), values(collection)) ? True() : False() +end + +function contains_nodelabel(collection::MixedArguments) + return contains_nodelabel(collection.args) | contains_nodelabel(collection.kwargs) +end + +# TODO improve documentation + +function make_node!(model::AbstractModel, ctx::Context, fform::F, lhs_interfaces, rhs_interfaces) where {F} + return make_node!(model, ctx, EmptyNodeCreationOptions, fform, lhs_interfaces, rhs_interfaces) +end + +make_node!(model::AbstractModel, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface, rhs_interfaces) where {F} = + make_node!(NodeType(model, fform), model, ctx, options, fform, lhs_interface, rhs_interfaces) + +# if it is composite, we assume it should be materialized and it is stochastic +# TODO: shall we not assume that the `Composite` node is necessarily stochastic? +make_node!( + nodetype::Composite, model::AbstractModel, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface, rhs_interfaces +) where {F} = make_node!(True(), nodetype, Stochastic(), model, ctx, options, fform, lhs_interface, rhs_interfaces) + +# If a node is an object and not a function, we materialize it as a stochastic atomic node +make_node!(model::AbstractModel, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface, rhs_interfaces::Nothing) where {F} = + make_node!(True(), Atomic(), Stochastic(), model, ctx, options, fform, lhs_interface, NamedTuple{}()) + +# If node is Atomic, check stochasticity +make_node!(::Atomic, model::AbstractModel, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface, rhs_interfaces) where {F} = + make_node!(Atomic(), NodeBehaviour(model, fform), model, ctx, options, fform, lhs_interface, rhs_interfaces) + +#If a node is deterministic, we check if there are any NodeLabel objects in the rhs_interfaces (direct check if node should be materialized) +make_node!( + atomic::Atomic, + deterministic::Deterministic, + model::AbstractModel, + ctx::Context, + options::NodeCreationOptions, + fform::F, + lhs_interface, + rhs_interfaces +) where {F} = + make_node!(contains_nodelabel(rhs_interfaces), atomic, deterministic, model, ctx, options, fform, lhs_interface, rhs_interfaces) + +# If the node should not be materialized (if it's Atomic, Deterministic and contains no NodeLabel objects), we return the `fform` evaluated at the interfaces +# This works only if the `lhs_interface` is `AnonymousVariable` (or the corresponding `ProxyLabel` with `AnonymousVariable` as the proxied variable) +__evaluate_fform(fform::F, args::Tuple) where {F} = fform(args...) +__evaluate_fform(fform::F, args::NamedTuple) where {F} = fform(; args...) +__evaluate_fform(fform::F, args::MixedArguments) where {F} = fform(args.args...; args.kwargs...) + +make_node!( + ::False, + ::Atomic, + ::Deterministic, + model::AbstractModel, + ctx::Context, + options::NodeCreationOptions, + fform::F, + lhs_interface::Union{AnonymousVariable, ProxyLabel{<:T, <:AnonymousVariable} where {T}}, + rhs_interfaces::Union{Tuple, NamedTuple, MixedArguments} +) where {F} = (nothing, __evaluate_fform(fform, rhs_interfaces)) + +# In case if the `lhs_interface` is something else we throw an error saying that `fform` cannot be instantiated since +# arguments are not stochastic and the `fform` is not stochastic either, thus the usage of `~` is invalid +make_node!( + ::False, + ::Atomic, + ::Deterministic, + model::AbstractModel, + ctx::Context, + options::NodeCreationOptions, + fform::F, + lhs_interface, + rhs_interfaces::Union{Tuple, NamedTuple, MixedArguments} +) where {F} = error("`$(fform)` cannot be used as a factor node. Both the arguments and the node are not stochastic.") + +# If a node is Stochastic, we always materialize. +make_node!( + ::Atomic, ::Stochastic, model::AbstractModel, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface, rhs_interfaces +) where {F} = make_node!(True(), Atomic(), Stochastic(), model, ctx, options, fform, lhs_interface, rhs_interfaces) + +function make_node!( + materialize::True, + node_type::NodeType, + behaviour::NodeBehaviour, + model::AbstractModel, + ctx::Context, + options::NodeCreationOptions, + fform::F, + lhs_interface::AnonymousVariable, + rhs_interfaces +) where {F} + (noderequired, lhs_materialized) = materialize_anonymous_variable!(lhs_interface, fform, rhs_interfaces)::Tuple{Bool, NodeLabel} + node_materialized = if noderequired + node, _ = make_node!(materialize, node_type, behaviour, model, ctx, options, fform, lhs_materialized, rhs_interfaces) + node + else + nothing + end + return node_materialized, lhs_materialized +end + +# If we have to materialize but the rhs_interfaces argument is not a NamedTuple, we convert it +make_node!( + materialize::True, + node_type::NodeType, + behaviour::NodeBehaviour, + model::AbstractModel, + ctx::Context, + options::NodeCreationOptions, + fform::F, + lhs_interface::Union{NodeLabel, ProxyLabel, AbstractVariableReference}, + rhs_interfaces::Tuple +) where {F} = make_node!( + materialize, + node_type, + behaviour, + model, + ctx, + options, + fform, + lhs_interface, + GraphPPL.default_parametrization(model, node_type, fform, rhs_interfaces) +) + +make_node!( + ::True, + node_type::NodeType, + behaviour::NodeBehaviour, + model::AbstractModel, + ctx::Context, + options::NodeCreationOptions, + fform::F, + lhs_interface::Union{NodeLabel, ProxyLabel, AbstractVariableReference}, + rhs_interfaces::MixedArguments +) where {F} = error("MixedArguments not supported for rhs_interfaces when node has to be materialized") + +make_node!( + materialize::True, + node_type::Composite, + behaviour::Stochastic, + model::AbstractModel, + ctx::Context, + options::NodeCreationOptions, + fform::F, + lhs_interface::Union{NodeLabel, ProxyLabel, AbstractVariableReference}, + rhs_interfaces::Tuple{} +) where {F} = make_node!(materialize, node_type, behaviour, model, ctx, options, fform, lhs_interface, NamedTuple{}()) + +make_node!( + materialize::True, + node_type::Composite, + behaviour::Stochastic, + model::AbstractModel, + ctx::Context, + options::NodeCreationOptions, + fform::F, + lhs_interface::Union{NodeLabel, ProxyLabel, AbstractVariableReference}, + rhs_interfaces::Tuple +) where {F} = error(lazy"Composite node $fform cannot should be called with explicitly naming the interface names") + +make_node!( + materialize::True, + node_type::Composite, + behaviour::Stochastic, + model::AbstractModel, + ctx::Context, + options::NodeCreationOptions, + fform::F, + lhs_interface::Union{NodeLabel, ProxyLabel, AbstractVariableReference}, + rhs_interfaces::NamedTuple +) where {F} = make_node!(Composite(), model, ctx, options, fform, lhs_interface, rhs_interfaces, static(length(rhs_interfaces) + 1)) + +""" + make_node! + +Make a new factor node in the AbstractModel and specified Context, attach it to the specified interfaces, and return the interface that is on the lhs of the `~` operator. + +# Arguments +- `model::AbstractModel`: The model to add the node to. +- `ctx::Context`: The context in which to add the node. +- `fform`: The function that the node represents. +- `lhs_interface`: The interface that is on the lhs of the `~` operator. +- `rhs_interfaces`: The interfaces that are the arguments of fform on the rhs of the `~` operator. +- `__parent_options__::NamedTuple = nothing`: The options to attach to the node. +- `__debug__::Bool = false`: Whether to attach debug information to the factor node. +""" +function make_node!( + materialize::True, + node_type::Atomic, + behaviour::NodeBehaviour, + model::AbstractModel, + context::Context, + options::NodeCreationOptions, + fform::F, + lhs_interface::Union{NodeLabel, ProxyLabel, AbstractVariableReference}, + rhs_interfaces::NamedTuple +) where {F} + aliased_rhs_interfaces = convert( + NamedTuple, interface_aliases(model, fform, StaticInterfaces(keys(rhs_interfaces))), values(rhs_interfaces) + ) + aliased_fform = factor_alias(model, fform, StaticInterfaces(keys(aliased_rhs_interfaces))) + prepared_interfaces = prepare_interfaces(model, aliased_fform, lhs_interface, aliased_rhs_interfaces) + sorted_interfaces = sort_interfaces(model, aliased_fform, prepared_interfaces) + interfaces = materialze_interfaces(model, context, sorted_interfaces) + nodeid, _, _ = materialize_factor_node!(model, context, options, aliased_fform, interfaces) + return nodeid, unroll(lhs_interface) +end + +function add_terminated_submodel!(model::AbstractModel, context::Context, fform, interfaces::NamedTuple) + return add_terminated_submodel!(model, context, NodeCreationOptions((; created_by = () -> :($QuoteNode(fform)))), fform, interfaces) +end + +function add_terminated_submodel!(model::AbstractModel, context::Context, options::NodeCreationOptions, fform, interfaces::NamedTuple) + returnval = add_terminated_submodel!(model, context, options, fform, interfaces, static(length(interfaces))) + returnval!(context, returnval) + return returnval +end + +""" +Add the `fform` as the toplevel model to the `model` and `context` with the specified `interfaces`. +Calls the postprocess logic for the attached plugins of the model. Should be called only once for a given `AbstractModel` object. +""" +function add_toplevel_model! end + +function add_toplevel_model!(model::AbstractModel, fform, interfaces) + return add_toplevel_model!(model, getcontext(model), fform, interfaces) +end + +function add_toplevel_model!(model::AbstractModel, context::Context, fform, interfaces) + add_terminated_submodel!(model, context, fform, interfaces) + foreach(getplugins(model)) do plugin + postprocess_plugin(plugin, model) + end + return model +end \ No newline at end of file diff --git a/src/nodes/node_properties.jl b/src/nodes/node_properties.jl new file mode 100644 index 00000000..8f532c24 --- /dev/null +++ b/src/nodes/node_properties.jl @@ -0,0 +1,92 @@ +""" + VariableNodeProperties(name, index, kind, link, value) + +Data associated with a variable node in a probabilistic graphical model. +""" +struct VariableNodeProperties + name::Symbol + index::Any + kind::Symbol + link::Any + value::Any +end + +VariableNodeProperties(; name, index, kind = VariableKindRandom, link = nothing, value = nothing) = + VariableNodeProperties(name, index, kind, link, value) + +is_factor(::VariableNodeProperties) = false +is_variable(::VariableNodeProperties) = true + +function Base.convert(::Type{VariableNodeProperties}, name::Symbol, index, options::NodeCreationOptions) + return VariableNodeProperties( + name = name, + index = index, + kind = get(options, :kind, VariableKindRandom), + link = get(options, :link, nothing), + value = get(options, :value, nothing) + ) +end + +getname(properties::VariableNodeProperties) = properties.name +getlink(properties::VariableNodeProperties) = properties.link +index(properties::VariableNodeProperties) = properties.index +value(properties::VariableNodeProperties) = properties.value + +"Defines a `random` (or `latent`) kind for a variable in a probabilistic graphical model." +const VariableKindRandom = :random +"Defines a `data` kind for a variable in a probabilistic graphical model." +const VariableKindData = :data +"Defines a `constant` kind for a variable in a probabilistic graphical model." +const VariableKindConstant = :constant +"Placeholder for a variable kind in a probabilistic graphical model." +const VariableKindUnknown = :unknown + +is_kind(properties::VariableNodeProperties, kind) = properties.kind === kind +is_kind(properties::VariableNodeProperties, ::Val{kind}) where {kind} = properties.kind === kind +is_random(properties::VariableNodeProperties) = is_kind(properties, Val(VariableKindRandom)) +is_data(properties::VariableNodeProperties) = is_kind(properties, Val(VariableKindData)) +is_constant(properties::VariableNodeProperties) = is_kind(properties, Val(VariableKindConstant)) + +const VariableNameAnonymous = :anonymous_var_graphppl + +is_anonymous(properties::VariableNodeProperties) = properties.name === VariableNameAnonymous + +function Base.show(io::IO, properties::VariableNodeProperties) + print(io, "name = ", properties.name, ", index = ", properties.index) + if !isnothing(properties.link) + print(io, ", linked to ", properties.link) + end +end + +""" + FactorNodeProperties(fform, neighbours) + +Data associated with a factor node in a probabilistic graphical model. +""" +struct FactorNodeProperties{D} + fform::Any + neighbors::Vector{Tuple{NodeLabel, EdgeLabel, D}} +end + +FactorNodeProperties(; fform, neighbors = Tuple{NodeLabel, EdgeLabel, NodeData}[]) = FactorNodeProperties(fform, neighbors) + +is_factor(::FactorNodeProperties) = true +is_variable(::FactorNodeProperties) = false + +function Base.convert(::Type{FactorNodeProperties}, fform, options::NodeCreationOptions) + return FactorNodeProperties(fform = fform, neighbors = get(options, :neighbors, Tuple{NodeLabel, EdgeLabel, NodeData}[])) +end + +getname(properties::FactorNodeProperties) = string(properties.fform) +prettyname(properties::FactorNodeProperties) = prettyname(properties.fform) +prettyname(fform::Any) = string(fform) # Can be overloaded for custom pretty names + +fform(properties::FactorNodeProperties) = properties.fform +neighbors(properties::FactorNodeProperties) = properties.neighbors +addneighbor!(properties::FactorNodeProperties, variable::NodeLabel, edge::EdgeLabel, data) = + push!(properties.neighbors, (variable, edge, data)) +neighbor_data(properties::FactorNodeProperties) = Iterators.map(neighbor -> neighbor[3], neighbors(properties)) + +function Base.show(io::IO, properties::FactorNodeProperties) + print(io, "fform = ", properties.fform, ", neighbors = ", properties.neighbors) +end \ No newline at end of file diff --git a/src/plugins/plugin_processing.jl b/src/plugins/plugin_processing.jl new file mode 100644 index 00000000..2dd61dc6 --- /dev/null +++ b/src/plugins/plugin_processing.jl @@ -0,0 +1,37 @@ +""" +A trait object for plugins that add extra functionality for factor nodes. +""" +struct FactorNodePlugin <: AbstractPluginTraitType end + +""" +A trait object for plugins that add extra functionality for variable nodes. +""" +struct VariableNodePlugin <: AbstractPluginTraitType end + +""" +A trait object for plugins that add extra functionality both for factor and variable nodes. +""" +struct FactorAndVariableNodesPlugin <: AbstractPluginTraitType end + +""" + preprocess_plugin(plugin, model, context, label, nodedata, options) + +Call a plugin specific logic for a node with label and nodedata upon their creation. +""" +function preprocess_plugin end + +""" + postprocess_plugin(plugin, model) + +Calls a plugin specific logic after the model has been created. By default does nothing. +""" +postprocess_plugin(plugin, model) = nothing + +function preprocess_plugins( + type::AbstractPluginTraitType, model::Model, context::Context, label::NodeLabel, nodedata::NodeData, options +)::Tuple{NodeLabel, NodeData} + plugins = filter(type, getplugins(model)) + return foldl(plugins; init = (label, nodedata)) do (label, nodedata), plugin + return preprocess_plugin(plugin, model, context, label, nodedata, options)::Tuple{NodeLabel, NodeData} + end::Tuple{NodeLabel, NodeData} +end \ No newline at end of file diff --git a/src/plugins_collection.jl b/src/plugins/plugins_collection.jl similarity index 100% rename from src/plugins_collection.jl rename to src/plugins/plugins_collection.jl diff --git a/src/utils/macro_utils.jl b/src/utils/macro_utils.jl new file mode 100644 index 00000000..5b4b1edc --- /dev/null +++ b/src/utils/macro_utils.jl @@ -0,0 +1,66 @@ +import MacroTools: postwalk, prewalk, @capture, walk +using NamedTupleTools +using Static + +__guard_f(f, e::Expr) = f(e) +__guard_f(f, x) = x + +struct guarded_walk{f} + guard::f +end + +function (w::guarded_walk)(f, x) + return w.guard(x) ? x : walk(x, x -> w(f, x), f) +end + +struct walk_until_occurrence{E} + patterns::E +end + +not_enter_indexed_walk = guarded_walk((x) -> (x isa Expr && x.head == :ref) || (x isa Expr && x.head == :call && x.args[1] == :new)) +not_created_by = guarded_walk((x) -> (x isa Expr && !isempty(x.args) && x.args[1] == :created_by)) + +function (w::walk_until_occurrence{E})(f, x) where {E <: Tuple} + return walk(x, z -> any(pattern -> @capture(x, $(pattern)), w.patterns) ? z : w(f, z), f) +end + +function (w::walk_until_occurrence{E})(f, x) where {E <: Expr} + return walk(x, z -> @capture(x, $(w.patterns)) ? z : w(f, z), f) +end + +what_walk(anything) = postwalk + +""" + apply_pipeline(e::Expr, pipeline) + +Apply a pipeline function to an expression. + +The `apply_pipeline` function takes an expression `e` and a `pipeline` function and applies the function in the pipeline to `e` when walking over it. The walk utilized can be specified by implementing `what_walk` for a pipeline funciton. + +# Arguments +- `e::Expr`: An expression to apply the pipeline to. +- `pipeline`: A function to apply to the expressions in `e`. + +# Returns +The result of applying the pipeline function to `e`. +""" +function apply_pipeline(e::Expr, pipeline::F) where {F} + walk = what_walk(pipeline) + return walk(x -> __guard_f(pipeline, x), e) +end + +""" + apply_pipeline_collection(e::Expr, collection) + +Similar to [`apply_pipeline`](@ref), but applies a collection of pipeline functions to an expression. + +# Arguments +- `e::Expr`: An expression to apply the pipeline to. +- `collection`: A collection of functions to apply to the expressions in `e`. + +# Returns +The result of applying the pipeline function to `e`. +""" +function apply_pipeline_collection(e::Expr, collection) + return reduce((e, pipeline) -> apply_pipeline(e, pipeline), collection, init = e) +end \ No newline at end of file diff --git a/test/testutils.jl b/test/testutils.jl index 10759e55..a349a180 100644 --- a/test/testutils.jl +++ b/test/testutils.jl @@ -28,7 +28,7 @@ end # We use a custom backend for testing purposes, instead of using the `DefaultBackend` # The `TestGraphPPLBackend` is a simple backend that specifies how to handle objects from `Distributions.jl` # It does use the default pipeline collection for the `@model` macro -struct TestGraphPPLBackend end +struct TestGraphPPLBackend <: GraphPPL.AbstractBackend end GraphPPL.model_macro_interior_pipelines(::TestGraphPPLBackend) = GraphPPL.model_macro_interior_pipelines(GraphPPL.DefaultBackend()) From f272a32f3ca7fe15299ccf468ded0196fc5d946c Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Tue, 6 May 2025 21:00:07 +0200 Subject: [PATCH 2/6] Update cursorrules --- .cursor/rules/graphppl.mdc | 34 ++++++++++++++++++++++++++++++++++ .cursor/rules/julia.mdc | 1 + 2 files changed, 35 insertions(+) create mode 100644 .cursor/rules/graphppl.mdc diff --git a/.cursor/rules/graphppl.mdc b/.cursor/rules/graphppl.mdc new file mode 100644 index 00000000..57a7af82 --- /dev/null +++ b/.cursor/rules/graphppl.mdc @@ -0,0 +1,34 @@ +--- +description: +globs: +alwaysApply: true +--- +# GraphPPL.jl Overview + +GraphPPL.jl is a probabilistic programming language focused on probabilistic graphical models. It materializes models as factor graphs and provides tools for model specification. + +## Main Components + +The main entry point is [src/GraphPPL.jl](mdc:src/GraphPPL.jl), which includes all the core modules. + +### Core Structure +- **Core**: Basic functionality and interfaces +- **Graph**: Factor graph representation and manipulation +- **Macros**: DSL for model specification via `@model` macro +- **Model**: Model representation classes +- **Nodes**: Graph node representation and behavior +- **Generators**: Code generation tools +- **Plugins**: Extension system for backend integration +- **Utils**: Helper functions and utilities + +### Testing +Tests are organized in the [test](mdc:test) directory, mirroring the source code structure. The test suite uses ReTestItems with `@testitem` blocks for self-contained tests. The [test/runtests.jl](mdc:test/runtests.jl) file is the entry point for running all tests. + +### Benchmarking +Performance benchmarks are in the [benchmark](mdc:benchmark) directory, using BenchmarkTools.jl. Run all benchmarks using [benchmark/benchmarks.jl](mdc:benchmark/benchmarks.jl). + +## Naming Conventions +- Use snake_case for function and variable names +- Use PascalCase for type names (structs and abstract types) +- Add comprehensive docstrings to functions and types +- Use `@kwdef` macro for structs to enable keyword constructors \ No newline at end of file diff --git a/.cursor/rules/julia.mdc b/.cursor/rules/julia.mdc index 8914defe..73c52bd4 100644 --- a/.cursor/rules/julia.mdc +++ b/.cursor/rules/julia.mdc @@ -76,6 +76,7 @@ Performance Optimization Testing - For each file in the source code create a test file with the `_tests.jl` suffix, e.g. `src/folder/subfoldeer/file.jl` -> `test/folder/subfolder/file_tests.jl` - Create small individual tests in `@testitem` blocks +- test files should only contain `@testitem` blocks and no extra code. All code within `@testitem` blocks are self-contained scopes. - Write test cases of increasing difficulty with comments explaining what is being tested. - Use individual `@test` calls for each assertion, not for blocks. - Example: From 410dbe4ea9438fe2bb663a1fa0d2a55bba2343b3 Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Tue, 6 May 2025 22:23:14 +0200 Subject: [PATCH 3/6] Refactor graph engine tests --- test/graph/graph_modification_tests.jl | 264 ++ test/graph/graph_properties_tests.jl | 182 ++ test/graph/graph_traversal_filtering_tests.jl | 201 ++ test/graph/indexing_refs_tests.jl | 388 +++ test/graph_engine_tests.jl | 2670 ----------------- test/model/context_tests.jl | 252 ++ test/model/model_construction_tests.jl | 38 + test/model/model_operations_tests.jl | 646 ++++ test/nodes/node_data_tests.jl | 224 ++ test/nodes/node_label_tests.jl | 181 ++ test/nodes/node_semantics_tests.jl | 209 ++ test/plugins/plugin_lifecycle_tests.jl | 70 + 12 files changed, 2655 insertions(+), 2670 deletions(-) create mode 100644 test/graph/graph_modification_tests.jl create mode 100644 test/graph/graph_properties_tests.jl create mode 100644 test/graph/graph_traversal_filtering_tests.jl create mode 100644 test/graph/indexing_refs_tests.jl delete mode 100644 test/graph_engine_tests.jl create mode 100644 test/model/context_tests.jl create mode 100644 test/model/model_construction_tests.jl create mode 100644 test/model/model_operations_tests.jl create mode 100644 test/nodes/node_data_tests.jl create mode 100644 test/nodes/node_label_tests.jl create mode 100644 test/nodes/node_semantics_tests.jl create mode 100644 test/plugins/plugin_lifecycle_tests.jl diff --git a/test/graph/graph_modification_tests.jl b/test/graph/graph_modification_tests.jl new file mode 100644 index 00000000..6d4ff8f5 --- /dev/null +++ b/test/graph/graph_modification_tests.jl @@ -0,0 +1,264 @@ +@testitem "setindex!(::Model, ::NodeData, ::NodeLabel)" begin + using Graphs + import GraphPPL: getcontext, NodeLabel, NodeData, VariableNodeProperties, FactorNodeProperties + + include("testutils.jl") + + model = create_test_model() + ctx = getcontext(model) + model[NodeLabel(:μ, 1)] = NodeData(ctx, VariableNodeProperties(name = :μ, index = nothing)) + @test nv(model) == 1 && ne(model) == 0 + + model[NodeLabel(:x, 2)] = NodeData(ctx, VariableNodeProperties(name = :x, index = nothing)) + @test nv(model) == 2 && ne(model) == 0 + + model[NodeLabel(sum, 3)] = NodeData(ctx, FactorNodeProperties(fform = sum)) + @test nv(model) == 3 && ne(model) == 0 + + @test_throws MethodError model[0] = 1 + @test_throws MethodError model["string"] = NodeData(ctx, VariableNodeProperties(name = :x, index = nothing)) + @test_throws MethodError model["string"] = NodeData(ctx, FactorNodeProperties(fform = sum)) +end + +@testitem "setindex!(::Model, ::EdgeLabel, ::NodeLabel, ::NodeLabel)" begin + using Graphs + import GraphPPL: getcontext, NodeLabel, NodeData, VariableNodeProperties, EdgeLabel + + include("testutils.jl") + + model = create_test_model() + ctx = getcontext(model) + + μ = NodeLabel(:μ, 1) + xref = NodeLabel(:x, 2) + + model[μ] = NodeData(ctx, VariableNodeProperties(name = :μ, index = nothing)) + model[xref] = NodeData(ctx, VariableNodeProperties(name = :x, index = nothing)) + model[μ, xref] = EdgeLabel(:interface, 1) + + @test ne(model) == 1 + @test_throws MethodError model[0, 1] = 1 + + # Test that we can't add an edge between two nodes that don't exist + model[μ, NodeLabel(:x, 100)] = EdgeLabel(:if, 1) + @test ne(model) == 1 +end + +@testitem "add_variable_node!" begin + import GraphPPL: + create_model, + add_variable_node!, + getcontext, + options, + NodeLabel, + ResizableArray, + nv, + ne, + NodeCreationOptions, + getproperties, + is_constant, + value + + include("testutils.jl") + + # Test 1: simple add variable to model + model = create_test_model() + ctx = getcontext(model) + node_id = add_variable_node!(model, ctx, NodeCreationOptions(), :x, nothing) + @test nv(model) == 1 && haskey(ctx.individual_variables, :x) && ctx.individual_variables[:x] == node_id + + # Test 2: Add second variable to model + add_variable_node!(model, ctx, NodeCreationOptions(), :y, nothing) + @test nv(model) == 2 && haskey(ctx, :y) + + # Test 3: Check that adding an integer variable throws a MethodError + @test_throws MethodError add_variable_node!(model, ctx, NodeCreationOptions(), 1) + @test_throws MethodError add_variable_node!(model, ctx, NodeCreationOptions(), 1, 1) + + # Test 4: Add a vector variable to the model + model = create_test_model() + ctx = getcontext(model) + ctx[:x] = ResizableArray(NodeLabel, Val(1)) + node_id = add_variable_node!(model, ctx, NodeCreationOptions(), :x, 2) + @test nv(model) == 1 && haskey(ctx, :x) && ctx[:x][2] == node_id && length(ctx[:x]) == 2 + + # Test 5: Add a second vector variable to the model + node_id = add_variable_node!(model, ctx, NodeCreationOptions(), :x, 1) + @test nv(model) == 2 && haskey(ctx, :x) && ctx[:x][1] == node_id && length(ctx[:x]) == 2 + + # Test 6: Add a tensor variable to the model + model = create_test_model() + ctx = getcontext(model) + ctx[:x] = ResizableArray(NodeLabel, Val(2)) + node_id = add_variable_node!(model, ctx, NodeCreationOptions(), :x, (2, 3)) + @test nv(model) == 1 && haskey(ctx, :x) && ctx[:x][2, 3] == node_id + + # Test 7: Add a second tensor variable to the model + node_id = add_variable_node!(model, ctx, NodeCreationOptions(), :x, (2, 4)) + @test nv(model) == 2 && haskey(ctx, :x) && ctx[:x][2, 4] == node_id + + # Test 9: Add a variable with a non-integer index + model = create_test_model() + ctx = getcontext(model) + ctx[:z] = ResizableArray(NodeLabel, Val(2)) + @test_throws MethodError add_variable_node!(model, ctx, NodeCreationOptions(), :z, "a") + @test_throws MethodError add_variable_node!(model, ctx, NodeCreationOptions(), :z, ("a", "a")) + @test_throws MethodError add_variable_node!(model, ctx, NodeCreationOptions(), :z, ("a", 1)) + @test_throws MethodError add_variable_node!(model, ctx, NodeCreationOptions(), :z, (1, "a")) + + # Test 10: Add a variable with a negative index + ctx[:x] = ResizableArray(NodeLabel, Val(1)) + @test_throws BoundsError add_variable_node!(model, ctx, NodeCreationOptions(), :x, -1) + + # Test 11: Add a variable with options + model = create_test_model() + ctx = getcontext(model) + var = add_variable_node!(model, ctx, NodeCreationOptions(kind = :constant, value = 1.0), :x, nothing) + @test nv(model) == 1 && + haskey(ctx, :x) && + ctx[:x] == var && + is_constant(getproperties(model[var])) && + value(getproperties(model[var])) == 1.0 + + # Test 12: Add a variable without options + model = create_test_model() + ctx = getcontext(model) + var = add_variable_node!(model, ctx, :x, nothing) + @test nv(model) == 1 && haskey(ctx, :x) && ctx[:x] == var +end + +@testitem "add_atomic_factor_node!" begin + using Distributions + using Graphs + import GraphPPL: create_model, add_atomic_factor_node!, getorcreate!, getcontext, getorcreate!, label_for, getname, NodeCreationOptions + + include("testutils.jl") + + # Test 1: Add an atomic factor node to the model + model = create_test_model(plugins = GraphPPL.PluginsCollection(GraphPPL.MetaPlugin())) + ctx = getcontext(model) + options = NodeCreationOptions() + xref = getorcreate!(model, ctx, NodeCreationOptions(), :x, nothing) + node_id, node_data, node_properties = add_atomic_factor_node!(model, ctx, options, sum) + @test model[node_id] === node_data + @test nv(model) == 2 && getname(label_for(model.graph, 2)) == sum + + # Test 2: Add a second atomic factor node to the model with the same name and assert they are different + node_id, node_data, node_properties = add_atomic_factor_node!(model, ctx, options, sum) + @test model[node_id] === node_data + @test nv(model) == 3 && getname(label_for(model.graph, 3)) == sum + + # Test 3: Add an atomic factor node with options + options = NodeCreationOptions((; meta = true,)) + node_id, node_data, node_properties = add_atomic_factor_node!(model, ctx, options, sum) + @test model[node_id] === node_data + @test nv(model) == 4 && getname(label_for(model.graph, 4)) == sum + @test GraphPPL.hasextra(node_data, :meta) + @test GraphPPL.getextra(node_data, :meta) == true + + # Test 4: Test that creating a node with an instantiated object is supported + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + prior = Normal(0, 1) + node_id, node_data, node_properties = add_atomic_factor_node!(model, ctx, options, prior) + @test model[node_id] === node_data + @test nv(model) == 1 && getname(label_for(model.graph, 1)) == Normal(0, 1) +end + +@testitem "add_composite_factor_node!" begin + using Graphs + import GraphPPL: create_model, add_composite_factor_node!, getcontext, to_symbol, children, add_variable_node!, Context + + include("testutils.jl") + + # Add a composite factor node to the model + model = create_test_model() + parent_ctx = getcontext(model) + child_ctx = getcontext(model) + add_variable_node!(model, child_ctx, :x, nothing) + add_variable_node!(model, child_ctx, :y, nothing) + node_id = add_composite_factor_node!(model, parent_ctx, child_ctx, :f) + @test nv(model) == 2 && + haskey(children(parent_ctx), node_id) && + children(parent_ctx)[node_id] === child_ctx && + length(child_ctx.individual_variables) == 2 + + # Add a composite factor node with a different name + node_id = add_composite_factor_node!(model, parent_ctx, child_ctx, :g) + @test nv(model) == 2 && + haskey(children(parent_ctx), node_id) && + children(parent_ctx)[node_id] === child_ctx && + length(child_ctx.individual_variables) == 2 + + # Add a composite factor node with an empty child context + empty_ctx = Context() + node_id = add_composite_factor_node!(model, parent_ctx, empty_ctx, :h) + @test nv(model) == 2 && + haskey(children(parent_ctx), node_id) && + children(parent_ctx)[node_id] === empty_ctx && + length(empty_ctx.individual_variables) == 0 +end + +@testitem "add_edge!(::Model, ::NodeLabel, ::NodeLabel, ::Symbol)" begin + import GraphPPL: + create_model, getcontext, nv, ne, NodeData, NodeLabel, EdgeLabel, add_edge!, getorcreate!, generate_nodelabel, NodeCreationOptions + + include("testutils.jl") + + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + xref, xdata, xproperties = GraphPPL.add_atomic_factor_node!(model, ctx, options, sum) + y = getorcreate!(model, ctx, :y, nothing) + + add_edge!(model, xref, xproperties, y, :interface) + + @test ne(model) == 1 + + @test_throws MethodError add_edge!(model, xref, xproperties, y, 123) +end + +@testitem "add_edge!(::Model, ::NodeLabel, ::Vector{NodeLabel}, ::Symbol)" begin + import GraphPPL: create_model, getcontext, nv, ne, NodeData, NodeLabel, EdgeLabel, add_edge!, getorcreate!, NodeCreationOptions + + include("testutils.jl") + + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + y = getorcreate!(model, ctx, :y, nothing) + + variable_nodes = [getorcreate!(model, ctx, i, nothing) for i in [:a, :b, :c]] + xref, xdata, xproperties = GraphPPL.add_atomic_factor_node!(model, ctx, options, sum) + add_edge!(model, xref, xproperties, variable_nodes, :interface) + + @test ne(model) == 3 && model[variable_nodes[1], xref] == EdgeLabel(:interface, 1) +end + +@testitem "prune!(m::Model)" begin + using Graphs + import GraphPPL: create_model, getcontext, getorcreate!, prune!, create_model, getorcreate!, add_edge!, NodeCreationOptions + + include("testutils.jl") + + # Test 1: Prune a node with no edges + model = create_test_model() + ctx = getcontext(model) + xref = getorcreate!(model, ctx, :x, nothing) + prune!(model) + @test nv(model) == 0 + + # Test 2: Prune two nodes + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + xref = getorcreate!(model, ctx, :x, nothing) + y, ydata, yproperties = GraphPPL.add_atomic_factor_node!(model, ctx, options, sum) + zref = getorcreate!(model, ctx, :z, nothing) + w = getorcreate!(model, ctx, :w, nothing) + + add_edge!(model, y, yproperties, zref, :test) + prune!(model) + @test nv(model) == 2 +end \ No newline at end of file diff --git a/test/graph/graph_properties_tests.jl b/test/graph/graph_properties_tests.jl new file mode 100644 index 00000000..4a5f678f --- /dev/null +++ b/test/graph/graph_properties_tests.jl @@ -0,0 +1,182 @@ +@testitem "degree" begin + import GraphPPL: create_model, getcontext, getorcreate!, NodeCreationOptions, make_node!, degree + + include("testutils.jl") + + for n in 5:10 + model = create_test_model() + ctx = getcontext(model) + + unused = getorcreate!(model, ctx, :unusued, nothing) + xref = getorcreate!(model, ctx, :x, nothing) + y = getorcreate!(model, ctx, :y, nothing) + + foreach(1:n) do k + getorcreate!(model, ctx, :z, k) + end + + zref = getorcreate!(model, ctx, :z, 1) + + @test degree(model, unused) === 0 + @test degree(model, xref) === 0 + @test degree(model, y) === 0 + @test all(zᵢ -> degree(model, zᵢ) === 0, zref) + + for i in 1:n + make_node!(model, ctx, NodeCreationOptions(), sum, y, (in = [xref, zref[i]],)) + end + + @test degree(model, unused) === 0 + @test degree(model, xref) === n + @test degree(model, y) === n + @test all(zᵢ -> degree(model, zᵢ) === 1, zref) + end +end + +@testitem "nv_ne(::Model)" begin + import GraphPPL: create_model, getcontext, nv, ne, NodeData, VariableNodeProperties, NodeLabel, EdgeLabel + + include("testutils.jl") + + model = create_test_model() + ctx = getcontext(model) + @test isempty(model) + @test nv(model) == 0 + @test ne(model) == 0 + + model[NodeLabel(:a, 1)] = NodeData(ctx, VariableNodeProperties(name = :a, index = nothing)) + model[NodeLabel(:b, 2)] = NodeData(ctx, VariableNodeProperties(name = :b, index = nothing)) + @test !isempty(model) + @test nv(model) == 2 + @test ne(model) == 0 + + model[NodeLabel(:a, 1), NodeLabel(:b, 2)] = EdgeLabel(:edge, 1) + @test !isempty(model) + @test nv(model) == 2 + @test ne(model) == 1 +end + +@testitem "edges" begin + import GraphPPL: + edges, + create_model, + getcontext, + getproperties, + NodeData, + VariableNodeProperties, + FactorNodeProperties, + NodeLabel, + EdgeLabel, + getname, + add_edge!, + has_edge, + getproperties + + include("testutils.jl") + + # Test 1: Test getting all edges from a model + model = create_test_model() + ctx = getcontext(model) + a = NodeLabel(:a, 1) + b = NodeLabel(:b, 2) + model[a] = NodeData(ctx, VariableNodeProperties(name = :a, index = nothing)) + model[b] = NodeData(ctx, FactorNodeProperties(fform = sum)) + @test !has_edge(model, a, b) + @test !has_edge(model, b, a) + add_edge!(model, b, getproperties(model[b]), a, :edge, 1) + @test has_edge(model, a, b) + @test has_edge(model, b, a) + @test length(edges(model)) == 1 + + c = NodeLabel(:c, 2) + model[c] = NodeData(ctx, FactorNodeProperties(fform = sum)) + @test !has_edge(model, a, c) + @test !has_edge(model, c, a) + add_edge!(model, c, getproperties(model[c]), a, :edge, 2) + @test has_edge(model, a, c) + @test has_edge(model, c, a) + + @test length(edges(model)) == 2 + + # Test 2: Test getting all edges from a model with a specific node + @test getname.(edges(model, a)) == [:edge, :edge] + @test getname.(edges(model, b)) == [:edge] + @test getname.(edges(model, c)) == [:edge] + # @test getname.(edges(model, [a, b])) == [:edge, :edge, :edge] +end + +@testitem "neighbors(::Model, ::NodeData)" begin + import GraphPPL: + create_model, + getcontext, + neighbors, + NodeData, + VariableNodeProperties, + FactorNodeProperties, + NodeLabel, + EdgeLabel, + getname, + ResizableArray, + add_edge!, + getproperties + + include("testutils.jl") + + using .TestUtils.ModelZoo + + model = create_test_model() + ctx = getcontext(model) + + a = NodeLabel(:a, 1) + b = NodeLabel(:b, 2) + model[a] = NodeData(ctx, FactorNodeProperties(fform = sum)) + model[b] = NodeData(ctx, VariableNodeProperties(name = :b, index = nothing)) + add_edge!(model, a, getproperties(model[a]), b, :edge, 1) + @test collect(neighbors(model, NodeLabel(:a, 1))) == [NodeLabel(:b, 2)] + + model = create_test_model() + ctx = getcontext(model) + a = ResizableArray(NodeLabel, Val(1)) + b = ResizableArray(NodeLabel, Val(1)) + for i in 1:3 + a[i] = NodeLabel(:a, i) + model[a[i]] = NodeData(ctx, FactorNodeProperties(fform = sum)) + b[i] = NodeLabel(:b, i) + model[b[i]] = NodeData(ctx, VariableNodeProperties(name = :b, index = i)) + add_edge!(model, a[i], getproperties(model[a[i]]), b[i], :edge, i) + end + for n in b + @test n ∈ neighbors(model, a) + end + # Test 2: Test getting sorted neighbors + model = create_model(simple_model()) + ctx = getcontext(model) + node = first(neighbors(model, ctx[:z])) # Normal node we're investigating is the only neighbor of `z` in the graph. + @test getname.(neighbors(model, node)) == [:z, :x, :y] + + # Test 3: Test getting sorted neighbors when one of the edge indices is nothing + model = create_model(vector_model()) + ctx = getcontext(model) + node = first(neighbors(model, ctx[:z][1])) + @test getname.(collect(neighbors(model, node))) == [:z, :x, :y] +end + +@testitem "save and load graph" begin + import GraphPPL: create_model, with_plugins, savegraph, loadgraph, getextra, as_node + + include("testutils.jl") + + using .TestUtils.ModelZoo + + model = create_model(with_plugins(vector_model(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin()))) + mktemp() do file, io + file = file * ".jld2" + savegraph(file, model) + model2 = loadgraph(file, GraphPPL.Model) + for (node, node2) in zip(filter(as_node(), model), filter(as_node(), model2)) + @test node == node2 + @test GraphPPL.getextra(model[node], :factorization_constraint_bitset) == + GraphPPL.getextra(model2[node2], :factorization_constraint_bitset) + end + end +end \ No newline at end of file diff --git a/test/graph/graph_traversal_filtering_tests.jl b/test/graph/graph_traversal_filtering_tests.jl new file mode 100644 index 00000000..0cb4a0c7 --- /dev/null +++ b/test/graph/graph_traversal_filtering_tests.jl @@ -0,0 +1,201 @@ +@testitem "factor_nodes" begin + import GraphPPL: create_model, factor_nodes, is_factor, labels + + include("testutils.jl") + + using .TestUtils.ModelZoo + + for modelfn in ModelsInTheZooWithoutArguments + model = create_model(modelfn()) + fnodes = collect(factor_nodes(model)) + for node in fnodes + @test is_factor(model[node]) + end + for label in labels(model) + if is_factor(model[label]) + @test label ∈ fnodes + end + end + end +end + +@testitem "factor_nodes with lambda function" begin + import GraphPPL: create_model, factor_nodes, is_factor, labels + + include("testutils.jl") + + using .TestUtils.ModelZoo + + for model_fn in ModelsInTheZooWithoutArguments + model = create_model(model_fn()) + fnodes = collect(factor_nodes(model)) + factor_nodes(model) do label, nodedata + @test is_factor(model[label]) + @test is_factor(nodedata) + @test model[label] === nodedata + @test label ∈ labels(model) + @test label ∈ fnodes + + clength = length(fnodes) + filter!(n -> n !== label, fnodes) + @test length(fnodes) === clength - 1 # Only one should be removed + end + @test length(fnodes) === 0 # all should be processed + end +end + +@testitem "variable_nodes" begin + import GraphPPL: create_model, variable_nodes, is_variable, labels + + include("testutils.jl") + + using .TestUtils.ModelZoo + + for model_fn in ModelsInTheZooWithoutArguments + model = create_model(model_fn()) + fnodes = collect(variable_nodes(model)) + for node in fnodes + @test is_variable(model[node]) + end + for label in labels(model) + if is_variable(model[label]) + @test label ∈ fnodes + end + end + end +end + +@testitem "variable_nodes with lambda function" begin + import GraphPPL: create_model, variable_nodes, is_variable, labels + + include("testutils.jl") + + using .TestUtils.ModelZoo + + for model_fn in ModelsInTheZooWithoutArguments + model = create_model(model_fn()) + fnodes = collect(variable_nodes(model)) + variable_nodes(model) do label, nodedata + @test is_variable(model[label]) + @test is_variable(nodedata) + @test model[label] === nodedata + @test label ∈ labels(model) + @test label ∈ fnodes + + clength = length(fnodes) + filter!(n -> n !== label, fnodes) + @test length(fnodes) === clength - 1 # Only one should be removed + end + @test length(fnodes) === 0 # all should be processed + end +end + +@testitem "variable_nodes with anonymous variables" begin + # The idea here is that the `variable_nodes` must return ALL anonymous variables as well + using Distributions + import GraphPPL: create_model, variable_nodes, getname, is_anonymous, getproperties + + include("testutils.jl") + + @model function simple_submodel_with_2_anonymous_for_variable_nodes(z, x, y) + # Creates two anonymous variables here + z ~ Normal(x + 1, y - 1) + end + + @model function simple_submodel_with_3_anonymous_for_variable_nodes(z, x, y) + # Creates three anonymous variables here + z ~ Normal(x + 1, y - 1 + 1) + end + + @model function simple_model_for_variable_nodes(submodel) + xref ~ Normal(0, 1) + y ~ Gamma(1, 1) + zref ~ submodel(x = xref, y = y) + end + + @testset let submodel = simple_submodel_with_2_anonymous_for_variable_nodes + model = create_model(simple_model_for_variable_nodes(submodel = submodel)) + @test length(collect(variable_nodes(model))) === 11 + @test length(collect(filter(v -> is_anonymous(getproperties(model[v])), collect(variable_nodes(model))))) === 2 + end + + @testset let submodel = simple_submodel_with_3_anonymous_for_variable_nodes + model = create_model(simple_model_for_variable_nodes(submodel = submodel)) + @test length(collect(variable_nodes(model))) === 13 # +1 for new anonymous +1 for new constant + @test length(collect(filter(v -> is_anonymous(getproperties(model[v])), collect(variable_nodes(model))))) === 3 + end +end + +@testitem "filter(::Predicate, ::Model)" begin + import GraphPPL: create_model, as_node, as_context, as_variable + + include("testutils.jl") + + using .TestUtils.ModelZoo + + model = create_model(simple_model()) + result = collect(filter(as_node(Normal) | as_variable(:x), model)) + @test length(result) == 3 + + model = create_model(outer()) + result = collect(filter(as_node(Gamma) & as_context(inner_inner), model)) + @test length(result) == 0 + + result = collect(filter(as_node(Gamma) | as_context(inner_inner), model)) + @test length(result) == 6 + + result = collect(filter(as_node(Normal) & as_context(inner_inner; children = true), model)) + @test length(result) == 1 +end + +@testitem "filter(::FactorNodePredicate, ::Model)" begin + import GraphPPL: create_model, as_node, getcontext + + include("testutils.jl") + + using .TestUtils.ModelZoo + + model = create_model(simple_model()) + context = getcontext(model) + result = filter(as_node(Normal), model) + @test collect(result) == [context[NormalMeanVariance, 1], context[NormalMeanVariance, 2]] + result = filter(as_node(), model) + @test collect(result) == [context[NormalMeanVariance, 1], context[GammaShapeScale, 1], context[NormalMeanVariance, 2]] +end + +@testitem "filter(::VariableNodePredicate, ::Model)" begin + import GraphPPL: create_model, as_variable, getcontext, variable_nodes + + include("testutils.jl") + + using .TestUtils.ModelZoo + + model = create_model(simple_model()) + context = getcontext(model) + result = filter(as_variable(:x), model) + @test collect(result) == [context[:x]...] + result = filter(as_variable(), model) + @test collect(result) == collect(variable_nodes(model)) +end + +@testitem "filter(::SubmodelPredicate, Model)" begin + import GraphPPL: create_model, as_context + + include("testutils.jl") + + using .TestUtils.ModelZoo + + model = create_model(outer()) + + result = filter(as_context(inner), model) + @test length(collect(result)) == 0 + + result = filter(as_context(inner; children = true), model) + @test length(collect(result)) == 1 + + result = filter(as_context(inner_inner), model) + @test length(collect(result)) == 1 + + result = filter(as_context(outer; children = true), model) + @test length(collect(result)) == 22 +end \ No newline at end of file diff --git a/test/graph/indexing_refs_tests.jl b/test/graph/indexing_refs_tests.jl new file mode 100644 index 00000000..71206749 --- /dev/null +++ b/test/graph/indexing_refs_tests.jl @@ -0,0 +1,388 @@ +@testitem "IndexedVariable" begin + import GraphPPL: IndexedVariable, CombinedRange, SplittedRange, getname, index + + # Test 1: Test IndexedVariable + @test IndexedVariable(:x, nothing) isa IndexedVariable + + # Test 2: Test IndexedVariable equality + lhs = IndexedVariable(:x, nothing) + rhs = IndexedVariable(:x, nothing) + @test lhs == rhs + @test lhs === rhs + @test lhs != IndexedVariable(:y, nothing) + @test lhs !== IndexedVariable(:y, nothing) + @test getname(IndexedVariable(:x, nothing)) === :x + @test getname(IndexedVariable(:x, 1)) === :x + @test getname(IndexedVariable(:y, nothing)) === :y + @test getname(IndexedVariable(:y, 1)) === :y + @test index(IndexedVariable(:x, nothing)) === nothing + @test index(IndexedVariable(:x, 1)) === 1 + @test index(IndexedVariable(:y, nothing)) === nothing + @test index(IndexedVariable(:y, 1)) === 1 +end + +@testitem "FunctionalIndex" begin + import GraphPPL: FunctionalIndex + + collection = [1, 2, 3, 4, 5] + + # Test 1: Test FunctionalIndex{:begin} + index = FunctionalIndex{:begin}(firstindex) + @test index(collection) === firstindex(collection) + + # Test 2: Test FunctionalIndex{:end} + index = FunctionalIndex{:end}(lastindex) + @test index(collection) === lastindex(collection) + + # Test 3: Test FunctionalIndex{:begin} + 1 + index = FunctionalIndex{:begin}(firstindex) + 1 + @test index(collection) === firstindex(collection) + 1 + + # Test 4: Test FunctionalIndex{:end} - 1 + index = FunctionalIndex{:end}(lastindex) - 1 + @test index(collection) === lastindex(collection) - 1 + + # Test 5: Test FunctionalIndex equality + lhs = FunctionalIndex{:begin}(firstindex) + rhs = FunctionalIndex{:begin}(firstindex) + @test lhs == rhs + @test lhs === rhs + @test lhs != FunctionalIndex{:end}(lastindex) + @test lhs !== FunctionalIndex{:end}(lastindex) + + for N in 1:5 + collection = ones(N) + @test FunctionalIndex{:nothing}(firstindex)(collection) === firstindex(collection) + @test FunctionalIndex{:nothing}(lastindex)(collection) === lastindex(collection) + @test (FunctionalIndex{:nothing}(firstindex) + 1)(collection) === firstindex(collection) + 1 + @test (FunctionalIndex{:nothing}(lastindex) - 1)(collection) === lastindex(collection) - 1 + @test (FunctionalIndex{:nothing}(firstindex) + 1 - 2 + 3)(collection) === firstindex(collection) + 1 - 2 + 3 + @test (FunctionalIndex{:nothing}(lastindex) - 1 + 2 - 3)(collection) === lastindex(collection) - 1 + 2 - 3 + end + + @test repr(FunctionalIndex{:begin}(firstindex)) === "(begin)" + @test repr(FunctionalIndex{:begin}(firstindex) + 1) === "((begin) + 1)" + @test repr(FunctionalIndex{:begin}(firstindex) - 1) === "((begin) - 1)" + @test repr(FunctionalIndex{:begin}(firstindex) - 1 + 1) === "(((begin) - 1) + 1)" + + @test repr(FunctionalIndex{:end}(lastindex)) === "(end)" + @test repr(FunctionalIndex{:end}(lastindex) + 1) === "((end) + 1)" + @test repr(FunctionalIndex{:end}(lastindex) - 1) === "((end) - 1)" + @test repr(FunctionalIndex{:end}(lastindex) - 1 + 1) === "(((end) - 1) + 1)" + + @test isbitstype(typeof((FunctionalIndex{:begin}(firstindex) + 1))) + @test isbitstype(typeof((FunctionalIndex{:begin}(firstindex) - 1))) + @test isbitstype(typeof((FunctionalIndex{:begin}(firstindex) + 1 + 1))) + @test isbitstype(typeof((FunctionalIndex{:begin}(firstindex) - 1 + 1))) +end + +@testitem "FunctionalRange" begin + import GraphPPL: FunctionalIndex + + collection = [1, 2, 3, 4, 5] + + range = FunctionalIndex{:begin}(firstindex):FunctionalIndex{:end}(lastindex) + @test collection[range] == collection + + range = (FunctionalIndex{:begin}(firstindex) + 1):(FunctionalIndex{:end}(lastindex) - 1) + @test collection[range] == collection[(begin + 1):(end - 1)] + + for i in 1:length(collection) + _range = i:FunctionalIndex{:end}(lastindex) + @test collection[_range] == collection[i:end] + end + + for i in 1:length(collection) + _range = FunctionalIndex{:begin}(firstindex):i + @test collection[_range] == collection[begin:i] + end +end + +@testitem "Lift index" begin + import GraphPPL: lift_index, True, False, checked_getindex + + @test lift_index(True(), nothing, nothing) === nothing + @test lift_index(True(), (1,), nothing) === (1,) + @test lift_index(True(), nothing, (1,)) === (1,) + @test lift_index(True(), (2,), (1,)) === (2,) + @test lift_index(True(), (2, 2), (1,)) === (2, 2) + + @test lift_index(False(), nothing, nothing) === nothing + @test lift_index(False(), (1,), nothing) === nothing + @test lift_index(False(), nothing, (1,)) === (1,) + @test lift_index(False(), (2,), (1,)) === (1,) + @test lift_index(False(), (2, 2), (1,)) === (1,) + + import GraphPPL: proxylabel, lift_index, unroll, ProxyLabel + + struct LiftingTest end + + GraphPPL.is_proxied(::Type{LiftingTest}) = GraphPPL.True() + + function GraphPPL.unroll(proxy::ProxyLabel, ::LiftingTest, index, maycreate, liftedindex) + if liftedindex === nothing + return checked_getindex("Hello", index) + else + return checked_getindex("World", index) + end + end + + @test unroll(proxylabel(:x, LiftingTest(), nothing, True())) === "Hello" + @test unroll(proxylabel(:x, LiftingTest(), (1,), True())) === 'W' + @test unroll(proxylabel(:r, proxylabel(:x, proxylabel(:z, LiftingTest(), nothing), (3,), True()), nothing)) === 'r' + @test unroll( + proxylabel(:r, proxylabel(:x, proxylabel(:w, proxylabel(:z, LiftingTest(), nothing), (2:3,), True()), (1,), False()), nothing) + ) === 'o' +end + +@testitem "`VariableRef` iterators interface" begin + import GraphPPL: VariableRef, getcontext, NodeCreationOptions, VariableKindData, getorcreate! + + include("testutils.jl") + + @testset "Missing internal and external collections" begin + model = create_test_model() + ctx = getcontext(model) + xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (nothing,)) + + @test @inferred(Base.IteratorSize(xref)) === Base.SizeUnknown() + @test @inferred(Base.IteratorEltype(xref)) === Base.EltypeUnknown() + @test @inferred(Base.eltype(xref)) === Any + end + + @testset "Existing internal and external collections" begin + model = create_test_model() + ctx = getcontext(model) + xcollection = getorcreate!(model, ctx, NodeCreationOptions(), :x, 1) + xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (1,), xcollection) + + @test @inferred(Base.IteratorSize(xref)) === Base.HasShape{1}() + @test @inferred(Base.IteratorEltype(xref)) === Base.HasEltype() + @test @inferred(Base.eltype(xref)) === GraphPPL.NodeLabel + end + + @testset "Missing internal but existing external collections" begin + model = create_test_model() + ctx = getcontext(model) + xref = VariableRef(model, ctx, NodeCreationOptions(kind = VariableKindData), :x, (nothing,), [1.0 1.0; 1.0 1.0]) + + @test @inferred(Base.IteratorSize(xref)) === Base.HasShape{2}() + @test @inferred(Base.IteratorEltype(xref)) === Base.HasEltype() + @test @inferred(Base.eltype(xref)) === Float64 + end +end + +@testitem "`VariableRef` in combination with `ProxyLabel` should create variables in the model" begin + import GraphPPL: + VariableRef, + makevarref, + getcontext, + getifcreated, + unroll, + set_maycreate, + ProxyLabel, + NodeLabel, + proxylabel, + NodeCreationOptions, + VariableKindRandom, + VariableKindData, + getproperties, + is_kind, + MissingCollection, + getorcreate! + + using Distributions + + include("testutils.jl") + + @testset "Individual variable creation" begin + model = create_test_model() + ctx = getcontext(model) + xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (nothing,)) + @test_throws "The variable `x` has been used, but has not been instantiated." getifcreated(model, ctx, xref) + x = unroll(proxylabel(:p, xref, nothing, True())) + @test x isa NodeLabel + @test x === ctx[:x] + @test is_kind(getproperties(model[x]), VariableKindRandom) + @test getifcreated(model, ctx, xref) === ctx[:x] + + zref = VariableRef(model, ctx, NodeCreationOptions(kind = VariableKindData), :z, (nothing,), MissingCollection()) + # @test_throws "The variable `z` has been used, but has not been instantiated." getifcreated(model, ctx, zref) + # The label above SHOULD NOT throw, since it has been instantiated with the `MissingCollection` + + # Top level `False` should not play a role here really, but is also essential + # The bottom level `True` does allow the creation of the variable and the top-level `False` should only fetch + z = unroll(proxylabel(:r, proxylabel(:w, zref, nothing, True()), nothing, False())) + @test z isa NodeLabel + @test z === ctx[:z] + @test is_kind(getproperties(model[z]), VariableKindData) + @test getifcreated(model, ctx, zref) === ctx[:z] + end + + @testset "Vectored variable creation" begin + model = create_test_model() + ctx = getcontext(model) + xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (nothing,)) + @test_throws "The variable `x` has been used, but has not been instantiated." getifcreated(model, ctx, xref) + for i in 1:10 + x = unroll(proxylabel(:x, xref, (i,), True())) + @test x isa NodeLabel + @test x === ctx[:x][i] + @test getifcreated(model, ctx, xref) === ctx[:x] + end + @test length(xref) === 10 + @test firstindex(xref) === 1 + @test lastindex(xref) === 10 + @test collect(eachindex(xref)) == collect(1:10) + @test size(xref) === (10,) + end + + @testset "Tensor variable creation" begin + model = create_test_model() + ctx = getcontext(model) + xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (nothing,)) + @test_throws "The variable `x` has been used, but has not been instantiated." getifcreated(model, ctx, xref) + for i in 1:10, j in 1:10 + xij = unroll(proxylabel(:x, xref, (i, j), True())) + @test xij isa NodeLabel + @test xij === ctx[:x][i, j] + @test getifcreated(model, ctx, xref) === ctx[:x] + end + @test length(xref) === 100 + @test firstindex(xref) === 1 + @test lastindex(xref) === 100 + @test collect(eachindex(xref)) == collect(CartesianIndices((1:10, 1:10))) + @test size(xref) === (10, 10) + end + + @testset "Variable should not be created if the `creation` flag is set to `False`" begin + model = create_test_model() + ctx = getcontext(model) + # `x` is not created here, should fail during `unroll` + xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (nothing,)) + @test_throws "The variable `x` has been used, but has not been instantiated." getifcreated(model, ctx, xref) + @test_throws "The variable `x` has been used, but has not been instantiated" unroll(proxylabel(:x, xref, nothing, False())) + # Force create `x` + getorcreate!(model, ctx, NodeCreationOptions(), :x, nothing) + # Since `x` has been created the `False` flag should not throw + xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (nothing,)) + @test ctx[:x] === unroll(proxylabel(:x, xref, nothing, False())) + @test getifcreated(model, ctx, xref) === ctx[:x] + end + + @testset "Variable should be created if the `Atomic` fform is used as a first argument with `makevarref`" begin + model = create_test_model() + ctx = getcontext(model) + # `x` is not created here, but `makevarref` takes into account the `Atomic/Composite` + # we always create a variable when used with `Atomic` + xref = makevarref(Normal, model, ctx, NodeCreationOptions(), :x, (nothing,)) + # `@inferred` here is important for simple use cases like `x ~ Normal(0, 1)`, so + # `x` can be inferred properly + @test ctx[:x] === @inferred(unroll(proxylabel(:x, xref, nothing, False()))) + end + + @testset "It should be possible to toggle `maycreate` flag" begin + model = create_test_model() + ctx = getcontext(model) + xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (nothing,)) + # The first time should throw since the variable has not been instantiated yet + @test_throws "The variable `x` has been used, but has not been instantiated." unroll(proxylabel(:x, xref, nothing, False())) + # Even though the `maycreate` flag is set to `True`, the `set_maycreate` should overwrite it with `False` + @test_throws "The variable `x` has been used, but has not been instantiated." unroll( + set_maycreate(proxylabel(:x, xref, nothing, True()), False()) + ) + + # Even though the `maycreate` flag is set to `False`, the `set_maycreate` should overwrite it with `True` + @test unroll(set_maycreate(proxylabel(:x, xref, nothing, False()), True())) === ctx[:x] + # At this point the variable should be created + @test unroll(proxylabel(:x, xref, nothing, False())) === ctx[:x] + @test unroll(proxylabel(:x, xref, nothing, True())) === ctx[:x] + + @test set_maycreate(1, True()) === 1 + @test set_maycreate(1, False()) === 1 + end +end + +@testitem "`VariableRef` comparison" begin + import GraphPPL: + VariableRef, + makevarref, + getcontext, + getifcreated, + unroll, + ProxyLabel, + NodeLabel, + proxylabel, + NodeCreationOptions, + VariableKindRandom, + VariableKindData, + getproperties, + is_kind, + MissingCollection, + getorcreate! + + using Distributions + + include("testutils.jl") + + model = create_test_model() + ctx = getcontext(model) + xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (nothing,)) + @test xref == xref + @test_throws( + "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time.", + xref != 1 + ) + @test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." 1 != + xref + @test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." xref == + 1 + @test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." 1 == + xref + @test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." xref > + 0 + @test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." 0 < + xref + @test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." "something" == + xref + @test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." 10 > + xref + @test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." xref < + 10 + @test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." 0 <= + xref + @test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." xref >= + 0 + @test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." xref <= + 0 + @test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." 0 >= + xref + + xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (1, 2)) + @test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." xref != + 1 + @test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." 1 != + xref + @test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." xref == + 1 + @test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." 1 == + xref + @test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." xref > + 0 + @test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." 0 < + xref + @test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." "something" == + xref + @test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." 10 > + xref + @test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." xref < + 10 + @test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." 0 <= + xref + @test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." xref >= + 0 + @test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." xref <= + 0 + @test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." 0 >= + xref +end \ No newline at end of file diff --git a/test/graph_engine_tests.jl b/test/graph_engine_tests.jl deleted file mode 100644 index a0a06575..00000000 --- a/test/graph_engine_tests.jl +++ /dev/null @@ -1,2670 +0,0 @@ -@testitem "IndexedVariable" begin - import GraphPPL: IndexedVariable, CombinedRange, SplittedRange, getname, index - - # Test 1: Test IndexedVariable - @test IndexedVariable(:x, nothing) isa IndexedVariable - - # Test 2: Test IndexedVariable equality - lhs = IndexedVariable(:x, nothing) - rhs = IndexedVariable(:x, nothing) - @test lhs == rhs - @test lhs === rhs - @test lhs != IndexedVariable(:y, nothing) - @test lhs !== IndexedVariable(:y, nothing) - @test getname(IndexedVariable(:x, nothing)) === :x - @test getname(IndexedVariable(:x, 1)) === :x - @test getname(IndexedVariable(:y, nothing)) === :y - @test getname(IndexedVariable(:y, 1)) === :y - @test index(IndexedVariable(:x, nothing)) === nothing - @test index(IndexedVariable(:x, 1)) === 1 - @test index(IndexedVariable(:y, nothing)) === nothing - @test index(IndexedVariable(:y, 1)) === 1 -end - -@testitem "FunctionalIndex" begin - import GraphPPL: FunctionalIndex - - collection = [1, 2, 3, 4, 5] - - # Test 1: Test FunctionalIndex{:begin} - index = FunctionalIndex{:begin}(firstindex) - @test index(collection) === firstindex(collection) - - # Test 2: Test FunctionalIndex{:end} - index = FunctionalIndex{:end}(lastindex) - @test index(collection) === lastindex(collection) - - # Test 3: Test FunctionalIndex{:begin} + 1 - index = FunctionalIndex{:begin}(firstindex) + 1 - @test index(collection) === firstindex(collection) + 1 - - # Test 4: Test FunctionalIndex{:end} - 1 - index = FunctionalIndex{:end}(lastindex) - 1 - @test index(collection) === lastindex(collection) - 1 - - # Test 5: Test FunctionalIndex equality - lhs = FunctionalIndex{:begin}(firstindex) - rhs = FunctionalIndex{:begin}(firstindex) - @test lhs == rhs - @test lhs === rhs - @test lhs != FunctionalIndex{:end}(lastindex) - @test lhs !== FunctionalIndex{:end}(lastindex) - - for N in 1:5 - collection = ones(N) - @test FunctionalIndex{:nothing}(firstindex)(collection) === firstindex(collection) - @test FunctionalIndex{:nothing}(lastindex)(collection) === lastindex(collection) - @test (FunctionalIndex{:nothing}(firstindex) + 1)(collection) === firstindex(collection) + 1 - @test (FunctionalIndex{:nothing}(lastindex) - 1)(collection) === lastindex(collection) - 1 - @test (FunctionalIndex{:nothing}(firstindex) + 1 - 2 + 3)(collection) === firstindex(collection) + 1 - 2 + 3 - @test (FunctionalIndex{:nothing}(lastindex) - 1 + 2 - 3)(collection) === lastindex(collection) - 1 + 2 - 3 - end - - @test repr(FunctionalIndex{:begin}(firstindex)) === "(begin)" - @test repr(FunctionalIndex{:begin}(firstindex) + 1) === "((begin) + 1)" - @test repr(FunctionalIndex{:begin}(firstindex) - 1) === "((begin) - 1)" - @test repr(FunctionalIndex{:begin}(firstindex) - 1 + 1) === "(((begin) - 1) + 1)" - - @test repr(FunctionalIndex{:end}(lastindex)) === "(end)" - @test repr(FunctionalIndex{:end}(lastindex) + 1) === "((end) + 1)" - @test repr(FunctionalIndex{:end}(lastindex) - 1) === "((end) - 1)" - @test repr(FunctionalIndex{:end}(lastindex) - 1 + 1) === "(((end) - 1) + 1)" - - @test isbitstype(typeof((FunctionalIndex{:begin}(firstindex) + 1))) - @test isbitstype(typeof((FunctionalIndex{:begin}(firstindex) - 1))) - @test isbitstype(typeof((FunctionalIndex{:begin}(firstindex) + 1 + 1))) - @test isbitstype(typeof((FunctionalIndex{:begin}(firstindex) - 1 + 1))) -end - -@testitem "FunctionalRange" begin - import GraphPPL: FunctionalIndex - - collection = [1, 2, 3, 4, 5] - - range = FunctionalIndex{:begin}(firstindex):FunctionalIndex{:end}(lastindex) - @test collection[range] == collection - - range = (FunctionalIndex{:begin}(firstindex) + 1):(FunctionalIndex{:end}(lastindex) - 1) - @test collection[range] == collection[(begin + 1):(end - 1)] - - for i in 1:length(collection) - _range = i:FunctionalIndex{:end}(lastindex) - @test collection[_range] == collection[i:end] - end - - for i in 1:length(collection) - _range = FunctionalIndex{:begin}(firstindex):i - @test collection[_range] == collection[begin:i] - end -end - -@testitem "model constructor" begin - import GraphPPL: create_model, Model - - include("testutils.jl") - - @test typeof(create_test_model()) <: Model - - @test_throws MethodError Model() -end - -@testitem "NodeData constructor" begin - import GraphPPL: create_model, getcontext, NodeData, FactorNodeProperties, VariableNodeProperties, getproperties - - include("testutils.jl") - - model = create_test_model() - context = getcontext(model) - - @testset "FactorNodeProperties" begin - properties = FactorNodeProperties(fform = String) - nodedata = NodeData(context, properties) - - @test getcontext(nodedata) === context - @test getproperties(nodedata) === properties - - io = IOBuffer() - - show(io, nodedata) - - output = String(take!(io)) - - @test !isempty(output) - @test contains(output, "String") # fform - end - - @testset "VariableNodeProperties" begin - properties = VariableNodeProperties(name = :x, index = 1) - nodedata = NodeData(context, properties) - - @test getcontext(nodedata) === context - @test getproperties(nodedata) === properties - - io = IOBuffer() - - show(io, nodedata) - - output = String(take!(io)) - - @test !isempty(output) - @test contains(output, "x") # name - @test contains(output, "1") # index - end -end - -@testitem "NodeDataExtraKey" begin - import GraphPPL: NodeDataExtraKey, getkey - - @test NodeDataExtraKey{:a, Int}() isa NodeDataExtraKey - @test NodeDataExtraKey{:a, Int}() === NodeDataExtraKey{:a, Int}() - @test NodeDataExtraKey{:a, Int}() !== NodeDataExtraKey{:a, Float64}() - @test NodeDataExtraKey{:a, Int}() !== NodeDataExtraKey{:b, Int}() - @test getkey(NodeDataExtraKey{:a, Int}()) === :a - @test getkey(NodeDataExtraKey{:a, Float64}()) === :a - @test getkey(NodeDataExtraKey{:b, Float64}()) === :b -end - -@testitem "NodeData extra properties" begin - import GraphPPL: - create_model, - getcontext, - NodeData, - FactorNodeProperties, - VariableNodeProperties, - getproperties, - setextra!, - getextra, - hasextra, - NodeDataExtraKey - - include("testutils.jl") - - model = create_test_model() - context = getcontext(model) - - @testset for properties in (FactorNodeProperties(fform = String), VariableNodeProperties(name = :x, index = 1)) - nodedata = NodeData(context, properties) - - @test !hasextra(nodedata, :a) - @test getextra(nodedata, :a, 2) === 2 - @test !hasextra(nodedata, :a) # the default should not add the extra property, only return - setextra!(nodedata, :a, 1) - @test hasextra(nodedata, :a) - @test getextra(nodedata, :a) === 1 - @test getextra(nodedata, :a, 2) === 1 - @test !hasextra(nodedata, :b) - @test_throws Exception getextra(nodedata, :b) - @test getextra(nodedata, :b, 2) === 2 - - # In the current implementation it is not possible to update extra properties - @test_throws Exception setextra!(nodedata, :a, 2) - - @test !hasextra(nodedata, :b) - setextra!(nodedata, :b, 2) - @test hasextra(nodedata, :b) - @test getextra(nodedata, :b) === 2 - - constkey_c_float = NodeDataExtraKey{:c, Float64}() - - @test !@inferred(hasextra(nodedata, constkey_c_float)) - @test @inferred(getextra(nodedata, constkey_c_float, 4.0)) === 4.0 - @inferred(setextra!(nodedata, constkey_c_float, 3.0)) - @test @inferred(hasextra(nodedata, constkey_c_float)) - @test @inferred(getextra(nodedata, constkey_c_float)) === 3.0 - @test @inferred(getextra(nodedata, constkey_c_float, 4.0)) === 3.0 - - # The default has a different type from the key (4.0 is Float and 4 is Int), thus the error - @test_throws MethodError getextra(nodedata, constkey_c_float, 4) - - constkey_d_int = NodeDataExtraKey{:d, Int64}() - - @test !@inferred(hasextra(nodedata, constkey_d_int)) - @inferred(setextra!(nodedata, constkey_d_int, 4)) - @test @inferred(hasextra(nodedata, constkey_d_int)) - @test @inferred(getextra(nodedata, constkey_d_int)) === 4 - end -end - -@testitem "factor_nodes" begin - import GraphPPL: create_model, factor_nodes, is_factor, labels - - include("testutils.jl") - - using .TestUtils.ModelZoo - - for modelfn in ModelsInTheZooWithoutArguments - model = create_model(modelfn()) - fnodes = collect(factor_nodes(model)) - for node in fnodes - @test is_factor(model[node]) - end - for label in labels(model) - if is_factor(model[label]) - @test label ∈ fnodes - end - end - end -end - -@testitem "factor_nodes with lambda function" begin - import GraphPPL: create_model, factor_nodes, is_factor, labels - - include("testutils.jl") - - using .TestUtils.ModelZoo - - for model_fn in ModelsInTheZooWithoutArguments - model = create_model(model_fn()) - fnodes = collect(factor_nodes(model)) - factor_nodes(model) do label, nodedata - @test is_factor(model[label]) - @test is_factor(nodedata) - @test model[label] === nodedata - @test label ∈ labels(model) - @test label ∈ fnodes - - clength = length(fnodes) - filter!(n -> n !== label, fnodes) - @test length(fnodes) === clength - 1 # Only one should be removed - end - @test length(fnodes) === 0 # all should be processed - end -end - -@testitem "variable_nodes" begin - import GraphPPL: create_model, variable_nodes, is_variable, labels - - include("testutils.jl") - - using .TestUtils.ModelZoo - - for model_fn in ModelsInTheZooWithoutArguments - model = create_model(model_fn()) - fnodes = collect(variable_nodes(model)) - for node in fnodes - @test is_variable(model[node]) - end - for label in labels(model) - if is_variable(model[label]) - @test label ∈ fnodes - end - end - end -end - -@testitem "variable_nodes with lambda function" begin - import GraphPPL: create_model, variable_nodes, is_variable, labels - - include("testutils.jl") - - using .TestUtils.ModelZoo - - for model_fn in ModelsInTheZooWithoutArguments - model = create_model(model_fn()) - fnodes = collect(variable_nodes(model)) - variable_nodes(model) do label, nodedata - @test is_variable(model[label]) - @test is_variable(nodedata) - @test model[label] === nodedata - @test label ∈ labels(model) - @test label ∈ fnodes - - clength = length(fnodes) - filter!(n -> n !== label, fnodes) - @test length(fnodes) === clength - 1 # Only one should be removed - end - @test length(fnodes) === 0 # all should be processed - end -end - -@testitem "variable_nodes with anonymous variables" begin - # The idea here is that the `variable_nodes` must return ALL anonymous variables as well - using Distributions - import GraphPPL: create_model, variable_nodes, getname, is_anonymous, getproperties - - include("testutils.jl") - - @model function simple_submodel_with_2_anonymous_for_variable_nodes(z, x, y) - # Creates two anonymous variables here - z ~ Normal(x + 1, y - 1) - end - - @model function simple_submodel_with_3_anonymous_for_variable_nodes(z, x, y) - # Creates three anonymous variables here - z ~ Normal(x + 1, y - 1 + 1) - end - - @model function simple_model_for_variable_nodes(submodel) - xref ~ Normal(0, 1) - y ~ Gamma(1, 1) - zref ~ submodel(x = xref, y = y) - end - - @testset let submodel = simple_submodel_with_2_anonymous_for_variable_nodes - model = create_model(simple_model_for_variable_nodes(submodel = submodel)) - @test length(collect(variable_nodes(model))) === 11 - @test length(collect(filter(v -> is_anonymous(getproperties(model[v])), collect(variable_nodes(model))))) === 2 - end - - @testset let submodel = simple_submodel_with_3_anonymous_for_variable_nodes - model = create_model(simple_model_for_variable_nodes(submodel = submodel)) - @test length(collect(variable_nodes(model))) === 13 # +1 for new anonymous +1 for new constant - @test length(collect(filter(v -> is_anonymous(getproperties(model[v])), collect(variable_nodes(model))))) === 3 - end -end - -@testitem "Predefined kinds of variable nodes" begin - import GraphPPL: VariableKindRandom, VariableKindData, VariableKindConstant - import GraphPPL: getcontext, getorcreate!, NodeCreationOptions, getproperties - - include("testutils.jl") - - model = create_test_model() - context = getcontext(model) - xref = getorcreate!(model, context, NodeCreationOptions(kind = VariableKindRandom), :x, nothing) - y = getorcreate!(model, context, NodeCreationOptions(kind = VariableKindData), :y, nothing) - zref = getorcreate!(model, context, NodeCreationOptions(kind = VariableKindConstant), :z, nothing) - - import GraphPPL: is_random, is_data, is_constant, is_kind - - xprops = getproperties(model[xref]) - yprops = getproperties(model[y]) - zprops = getproperties(model[zref]) - - @test is_random(xprops) && is_kind(xprops, VariableKindRandom) - @test is_data(yprops) && is_kind(yprops, VariableKindData) - @test is_constant(zprops) && is_kind(zprops, VariableKindConstant) -end - -@testitem "degree" begin - import GraphPPL: create_model, getcontext, getorcreate!, NodeCreationOptions, make_node!, degree - - include("testutils.jl") - - for n in 5:10 - model = create_test_model() - ctx = getcontext(model) - - unused = getorcreate!(model, ctx, :unusued, nothing) - xref = getorcreate!(model, ctx, :x, nothing) - y = getorcreate!(model, ctx, :y, nothing) - - foreach(1:n) do k - getorcreate!(model, ctx, :z, k) - end - - zref = getorcreate!(model, ctx, :z, 1) - - @test degree(model, unused) === 0 - @test degree(model, xref) === 0 - @test degree(model, y) === 0 - @test all(zᵢ -> degree(model, zᵢ) === 0, zref) - - for i in 1:n - make_node!(model, ctx, NodeCreationOptions(), sum, y, (in = [xref, zref[i]],)) - end - - @test degree(model, unused) === 0 - @test degree(model, xref) === n - @test degree(model, y) === n - @test all(zᵢ -> degree(model, zᵢ) === 1, zref) - end -end - -@testitem "is_constant" begin - import GraphPPL: create_model, is_constant, variable_nodes, getname, getproperties - - include("testutils.jl") - - using .TestUtils.ModelZoo - - for model_fn in ModelsInTheZooWithoutArguments - model = create_model(model_fn()) - for label in variable_nodes(model) - node = model[label] - props = getproperties(node) - if occursin("constvar", string(getname(props))) - @test is_constant(props) - else - @test !is_constant(props) - end - end - end -end - -@testitem "is_data" begin - import GraphPPL: is_data, create_model, getcontext, getorcreate!, variable_nodes, NodeCreationOptions, getproperties - - include("testutils.jl") - - m = create_test_model() - ctx = getcontext(m) - xref = getorcreate!(m, ctx, NodeCreationOptions(kind = :data), :x, nothing) - @test is_data(getproperties(m[xref])) - - using .TestUtils.ModelZoo - - # Since the models here are without top arguments they cannot create `data` labels - for model_fn in ModelsInTheZooWithoutArguments - model = create_model(model_fn()) - for label in variable_nodes(model) - @test !is_data(getproperties(model[label])) - end - end -end - -@testitem "NodeCreationOptions" begin - import GraphPPL: NodeCreationOptions, withopts, withoutopts - - include("testutils.jl") - - @test NodeCreationOptions() == NodeCreationOptions() - @test keys(NodeCreationOptions()) === () - @test NodeCreationOptions(arbitrary_option = 1) == NodeCreationOptions((; arbitrary_option = 1)) - - @test haskey(NodeCreationOptions(arbitrary_option = 1), :arbitrary_option) - @test NodeCreationOptions(arbitrary_option = 1)[:arbitrary_option] === 1 - - @test @inferred(haskey(NodeCreationOptions(), :a)) === false - @test @inferred(haskey(NodeCreationOptions(), :b)) === false - @test @inferred(haskey(NodeCreationOptions(a = 1, b = 2), :b)) === true - @test @inferred(haskey(NodeCreationOptions(a = 1, b = 2), :c)) === false - @test @inferred(NodeCreationOptions(a = 1, b = 2)[:a]) === 1 - @test @inferred(NodeCreationOptions(a = 1, b = 2)[:b]) === 2 - - @test_throws ErrorException NodeCreationOptions()[:a] - @test_throws ErrorException NodeCreationOptions(a = 1, b = 2)[:c] - - @test @inferred(get(NodeCreationOptions(), :a, 2)) === 2 - @test @inferred(get(NodeCreationOptions(), :b, 3)) === 3 - @test @inferred(get(NodeCreationOptions(), :c, 4)) === 4 - @test @inferred(get(NodeCreationOptions(a = 1, b = 2), :a, 2)) === 1 - @test @inferred(get(NodeCreationOptions(a = 1, b = 2), :b, 3)) === 2 - @test @inferred(get(NodeCreationOptions(a = 1, b = 2), :c, 4)) === 4 - - @test NodeCreationOptions(a = 1, b = 2)[(:a,)] === NodeCreationOptions(a = 1) - @test NodeCreationOptions(a = 1, b = 2)[(:b,)] === NodeCreationOptions(b = 2) - - @test keys(NodeCreationOptions(a = 1, b = 2)) == (:a, :b) - - @test @inferred(withopts(NodeCreationOptions(), (a = 1,))) == NodeCreationOptions(a = 1) - @test @inferred(withopts(NodeCreationOptions(b = 2), (a = 1,))) == NodeCreationOptions(b = 2, a = 1) - - @test @inferred(withoutopts(NodeCreationOptions(), Val((:a,)))) == NodeCreationOptions() - @test @inferred(withoutopts(NodeCreationOptions(b = 1), Val((:a,)))) == NodeCreationOptions(b = 1) - @test @inferred(withoutopts(NodeCreationOptions(a = 1), Val((:a,)))) == NodeCreationOptions() - @test @inferred(withoutopts(NodeCreationOptions(a = 1, b = 2), Val((:c,)))) == NodeCreationOptions(a = 1, b = 2) -end - -@testitem "Check that factor node plugins are uniquely recreated" begin - import GraphPPL: create_model, with_plugins, getplugins, factor_nodes, PluginsCollection, setextra!, getextra - - include("testutils.jl") - - using .TestUtils.ModelZoo - - struct AnArbitraryPluginForTestUniqeness end - - GraphPPL.plugin_type(::AnArbitraryPluginForTestUniqeness) = GraphPPL.FactorNodePlugin() - - count = Ref(0) - - function GraphPPL.preprocess_plugin(::AnArbitraryPluginForTestUniqeness, model, context, label, nodedata, options) - setextra!(nodedata, :count, count[]) - count[] = count[] + 1 - return label, nodedata - end - - for model_fn in ModelsInTheZooWithoutArguments - model = create_model(with_plugins(model_fn(), PluginsCollection(AnArbitraryPluginForTestUniqeness()))) - for f1 in factor_nodes(model), f2 in factor_nodes(model) - if f1 !== f2 - @test getextra(model[f1], :count) !== getextra(model[f2], :count) - else - @test getextra(model[f1], :count) === getextra(model[f2], :count) - end - end - end -end - -@testitem "Check that plugins may change the options" begin - import GraphPPL: - NodeData, - variable_nodes, - getname, - index, - is_constant, - getproperties, - value, - PluginsCollection, - VariableNodeProperties, - NodeCreationOptions, - create_model, - with_plugins - - include("testutils.jl") - - using .TestUtils.ModelZoo - - struct AnArbitraryPluginForChangingOptions end - - GraphPPL.plugin_type(::AnArbitraryPluginForChangingOptions) = GraphPPL.VariableNodePlugin() - - function GraphPPL.preprocess_plugin(::AnArbitraryPluginForChangingOptions, model, context, label, nodedata, options) - # Here we replace the original options entirely - return label, NodeData(context, convert(VariableNodeProperties, :x, nothing, NodeCreationOptions(kind = :constant, value = 1.0))) - end - - for model_fn in ModelsInTheZooWithoutArguments - model = create_model(with_plugins(model_fn(), PluginsCollection(AnArbitraryPluginForChangingOptions()))) - for v in variable_nodes(model) - @test getname(getproperties(model[v])) === :x - @test index(getproperties(model[v])) === nothing - @test is_constant(getproperties(model[v])) === true - @test value(getproperties(model[v])) === 1.0 - end - end -end - -@testitem "proxy labels" begin - import GraphPPL: NodeLabel, ProxyLabel, proxylabel, getname, unroll, ResizableArray, FunctionalIndex - - y = NodeLabel(:y, 1) - - let p = proxylabel(:x, y, nothing) - @test last(p) === y - @test getname(p) === :x - @test getname(last(p)) === :y - end - - let p = proxylabel(:x, y, (1,)) - @test_throws "Indexing a single node label `y` with an index `[1]` is not allowed" unroll(p) - end - - let p = proxylabel(:x, y, (1, 2)) - @test_throws "Indexing a single node label `y` with an index `[1, 2]` is not allowed" unroll(p) - end - - let p = proxylabel(:r, proxylabel(:x, y, nothing), nothing) - @test last(p) === y - @test getname(p) === :r - @test getname(last(p)) === :y - end - - for n in (5, 10) - s = ResizableArray(NodeLabel, Val(1)) - - for i in 1:n - s[i] = NodeLabel(:s, i) - end - - let p = proxylabel(:x, s, nothing) - @test last(p) === s - @test all(i -> p[i] === s[i], 1:length(s)) - @test unroll(p) === s - end - - for i in 1:5 - let p = proxylabel(:r, proxylabel(:x, s, (i,)), nothing) - @test unroll(p) === s[i] - end - end - - let p = proxylabel(:r, proxylabel(:x, s, (2:4,)), (2,)) - @test unroll(p) === s[3] - end - let p = proxylabel(:x, s, (2:4,)) - @test p[1] === s[2] - end - end - - for n in (5, 10) - s = ResizableArray(NodeLabel, Val(1)) - - for i in 1:n - s[i] = NodeLabel(:s, i) - end - - let p = proxylabel(:x, s, FunctionalIndex{:begin}(firstindex)) - @test unroll(p) === s[begin] - end - end -end - -@testitem "Lift index" begin - import GraphPPL: lift_index, True, False, checked_getindex - - @test lift_index(True(), nothing, nothing) === nothing - @test lift_index(True(), (1,), nothing) === (1,) - @test lift_index(True(), nothing, (1,)) === (1,) - @test lift_index(True(), (2,), (1,)) === (2,) - @test lift_index(True(), (2, 2), (1,)) === (2, 2) - - @test lift_index(False(), nothing, nothing) === nothing - @test lift_index(False(), (1,), nothing) === nothing - @test lift_index(False(), nothing, (1,)) === (1,) - @test lift_index(False(), (2,), (1,)) === (1,) - @test lift_index(False(), (2, 2), (1,)) === (1,) - - import GraphPPL: proxylabel, lift_index, unroll, ProxyLabel - - struct LiftingTest end - - GraphPPL.is_proxied(::Type{LiftingTest}) = GraphPPL.True() - - function GraphPPL.unroll(proxy::ProxyLabel, ::LiftingTest, index, maycreate, liftedindex) - if liftedindex === nothing - return checked_getindex("Hello", index) - else - return checked_getindex("World", index) - end - end - - @test unroll(proxylabel(:x, LiftingTest(), nothing, True())) === "Hello" - @test unroll(proxylabel(:x, LiftingTest(), (1,), True())) === 'W' - @test unroll(proxylabel(:r, proxylabel(:x, proxylabel(:z, LiftingTest(), nothing), (3,), True()), nothing)) === 'r' - @test unroll( - proxylabel(:r, proxylabel(:x, proxylabel(:w, proxylabel(:z, LiftingTest(), nothing), (2:3,), True()), (1,), False()), nothing) - ) === 'o' -end - -@testitem "`VariableRef` iterators interface" begin - import GraphPPL: VariableRef, getcontext, NodeCreationOptions, VariableKindData, getorcreate! - - include("testutils.jl") - - @testset "Missing internal and external collections" begin - model = create_test_model() - ctx = getcontext(model) - xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (nothing,)) - - @test @inferred(Base.IteratorSize(xref)) === Base.SizeUnknown() - @test @inferred(Base.IteratorEltype(xref)) === Base.EltypeUnknown() - @test @inferred(Base.eltype(xref)) === Any - end - - @testset "Existing internal and external collections" begin - model = create_test_model() - ctx = getcontext(model) - xcollection = getorcreate!(model, ctx, NodeCreationOptions(), :x, 1) - xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (1,), xcollection) - - @test @inferred(Base.IteratorSize(xref)) === Base.HasShape{1}() - @test @inferred(Base.IteratorEltype(xref)) === Base.HasEltype() - @test @inferred(Base.eltype(xref)) === GraphPPL.NodeLabel - end - - @testset "Missing internal but existing external collections" begin - model = create_test_model() - ctx = getcontext(model) - xref = VariableRef(model, ctx, NodeCreationOptions(kind = VariableKindData), :x, (nothing,), [1.0 1.0; 1.0 1.0]) - - @test @inferred(Base.IteratorSize(xref)) === Base.HasShape{2}() - @test @inferred(Base.IteratorEltype(xref)) === Base.HasEltype() - @test @inferred(Base.eltype(xref)) === Float64 - end -end - -@testitem "`VariableRef` in combination with `ProxyLabel` should create variables in the model" begin - import GraphPPL: - VariableRef, - makevarref, - getcontext, - getifcreated, - unroll, - set_maycreate, - ProxyLabel, - NodeLabel, - proxylabel, - NodeCreationOptions, - VariableKindRandom, - VariableKindData, - getproperties, - is_kind, - MissingCollection, - getorcreate! - - using Distributions - - include("testutils.jl") - - @testset "Individual variable creation" begin - model = create_test_model() - ctx = getcontext(model) - xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (nothing,)) - @test_throws "The variable `x` has been used, but has not been instantiated." getifcreated(model, ctx, xref) - x = unroll(proxylabel(:p, xref, nothing, True())) - @test x isa NodeLabel - @test x === ctx[:x] - @test is_kind(getproperties(model[x]), VariableKindRandom) - @test getifcreated(model, ctx, xref) === ctx[:x] - - zref = VariableRef(model, ctx, NodeCreationOptions(kind = VariableKindData), :z, (nothing,), MissingCollection()) - # @test_throws "The variable `z` has been used, but has not been instantiated." getifcreated(model, ctx, zref) - # The label above SHOULD NOT throw, since it has been instantiated with the `MissingCollection` - - # Top level `False` should not play a role here really, but is also essential - # The bottom level `True` does allow the creation of the variable and the top-level `False` should only fetch - z = unroll(proxylabel(:r, proxylabel(:w, zref, nothing, True()), nothing, False())) - @test z isa NodeLabel - @test z === ctx[:z] - @test is_kind(getproperties(model[z]), VariableKindData) - @test getifcreated(model, ctx, zref) === ctx[:z] - end - - @testset "Vectored variable creation" begin - model = create_test_model() - ctx = getcontext(model) - xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (nothing,)) - @test_throws "The variable `x` has been used, but has not been instantiated." getifcreated(model, ctx, xref) - for i in 1:10 - x = unroll(proxylabel(:x, xref, (i,), True())) - @test x isa NodeLabel - @test x === ctx[:x][i] - @test getifcreated(model, ctx, xref) === ctx[:x] - end - @test length(xref) === 10 - @test firstindex(xref) === 1 - @test lastindex(xref) === 10 - @test collect(eachindex(xref)) == collect(1:10) - @test size(xref) === (10,) - end - - @testset "Tensor variable creation" begin - model = create_test_model() - ctx = getcontext(model) - xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (nothing,)) - @test_throws "The variable `x` has been used, but has not been instantiated." getifcreated(model, ctx, xref) - for i in 1:10, j in 1:10 - xij = unroll(proxylabel(:x, xref, (i, j), True())) - @test xij isa NodeLabel - @test xij === ctx[:x][i, j] - @test getifcreated(model, ctx, xref) === ctx[:x] - end - @test length(xref) === 100 - @test firstindex(xref) === 1 - @test lastindex(xref) === 100 - @test collect(eachindex(xref)) == collect(CartesianIndices((1:10, 1:10))) - @test size(xref) === (10, 10) - end - - @testset "Variable should not be created if the `creation` flag is set to `False`" begin - model = create_test_model() - ctx = getcontext(model) - # `x` is not created here, should fail during `unroll` - xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (nothing,)) - @test_throws "The variable `x` has been used, but has not been instantiated." getifcreated(model, ctx, xref) - @test_throws "The variable `x` has been used, but has not been instantiated" unroll(proxylabel(:x, xref, nothing, False())) - # Force create `x` - getorcreate!(model, ctx, NodeCreationOptions(), :x, nothing) - # Since `x` has been created the `False` flag should not throw - xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (nothing,)) - @test ctx[:x] === unroll(proxylabel(:x, xref, nothing, False())) - @test getifcreated(model, ctx, xref) === ctx[:x] - end - - @testset "Variable should be created if the `Atomic` fform is used as a first argument with `makevarref`" begin - model = create_test_model() - ctx = getcontext(model) - # `x` is not created here, but `makevarref` takes into account the `Atomic/Composite` - # we always create a variable when used with `Atomic` - xref = makevarref(Normal, model, ctx, NodeCreationOptions(), :x, (nothing,)) - # `@inferred` here is important for simple use cases like `x ~ Normal(0, 1)`, so - # `x` can be inferred properly - @test ctx[:x] === @inferred(unroll(proxylabel(:x, xref, nothing, False()))) - end - - @testset "It should be possible to toggle `maycreate` flag" begin - model = create_test_model() - ctx = getcontext(model) - xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (nothing,)) - # The first time should throw since the variable has not been instantiated yet - @test_throws "The variable `x` has been used, but has not been instantiated." unroll(proxylabel(:x, xref, nothing, False())) - # Even though the `maycreate` flag is set to `True`, the `set_maycreate` should overwrite it with `False` - @test_throws "The variable `x` has been used, but has not been instantiated." unroll( - set_maycreate(proxylabel(:x, xref, nothing, True()), False()) - ) - - # Even though the `maycreate` flag is set to `False`, the `set_maycreate` should overwrite it with `True` - @test unroll(set_maycreate(proxylabel(:x, xref, nothing, False()), True())) === ctx[:x] - # At this point the variable should be created - @test unroll(proxylabel(:x, xref, nothing, False())) === ctx[:x] - @test unroll(proxylabel(:x, xref, nothing, True())) === ctx[:x] - - @test set_maycreate(1, True()) === 1 - @test set_maycreate(1, False()) === 1 - end -end - -@testitem "`VariableRef` comparison" begin - import GraphPPL: - VariableRef, - makevarref, - getcontext, - getifcreated, - unroll, - ProxyLabel, - NodeLabel, - proxylabel, - NodeCreationOptions, - VariableKindRandom, - VariableKindData, - getproperties, - is_kind, - MissingCollection, - getorcreate! - - using Distributions - - include("testutils.jl") - - model = create_test_model() - ctx = getcontext(model) - xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (nothing,)) - @test xref == xref - @test_throws( - "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time.", - xref != 1 - ) - @test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." 1 != - xref - @test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." xref == - 1 - @test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." 1 == - xref - @test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." xref > - 0 - @test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." 0 < - xref - @test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." "something" == - xref - @test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." 10 > - xref - @test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." xref < - 10 - @test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." 0 <= - xref - @test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." xref >= - 0 - @test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." xref <= - 0 - @test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." 0 >= - xref - - xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (1, 2)) - @test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." xref != - 1 - @test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." 1 != - xref - @test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." xref == - 1 - @test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." 1 == - xref - @test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." xref > - 0 - @test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." 0 < - xref - @test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." "something" == - xref - @test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." 10 > - xref - @test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." xref < - 10 - @test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." 0 <= - xref - @test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." xref >= - 0 - @test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." xref <= - 0 - @test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." 0 >= - xref -end - -@testitem "NodeLabel properties" begin - import GraphPPL: NodeLabel - - xref = NodeLabel(:x, 1) - @test xref[1] == xref - @test length(xref) === 1 - @test GraphPPL.to_symbol(xref) === :x_1 - - y = NodeLabel(:y, 2) - @test xref < y -end - -@testitem "getname(::NodeLabel)" begin - import GraphPPL: ResizableArray, NodeLabel, getname - - xref = NodeLabel(:x, 1) - @test getname(xref) == :x - - xref = ResizableArray(NodeLabel, Val(1)) - xref[1] = NodeLabel(:x, 1) - @test getname(xref) == :x - - xref = ResizableArray(NodeLabel, Val(1)) - xref[2] = NodeLabel(:x, 1) - @test getname(xref) == :x -end - -@testitem "setindex!(::Model, ::NodeData, ::NodeLabel)" begin - using Graphs - import GraphPPL: getcontext, NodeLabel, NodeData, VariableNodeProperties, FactorNodeProperties - - include("testutils.jl") - - model = create_test_model() - ctx = getcontext(model) - model[NodeLabel(:μ, 1)] = NodeData(ctx, VariableNodeProperties(name = :μ, index = nothing)) - @test nv(model) == 1 && ne(model) == 0 - - model[NodeLabel(:x, 2)] = NodeData(ctx, VariableNodeProperties(name = :x, index = nothing)) - @test nv(model) == 2 && ne(model) == 0 - - model[NodeLabel(sum, 3)] = NodeData(ctx, FactorNodeProperties(fform = sum)) - @test nv(model) == 3 && ne(model) == 0 - - @test_throws MethodError model[0] = 1 - @test_throws MethodError model["string"] = NodeData(ctx, VariableNodeProperties(name = :x, index = nothing)) - @test_throws MethodError model["string"] = NodeData(ctx, FactorNodeProperties(fform = sum)) -end - -@testitem "setindex!(::Model, ::EdgeLabel, ::NodeLabel, ::NodeLabel)" begin - using Graphs - import GraphPPL: getcontext, NodeLabel, NodeData, VariableNodeProperties, EdgeLabel - - include("testutils.jl") - - model = create_test_model() - ctx = getcontext(model) - - μ = NodeLabel(:μ, 1) - xref = NodeLabel(:x, 2) - - model[μ] = NodeData(ctx, VariableNodeProperties(name = :μ, index = nothing)) - model[xref] = NodeData(ctx, VariableNodeProperties(name = :x, index = nothing)) - model[μ, xref] = EdgeLabel(:interface, 1) - - @test ne(model) == 1 - @test_throws MethodError model[0, 1] = 1 - - # Test that we can't add an edge between two nodes that don't exist - model[μ, NodeLabel(:x, 100)] = EdgeLabel(:if, 1) - @test ne(model) == 1 -end - -@testitem "setindex!(::Context, ::ResizableArray{NodeLabel}, ::Symbol)" begin - import GraphPPL: NodeLabel, ResizableArray, Context, vector_variables, tensor_variables - - context = Context() - context[:x] = ResizableArray(NodeLabel, Val(1)) - @test haskey(vector_variables(context), :x) - - context[:y] = ResizableArray(NodeLabel, Val(2)) - @test haskey(tensor_variables(context), :y) -end - -@testitem "getindex(::Model, ::NodeLabel)" begin - import GraphPPL: create_model, getcontext, NodeLabel, NodeData, VariableNodeProperties, getproperties - - include("testutils.jl") - - model = create_test_model() - ctx = getcontext(model) - label = NodeLabel(:x, 1) - model[label] = NodeData(ctx, VariableNodeProperties(name = :x, index = nothing)) - @test isa(model[label], NodeData) - @test isa(getproperties(model[label]), VariableNodeProperties) - @test_throws KeyError model[NodeLabel(:x, 10)] - @test_throws MethodError model[0] -end - -@testitem "getcounter and setcounter!" begin - import GraphPPL: create_model, setcounter!, getcounter - - include("testutils.jl") - - model = create_test_model() - - @test setcounter!(model, 1) == 1 - @test getcounter(model) == 1 - @test setcounter!(model, 2) == 2 - @test getcounter(model) == 2 - @test setcounter!(model, getcounter(model) + 1) == 3 - @test setcounter!(model, 100) == 100 - @test getcounter(model) == 100 -end - -@testitem "nv_ne(::Model)" begin - import GraphPPL: create_model, getcontext, nv, ne, NodeData, VariableNodeProperties, NodeLabel, EdgeLabel - - include("testutils.jl") - - model = create_test_model() - ctx = getcontext(model) - @test isempty(model) - @test nv(model) == 0 - @test ne(model) == 0 - - model[NodeLabel(:a, 1)] = NodeData(ctx, VariableNodeProperties(name = :a, index = nothing)) - model[NodeLabel(:b, 2)] = NodeData(ctx, VariableNodeProperties(name = :b, index = nothing)) - @test !isempty(model) - @test nv(model) == 2 - @test ne(model) == 0 - - model[NodeLabel(:a, 1), NodeLabel(:b, 2)] = EdgeLabel(:edge, 1) - @test !isempty(model) - @test nv(model) == 2 - @test ne(model) == 1 -end - -@testitem "edges" begin - import GraphPPL: - edges, - create_model, - getcontext, - getproperties, - NodeData, - VariableNodeProperties, - FactorNodeProperties, - NodeLabel, - EdgeLabel, - getname, - add_edge!, - has_edge, - getproperties - - include("testutils.jl") - - # Test 1: Test getting all edges from a model - model = create_test_model() - ctx = getcontext(model) - a = NodeLabel(:a, 1) - b = NodeLabel(:b, 2) - model[a] = NodeData(ctx, VariableNodeProperties(name = :a, index = nothing)) - model[b] = NodeData(ctx, FactorNodeProperties(fform = sum)) - @test !has_edge(model, a, b) - @test !has_edge(model, b, a) - add_edge!(model, b, getproperties(model[b]), a, :edge, 1) - @test has_edge(model, a, b) - @test has_edge(model, b, a) - @test length(edges(model)) == 1 - - c = NodeLabel(:c, 2) - model[c] = NodeData(ctx, FactorNodeProperties(fform = sum)) - @test !has_edge(model, a, c) - @test !has_edge(model, c, a) - add_edge!(model, c, getproperties(model[c]), a, :edge, 2) - @test has_edge(model, a, c) - @test has_edge(model, c, a) - - @test length(edges(model)) == 2 - - # Test 2: Test getting all edges from a model with a specific node - @test getname.(edges(model, a)) == [:edge, :edge] - @test getname.(edges(model, b)) == [:edge] - @test getname.(edges(model, c)) == [:edge] - # @test getname.(edges(model, [a, b])) == [:edge, :edge, :edge] -end - -@testitem "neighbors(::Model, ::NodeData)" begin - import GraphPPL: - create_model, - getcontext, - neighbors, - NodeData, - VariableNodeProperties, - FactorNodeProperties, - NodeLabel, - EdgeLabel, - getname, - ResizableArray, - add_edge!, - getproperties - - include("testutils.jl") - - using .TestUtils.ModelZoo - - model = create_test_model() - ctx = getcontext(model) - - a = NodeLabel(:a, 1) - b = NodeLabel(:b, 2) - model[a] = NodeData(ctx, FactorNodeProperties(fform = sum)) - model[b] = NodeData(ctx, VariableNodeProperties(name = :b, index = nothing)) - add_edge!(model, a, getproperties(model[a]), b, :edge, 1) - @test collect(neighbors(model, NodeLabel(:a, 1))) == [NodeLabel(:b, 2)] - - model = create_test_model() - ctx = getcontext(model) - a = ResizableArray(NodeLabel, Val(1)) - b = ResizableArray(NodeLabel, Val(1)) - for i in 1:3 - a[i] = NodeLabel(:a, i) - model[a[i]] = NodeData(ctx, FactorNodeProperties(fform = sum)) - b[i] = NodeLabel(:b, i) - model[b[i]] = NodeData(ctx, VariableNodeProperties(name = :b, index = i)) - add_edge!(model, a[i], getproperties(model[a[i]]), b[i], :edge, i) - end - for n in b - @test n ∈ neighbors(model, a) - end - # Test 2: Test getting sorted neighbors - model = create_model(simple_model()) - ctx = getcontext(model) - node = first(neighbors(model, ctx[:z])) # Normal node we're investigating is the only neighbor of `z` in the graph. - @test getname.(neighbors(model, node)) == [:z, :x, :y] - - # Test 3: Test getting sorted neighbors when one of the edge indices is nothing - model = create_model(vector_model()) - ctx = getcontext(model) - node = first(neighbors(model, ctx[:z][1])) - @test getname.(collect(neighbors(model, node))) == [:z, :x, :y] -end - -@testitem "filter(::Predicate, ::Model)" begin - import GraphPPL: create_model, as_node, as_context, as_variable - - include("testutils.jl") - - using .TestUtils.ModelZoo - - model = create_model(simple_model()) - result = collect(filter(as_node(Normal) | as_variable(:x), model)) - @test length(result) == 3 - - model = create_model(outer()) - result = collect(filter(as_node(Gamma) & as_context(inner_inner), model)) - @test length(result) == 0 - - result = collect(filter(as_node(Gamma) | as_context(inner_inner), model)) - @test length(result) == 6 - - result = collect(filter(as_node(Normal) & as_context(inner_inner; children = true), model)) - @test length(result) == 1 -end - -@testitem "filter(::FactorNodePredicate, ::Model)" begin - import GraphPPL: create_model, as_node, getcontext - - include("testutils.jl") - - using .TestUtils.ModelZoo - - model = create_model(simple_model()) - context = getcontext(model) - result = filter(as_node(Normal), model) - @test collect(result) == [context[NormalMeanVariance, 1], context[NormalMeanVariance, 2]] - result = filter(as_node(), model) - @test collect(result) == [context[NormalMeanVariance, 1], context[GammaShapeScale, 1], context[NormalMeanVariance, 2]] -end - -@testitem "filter(::VariableNodePredicate, ::Model)" begin - import GraphPPL: create_model, as_variable, getcontext, variable_nodes - - include("testutils.jl") - - using .TestUtils.ModelZoo - - model = create_model(simple_model()) - context = getcontext(model) - result = filter(as_variable(:x), model) - @test collect(result) == [context[:x]...] - result = filter(as_variable(), model) - @test collect(result) == collect(variable_nodes(model)) -end - -@testitem "filter(::SubmodelPredicate, Model)" begin - import GraphPPL: create_model, as_context - - include("testutils.jl") - - using .TestUtils.ModelZoo - - model = create_model(outer()) - - result = filter(as_context(inner), model) - @test length(collect(result)) == 0 - - result = filter(as_context(inner; children = true), model) - @test length(collect(result)) == 1 - - result = filter(as_context(inner_inner), model) - @test length(collect(result)) == 1 - - result = filter(as_context(outer; children = true), model) - @test length(collect(result)) == 22 -end - -@testitem "generate_nodelabel(::Model, ::Symbol)" begin - import GraphPPL: create_model, gensym, NodeLabel, generate_nodelabel - - include("testutils.jl") - - model = create_test_model() - first_sym = generate_nodelabel(model, :x) - @test typeof(first_sym) == NodeLabel - - second_sym = generate_nodelabel(model, :x) - @test first_sym != second_sym && first_sym.name == second_sym.name - - id = generate_nodelabel(model, :c) - @test id.name == :c && id.global_counter == 3 -end - -@testitem "getname" begin - import GraphPPL: getname - @test getname(+) == "+" - @test getname(-) == "-" - @test getname(sin) == "sin" - @test getname(cos) == "cos" - @test getname(exp) == "exp" -end - -@testitem "Context" begin - import GraphPPL: Context - - ctx1 = Context() - @test typeof(ctx1) == Context && ctx1.prefix == "" && length(ctx1.individual_variables) == 0 && ctx1.depth == 0 - - io = IOBuffer() - show(io, ctx1) - output = String(take!(io)) - @test !isempty(output) - @test contains(output, "identity") # fform - - # By default `returnval` is not defined - @test_throws UndefRefError GraphPPL.returnval(ctx1) - for i in 1:10 - GraphPPL.returnval!(ctx1, (i, "$i")) - @test GraphPPL.returnval(ctx1) == (i, "$i") - end - - function test end - - ctx2 = Context(0, test, "test", nothing) - @test contains(repr(ctx2), "test") - @test typeof(ctx2) == Context && ctx2.prefix == "test" && length(ctx2.individual_variables) == 0 && ctx2.depth == 0 - - function layer end - - ctx3 = Context(ctx2, layer) - @test typeof(ctx3) == Context && ctx3.prefix == "test_layer" && length(ctx3.individual_variables) == 0 && ctx3.depth == 1 - - @test_throws MethodError Context(ctx2, :my_model) - - function secondlayer end - - ctx5 = Context(ctx2, secondlayer) - @test typeof(ctx5) == Context && ctx5.prefix == "test_secondlayer" && length(ctx5.individual_variables) == 0 && ctx5.depth == 1 - - ctx6 = Context(ctx3, secondlayer) - @test typeof(ctx6) == Context && ctx6.prefix == "test_layer_secondlayer" && length(ctx6.individual_variables) == 0 && ctx6.depth == 2 -end - -@testitem "haskey(::Context)" begin - import GraphPPL: - Context, - NodeLabel, - ResizableArray, - ProxyLabel, - individual_variables, - vector_variables, - tensor_variables, - proxies, - children, - proxylabel - - ctx = Context() - xlab = NodeLabel(:x, 1) - @test !haskey(ctx, :x) - ctx[:x] = xlab - @test haskey(ctx, :x) - @test haskey(individual_variables(ctx), :x) - @test !haskey(vector_variables(ctx), :x) - @test !haskey(tensor_variables(ctx), :x) - @test !haskey(proxies(ctx), :x) - - @test !haskey(ctx, :y) - ctx[:y] = ResizableArray(NodeLabel, Val(1)) - @test haskey(ctx, :y) - @test !haskey(individual_variables(ctx), :y) - @test haskey(vector_variables(ctx), :y) - @test !haskey(tensor_variables(ctx), :y) - @test !haskey(proxies(ctx), :y) - - @test !haskey(ctx, :z) - ctx[:z] = ResizableArray(NodeLabel, Val(2)) - @test haskey(ctx, :z) - @test !haskey(individual_variables(ctx), :z) - @test !haskey(vector_variables(ctx), :z) - @test haskey(tensor_variables(ctx), :z) - @test !haskey(proxies(ctx), :z) - - @test !haskey(ctx, :proxy) - ctx[:proxy] = proxylabel(:proxy, xlab, nothing) - @test !haskey(individual_variables(ctx), :proxy) - @test !haskey(vector_variables(ctx), :proxy) - @test !haskey(tensor_variables(ctx), :proxy) - @test haskey(proxies(ctx), :proxy) - - @test !haskey(ctx, GraphPPL.FactorID(sum, 1)) - ctx[GraphPPL.FactorID(sum, 1)] = Context() - @test haskey(ctx, GraphPPL.FactorID(sum, 1)) - @test haskey(children(ctx), GraphPPL.FactorID(sum, 1)) -end - -@testitem "getindex(::Context, ::Symbol)" begin - import GraphPPL: Context, NodeLabel - - ctx = Context() - xlab = NodeLabel(:x, 1) - @test_throws KeyError ctx[:x] - ctx[:x] = xlab - @test ctx[:x] == xlab -end - -@testitem "getindex(::Context, ::FactorID)" begin - import GraphPPL: Context, NodeLabel, FactorID - - ctx = Context() - @test_throws KeyError ctx[FactorID(sum, 1)] - ctx[FactorID(sum, 1)] = Context() - @test ctx[FactorID(sum, 1)] == ctx.children[FactorID(sum, 1)] - - @test_throws KeyError ctx[FactorID(sum, 2)] - ctx[FactorID(sum, 2)] = NodeLabel(:sum, 1) - @test ctx[FactorID(sum, 2)] == ctx.factor_nodes[FactorID(sum, 2)] -end - -@testitem "getcontext(::Model)" begin - import GraphPPL: Context, getcontext, create_model, add_variable_node!, NodeCreationOptions - - include("testutils.jl") - - model = create_test_model() - @test getcontext(model) == model.graph[] - add_variable_node!(model, getcontext(model), NodeCreationOptions(), :x, nothing) - @test getcontext(model)[:x] == model.graph[][:x] -end - -@testitem "path_to_root(::Context)" begin - import GraphPPL: create_model, Context, path_to_root, getcontext - - include("testutils.jl") - - using .TestUtils.ModelZoo - - ctx = Context() - @test path_to_root(ctx) == [ctx] - - model = create_model(outer()) - ctx = getcontext(model) - inner_context = ctx[inner, 1] - inner_inner_context = inner_context[inner_inner, 1] - @test path_to_root(inner_inner_context) == [inner_inner_context, inner_context, ctx] -end - -@testitem "VarDict" begin - using Distributions - import GraphPPL: - Context, VarDict, create_model, getorcreate!, datalabel, NodeCreationOptions, getcontext, is_random, is_data, getproperties - - include("testutils.jl") - - ctx = Context() - vardict = VarDict(ctx) - @test isa(vardict, VarDict) - - @model function submodel(y, x_prev, x_next) - γ ~ Gamma(1, 1) - x_next ~ Normal(x_prev, γ) - y ~ Normal(x_next, 1) - end - - @model function state_space_model_with_new(y) - x[1] ~ Normal(0, 1) - y[1] ~ Normal(x[1], 1) - for i in 2:length(y) - y[i] ~ submodel(x_next = new(x[i]), x_prev = x[i - 1]) - end - end - - ydata = ones(10) - model = create_model(state_space_model_with_new()) do model, ctx - y = datalabel(model, ctx, NodeCreationOptions(kind = :data), :y, ydata) - return (y = y,) - end - - context = getcontext(model) - vardict = VarDict(context) - - @test haskey(vardict, :y) - @test haskey(vardict, :x) - for i in 1:(length(ydata) - 1) - @test haskey(vardict, (submodel, i)) - @test haskey(vardict[submodel, i], :γ) - end - - @test vardict[:y] === context[:y] - @test vardict[:x] === context[:x] - @test vardict[submodel, 1] == VarDict(context[submodel, 1]) - - result = map(identity, vardict) - @test haskey(result, :y) - @test haskey(result, :x) - for i in 1:(length(ydata) - 1) - @test haskey(result, (submodel, i)) - @test haskey(result[submodel, i], :γ) - end - - result = map(vardict) do variable - return length(variable) - end - @test haskey(result, :y) - @test haskey(result, :x) - @test result[:y] === length(ydata) - @test result[:x] === length(ydata) - for i in 1:(length(ydata) - 1) - @test result[(submodel, i)][:γ] === 1 - @test result[GraphPPL.FactorID(submodel, i)][:γ] === 1 - @test result[submodel, i][:γ] === 1 - end - - # Filter only random variables - result = filter(vardict) do label - if label isa GraphPPL.ResizableArray - all(is_random.(getproperties.(model[label]))) - else - return is_random(getproperties(model[label])) - end - end - @test !haskey(result, :y) - @test haskey(result, :x) - for i in 1:(length(ydata) - 1) - @test haskey(result, (submodel, i)) - @test haskey(result[submodel, i], :γ) - end - - # Filter only data variables - result = filter(vardict) do label - if label isa GraphPPL.ResizableArray - all(is_data.(getproperties.(model[label]))) - else - return is_data(getproperties(model[label])) - end - end - @test haskey(result, :y) - @test !haskey(result, :x) - for i in 1:(length(ydata) - 1) - @test haskey(result, (submodel, i)) - @test !haskey(result[submodel, i], :γ) - end -end - -@testitem "NodeType" begin - import GraphPPL: NodeType, Composite, Atomic - - include("testutils.jl") - - using .TestUtils.ModelZoo - - model = create_test_model() - - @test NodeType(model, Composite) == Atomic() - @test NodeType(model, Atomic) == Atomic() - @test NodeType(model, abs) == Atomic() - @test NodeType(model, Normal) == Atomic() - @test NodeType(model, NormalMeanVariance) == Atomic() - @test NodeType(model, NormalMeanPrecision) == Atomic() - - # Could test all here - for model_fn in ModelsInTheZooWithoutArguments - @test NodeType(model, model_fn) == Composite() - end -end - -@testitem "NodeBehaviour" begin - import GraphPPL: NodeBehaviour, Deterministic, Stochastic - - include("testutils.jl") - - using .TestUtils.ModelZoo - - model = create_test_model() - - @test NodeBehaviour(model, () -> 1) == Deterministic() - @test NodeBehaviour(model, Matrix) == Deterministic() - @test NodeBehaviour(model, abs) == Deterministic() - @test NodeBehaviour(model, Normal) == Stochastic() - @test NodeBehaviour(model, NormalMeanVariance) == Stochastic() - @test NodeBehaviour(model, NormalMeanPrecision) == Stochastic() - - # Could test all here - for model_fn in ModelsInTheZooWithoutArguments - @test NodeBehaviour(model, model_fn) == Stochastic() - end -end - -@testitem "create_test_model()" begin - import GraphPPL: create_model, Model, nv, ne - - include("testutils.jl") - - model = create_test_model() - @test typeof(model) <: Model && nv(model) == 0 && ne(model) == 0 - - @test_throws MethodError create_test_model(:x, :y, :z) -end - -@testitem "copy_markov_blanket_to_child_context" begin - import GraphPPL: - create_model, copy_markov_blanket_to_child_context, Context, getorcreate!, proxylabel, unroll, getcontext, NodeCreationOptions - - include("testutils.jl") - - # Copy individual variables - model = create_test_model() - ctx = getcontext(model) - function child end - child_context = Context(ctx, child) - xref = getorcreate!(model, ctx, NodeCreationOptions(), :x, nothing) - y = getorcreate!(model, ctx, NodeCreationOptions(), :y, nothing) - zref = getorcreate!(model, ctx, NodeCreationOptions(), :z, nothing) - - # Do not copy constant variables - model = create_test_model() - ctx = getcontext(model) - xref = getorcreate!(model, ctx, NodeCreationOptions(), :x, nothing) - child_context = Context(ctx, child) - copy_markov_blanket_to_child_context(child_context, (in = 1,)) - @test !haskey(child_context, :in) - - # Do not copy vector valued constant variables - model = create_test_model() - ctx = getcontext(model) - child_context = Context(ctx, child) - copy_markov_blanket_to_child_context(child_context, (in = [1, 2, 3],)) - @test !haskey(child_context, :in) - - # Copy ProxyLabel variables to child context - model = create_test_model() - ctx = getcontext(model) - xref = getorcreate!(model, ctx, NodeCreationOptions(), :x, nothing) - xref = proxylabel(:x, xref, nothing) - child_context = Context(ctx, child) - copy_markov_blanket_to_child_context(child_context, (in = xref,)) - @test child_context[:in] == xref -end - -@testitem "getorcreate!" begin - using Graphs - import GraphPPL: - create_model, - getcontext, - getorcreate!, - check_variate_compatability, - NodeLabel, - ResizableArray, - NodeCreationOptions, - getproperties, - is_kind - - include("testutils.jl") - - let # let block to suppress the scoping warnings - # Test 1: Creation of regular one-dimensional variable - model = create_test_model() - ctx = getcontext(model) - x = getorcreate!(model, ctx, :x, nothing) - @test nv(model) == 1 && ne(model) == 0 - - # Test 2: Ensure that getorcreating this variable again does not create a new node - x2 = getorcreate!(model, ctx, :x, nothing) - @test x == x2 && nv(model) == 1 && ne(model) == 0 - - # Test 3: Ensure that calling x another time gives us x - x = getorcreate!(model, ctx, :x, nothing) - @test x == x2 && nv(model) == 1 && ne(model) == 0 - - # Test 4: Test that creating a vector variable creates an array of the correct size - model = create_test_model() - ctx = getcontext(model) - y = getorcreate!(model, ctx, :y, 1) - @test nv(model) == 1 && ne(model) == 0 && y isa ResizableArray && y[1] isa NodeLabel - - # Test 5: Test that recreating the same variable changes nothing - y2 = getorcreate!(model, ctx, :y, 1) - @test y == y2 && nv(model) == 1 && ne(model) == 0 - - # Test 6: Test that adding a variable to this vector variable increases the size of the array - y = getorcreate!(model, ctx, :y, 2) - @test nv(model) == 2 && y[2] isa NodeLabel && haskey(ctx.vector_variables, :y) - - # Test 7: Test that getting this variable without index does not work - @test_throws ErrorException getorcreate!(model, ctx, :y, nothing) - - # Test 8: Test that getting this variable with an index that is too large does not work - @test_throws ErrorException getorcreate!(model, ctx, :y, 1, 2) - - #Test 9: Test that creating a tensor variable creates a tensor of the correct size - model = create_test_model() - ctx = getcontext(model) - z = getorcreate!(model, ctx, :z, 1, 1) - @test nv(model) == 1 && ne(model) == 0 && z isa ResizableArray && z[1, 1] isa NodeLabel - - #Test 10: Test that recreating the same variable changes nothing - z2 = getorcreate!(model, ctx, :z, 1, 1) - @test z == z2 && nv(model) == 1 && ne(model) == 0 - - #Test 11: Test that adding a variable to this tensor variable increases the size of the array - z = getorcreate!(model, ctx, :z, 2, 2) - @test nv(model) == 2 && z[2, 2] isa NodeLabel && haskey(ctx.tensor_variables, :z) - - #Test 12: Test that getting this variable without index does not work - @test_throws ErrorException z = getorcreate!(model, ctx, :z, nothing) - - #Test 13: Test that getting this variable with an index that is too small does not work - @test_throws ErrorException z = getorcreate!(model, ctx, :z, 1) - - #Test 14: Test that getting this variable with an index that is too large does not work - @test_throws ErrorException z = getorcreate!(model, ctx, :z, 1, 2, 3) - - # Test 15: Test that creating a variable that exists in the model scope but not in local scope still throws an error - let # force local scope - model = create_test_model() - ctx = getcontext(model) - getorcreate!(model, ctx, :a, nothing) - @test_throws ErrorException a = getorcreate!(model, ctx, :a, 1) - @test_throws ErrorException a = getorcreate!(model, ctx, :a, 1, 1) - end - - # Test 16. Test that the index is required to create a variable in the model - model = create_test_model() - ctx = getcontext(model) - @test_throws ErrorException getorcreate!(model, ctx, :a) - @test_throws ErrorException getorcreate!(model, ctx, NodeCreationOptions(), :a) - @test_throws ErrorException getorcreate!(model, ctx, NodeCreationOptions(kind = :data), :a) - @test_throws ErrorException getorcreate!(model, ctx, NodeCreationOptions(kind = :constant, value = 2), :a) - - # Test 17. Range based getorcreate! - model = create_test_model() - ctx = getcontext(model) - var = getorcreate!(model, ctx, :a, 1:2) - @test nv(model) == 2 && var[1] isa NodeLabel && var[2] isa NodeLabel - - # Test 17.1 Range based getorcreate! should use the same options - model = create_test_model() - ctx = getcontext(model) - var = getorcreate!(model, ctx, NodeCreationOptions(kind = :data), :a, 1:2) - @test nv(model) == 2 && var[1] isa NodeLabel && var[2] isa NodeLabel - @test is_kind(getproperties(model[var[1]]), :data) - @test is_kind(getproperties(model[var[1]]), :data) - - # Test 18. Range x2 based getorcreate! - model = create_test_model() - ctx = getcontext(model) - var = getorcreate!(model, ctx, :a, 1:2, 1:3) - @test nv(model) == 6 - for i in 1:2, j in 1:3 - @test var[i, j] isa NodeLabel - end - - # Test 18. Range x2 based getorcreate! should use the same options - model = create_test_model() - ctx = getcontext(model) - var = getorcreate!(model, ctx, NodeCreationOptions(kind = :data), :a, 1:2, 1:3) - @test nv(model) == 6 - for i in 1:2, j in 1:3 - @test var[i, j] isa NodeLabel - @test is_kind(getproperties(model[var[i, j]]), :data) - end - end -end - -@testitem "getifcreated" begin - using Graphs - import GraphPPL: - create_model, - getifcreated, - getorcreate!, - getcontext, - getproperties, - getname, - value, - getorcreate!, - proxylabel, - value, - NodeCreationOptions - - include("testutils.jl") - - model = create_test_model() - ctx = getcontext(model) - - # Test case 1: check that getifcreated the variable created by getorcreate - xref = getorcreate!(model, ctx, NodeCreationOptions(), :x, nothing) - @test getifcreated(model, ctx, xref) == xref - - # Test case 2: check that getifcreated returns the variable created by getorcreate in a vector - y = getorcreate!(model, ctx, NodeCreationOptions(), :y, 1) - @test getifcreated(model, ctx, y[1]) == y[1] - - # Test case 3: check that getifcreated returns a new variable node when called with integer input - c = getifcreated(model, ctx, 1) - @test value(getproperties(model[c])) == 1 - - # Test case 4: check that getifcreated returns a new variable node when called with a vector input - c = getifcreated(model, ctx, [1, 2]) - @test value(getproperties(model[c])) == [1, 2] - - # Test case 5: check that getifcreated returns a tuple of variable nodes when called with a tuple of NodeData - output = getifcreated(model, ctx, (xref, y[1])) - @test output == (xref, y[1]) - - # Test case 6: check that getifcreated returns a tuple of new variable nodes when called with a tuple of integers - output = getifcreated(model, ctx, (1, 2)) - @test value(getproperties(model[output[1]])) == 1 - @test value(getproperties(model[output[2]])) == 2 - - # Test case 7: check that getifcreated returns a tuple of variable nodes when called with a tuple of mixed input - output = getifcreated(model, ctx, (xref, 1)) - @test output[1] == xref && value(getproperties(model[output[2]])) == 1 - - # Test case 10: check that getifcreated returns the variable node if we create a variable and call it by symbol in a vector - model = create_test_model() - ctx = getcontext(model) - zref = getorcreate!(model, ctx, NodeCreationOptions(), :z, 1) - z_fetched = getifcreated(model, ctx, zref[1]) - @test z_fetched == zref[1] - - # Test case 11: Test that getifcreated returns a constant node when we call it with a symbol - model = create_test_model() - ctx = getcontext(model) - zref = getifcreated(model, ctx, :Bernoulli) - @test value(getproperties(model[zref])) == :Bernoulli - - # Test case 12: Test that getifcreated returns a vector of NodeLabels if called with a vector of NodeLabels - model = create_test_model() - ctx = getcontext(model) - xref = getorcreate!(model, ctx, NodeCreationOptions(), :x, nothing) - y = getorcreate!(model, ctx, NodeCreationOptions(), :y, nothing) - zref = getifcreated(model, ctx, [xref, y]) - @test zref == [xref, y] - - # Test case 13: Test that getifcreated returns a ResizableArray tensor of NodeLabels if called with a ResizableArray tensor of NodeLabels - model = create_test_model() - ctx = getcontext(model) - xref = getorcreate!(model, ctx, NodeCreationOptions(), :x, 1, 1) - xref = getorcreate!(model, ctx, NodeCreationOptions(), :x, 2, 1) - zref = getifcreated(model, ctx, xref) - @test zref == xref - - # Test case 14: Test that getifcreated returns multiple variables if called with a tuple of constants - model = create_test_model() - ctx = getcontext(model) - zref = getifcreated(model, ctx, ([1, 1], 2)) - @test nv(model) == 2 && value(getproperties(model[zref[1]])) == [1, 1] && value(getproperties(model[zref[2]])) == 2 - - # Test case 15: Test that getifcreated returns a ProxyLabel if called with a ProxyLabel - model = create_test_model() - ctx = getcontext(model) - xref = getorcreate!(model, ctx, NodeCreationOptions(), :x, nothing) - xref = proxylabel(:x, xref, nothing) - zref = getifcreated(model, ctx, xref) - @test zref === xref -end - -@testitem "datalabel" begin - import GraphPPL: getcontext, datalabel, NodeCreationOptions, VariableKindData, VariableKindRandom, unroll, proxylabel - - include("testutils.jl") - - model = create_test_model() - ctx = getcontext(model) - ylabel = datalabel(model, ctx, NodeCreationOptions(kind = VariableKindData), :y) - yvar = unroll(ylabel) - @test haskey(ctx, :y) && ctx[:y] === yvar - @test GraphPPL.nv(model) === 1 - # subsequent unroll should return the same variable - unroll(ylabel) - @test haskey(ctx, :y) && ctx[:y] === yvar - @test GraphPPL.nv(model) === 1 - - yvlabel = datalabel(model, ctx, NodeCreationOptions(kind = VariableKindData), :yv, [1, 2, 3]) - for i in 1:3 - yvvar = unroll(proxylabel(:yv, yvlabel, (i,))) - @test haskey(ctx, :yv) && ctx[:yv][i] === yvvar - @test GraphPPL.nv(model) === 1 + i - end - # Incompatible data indices - @test_throws "The index `[4]` is not compatible with the underlying collection provided for the label `yv`" unroll( - proxylabel(:yv, yvlabel, (4,)) - ) - @test_throws "The underlying data provided for `yv` is `[1, 2, 3]`" unroll(proxylabel(:yv, yvlabel, (4,))) - - @test_throws "`datalabel` only supports `VariableKindData` in `NodeCreationOptions`" datalabel(model, ctx, NodeCreationOptions(), :z) - @test_throws "`datalabel` only supports `VariableKindData` in `NodeCreationOptions`" datalabel( - model, ctx, NodeCreationOptions(kind = VariableKindRandom), :z - ) -end - -@testitem "add_variable_node!" begin - import GraphPPL: - create_model, - add_variable_node!, - getcontext, - options, - NodeLabel, - ResizableArray, - nv, - ne, - NodeCreationOptions, - getproperties, - is_constant, - value - - include("testutils.jl") - - # Test 1: simple add variable to model - model = create_test_model() - ctx = getcontext(model) - node_id = add_variable_node!(model, ctx, NodeCreationOptions(), :x, nothing) - @test nv(model) == 1 && haskey(ctx.individual_variables, :x) && ctx.individual_variables[:x] == node_id - - # Test 2: Add second variable to model - add_variable_node!(model, ctx, NodeCreationOptions(), :y, nothing) - @test nv(model) == 2 && haskey(ctx, :y) - - # Test 3: Check that adding an integer variable throws a MethodError - @test_throws MethodError add_variable_node!(model, ctx, NodeCreationOptions(), 1) - @test_throws MethodError add_variable_node!(model, ctx, NodeCreationOptions(), 1, 1) - - # Test 4: Add a vector variable to the model - model = create_test_model() - ctx = getcontext(model) - ctx[:x] = ResizableArray(NodeLabel, Val(1)) - node_id = add_variable_node!(model, ctx, NodeCreationOptions(), :x, 2) - @test nv(model) == 1 && haskey(ctx, :x) && ctx[:x][2] == node_id && length(ctx[:x]) == 2 - - # Test 5: Add a second vector variable to the model - node_id = add_variable_node!(model, ctx, NodeCreationOptions(), :x, 1) - @test nv(model) == 2 && haskey(ctx, :x) && ctx[:x][1] == node_id && length(ctx[:x]) == 2 - - # Test 6: Add a tensor variable to the model - model = create_test_model() - ctx = getcontext(model) - ctx[:x] = ResizableArray(NodeLabel, Val(2)) - node_id = add_variable_node!(model, ctx, NodeCreationOptions(), :x, (2, 3)) - @test nv(model) == 1 && haskey(ctx, :x) && ctx[:x][2, 3] == node_id - - # Test 7: Add a second tensor variable to the model - node_id = add_variable_node!(model, ctx, NodeCreationOptions(), :x, (2, 4)) - @test nv(model) == 2 && haskey(ctx, :x) && ctx[:x][2, 4] == node_id - - # Test 9: Add a variable with a non-integer index - model = create_test_model() - ctx = getcontext(model) - ctx[:z] = ResizableArray(NodeLabel, Val(2)) - @test_throws MethodError add_variable_node!(model, ctx, NodeCreationOptions(), :z, "a") - @test_throws MethodError add_variable_node!(model, ctx, NodeCreationOptions(), :z, ("a", "a")) - @test_throws MethodError add_variable_node!(model, ctx, NodeCreationOptions(), :z, ("a", 1)) - @test_throws MethodError add_variable_node!(model, ctx, NodeCreationOptions(), :z, (1, "a")) - - # Test 10: Add a variable with a negative index - ctx[:x] = ResizableArray(NodeLabel, Val(1)) - @test_throws BoundsError add_variable_node!(model, ctx, NodeCreationOptions(), :x, -1) - - # Test 11: Add a variable with options - model = create_test_model() - ctx = getcontext(model) - var = add_variable_node!(model, ctx, NodeCreationOptions(kind = :constant, value = 1.0), :x, nothing) - @test nv(model) == 1 && - haskey(ctx, :x) && - ctx[:x] == var && - is_constant(getproperties(model[var])) && - value(getproperties(model[var])) == 1.0 - - # Test 12: Add a variable without options - model = create_test_model() - ctx = getcontext(model) - var = add_variable_node!(model, ctx, :x, nothing) - @test nv(model) == 1 && haskey(ctx, :x) && ctx[:x] == var -end - -@testitem "interface_alias" begin - using GraphPPL - import GraphPPL: interface_aliases, StaticInterfaces - - include("testutils.jl") - - model = create_test_model() - - @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :μ, :τ)))) === StaticInterfaces((:out, :μ, :τ)) - @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :mean, :precision)))) === StaticInterfaces((:out, :μ, :τ)) - @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :μ, :precision)))) === StaticInterfaces((:out, :μ, :τ)) - @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :mean, :τ)))) === StaticInterfaces((:out, :μ, :τ)) - @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :m, :precision)))) === StaticInterfaces((:out, :μ, :τ)) - @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :m, :τ)))) === StaticInterfaces((:out, :μ, :τ)) - @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :mean, :p)))) === StaticInterfaces((:out, :μ, :τ)) - @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :m, :p)))) === StaticInterfaces((:out, :μ, :τ)) - @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :μ, :p)))) === StaticInterfaces((:out, :μ, :τ)) - @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :mean, :prec)))) === StaticInterfaces((:out, :μ, :τ)) - @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :m, :prec)))) === StaticInterfaces((:out, :μ, :τ)) - @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :μ, :prec)))) === StaticInterfaces((:out, :μ, :τ)) - - @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :μ, :τ)))) === 0 - @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :mean, :precision)))) === 0 - @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :mean, :τ)))) === 0 - @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :μ, :precision)))) === 0 - @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :m, :precision)))) === 0 - @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :m, :τ)))) === 0 - @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :mean, :p)))) === 0 - @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :m, :p)))) === 0 - @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :μ, :p)))) === 0 - @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :mean, :prec)))) === 0 - @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :m, :prec)))) === 0 - @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :μ, :prec)))) === 0 - - @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :μ, :σ)))) === StaticInterfaces((:out, :μ, :σ)) - @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :mean, :variance)))) === StaticInterfaces((:out, :μ, :σ)) - @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :μ, :variance)))) === StaticInterfaces((:out, :μ, :σ)) - @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :mean, :σ)))) === StaticInterfaces((:out, :μ, :σ)) - - @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :μ, :σ)))) === 0 - @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :mean, :variance)))) === 0 - @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :mean, :σ)))) === 0 - @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :μ, :variance)))) === 0 -end - -@testitem "add_atomic_factor_node!" begin - using Distributions - using Graphs - import GraphPPL: create_model, add_atomic_factor_node!, getorcreate!, getcontext, getorcreate!, label_for, getname, NodeCreationOptions - - include("testutils.jl") - - # Test 1: Add an atomic factor node to the model - model = create_test_model(plugins = GraphPPL.PluginsCollection(GraphPPL.MetaPlugin())) - ctx = getcontext(model) - options = NodeCreationOptions() - xref = getorcreate!(model, ctx, NodeCreationOptions(), :x, nothing) - node_id, node_data, node_properties = add_atomic_factor_node!(model, ctx, options, sum) - @test model[node_id] === node_data - @test nv(model) == 2 && getname(label_for(model.graph, 2)) == sum - - # Test 2: Add a second atomic factor node to the model with the same name and assert they are different - node_id, node_data, node_properties = add_atomic_factor_node!(model, ctx, options, sum) - @test model[node_id] === node_data - @test nv(model) == 3 && getname(label_for(model.graph, 3)) == sum - - # Test 3: Add an atomic factor node with options - options = NodeCreationOptions((; meta = true,)) - node_id, node_data, node_properties = add_atomic_factor_node!(model, ctx, options, sum) - @test model[node_id] === node_data - @test nv(model) == 4 && getname(label_for(model.graph, 4)) == sum - @test GraphPPL.hasextra(node_data, :meta) - @test GraphPPL.getextra(node_data, :meta) == true - - # Test 4: Test that creating a node with an instantiated object is supported - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - prior = Normal(0, 1) - node_id, node_data, node_properties = add_atomic_factor_node!(model, ctx, options, prior) - @test model[node_id] === node_data - @test nv(model) == 1 && getname(label_for(model.graph, 1)) == Normal(0, 1) -end - -@testitem "add_composite_factor_node!" begin - using Graphs - import GraphPPL: create_model, add_composite_factor_node!, getcontext, to_symbol, children, add_variable_node!, Context - - include("testutils.jl") - - # Add a composite factor node to the model - model = create_test_model() - parent_ctx = getcontext(model) - child_ctx = getcontext(model) - add_variable_node!(model, child_ctx, :x, nothing) - add_variable_node!(model, child_ctx, :y, nothing) - node_id = add_composite_factor_node!(model, parent_ctx, child_ctx, :f) - @test nv(model) == 2 && - haskey(children(parent_ctx), node_id) && - children(parent_ctx)[node_id] === child_ctx && - length(child_ctx.individual_variables) == 2 - - # Add a composite factor node with a different name - node_id = add_composite_factor_node!(model, parent_ctx, child_ctx, :g) - @test nv(model) == 2 && - haskey(children(parent_ctx), node_id) && - children(parent_ctx)[node_id] === child_ctx && - length(child_ctx.individual_variables) == 2 - - # Add a composite factor node with an empty child context - empty_ctx = Context() - node_id = add_composite_factor_node!(model, parent_ctx, empty_ctx, :h) - @test nv(model) == 2 && - haskey(children(parent_ctx), node_id) && - children(parent_ctx)[node_id] === empty_ctx && - length(empty_ctx.individual_variables) == 0 -end - -@testitem "add_edge!(::Model, ::NodeLabel, ::NodeLabel, ::Symbol)" begin - import GraphPPL: - create_model, getcontext, nv, ne, NodeData, NodeLabel, EdgeLabel, add_edge!, getorcreate!, generate_nodelabel, NodeCreationOptions - - include("testutils.jl") - - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - xref, xdata, xproperties = GraphPPL.add_atomic_factor_node!(model, ctx, options, sum) - y = getorcreate!(model, ctx, :y, nothing) - - add_edge!(model, xref, xproperties, y, :interface) - - @test ne(model) == 1 - - @test_throws MethodError add_edge!(model, xref, xproperties, y, 123) -end - -@testitem "add_edge!(::Model, ::NodeLabel, ::Vector{NodeLabel}, ::Symbol)" begin - import GraphPPL: create_model, getcontext, nv, ne, NodeData, NodeLabel, EdgeLabel, add_edge!, getorcreate!, NodeCreationOptions - - include("testutils.jl") - - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - y = getorcreate!(model, ctx, :y, nothing) - - variable_nodes = [getorcreate!(model, ctx, i, nothing) for i in [:a, :b, :c]] - xref, xdata, xproperties = GraphPPL.add_atomic_factor_node!(model, ctx, options, sum) - add_edge!(model, xref, xproperties, variable_nodes, :interface) - - @test ne(model) == 3 && model[variable_nodes[1], xref] == EdgeLabel(:interface, 1) -end - -@testitem "default_parametrization" begin - import GraphPPL: default_parametrization, Composite, Atomic - - include("testutils.jl") - - using .TestUtils.ModelZoo - - model = create_test_model() - - # Test 1: Add default arguments to Normal call - @test default_parametrization(model, Atomic(), Normal, (0, 1)) == (μ = 0, σ = 1) - - # Test 2: Add :in to function call that has default behaviour - @test default_parametrization(model, Atomic(), +, (1, 2)) == (in = (1, 2),) - - # Test 3: Add :in to function call that has default behaviour with nested interfaces - @test default_parametrization(model, Atomic(), +, ([1, 1], 2)) == (in = ([1, 1], 2),) - - @test_throws ErrorException default_parametrization(model, Composite(), gcv, (1, 2)) -end - -@testitem "contains_nodelabel" begin - import GraphPPL: create_model, getcontext, getorcreate!, contains_nodelabel, NodeCreationOptions, True, False, MixedArguments - - include("testutils.jl") - - model = create_test_model() - ctx = getcontext(model) - a = getorcreate!(model, ctx, :x, nothing) - b = getorcreate!(model, ctx, NodeCreationOptions(kind = :data), :x, nothing) - c = 1.0 - - # Test 1. Tuple based input - @test contains_nodelabel((a, b, c)) === True() - @test contains_nodelabel((a, b)) === True() - @test contains_nodelabel((a,)) === True() - @test contains_nodelabel((b,)) === True() - @test contains_nodelabel((c,)) === False() - - # Test 2. Named tuple based input - @test @inferred(contains_nodelabel((; a = a, b = b, c = c))) === True() - @test @inferred(contains_nodelabel((; a = a, b = b))) === True() - @test @inferred(contains_nodelabel((; a = a))) === True() - @test @inferred(contains_nodelabel((; b = b))) === True() - @test @inferred(contains_nodelabel((; c = c))) === False() - - # Test 3. MixedArguments based input - @test @inferred(contains_nodelabel(MixedArguments((), (; a = a, b = b, c = c)))) === True() - @test @inferred(contains_nodelabel(MixedArguments((), (; a = a, b = b)))) === True() - @test @inferred(contains_nodelabel(MixedArguments((), (; a = a)))) === True() - @test @inferred(contains_nodelabel(MixedArguments((), (; b = b)))) === True() - @test @inferred(contains_nodelabel(MixedArguments((), (; c = c)))) === False() - - @test @inferred(contains_nodelabel(MixedArguments((a,), (; b = b, c = c)))) === True() - @test @inferred(contains_nodelabel(MixedArguments((c,), (; a = a, b = b)))) === True() - @test @inferred(contains_nodelabel(MixedArguments((b,), (; a = a)))) === True() - @test @inferred(contains_nodelabel(MixedArguments((c,), (; b = b)))) === True() - @test @inferred(contains_nodelabel(MixedArguments((c,), (;)))) === False() - @test @inferred(contains_nodelabel(MixedArguments((), (; c = c)))) === False() -end - -@testitem "make_node!(::Atomic)" begin - using Graphs, BitSetTuples - import GraphPPL: - getcontext, - make_node!, - create_model, - getorcreate!, - AnonymousVariable, - proxylabel, - getname, - label_for, - edges, - MixedArguments, - prune!, - fform, - value, - NodeCreationOptions, - getproperties - - include("testutils.jl") - - # Test 1: Deterministic call returns result of deterministic function and does not create new node - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - xref = AnonymousVariable(model, ctx) - @test make_node!(model, ctx, options, +, xref, (1, 1)) == (nothing, 2) - @test make_node!(model, ctx, options, sin, xref, (0,)) == (nothing, 0) - @test nv(model) == 0 - - xref = proxylabel(:proxy, AnonymousVariable(model, ctx), nothing) - @test make_node!(model, ctx, options, +, xref, (1, 1)) == (nothing, 2) - @test make_node!(model, ctx, options, sin, xref, (0,)) == (nothing, 0) - @test nv(model) == 0 - - # Test 2: Stochastic atomic call returns a new node id - node_id, _ = make_node!(model, ctx, options, Normal, xref, (μ = 0, σ = 1)) - @test nv(model) == 4 - @test getname.(edges(model, node_id)) == [:out, :μ, :σ] - @test getname.(edges(model, node_id)) == [:out, :μ, :σ] - - # Test 3: Stochastic atomic call with an AbstractArray as rhs_interfaces - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - xref = getorcreate!(model, ctx, :x, nothing) - make_node!(model, ctx, options, Normal, xref, (0, 1)) - @test nv(model) == 4 && ne(model) == 3 - - # Test 4: Deterministic atomic call with nodelabels should create the actual node - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - in1 = getorcreate!(model, ctx, :in1, nothing) - in2 = getorcreate!(model, ctx, :in2, nothing) - out = getorcreate!(model, ctx, :out, nothing) - make_node!(model, ctx, options, +, out, (in1, in2)) - @test nv(model) == 4 && ne(model) == 3 - - # Test 5: Deterministic atomic call with nodelabels should create the actual node - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - in1 = getorcreate!(model, ctx, :in1, nothing) - in2 = getorcreate!(model, ctx, :in2, nothing) - out = getorcreate!(model, ctx, :out, nothing) - make_node!(model, ctx, options, +, out, (in = [in1, in2],)) - @test nv(model) == 4 - - # Test 6: Stochastic node with default arguments - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - xref = getorcreate!(model, ctx, :x, nothing) - node_id, _ = make_node!(model, ctx, options, Normal, xref, (0, 1)) - @test nv(model) == 4 - @test getname.(edges(model, node_id)) == [:out, :μ, :σ] - @test getname.(edges(model, node_id)) == [:out, :μ, :σ] - - # Test 7: Stochastic node with instantiated object - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - uprior = Normal(0, 1) - xref = getorcreate!(model, ctx, :x, nothing) - node_id = make_node!(model, ctx, options, uprior, xref, nothing) - @test nv(model) == 2 - - # Test 8: Deterministic node with nodelabel objects where all interfaces are already defined (no missing interfaces) - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - in1 = getorcreate!(model, ctx, :in1, nothing) - in2 = getorcreate!(model, ctx, :in2, nothing) - out = getorcreate!(model, ctx, :out, nothing) - @test_throws "Expected only one missing interface, got () of length 0 (node sum with interfaces (:in, :out))" make_node!( - model, ctx, options, +, out, (in = in1, out = in2) - ) - - # Test 8: Stochastic node with nodelabel objects where we have an array on the rhs (so should create 1 node for [0, 1]) - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - out = getorcreate!(model, ctx, :out, nothing) - nodeid, _ = make_node!(model, ctx, options, ArbitraryNode, out, (in = [0, 1],)) - @test nv(model) == 3 && value(getproperties(model[ctx[:constvar_2]])) == [0, 1] - - # Test 9: Stochastic node with all interfaces defined as constants - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - out = getorcreate!(model, ctx, :out, nothing) - nodeid, _ = make_node!(model, ctx, options, ArbitraryNode, out, (1, 1)) - @test nv(model) == 4 - @test getname.(edges(model, nodeid)) == [:out, :in, :in] - @test getname.(edges(model, nodeid)) == [:out, :in, :in] - - #Test 10: Deterministic node with keyword arguments - function abc(; a = 1, b = 2) - return a + b - end - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - out = AnonymousVariable(model, ctx) - @test make_node!(model, ctx, options, abc, out, (a = 1, b = 2)) == (nothing, 3) - - # Test 11: Deterministic node with mixed arguments - function abc(a; b = 2) - return a + b - end - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - out = AnonymousVariable(model, ctx) - @test make_node!(model, ctx, options, abc, out, MixedArguments((2,), (b = 2,))) == (nothing, 4) - - # Test 12: Deterministic node with mixed arguments that has to be materialized should throw error - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - out = getorcreate!(model, ctx, :out, nothing) - a = getorcreate!(model, ctx, :a, nothing) - @test_throws ErrorException make_node!(model, ctx, options, abc, out, MixedArguments((a,), (b = 2,))) - - # Test 13: Make stochastic node with aliases - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - xref = getorcreate!(model, ctx, :x, nothing) - node_id = make_node!(model, ctx, options, Normal, xref, (μ = 0, τ = 1)) - @test any((key) -> fform(key) == NormalMeanPrecision, keys(ctx.factor_nodes)) - @test nv(model) == 4 - - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - xref = getorcreate!(model, ctx, :x, nothing) - node_id = make_node!(model, ctx, options, Normal, xref, (μ = 0, σ = 1)) - @test any((key) -> fform(key) == NormalMeanVariance, keys(ctx.factor_nodes)) - @test nv(model) == 4 - - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - xref = getorcreate!(model, ctx, :x, nothing) - node_id = make_node!(model, ctx, options, Normal, xref, (0, 1)) - @test any((key) -> fform(key) == NormalMeanVariance, keys(ctx.factor_nodes)) - @test nv(model) == 4 - - # Test 14: Make deterministic node with ProxyLabels as arguments - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - xref = getorcreate!(model, ctx, :x, nothing) - xref = proxylabel(:x, xref, nothing) - y = getorcreate!(model, ctx, :y, nothing) - y = proxylabel(:y, y, nothing) - zref = getorcreate!(model, ctx, :z, nothing) - node_id = make_node!(model, ctx, options, +, zref, (xref, y)) - prune!(model) - @test nv(model) == 4 - - # Test 15.1: Make stochastic node with aliased interfaces - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - μ = getorcreate!(model, ctx, :μ, nothing) - σ = getorcreate!(model, ctx, :σ, nothing) - out = getorcreate!(model, ctx, :out, nothing) - for keys in [(:mean, :variance), (:m, :variance), (:mean, :v)] - local node_id, _ = make_node!(model, ctx, options, Normal, out, NamedTuple{keys}((μ, σ))) - @test GraphPPL.fform(GraphPPL.getproperties(model[node_id])) === NormalMeanVariance - @test GraphPPL.neighbors(model, node_id) == [out, μ, σ] - end - - # Test 15.2: Make stochastic node with aliased interfaces - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - μ = getorcreate!(model, ctx, :μ, nothing) - p = getorcreate!(model, ctx, :σ, nothing) - out = getorcreate!(model, ctx, :out, nothing) - for keys in [(:mean, :precision), (:m, :precision), (:mean, :p)] - local node_id, _ = make_node!(model, ctx, options, Normal, out, NamedTuple{keys}((μ, p))) - @test GraphPPL.fform(GraphPPL.getproperties(model[node_id])) === NormalMeanPrecision - @test GraphPPL.neighbors(model, node_id) == [out, μ, p] - end -end - -@testitem "materialize_factor_node!" begin - using Distributions - using Graphs - import GraphPPL: - getcontext, - materialize_factor_node!, - create_model, - getorcreate!, - getifcreated, - proxylabel, - prune!, - getname, - label_for, - edges, - NodeCreationOptions - - include("testutils.jl") - - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - xref = getorcreate!(model, ctx, options, :x, nothing) - μref = getifcreated(model, ctx, 0) - σref = getifcreated(model, ctx, 1) - - # Test 1: Stochastic atomic call returns a new node - node_id, _, _ = materialize_factor_node!(model, ctx, options, Normal, (out = xref, μ = μref, σ = σref)) - @test nv(model) == 4 - @test getname.(edges(model, node_id)) == [:out, :μ, :σ] - @test getname.(edges(model, node_id)) == [:out, :μ, :σ] - - # Test 3: Stochastic atomic call with an AbstractArray as rhs_interfaces - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - xref = getorcreate!(model, ctx, :x, nothing) - μref = getifcreated(model, ctx, 0) - σref = getifcreated(model, ctx, 1) - materialize_factor_node!(model, ctx, options, Normal, (out = xref, μ = μref, σ = σref)) - @test nv(model) == 4 && ne(model) == 3 - - # Test 4: Deterministic atomic call with nodelabels should create the actual node - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - in1 = getorcreate!(model, ctx, :in1, nothing) - in2 = getorcreate!(model, ctx, :in2, nothing) - out = getorcreate!(model, ctx, :out, nothing) - materialize_factor_node!(model, ctx, options, +, (out = out, in = (in1, in2))) - @test nv(model) == 4 && ne(model) == 3 - - # Test 14: Make deterministic node with ProxyLabels as arguments - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - xref = getorcreate!(model, ctx, :x, nothing) - xref = proxylabel(:x, xref, nothing) - y = getorcreate!(model, ctx, :y, nothing) - y = proxylabel(:y, y, nothing) - zref = getorcreate!(model, ctx, :z, nothing) - node_id = materialize_factor_node!(model, ctx, options, +, (out = zref, in = (xref, y))) - prune!(model) - @test nv(model) == 4 -end - -@testitem "make_node!(::Composite)" begin - using MetaGraphsNext, Graphs - import GraphPPL: getcontext, make_node!, create_model, getorcreate!, proxylabel, NodeCreationOptions - - include("testutils.jl") - - using .TestUtils.ModelZoo - - #test make node for priors - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - xref = getorcreate!(model, ctx, :x, nothing) - make_node!(model, ctx, options, prior, proxylabel(:x, xref, nothing), ()) - @test nv(model) == 4 - @test ctx[prior, 1][:a] == proxylabel(:x, xref, nothing) - - #test make node for other composite models - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - xref = getorcreate!(model, ctx, :x, nothing) - @test_throws ErrorException make_node!(model, ctx, options, gcv, proxylabel(:x, xref, nothing), (0, 1)) - - # test make node of broadcastable composite model - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - out = getorcreate!(model, ctx, :out, nothing) - model = create_model(broadcaster()) - @test nv(model) == 103 -end - -@testitem "prune!(m::Model)" begin - using Graphs - import GraphPPL: create_model, getcontext, getorcreate!, prune!, create_model, getorcreate!, add_edge!, NodeCreationOptions - - include("testutils.jl") - - # Test 1: Prune a node with no edges - model = create_test_model() - ctx = getcontext(model) - xref = getorcreate!(model, ctx, :x, nothing) - prune!(model) - @test nv(model) == 0 - - # Test 2: Prune two nodes - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - xref = getorcreate!(model, ctx, :x, nothing) - y, ydata, yproperties = GraphPPL.add_atomic_factor_node!(model, ctx, options, sum) - zref = getorcreate!(model, ctx, :z, nothing) - w = getorcreate!(model, ctx, :w, nothing) - - add_edge!(model, y, yproperties, zref, :test) - prune!(model) - @test nv(model) == 2 -end - -@testitem "broadcast" begin - import GraphPPL: NodeLabel, ResizableArray, create_model, getcontext, getorcreate!, make_node!, NodeCreationOptions - - include("testutils.jl") - - # Test 1: Broadcast a vector node - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - xref = getorcreate!(model, ctx, :x, 1) - xref = getorcreate!(model, ctx, :x, 2) - y = getorcreate!(model, ctx, :y, 1) - y = getorcreate!(model, ctx, :y, 2) - zref = getorcreate!(model, ctx, :z, 1) - zref = getorcreate!(model, ctx, :z, 2) - zref = broadcast((z_, x_, y_) -> begin - var = make_node!(model, ctx, options, +, z_, (x_, y_)) - end, zref, xref, y) - @test size(zref) == (2,) - - # Test 2: Broadcast a matrix node - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - xref = getorcreate!(model, ctx, :x, 1, 1) - xref = getorcreate!(model, ctx, :x, 1, 2) - xref = getorcreate!(model, ctx, :x, 2, 1) - xref = getorcreate!(model, ctx, :x, 2, 2) - - y = getorcreate!(model, ctx, :y, 1, 1) - y = getorcreate!(model, ctx, :y, 1, 2) - y = getorcreate!(model, ctx, :y, 2, 1) - y = getorcreate!(model, ctx, :y, 2, 2) - - zref = getorcreate!(model, ctx, :z, 1, 1) - zref = getorcreate!(model, ctx, :z, 1, 2) - zref = getorcreate!(model, ctx, :z, 2, 1) - zref = getorcreate!(model, ctx, :z, 2, 2) - - zref = broadcast((z_, x_, y_) -> begin - var = make_node!(model, ctx, options, +, z_, (x_, y_)) - end, zref, xref, y) - @test size(zref) == (2, 2) - - # Test 3: Broadcast a vector node with a matrix node - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - xref = getorcreate!(model, ctx, :x, 1) - xref = getorcreate!(model, ctx, :x, 2) - y = getorcreate!(model, ctx, :y, 1, 1) - y = getorcreate!(model, ctx, :y, 1, 2) - y = getorcreate!(model, ctx, :y, 2, 1) - y = getorcreate!(model, ctx, :y, 2, 2) - - zref = getorcreate!(model, ctx, :z, 1, 1) - zref = getorcreate!(model, ctx, :z, 1, 2) - zref = getorcreate!(model, ctx, :z, 2, 1) - zref = getorcreate!(model, ctx, :z, 2, 2) - - zref = broadcast((z_, x_, y_) -> begin - var = make_node!(model, ctx, options, +, z_, (x_, y_)) - end, zref, xref, y) - @test size(zref) == (2, 2) -end - -@testitem "getindex for StaticInterfaces" begin - import GraphPPL: StaticInterfaces - - interfaces = (:a, :b, :c) - sinterfaces = StaticInterfaces(interfaces) - - for (i, interface) in enumerate(interfaces) - @test sinterfaces[i] === interface - end -end - -@testitem "missing_interfaces" begin - import GraphPPL: missing_interfaces, interfaces - - include("testutils.jl") - - model = create_test_model() - - function abc end - - GraphPPL.interfaces(::TestUtils.TestGraphPPLBackend, ::typeof(abc), ::StaticInt{3}) = GraphPPL.StaticInterfaces((:in1, :in2, :out)) - - @test missing_interfaces(model, abc, static(3), (in1 = :x, in2 = :y)) == GraphPPL.StaticInterfaces((:out,)) - @test missing_interfaces(model, abc, static(3), (out = :y,)) == GraphPPL.StaticInterfaces((:in1, :in2)) - @test missing_interfaces(model, abc, static(3), NamedTuple()) == GraphPPL.StaticInterfaces((:in1, :in2, :out)) - - function xyz end - - GraphPPL.interfaces(::TestUtils.TestGraphPPLBackend, ::typeof(xyz), ::StaticInt{0}) = GraphPPL.StaticInterfaces(()) - @test missing_interfaces(model, xyz, static(0), (in1 = :x, in2 = :y)) == GraphPPL.StaticInterfaces(()) - - function foo end - - GraphPPL.interfaces(::TestUtils.TestGraphPPLBackend, ::typeof(foo), ::StaticInt{2}) = GraphPPL.StaticInterfaces((:a, :b)) - @test missing_interfaces(model, foo, static(2), (a = 1, b = 2)) == GraphPPL.StaticInterfaces(()) - - function bar end - GraphPPL.interfaces(::TestUtils.TestGraphPPLBackend, ::typeof(bar), ::StaticInt{2}) = GraphPPL.StaticInterfaces((:in1, :in2, :out)) - @test missing_interfaces(model, bar, static(2), (in1 = 1, in2 = 2, out = 3, test = 4)) == GraphPPL.StaticInterfaces(()) -end - -@testitem "sort_interfaces" begin - import GraphPPL: sort_interfaces - - include("testutils.jl") - - model = create_test_model() - - # Test 1: Test that sort_interfaces sorts the interfaces in the correct order - @test sort_interfaces(model, NormalMeanVariance, (μ = 1, σ = 1, out = 1)) == (out = 1, μ = 1, σ = 1) - @test sort_interfaces(model, NormalMeanVariance, (out = 1, μ = 1, σ = 1)) == (out = 1, μ = 1, σ = 1) - @test sort_interfaces(model, NormalMeanVariance, (σ = 1, out = 1, μ = 1)) == (out = 1, μ = 1, σ = 1) - @test sort_interfaces(model, NormalMeanVariance, (σ = 1, μ = 1, out = 1)) == (out = 1, μ = 1, σ = 1) - @test sort_interfaces(model, NormalMeanPrecision, (μ = 1, τ = 1, out = 1)) == (out = 1, μ = 1, τ = 1) - @test sort_interfaces(model, NormalMeanPrecision, (out = 1, μ = 1, τ = 1)) == (out = 1, μ = 1, τ = 1) - @test sort_interfaces(model, NormalMeanPrecision, (τ = 1, out = 1, μ = 1)) == (out = 1, μ = 1, τ = 1) - @test sort_interfaces(model, NormalMeanPrecision, (τ = 1, μ = 1, out = 1)) == (out = 1, μ = 1, τ = 1) - - @test_throws ErrorException sort_interfaces(model, NormalMeanVariance, (σ = 1, μ = 1, τ = 1)) -end - -@testitem "prepare_interfaces" begin - import GraphPPL: prepare_interfaces - - include("testutils.jl") - - using .TestUtils.ModelZoo - - model = create_test_model() - - @test prepare_interfaces(model, anonymous_in_loop, 1, (y = 1,)) == (x = 1, y = 1) - @test prepare_interfaces(model, anonymous_in_loop, 1, (x = 1,)) == (y = 1, x = 1) - - @test prepare_interfaces(model, type_arguments, 1, (x = 1,)) == (n = 1, x = 1) - @test prepare_interfaces(model, type_arguments, 1, (n = 1,)) == (x = 1, n = 1) -end - -@testitem "save and load graph" begin - import GraphPPL: create_model, with_plugins, savegraph, loadgraph, getextra, as_node - - include("testutils.jl") - - using .TestUtils.ModelZoo - - model = create_model(with_plugins(vector_model(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin()))) - mktemp() do file, io - file = file * ".jld2" - savegraph(file, model) - model2 = loadgraph(file, GraphPPL.Model) - for (node, node2) in zip(filter(as_node(), model), filter(as_node(), model2)) - @test node == node2 - @test GraphPPL.getextra(model[node], :factorization_constraint_bitset) == - GraphPPL.getextra(model2[node2], :factorization_constraint_bitset) - end - end -end - -@testitem "factor_alias" begin - import GraphPPL: factor_alias, StaticInterfaces - - include("testutils.jl") - - function abc end - function xyz end - - GraphPPL.factor_alias(::TestUtils.TestGraphPPLBackend, ::typeof(abc), ::StaticInterfaces{(:a, :b)}) = abc - GraphPPL.factor_alias(::TestUtils.TestGraphPPLBackend, ::typeof(abc), ::StaticInterfaces{(:x, :y)}) = xyz - - GraphPPL.factor_alias(::TestUtils.TestGraphPPLBackend, ::typeof(xyz), ::StaticInterfaces{(:a, :b)}) = abc - GraphPPL.factor_alias(::TestUtils.TestGraphPPLBackend, ::typeof(xyz), ::StaticInterfaces{(:x, :y)}) = xyz - - model = create_test_model() - - @test factor_alias(model, abc, StaticInterfaces((:a, :b))) === abc - @test factor_alias(model, abc, StaticInterfaces((:x, :y))) === xyz - - @test factor_alias(model, xyz, StaticInterfaces((:a, :b))) === abc - @test factor_alias(model, xyz, StaticInterfaces((:x, :y))) === xyz -end diff --git a/test/model/context_tests.jl b/test/model/context_tests.jl new file mode 100644 index 00000000..a0fe25b3 --- /dev/null +++ b/test/model/context_tests.jl @@ -0,0 +1,252 @@ + +@testitem "Context" begin + import GraphPPL: Context + + ctx1 = Context() + @test typeof(ctx1) == Context && ctx1.prefix == "" && length(ctx1.individual_variables) == 0 && ctx1.depth == 0 + + io = IOBuffer() + show(io, ctx1) + output = String(take!(io)) + @test !isempty(output) + @test contains(output, "identity") # fform + + # By default `returnval` is not defined + @test_throws UndefRefError GraphPPL.returnval(ctx1) + for i in 1:10 + GraphPPL.returnval!(ctx1, (i, "$i")) + @test GraphPPL.returnval(ctx1) == (i, "$i") + end + + function test end + + ctx2 = Context(0, test, "test", nothing) + @test contains(repr(ctx2), "test") + @test typeof(ctx2) == Context && ctx2.prefix == "test" && length(ctx2.individual_variables) == 0 && ctx2.depth == 0 + + function layer end + + ctx3 = Context(ctx2, layer) + @test typeof(ctx3) == Context && ctx3.prefix == "test_layer" && length(ctx3.individual_variables) == 0 && ctx3.depth == 1 + + @test_throws MethodError Context(ctx2, :my_model) + + function secondlayer end + + ctx5 = Context(ctx2, secondlayer) + @test typeof(ctx5) == Context && ctx5.prefix == "test_secondlayer" && length(ctx5.individual_variables) == 0 && ctx5.depth == 1 + + ctx6 = Context(ctx3, secondlayer) + @test typeof(ctx6) == Context && ctx6.prefix == "test_layer_secondlayer" && length(ctx6.individual_variables) == 0 && ctx6.depth == 2 +end + +@testitem "haskey(::Context)" begin + import GraphPPL: + Context, + NodeLabel, + ResizableArray, + ProxyLabel, + individual_variables, + vector_variables, + tensor_variables, + proxies, + children, + proxylabel + + ctx = Context() + xlab = NodeLabel(:x, 1) + @test !haskey(ctx, :x) + ctx[:x] = xlab + @test haskey(ctx, :x) + @test haskey(individual_variables(ctx), :x) + @test !haskey(vector_variables(ctx), :x) + @test !haskey(tensor_variables(ctx), :x) + @test !haskey(proxies(ctx), :x) + + @test !haskey(ctx, :y) + ctx[:y] = ResizableArray(NodeLabel, Val(1)) + @test haskey(ctx, :y) + @test !haskey(individual_variables(ctx), :y) + @test haskey(vector_variables(ctx), :y) + @test !haskey(tensor_variables(ctx), :y) + @test !haskey(proxies(ctx), :y) + + @test !haskey(ctx, :z) + ctx[:z] = ResizableArray(NodeLabel, Val(2)) + @test haskey(ctx, :z) + @test !haskey(individual_variables(ctx), :z) + @test !haskey(vector_variables(ctx), :z) + @test haskey(tensor_variables(ctx), :z) + @test !haskey(proxies(ctx), :z) + + @test !haskey(ctx, :proxy) + ctx[:proxy] = proxylabel(:proxy, xlab, nothing) + @test !haskey(individual_variables(ctx), :proxy) + @test !haskey(vector_variables(ctx), :proxy) + @test !haskey(tensor_variables(ctx), :proxy) + @test haskey(proxies(ctx), :proxy) + + @test !haskey(ctx, GraphPPL.FactorID(sum, 1)) + ctx[GraphPPL.FactorID(sum, 1)] = Context() + @test haskey(ctx, GraphPPL.FactorID(sum, 1)) + @test haskey(children(ctx), GraphPPL.FactorID(sum, 1)) +end + +@testitem "getindex(::Context, ::Symbol)" begin + import GraphPPL: Context, NodeLabel + + ctx = Context() + xlab = NodeLabel(:x, 1) + @test_throws KeyError ctx[:x] + ctx[:x] = xlab + @test ctx[:x] == xlab +end + +@testitem "getindex(::Context, ::FactorID)" begin + import GraphPPL: Context, NodeLabel, FactorID + + ctx = Context() + @test_throws KeyError ctx[FactorID(sum, 1)] + ctx[FactorID(sum, 1)] = Context() + @test ctx[FactorID(sum, 1)] == ctx.children[FactorID(sum, 1)] + + @test_throws KeyError ctx[FactorID(sum, 2)] + ctx[FactorID(sum, 2)] = NodeLabel(:sum, 1) + @test ctx[FactorID(sum, 2)] == ctx.factor_nodes[FactorID(sum, 2)] +end + +@testitem "getcontext(::Model)" begin + import GraphPPL: Context, getcontext, create_model, add_variable_node!, NodeCreationOptions + + include("testutils.jl") + + model = create_test_model() + @test getcontext(model) == model.graph[] + add_variable_node!(model, getcontext(model), NodeCreationOptions(), :x, nothing) + @test getcontext(model)[:x] == model.graph[][:x] +end + +@testitem "path_to_root(::Context)" begin + import GraphPPL: create_model, Context, path_to_root, getcontext + + include("testutils.jl") + + using .TestUtils.ModelZoo + + ctx = Context() + @test path_to_root(ctx) == [ctx] + + model = create_model(outer()) + ctx = getcontext(model) + inner_context = ctx[inner, 1] + inner_inner_context = inner_context[inner_inner, 1] + @test path_to_root(inner_inner_context) == [inner_inner_context, inner_context, ctx] +end + +@testitem "VarDict" begin + using Distributions + import GraphPPL: + Context, VarDict, create_model, getorcreate!, datalabel, NodeCreationOptions, getcontext, is_random, is_data, getproperties + + include("testutils.jl") + + ctx = Context() + vardict = VarDict(ctx) + @test isa(vardict, VarDict) + + @model function submodel(y, x_prev, x_next) + γ ~ Gamma(1, 1) + x_next ~ Normal(x_prev, γ) + y ~ Normal(x_next, 1) + end + + @model function state_space_model_with_new(y) + x[1] ~ Normal(0, 1) + y[1] ~ Normal(x[1], 1) + for i in 2:length(y) + y[i] ~ submodel(x_next = new(x[i]), x_prev = x[i - 1]) + end + end + + ydata = ones(10) + model = create_model(state_space_model_with_new()) do model, ctx + y = datalabel(model, ctx, NodeCreationOptions(kind = :data), :y, ydata) + return (y = y,) + end + + context = getcontext(model) + vardict = VarDict(context) + + @test haskey(vardict, :y) + @test haskey(vardict, :x) + for i in 1:(length(ydata) - 1) + @test haskey(vardict, (submodel, i)) + @test haskey(vardict[submodel, i], :γ) + end + + @test vardict[:y] === context[:y] + @test vardict[:x] === context[:x] + @test vardict[submodel, 1] == VarDict(context[submodel, 1]) + + result = map(identity, vardict) + @test haskey(result, :y) + @test haskey(result, :x) + for i in 1:(length(ydata) - 1) + @test haskey(result, (submodel, i)) + @test haskey(result[submodel, i], :γ) + end + + result = map(vardict) do variable + return length(variable) + end + @test haskey(result, :y) + @test haskey(result, :x) + @test result[:y] === length(ydata) + @test result[:x] === length(ydata) + for i in 1:(length(ydata) - 1) + @test result[(submodel, i)][:γ] === 1 + @test result[GraphPPL.FactorID(submodel, i)][:γ] === 1 + @test result[submodel, i][:γ] === 1 + end + + # Filter only random variables + result = filter(vardict) do label + if label isa GraphPPL.ResizableArray + all(is_random.(getproperties.(model[label]))) + else + return is_random(getproperties(model[label])) + end + end + @test !haskey(result, :y) + @test haskey(result, :x) + for i in 1:(length(ydata) - 1) + @test haskey(result, (submodel, i)) + @test haskey(result[submodel, i], :γ) + end + + # Filter only data variables + result = filter(vardict) do label + if label isa GraphPPL.ResizableArray + all(is_data.(getproperties.(model[label]))) + else + return is_data(getproperties(model[label])) + end + end + @test haskey(result, :y) + @test !haskey(result, :x) + for i in 1:(length(ydata) - 1) + @test haskey(result, (submodel, i)) + @test !haskey(result[submodel, i], :γ) + end +end + +@testitem "setindex!(::Context, ::ResizableArray{NodeLabel}, ::Symbol)" begin + import GraphPPL: NodeLabel, ResizableArray, Context, vector_variables, tensor_variables + + context = Context() + context[:x] = ResizableArray(NodeLabel, Val(1)) + @test haskey(vector_variables(context), :x) + + context[:y] = ResizableArray(NodeLabel, Val(2)) + @test haskey(tensor_variables(context), :y) +end \ No newline at end of file diff --git a/test/model/model_construction_tests.jl b/test/model/model_construction_tests.jl new file mode 100644 index 00000000..a2cf4a16 --- /dev/null +++ b/test/model/model_construction_tests.jl @@ -0,0 +1,38 @@ + +@testitem "model constructor" begin + import GraphPPL: create_model, Model + + include("testutils.jl") + + @test typeof(create_test_model()) <: Model + + @test_throws MethodError Model() +end + +# TODO this is not a test for GraphPPL but for the tests. +@testitem "create_test_model()" begin + import GraphPPL: create_model, Model, nv, ne + + include("testutils.jl") + + model = create_test_model() + @test typeof(model) <: Model && nv(model) == 0 && ne(model) == 0 + + @test_throws MethodError create_test_model(:x, :y, :z) +end + +@testitem "getcounter and setcounter!" begin + import GraphPPL: create_model, setcounter!, getcounter + + include("testutils.jl") + + model = create_test_model() + + @test setcounter!(model, 1) == 1 + @test getcounter(model) == 1 + @test setcounter!(model, 2) == 2 + @test getcounter(model) == 2 + @test setcounter!(model, getcounter(model) + 1) == 3 + @test setcounter!(model, 100) == 100 + @test getcounter(model) == 100 +end \ No newline at end of file diff --git a/test/model/model_operations_tests.jl b/test/model/model_operations_tests.jl new file mode 100644 index 00000000..44b0a279 --- /dev/null +++ b/test/model/model_operations_tests.jl @@ -0,0 +1,646 @@ +@testitem "getindex(::Model, ::NodeLabel)" begin + import GraphPPL: create_model, getcontext, NodeLabel, NodeData, VariableNodeProperties, getproperties + + include("testutils.jl") + + model = create_test_model() + ctx = getcontext(model) + label = NodeLabel(:x, 1) + model[label] = NodeData(ctx, VariableNodeProperties(name = :x, index = nothing)) + @test isa(model[label], NodeData) + @test isa(getproperties(model[label]), VariableNodeProperties) + @test_throws KeyError model[NodeLabel(:x, 10)] + @test_throws MethodError model[0] +end + +@testitem "copy_markov_blanket_to_child_context" begin + import GraphPPL: + create_model, copy_markov_blanket_to_child_context, Context, getorcreate!, proxylabel, unroll, getcontext, NodeCreationOptions + + include("testutils.jl") + + # Copy individual variables + model = create_test_model() + ctx = getcontext(model) + function child end + child_context = Context(ctx, child) + xref = getorcreate!(model, ctx, NodeCreationOptions(), :x, nothing) + y = getorcreate!(model, ctx, NodeCreationOptions(), :y, nothing) + zref = getorcreate!(model, ctx, NodeCreationOptions(), :z, nothing) + + # Do not copy constant variables + model = create_test_model() + ctx = getcontext(model) + xref = getorcreate!(model, ctx, NodeCreationOptions(), :x, nothing) + child_context = Context(ctx, child) + copy_markov_blanket_to_child_context(child_context, (in = 1,)) + @test !haskey(child_context, :in) + + # Do not copy vector valued constant variables + model = create_test_model() + ctx = getcontext(model) + child_context = Context(ctx, child) + copy_markov_blanket_to_child_context(child_context, (in = [1, 2, 3],)) + @test !haskey(child_context, :in) + + # Copy ProxyLabel variables to child context + model = create_test_model() + ctx = getcontext(model) + xref = getorcreate!(model, ctx, NodeCreationOptions(), :x, nothing) + xref = proxylabel(:x, xref, nothing) + child_context = Context(ctx, child) + copy_markov_blanket_to_child_context(child_context, (in = xref,)) + @test child_context[:in] == xref +end + +@testitem "getorcreate!" begin + using Graphs + import GraphPPL: + create_model, + getcontext, + getorcreate!, + check_variate_compatability, + NodeLabel, + ResizableArray, + NodeCreationOptions, + getproperties, + is_kind + + include("testutils.jl") + + let # let block to suppress the scoping warnings + # Test 1: Creation of regular one-dimensional variable + model = create_test_model() + ctx = getcontext(model) + x = getorcreate!(model, ctx, :x, nothing) + @test nv(model) == 1 && ne(model) == 0 + + # Test 2: Ensure that getorcreating this variable again does not create a new node + x2 = getorcreate!(model, ctx, :x, nothing) + @test x == x2 && nv(model) == 1 && ne(model) == 0 + + # Test 3: Ensure that calling x another time gives us x + x = getorcreate!(model, ctx, :x, nothing) + @test x == x2 && nv(model) == 1 && ne(model) == 0 + + # Test 4: Test that creating a vector variable creates an array of the correct size + model = create_test_model() + ctx = getcontext(model) + y = getorcreate!(model, ctx, :y, 1) + @test nv(model) == 1 && ne(model) == 0 && y isa ResizableArray && y[1] isa NodeLabel + + # Test 5: Test that recreating the same variable changes nothing + y2 = getorcreate!(model, ctx, :y, 1) + @test y == y2 && nv(model) == 1 && ne(model) == 0 + + # Test 6: Test that adding a variable to this vector variable increases the size of the array + y = getorcreate!(model, ctx, :y, 2) + @test nv(model) == 2 && y[2] isa NodeLabel && haskey(ctx.vector_variables, :y) + + # Test 7: Test that getting this variable without index does not work + @test_throws ErrorException getorcreate!(model, ctx, :y, nothing) + + # Test 8: Test that getting this variable with an index that is too large does not work + @test_throws ErrorException getorcreate!(model, ctx, :y, 1, 2) + + #Test 9: Test that creating a tensor variable creates a tensor of the correct size + model = create_test_model() + ctx = getcontext(model) + z = getorcreate!(model, ctx, :z, 1, 1) + @test nv(model) == 1 && ne(model) == 0 && z isa ResizableArray && z[1, 1] isa NodeLabel + + #Test 10: Test that recreating the same variable changes nothing + z2 = getorcreate!(model, ctx, :z, 1, 1) + @test z == z2 && nv(model) == 1 && ne(model) == 0 + + #Test 11: Test that adding a variable to this tensor variable increases the size of the array + z = getorcreate!(model, ctx, :z, 2, 2) + @test nv(model) == 2 && z[2, 2] isa NodeLabel && haskey(ctx.tensor_variables, :z) + + #Test 12: Test that getting this variable without index does not work + @test_throws ErrorException z = getorcreate!(model, ctx, :z, nothing) + + #Test 13: Test that getting this variable with an index that is too small does not work + @test_throws ErrorException z = getorcreate!(model, ctx, :z, 1) + + #Test 14: Test that getting this variable with an index that is too large does not work + @test_throws ErrorException z = getorcreate!(model, ctx, :z, 1, 2, 3) + + # Test 15: Test that creating a variable that exists in the model scope but not in local scope still throws an error + let # force local scope + model = create_test_model() + ctx = getcontext(model) + getorcreate!(model, ctx, :a, nothing) + @test_throws ErrorException a = getorcreate!(model, ctx, :a, 1) + @test_throws ErrorException a = getorcreate!(model, ctx, :a, 1, 1) + end + + # Test 16. Test that the index is required to create a variable in the model + model = create_test_model() + ctx = getcontext(model) + @test_throws ErrorException getorcreate!(model, ctx, :a) + @test_throws ErrorException getorcreate!(model, ctx, NodeCreationOptions(), :a) + @test_throws ErrorException getorcreate!(model, ctx, NodeCreationOptions(kind = :data), :a) + @test_throws ErrorException getorcreate!(model, ctx, NodeCreationOptions(kind = :constant, value = 2), :a) + + # Test 17. Range based getorcreate! + model = create_test_model() + ctx = getcontext(model) + var = getorcreate!(model, ctx, :a, 1:2) + @test nv(model) == 2 && var[1] isa NodeLabel && var[2] isa NodeLabel + + # Test 17.1 Range based getorcreate! should use the same options + model = create_test_model() + ctx = getcontext(model) + var = getorcreate!(model, ctx, NodeCreationOptions(kind = :data), :a, 1:2) + @test nv(model) == 2 && var[1] isa NodeLabel && var[2] isa NodeLabel + @test is_kind(getproperties(model[var[1]]), :data) + @test is_kind(getproperties(model[var[1]]), :data) + + # Test 18. Range x2 based getorcreate! + model = create_test_model() + ctx = getcontext(model) + var = getorcreate!(model, ctx, :a, 1:2, 1:3) + @test nv(model) == 6 + for i in 1:2, j in 1:3 + @test var[i, j] isa NodeLabel + end + + # Test 18. Range x2 based getorcreate! should use the same options + model = create_test_model() + ctx = getcontext(model) + var = getorcreate!(model, ctx, NodeCreationOptions(kind = :data), :a, 1:2, 1:3) + @test nv(model) == 6 + for i in 1:2, j in 1:3 + @test var[i, j] isa NodeLabel + @test is_kind(getproperties(model[var[i, j]]), :data) + end + end +end + +@testitem "getifcreated" begin + using Graphs + import GraphPPL: + create_model, + getifcreated, + getorcreate!, + getcontext, + getproperties, + getname, + value, + getorcreate!, + proxylabel, + value, + NodeCreationOptions + + include("testutils.jl") + + model = create_test_model() + ctx = getcontext(model) + + # Test case 1: check that getifcreated the variable created by getorcreate + xref = getorcreate!(model, ctx, NodeCreationOptions(), :x, nothing) + @test getifcreated(model, ctx, xref) == xref + + # Test case 2: check that getifcreated returns the variable created by getorcreate in a vector + y = getorcreate!(model, ctx, NodeCreationOptions(), :y, 1) + @test getifcreated(model, ctx, y[1]) == y[1] + + # Test case 3: check that getifcreated returns a new variable node when called with integer input + c = getifcreated(model, ctx, 1) + @test value(getproperties(model[c])) == 1 + + # Test case 4: check that getifcreated returns a new variable node when called with a vector input + c = getifcreated(model, ctx, [1, 2]) + @test value(getproperties(model[c])) == [1, 2] + + # Test case 5: check that getifcreated returns a tuple of variable nodes when called with a tuple of NodeData + output = getifcreated(model, ctx, (xref, y[1])) + @test output == (xref, y[1]) + + # Test case 6: check that getifcreated returns a tuple of new variable nodes when called with a tuple of integers + output = getifcreated(model, ctx, (1, 2)) + @test value(getproperties(model[output[1]])) == 1 + @test value(getproperties(model[output[2]])) == 2 + + # Test case 7: check that getifcreated returns a tuple of variable nodes when called with a tuple of mixed input + output = getifcreated(model, ctx, (xref, 1)) + @test output[1] == xref && value(getproperties(model[output[2]])) == 1 + + # Test case 10: check that getifcreated returns the variable node if we create a variable and call it by symbol in a vector + model = create_test_model() + ctx = getcontext(model) + zref = getorcreate!(model, ctx, NodeCreationOptions(), :z, 1) + z_fetched = getifcreated(model, ctx, zref[1]) + @test z_fetched == zref[1] + + # Test case 11: Test that getifcreated returns a constant node when we call it with a symbol + model = create_test_model() + ctx = getcontext(model) + zref = getifcreated(model, ctx, :Bernoulli) + @test value(getproperties(model[zref])) == :Bernoulli + + # Test case 12: Test that getifcreated returns a vector of NodeLabels if called with a vector of NodeLabels + model = create_test_model() + ctx = getcontext(model) + xref = getorcreate!(model, ctx, NodeCreationOptions(), :x, nothing) + y = getorcreate!(model, ctx, NodeCreationOptions(), :y, nothing) + zref = getifcreated(model, ctx, [xref, y]) + @test zref == [xref, y] + + # Test case 13: Test that getifcreated returns a ResizableArray tensor of NodeLabels if called with a ResizableArray tensor of NodeLabels + model = create_test_model() + ctx = getcontext(model) + xref = getorcreate!(model, ctx, NodeCreationOptions(), :x, 1, 1) + xref = getorcreate!(model, ctx, NodeCreationOptions(), :x, 2, 1) + zref = getifcreated(model, ctx, xref) + @test zref == xref + + # Test case 14: Test that getifcreated returns multiple variables if called with a tuple of constants + model = create_test_model() + ctx = getcontext(model) + zref = getifcreated(model, ctx, ([1, 1], 2)) + @test nv(model) == 2 && value(getproperties(model[zref[1]])) == [1, 1] && value(getproperties(model[zref[2]])) == 2 + + # Test case 15: Test that getifcreated returns a ProxyLabel if called with a ProxyLabel + model = create_test_model() + ctx = getcontext(model) + xref = getorcreate!(model, ctx, NodeCreationOptions(), :x, nothing) + xref = proxylabel(:x, xref, nothing) + zref = getifcreated(model, ctx, xref) + @test zref === xref +end + +@testitem "make_node!(::Atomic)" begin + using Graphs, BitSetTuples + import GraphPPL: + getcontext, + make_node!, + create_model, + getorcreate!, + AnonymousVariable, + proxylabel, + getname, + label_for, + edges, + MixedArguments, + prune!, + fform, + value, + NodeCreationOptions, + getproperties + + include("testutils.jl") + + # Test 1: Deterministic call returns result of deterministic function and does not create new node + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + xref = AnonymousVariable(model, ctx) + @test make_node!(model, ctx, options, +, xref, (1, 1)) == (nothing, 2) + @test make_node!(model, ctx, options, sin, xref, (0,)) == (nothing, 0) + @test nv(model) == 0 + + xref = proxylabel(:proxy, AnonymousVariable(model, ctx), nothing) + @test make_node!(model, ctx, options, +, xref, (1, 1)) == (nothing, 2) + @test make_node!(model, ctx, options, sin, xref, (0,)) == (nothing, 0) + @test nv(model) == 0 + + # Test 2: Stochastic atomic call returns a new node id + node_id, _ = make_node!(model, ctx, options, Normal, xref, (μ = 0, σ = 1)) + @test nv(model) == 4 + @test getname.(edges(model, node_id)) == [:out, :μ, :σ] + @test getname.(edges(model, node_id)) == [:out, :μ, :σ] + + # Test 3: Stochastic atomic call with an AbstractArray as rhs_interfaces + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + xref = getorcreate!(model, ctx, :x, nothing) + make_node!(model, ctx, options, Normal, xref, (0, 1)) + @test nv(model) == 4 && ne(model) == 3 + + # Test 4: Deterministic atomic call with nodelabels should create the actual node + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + in1 = getorcreate!(model, ctx, :in1, nothing) + in2 = getorcreate!(model, ctx, :in2, nothing) + out = getorcreate!(model, ctx, :out, nothing) + make_node!(model, ctx, options, +, out, (in1, in2)) + @test nv(model) == 4 && ne(model) == 3 + + # Test 5: Deterministic atomic call with nodelabels should create the actual node + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + in1 = getorcreate!(model, ctx, :in1, nothing) + in2 = getorcreate!(model, ctx, :in2, nothing) + out = getorcreate!(model, ctx, :out, nothing) + make_node!(model, ctx, options, +, out, (in = [in1, in2],)) + @test nv(model) == 4 + + # Test 6: Stochastic node with default arguments + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + xref = getorcreate!(model, ctx, :x, nothing) + node_id, _ = make_node!(model, ctx, options, Normal, xref, (0, 1)) + @test nv(model) == 4 + @test getname.(edges(model, node_id)) == [:out, :μ, :σ] + @test getname.(edges(model, node_id)) == [:out, :μ, :σ] + + # Test 7: Stochastic node with instantiated object + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + uprior = Normal(0, 1) + xref = getorcreate!(model, ctx, :x, nothing) + node_id = make_node!(model, ctx, options, uprior, xref, nothing) + @test nv(model) == 2 + + # Test 8: Deterministic node with nodelabel objects where all interfaces are already defined (no missing interfaces) + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + in1 = getorcreate!(model, ctx, :in1, nothing) + in2 = getorcreate!(model, ctx, :in2, nothing) + out = getorcreate!(model, ctx, :out, nothing) + @test_throws "Expected only one missing interface, got () of length 0 (node sum with interfaces (:in, :out))" make_node!( + model, ctx, options, +, out, (in = in1, out = in2) + ) + + # Test 8: Stochastic node with nodelabel objects where we have an array on the rhs (so should create 1 node for [0, 1]) + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + out = getorcreate!(model, ctx, :out, nothing) + nodeid, _ = make_node!(model, ctx, options, ArbitraryNode, out, (in = [0, 1],)) + @test nv(model) == 3 && value(getproperties(model[ctx[:constvar_2]])) == [0, 1] + + # Test 9: Stochastic node with all interfaces defined as constants + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + out = getorcreate!(model, ctx, :out, nothing) + nodeid, _ = make_node!(model, ctx, options, ArbitraryNode, out, (1, 1)) + @test nv(model) == 4 + @test getname.(edges(model, nodeid)) == [:out, :in, :in] + @test getname.(edges(model, nodeid)) == [:out, :in, :in] + + #Test 10: Deterministic node with keyword arguments + function abc(; a = 1, b = 2) + return a + b + end + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + out = AnonymousVariable(model, ctx) + @test make_node!(model, ctx, options, abc, out, (a = 1, b = 2)) == (nothing, 3) + + # Test 11: Deterministic node with mixed arguments + function abc(a; b = 2) + return a + b + end + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + out = AnonymousVariable(model, ctx) + @test make_node!(model, ctx, options, abc, out, MixedArguments((2,), (b = 2,))) == (nothing, 4) + + # Test 12: Deterministic node with mixed arguments that has to be materialized should throw error + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + out = getorcreate!(model, ctx, :out, nothing) + a = getorcreate!(model, ctx, :a, nothing) + @test_throws ErrorException make_node!(model, ctx, options, abc, out, MixedArguments((a,), (b = 2,))) + + # Test 13: Make stochastic node with aliases + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + xref = getorcreate!(model, ctx, :x, nothing) + node_id = make_node!(model, ctx, options, Normal, xref, (μ = 0, τ = 1)) + @test any((key) -> fform(key) == NormalMeanPrecision, keys(ctx.factor_nodes)) + @test nv(model) == 4 + + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + xref = getorcreate!(model, ctx, :x, nothing) + node_id = make_node!(model, ctx, options, Normal, xref, (μ = 0, σ = 1)) + @test any((key) -> fform(key) == NormalMeanVariance, keys(ctx.factor_nodes)) + @test nv(model) == 4 + + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + xref = getorcreate!(model, ctx, :x, nothing) + node_id = make_node!(model, ctx, options, Normal, xref, (0, 1)) + @test any((key) -> fform(key) == NormalMeanVariance, keys(ctx.factor_nodes)) + @test nv(model) == 4 + + # Test 14: Make deterministic node with ProxyLabels as arguments + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + xref = getorcreate!(model, ctx, :x, nothing) + xref = proxylabel(:x, xref, nothing) + y = getorcreate!(model, ctx, :y, nothing) + y = proxylabel(:y, y, nothing) + zref = getorcreate!(model, ctx, :z, nothing) + node_id = make_node!(model, ctx, options, +, zref, (xref, y)) + prune!(model) + @test nv(model) == 4 + + # Test 15.1: Make stochastic node with aliased interfaces + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + μ = getorcreate!(model, ctx, :μ, nothing) + σ = getorcreate!(model, ctx, :σ, nothing) + out = getorcreate!(model, ctx, :out, nothing) + for keys in [(:mean, :variance), (:m, :variance), (:mean, :v)] + local node_id, _ = make_node!(model, ctx, options, Normal, out, NamedTuple{keys}((μ, σ))) + @test GraphPPL.fform(GraphPPL.getproperties(model[node_id])) === NormalMeanVariance + @test GraphPPL.neighbors(model, node_id) == [out, μ, σ] + end + + # Test 15.2: Make stochastic node with aliased interfaces + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + μ = getorcreate!(model, ctx, :μ, nothing) + p = getorcreate!(model, ctx, :σ, nothing) + out = getorcreate!(model, ctx, :out, nothing) + for keys in [(:mean, :precision), (:m, :precision), (:mean, :p)] + local node_id, _ = make_node!(model, ctx, options, Normal, out, NamedTuple{keys}((μ, p))) + @test GraphPPL.fform(GraphPPL.getproperties(model[node_id])) === NormalMeanPrecision + @test GraphPPL.neighbors(model, node_id) == [out, μ, p] + end +end + +@testitem "materialize_factor_node!" begin + using Distributions + using Graphs + import GraphPPL: + getcontext, + materialize_factor_node!, + create_model, + getorcreate!, + getifcreated, + proxylabel, + prune!, + getname, + label_for, + edges, + NodeCreationOptions + + include("testutils.jl") + + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + xref = getorcreate!(model, ctx, options, :x, nothing) + μref = getifcreated(model, ctx, 0) + σref = getifcreated(model, ctx, 1) + + # Test 1: Stochastic atomic call returns a new node + node_id, _, _ = materialize_factor_node!(model, ctx, options, Normal, (out = xref, μ = μref, σ = σref)) + @test nv(model) == 4 + @test getname.(edges(model, node_id)) == [:out, :μ, :σ] + @test getname.(edges(model, node_id)) == [:out, :μ, :σ] + + # Test 3: Stochastic atomic call with an AbstractArray as rhs_interfaces + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + xref = getorcreate!(model, ctx, :x, nothing) + μref = getifcreated(model, ctx, 0) + σref = getifcreated(model, ctx, 1) + materialize_factor_node!(model, ctx, options, Normal, (out = xref, μ = μref, σ = σref)) + @test nv(model) == 4 && ne(model) == 3 + + # Test 4: Deterministic atomic call with nodelabels should create the actual node + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + in1 = getorcreate!(model, ctx, :in1, nothing) + in2 = getorcreate!(model, ctx, :in2, nothing) + out = getorcreate!(model, ctx, :out, nothing) + materialize_factor_node!(model, ctx, options, +, (out = out, in = (in1, in2))) + @test nv(model) == 4 && ne(model) == 3 + + # Test 14: Make deterministic node with ProxyLabels as arguments + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + xref = getorcreate!(model, ctx, :x, nothing) + xref = proxylabel(:x, xref, nothing) + y = getorcreate!(model, ctx, :y, nothing) + y = proxylabel(:y, y, nothing) + zref = getorcreate!(model, ctx, :z, nothing) + node_id = materialize_factor_node!(model, ctx, options, +, (out = zref, in = (xref, y))) + prune!(model) + @test nv(model) == 4 +end + +@testitem "make_node!(::Composite)" begin + using MetaGraphsNext, Graphs + import GraphPPL: getcontext, make_node!, create_model, getorcreate!, proxylabel, NodeCreationOptions + + include("testutils.jl") + + using .TestUtils.ModelZoo + + #test make node for priors + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + xref = getorcreate!(model, ctx, :x, nothing) + make_node!(model, ctx, options, prior, proxylabel(:x, xref, nothing), ()) + @test nv(model) == 4 + @test ctx[prior, 1][:a] == proxylabel(:x, xref, nothing) + + #test make node for other composite models + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + xref = getorcreate!(model, ctx, :x, nothing) + @test_throws ErrorException make_node!(model, ctx, options, gcv, proxylabel(:x, xref, nothing), (0, 1)) + + # test make node of broadcastable composite model + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + out = getorcreate!(model, ctx, :out, nothing) + model = create_model(broadcaster()) + @test nv(model) == 103 +end + +@testitem "broadcast" begin + import GraphPPL: NodeLabel, ResizableArray, create_model, getcontext, getorcreate!, make_node!, NodeCreationOptions + + include("testutils.jl") + + # Test 1: Broadcast a vector node + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + xref = getorcreate!(model, ctx, :x, 1) + xref = getorcreate!(model, ctx, :x, 2) + y = getorcreate!(model, ctx, :y, 1) + y = getorcreate!(model, ctx, :y, 2) + zref = getorcreate!(model, ctx, :z, 1) + zref = getorcreate!(model, ctx, :z, 2) + zref = broadcast((z_, x_, y_) -> begin + var = make_node!(model, ctx, options, +, z_, (x_, y_)) + end, zref, xref, y) + @test size(zref) == (2,) + + # Test 2: Broadcast a matrix node + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + xref = getorcreate!(model, ctx, :x, 1, 1) + xref = getorcreate!(model, ctx, :x, 1, 2) + xref = getorcreate!(model, ctx, :x, 2, 1) + xref = getorcreate!(model, ctx, :x, 2, 2) + + y = getorcreate!(model, ctx, :y, 1, 1) + y = getorcreate!(model, ctx, :y, 1, 2) + y = getorcreate!(model, ctx, :y, 2, 1) + y = getorcreate!(model, ctx, :y, 2, 2) + + zref = getorcreate!(model, ctx, :z, 1, 1) + zref = getorcreate!(model, ctx, :z, 1, 2) + zref = getorcreate!(model, ctx, :z, 2, 1) + zref = getorcreate!(model, ctx, :z, 2, 2) + + zref = broadcast((z_, x_, y_) -> begin + var = make_node!(model, ctx, options, +, z_, (x_, y_)) + end, zref, xref, y) + @test size(zref) == (2, 2) + + # Test 3: Broadcast a vector node with a matrix node + model = create_test_model() + ctx = getcontext(model) + options = NodeCreationOptions() + xref = getorcreate!(model, ctx, :x, 1) + xref = getorcreate!(model, ctx, :x, 2) + y = getorcreate!(model, ctx, :y, 1, 1) + y = getorcreate!(model, ctx, :y, 1, 2) + y = getorcreate!(model, ctx, :y, 2, 1) + y = getorcreate!(model, ctx, :y, 2, 2) + + zref = getorcreate!(model, ctx, :z, 1, 1) + zref = getorcreate!(model, ctx, :z, 1, 2) + zref = getorcreate!(model, ctx, :z, 2, 1) + zref = getorcreate!(model, ctx, :z, 2, 2) + + zref = broadcast((z_, x_, y_) -> begin + var = make_node!(model, ctx, options, +, z_, (x_, y_)) + end, zref, xref, y) + @test size(zref) == (2, 2) +end \ No newline at end of file diff --git a/test/nodes/node_data_tests.jl b/test/nodes/node_data_tests.jl new file mode 100644 index 00000000..fba52203 --- /dev/null +++ b/test/nodes/node_data_tests.jl @@ -0,0 +1,224 @@ +@testitem "NodeData constructor" begin + import GraphPPL: create_model, getcontext, NodeData, FactorNodeProperties, VariableNodeProperties, getproperties + + include("testutils.jl") + + model = create_test_model() + context = getcontext(model) + + @testset "FactorNodeProperties" begin + properties = FactorNodeProperties(fform = String) + nodedata = NodeData(context, properties) + + @test getcontext(nodedata) === context + @test getproperties(nodedata) === properties + + io = IOBuffer() + + show(io, nodedata) + + output = String(take!(io)) + + @test !isempty(output) + @test contains(output, "String") # fform + end + + @testset "VariableNodeProperties" begin + properties = VariableNodeProperties(name = :x, index = 1) + nodedata = NodeData(context, properties) + + @test getcontext(nodedata) === context + @test getproperties(nodedata) === properties + + io = IOBuffer() + + show(io, nodedata) + + output = String(take!(io)) + + @test !isempty(output) + @test contains(output, "x") # name + @test contains(output, "1") # index + end +end + +@testitem "NodeDataExtraKey" begin + import GraphPPL: NodeDataExtraKey, getkey + + @test NodeDataExtraKey{:a, Int}() isa NodeDataExtraKey + @test NodeDataExtraKey{:a, Int}() === NodeDataExtraKey{:a, Int}() + @test NodeDataExtraKey{:a, Int}() !== NodeDataExtraKey{:a, Float64}() + @test NodeDataExtraKey{:a, Int}() !== NodeDataExtraKey{:b, Int}() + @test getkey(NodeDataExtraKey{:a, Int}()) === :a + @test getkey(NodeDataExtraKey{:a, Float64}()) === :a + @test getkey(NodeDataExtraKey{:b, Float64}()) === :b +end + +@testitem "NodeData extra properties" begin + import GraphPPL: + create_model, + getcontext, + NodeData, + FactorNodeProperties, + VariableNodeProperties, + getproperties, + setextra!, + getextra, + hasextra, + NodeDataExtraKey + + include("testutils.jl") + + model = create_test_model() + context = getcontext(model) + + @testset for properties in (FactorNodeProperties(fform = String), VariableNodeProperties(name = :x, index = 1)) + nodedata = NodeData(context, properties) + + @test !hasextra(nodedata, :a) + @test getextra(nodedata, :a, 2) === 2 + @test !hasextra(nodedata, :a) # the default should not add the extra property, only return + setextra!(nodedata, :a, 1) + @test hasextra(nodedata, :a) + @test getextra(nodedata, :a) === 1 + @test getextra(nodedata, :a, 2) === 1 + @test !hasextra(nodedata, :b) + @test_throws Exception getextra(nodedata, :b) + @test getextra(nodedata, :b, 2) === 2 + + # In the current implementation it is not possible to update extra properties + @test_throws Exception setextra!(nodedata, :a, 2) + + @test !hasextra(nodedata, :b) + setextra!(nodedata, :b, 2) + @test hasextra(nodedata, :b) + @test getextra(nodedata, :b) === 2 + + constkey_c_float = NodeDataExtraKey{:c, Float64}() + + @test !@inferred(hasextra(nodedata, constkey_c_float)) + @test @inferred(getextra(nodedata, constkey_c_float, 4.0)) === 4.0 + @inferred(setextra!(nodedata, constkey_c_float, 3.0)) + @test @inferred(hasextra(nodedata, constkey_c_float)) + @test @inferred(getextra(nodedata, constkey_c_float)) === 3.0 + @test @inferred(getextra(nodedata, constkey_c_float, 4.0)) === 3.0 + + # The default has a different type from the key (4.0 is Float and 4 is Int), thus the error + @test_throws MethodError getextra(nodedata, constkey_c_float, 4) + + constkey_d_int = NodeDataExtraKey{:d, Int64}() + + @test !@inferred(hasextra(nodedata, constkey_d_int)) + @inferred(setextra!(nodedata, constkey_d_int, 4)) + @test @inferred(hasextra(nodedata, constkey_d_int)) + @test @inferred(getextra(nodedata, constkey_d_int)) === 4 + end +end + +@testitem "NodeCreationOptions" begin + import GraphPPL: NodeCreationOptions, withopts, withoutopts + + include("testutils.jl") + + @test NodeCreationOptions() == NodeCreationOptions() + @test keys(NodeCreationOptions()) === () + @test NodeCreationOptions(arbitrary_option = 1) == NodeCreationOptions((; arbitrary_option = 1)) + + @test haskey(NodeCreationOptions(arbitrary_option = 1), :arbitrary_option) + @test NodeCreationOptions(arbitrary_option = 1)[:arbitrary_option] === 1 + + @test @inferred(haskey(NodeCreationOptions(), :a)) === false + @test @inferred(haskey(NodeCreationOptions(), :b)) === false + @test @inferred(haskey(NodeCreationOptions(a = 1, b = 2), :b)) === true + @test @inferred(haskey(NodeCreationOptions(a = 1, b = 2), :c)) === false + @test @inferred(NodeCreationOptions(a = 1, b = 2)[:a]) === 1 + @test @inferred(NodeCreationOptions(a = 1, b = 2)[:b]) === 2 + + @test_throws ErrorException NodeCreationOptions()[:a] + @test_throws ErrorException NodeCreationOptions(a = 1, b = 2)[:c] + + @test @inferred(get(NodeCreationOptions(), :a, 2)) === 2 + @test @inferred(get(NodeCreationOptions(), :b, 3)) === 3 + @test @inferred(get(NodeCreationOptions(), :c, 4)) === 4 + @test @inferred(get(NodeCreationOptions(a = 1, b = 2), :a, 2)) === 1 + @test @inferred(get(NodeCreationOptions(a = 1, b = 2), :b, 3)) === 2 + @test @inferred(get(NodeCreationOptions(a = 1, b = 2), :c, 4)) === 4 + + @test NodeCreationOptions(a = 1, b = 2)[(:a,)] === NodeCreationOptions(a = 1) + @test NodeCreationOptions(a = 1, b = 2)[(:b,)] === NodeCreationOptions(b = 2) + + @test keys(NodeCreationOptions(a = 1, b = 2)) == (:a, :b) + + @test @inferred(withopts(NodeCreationOptions(), (a = 1,))) == NodeCreationOptions(a = 1) + @test @inferred(withopts(NodeCreationOptions(b = 2), (a = 1,))) == NodeCreationOptions(b = 2, a = 1) + + @test @inferred(withoutopts(NodeCreationOptions(), Val((:a,)))) == NodeCreationOptions() + @test @inferred(withoutopts(NodeCreationOptions(b = 1), Val((:a,)))) == NodeCreationOptions(b = 1) + @test @inferred(withoutopts(NodeCreationOptions(a = 1), Val((:a,)))) == NodeCreationOptions() + @test @inferred(withoutopts(NodeCreationOptions(a = 1, b = 2), Val((:c,)))) == NodeCreationOptions(a = 1, b = 2) +end + +@testitem "is_constant" begin + import GraphPPL: create_model, is_constant, variable_nodes, getname, getproperties + + include("testutils.jl") + + using .TestUtils.ModelZoo + + for model_fn in ModelsInTheZooWithoutArguments + model = create_model(model_fn()) + for label in variable_nodes(model) + node = model[label] + props = getproperties(node) + if occursin("constvar", string(getname(props))) + @test is_constant(props) + else + @test !is_constant(props) + end + end + end +end + +@testitem "is_data" begin + import GraphPPL: is_data, create_model, getcontext, getorcreate!, variable_nodes, NodeCreationOptions, getproperties + + include("testutils.jl") + + m = create_test_model() + ctx = getcontext(m) + xref = getorcreate!(m, ctx, NodeCreationOptions(kind = :data), :x, nothing) + @test is_data(getproperties(m[xref])) + + using .TestUtils.ModelZoo + + # Since the models here are without top arguments they cannot create `data` labels + for model_fn in ModelsInTheZooWithoutArguments + model = create_model(model_fn()) + for label in variable_nodes(model) + @test !is_data(getproperties(model[label])) + end + end +end + +@testitem "Predefined kinds of variable nodes" begin + import GraphPPL: VariableKindRandom, VariableKindData, VariableKindConstant + import GraphPPL: getcontext, getorcreate!, NodeCreationOptions, getproperties + + include("testutils.jl") + + model = create_test_model() + context = getcontext(model) + xref = getorcreate!(model, context, NodeCreationOptions(kind = VariableKindRandom), :x, nothing) + y = getorcreate!(model, context, NodeCreationOptions(kind = VariableKindData), :y, nothing) + zref = getorcreate!(model, context, NodeCreationOptions(kind = VariableKindConstant), :z, nothing) + + import GraphPPL: is_random, is_data, is_constant, is_kind + + xprops = getproperties(model[xref]) + yprops = getproperties(model[y]) + zprops = getproperties(model[zref]) + + @test is_random(xprops) && is_kind(xprops, VariableKindRandom) + @test is_data(yprops) && is_kind(yprops, VariableKindData) + @test is_constant(zprops) && is_kind(zprops, VariableKindConstant) +end \ No newline at end of file diff --git a/test/nodes/node_label_tests.jl b/test/nodes/node_label_tests.jl new file mode 100644 index 00000000..fe37a839 --- /dev/null +++ b/test/nodes/node_label_tests.jl @@ -0,0 +1,181 @@ +@testitem "NodeLabel properties" begin + import GraphPPL: NodeLabel + + xref = NodeLabel(:x, 1) + @test xref[1] == xref + @test length(xref) === 1 + @test GraphPPL.to_symbol(xref) === :x_1 + + y = NodeLabel(:y, 2) + @test xref < y +end + +@testitem "getname(::NodeLabel)" begin + import GraphPPL: ResizableArray, NodeLabel, getname + + xref = NodeLabel(:x, 1) + @test getname(xref) == :x + + xref = ResizableArray(NodeLabel, Val(1)) + xref[1] = NodeLabel(:x, 1) + @test getname(xref) == :x + + xref = ResizableArray(NodeLabel, Val(1)) + xref[2] = NodeLabel(:x, 1) + @test getname(xref) == :x +end + +@testitem "generate_nodelabel(::Model, ::Symbol)" begin + import GraphPPL: create_model, gensym, NodeLabel, generate_nodelabel + + include("testutils.jl") + + model = create_test_model() + first_sym = generate_nodelabel(model, :x) + @test typeof(first_sym) == NodeLabel + + second_sym = generate_nodelabel(model, :x) + @test first_sym != second_sym && first_sym.name == second_sym.name + + id = generate_nodelabel(model, :c) + @test id.name == :c && id.global_counter == 3 +end + +@testitem "proxy labels" begin + import GraphPPL: NodeLabel, ProxyLabel, proxylabel, getname, unroll, ResizableArray, FunctionalIndex + + y = NodeLabel(:y, 1) + + let p = proxylabel(:x, y, nothing) + @test last(p) === y + @test getname(p) === :x + @test getname(last(p)) === :y + end + + let p = proxylabel(:x, y, (1,)) + @test_throws "Indexing a single node label `y` with an index `[1]` is not allowed" unroll(p) + end + + let p = proxylabel(:x, y, (1, 2)) + @test_throws "Indexing a single node label `y` with an index `[1, 2]` is not allowed" unroll(p) + end + + let p = proxylabel(:r, proxylabel(:x, y, nothing), nothing) + @test last(p) === y + @test getname(p) === :r + @test getname(last(p)) === :y + end + + for n in (5, 10) + s = ResizableArray(NodeLabel, Val(1)) + + for i in 1:n + s[i] = NodeLabel(:s, i) + end + + let p = proxylabel(:x, s, nothing) + @test last(p) === s + @test all(i -> p[i] === s[i], 1:length(s)) + @test unroll(p) === s + end + + for i in 1:5 + let p = proxylabel(:r, proxylabel(:x, s, (i,)), nothing) + @test unroll(p) === s[i] + end + end + + let p = proxylabel(:r, proxylabel(:x, s, (2:4,)), (2,)) + @test unroll(p) === s[3] + end + let p = proxylabel(:x, s, (2:4,)) + @test p[1] === s[2] + end + end + + for n in (5, 10) + s = ResizableArray(NodeLabel, Val(1)) + + for i in 1:n + s[i] = NodeLabel(:s, i) + end + + let p = proxylabel(:x, s, FunctionalIndex{:begin}(firstindex)) + @test unroll(p) === s[begin] + end + end +end + +@testitem "datalabel" begin + import GraphPPL: getcontext, datalabel, NodeCreationOptions, VariableKindData, VariableKindRandom, unroll, proxylabel + + include("testutils.jl") + + model = create_test_model() + ctx = getcontext(model) + ylabel = datalabel(model, ctx, NodeCreationOptions(kind = VariableKindData), :y) + yvar = unroll(ylabel) + @test haskey(ctx, :y) && ctx[:y] === yvar + @test GraphPPL.nv(model) === 1 + # subsequent unroll should return the same variable + unroll(ylabel) + @test haskey(ctx, :y) && ctx[:y] === yvar + @test GraphPPL.nv(model) === 1 + + yvlabel = datalabel(model, ctx, NodeCreationOptions(kind = VariableKindData), :yv, [1, 2, 3]) + for i in 1:3 + yvvar = unroll(proxylabel(:yv, yvlabel, (i,))) + @test haskey(ctx, :yv) && ctx[:yv][i] === yvvar + @test GraphPPL.nv(model) === 1 + i + end + # Incompatible data indices + @test_throws "The index `[4]` is not compatible with the underlying collection provided for the label `yv`" unroll( + proxylabel(:yv, yvlabel, (4,)) + ) + @test_throws "The underlying data provided for `yv` is `[1, 2, 3]`" unroll(proxylabel(:yv, yvlabel, (4,))) + + @test_throws "`datalabel` only supports `VariableKindData` in `NodeCreationOptions`" datalabel(model, ctx, NodeCreationOptions(), :z) + @test_throws "`datalabel` only supports `VariableKindData` in `NodeCreationOptions`" datalabel( + model, ctx, NodeCreationOptions(kind = VariableKindRandom), :z + ) +end + +@testitem "contains_nodelabel" begin + import GraphPPL: create_model, getcontext, getorcreate!, contains_nodelabel, NodeCreationOptions, True, False, MixedArguments + + include("testutils.jl") + + model = create_test_model() + ctx = getcontext(model) + a = getorcreate!(model, ctx, :x, nothing) + b = getorcreate!(model, ctx, NodeCreationOptions(kind = :data), :x, nothing) + c = 1.0 + + # Test 1. Tuple based input + @test contains_nodelabel((a, b, c)) === True() + @test contains_nodelabel((a, b)) === True() + @test contains_nodelabel((a,)) === True() + @test contains_nodelabel((b,)) === True() + @test contains_nodelabel((c,)) === False() + + # Test 2. Named tuple based input + @test @inferred(contains_nodelabel((; a = a, b = b, c = c))) === True() + @test @inferred(contains_nodelabel((; a = a, b = b))) === True() + @test @inferred(contains_nodelabel((; a = a))) === True() + @test @inferred(contains_nodelabel((; b = b))) === True() + @test @inferred(contains_nodelabel((; c = c))) === False() + + # Test 3. MixedArguments based input + @test @inferred(contains_nodelabel(MixedArguments((), (; a = a, b = b, c = c)))) === True() + @test @inferred(contains_nodelabel(MixedArguments((), (; a = a, b = b)))) === True() + @test @inferred(contains_nodelabel(MixedArguments((), (; a = a)))) === True() + @test @inferred(contains_nodelabel(MixedArguments((), (; b = b)))) === True() + @test @inferred(contains_nodelabel(MixedArguments((), (; c = c)))) === False() + + @test @inferred(contains_nodelabel(MixedArguments((a,), (; b = b, c = c)))) === True() + @test @inferred(contains_nodelabel(MixedArguments((c,), (; a = a, b = b)))) === True() + @test @inferred(contains_nodelabel(MixedArguments((b,), (; a = a)))) === True() + @test @inferred(contains_nodelabel(MixedArguments((c,), (; b = b)))) === True() + @test @inferred(contains_nodelabel(MixedArguments((c,), (;)))) === False() + @test @inferred(contains_nodelabel(MixedArguments((), (; c = c)))) === False() +end \ No newline at end of file diff --git a/test/nodes/node_semantics_tests.jl b/test/nodes/node_semantics_tests.jl new file mode 100644 index 00000000..a5f1b54d --- /dev/null +++ b/test/nodes/node_semantics_tests.jl @@ -0,0 +1,209 @@ +@testitem "NodeType" begin + import GraphPPL: NodeType, Composite, Atomic + + include("testutils.jl") + + using .TestUtils.ModelZoo + + model = create_test_model() + + @test NodeType(model, Composite) == Atomic() + @test NodeType(model, Atomic) == Atomic() + @test NodeType(model, abs) == Atomic() + @test NodeType(model, Normal) == Atomic() + @test NodeType(model, NormalMeanVariance) == Atomic() + @test NodeType(model, NormalMeanPrecision) == Atomic() + + # Could test all here + for model_fn in ModelsInTheZooWithoutArguments + @test NodeType(model, model_fn) == Composite() + end +end + +@testitem "NodeBehaviour" begin + import GraphPPL: NodeBehaviour, Deterministic, Stochastic + + include("testutils.jl") + + using .TestUtils.ModelZoo + + model = create_test_model() + + @test NodeBehaviour(model, () -> 1) == Deterministic() + @test NodeBehaviour(model, Matrix) == Deterministic() + @test NodeBehaviour(model, abs) == Deterministic() + @test NodeBehaviour(model, Normal) == Stochastic() + @test NodeBehaviour(model, NormalMeanVariance) == Stochastic() + @test NodeBehaviour(model, NormalMeanPrecision) == Stochastic() + + # Could test all here + for model_fn in ModelsInTheZooWithoutArguments + @test NodeBehaviour(model, model_fn) == Stochastic() + end +end + +@testitem "interface_alias" begin + using GraphPPL + import GraphPPL: interface_aliases, StaticInterfaces + + include("testutils.jl") + + model = create_test_model() + + @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :μ, :τ)))) === StaticInterfaces((:out, :μ, :τ)) + @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :mean, :precision)))) === StaticInterfaces((:out, :μ, :τ)) + @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :μ, :precision)))) === StaticInterfaces((:out, :μ, :τ)) + @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :mean, :τ)))) === StaticInterfaces((:out, :μ, :τ)) + @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :m, :precision)))) === StaticInterfaces((:out, :μ, :τ)) + @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :m, :τ)))) === StaticInterfaces((:out, :μ, :τ)) + @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :mean, :p)))) === StaticInterfaces((:out, :μ, :τ)) + @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :m, :p)))) === StaticInterfaces((:out, :μ, :τ)) + @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :μ, :p)))) === StaticInterfaces((:out, :μ, :τ)) + @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :mean, :prec)))) === StaticInterfaces((:out, :μ, :τ)) + @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :m, :prec)))) === StaticInterfaces((:out, :μ, :τ)) + @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :μ, :prec)))) === StaticInterfaces((:out, :μ, :τ)) + + @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :μ, :τ)))) === 0 + @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :mean, :precision)))) === 0 + @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :mean, :τ)))) === 0 + @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :μ, :precision)))) === 0 + @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :m, :precision)))) === 0 + @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :m, :τ)))) === 0 + @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :mean, :p)))) === 0 + @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :m, :p)))) === 0 + @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :μ, :p)))) === 0 + @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :mean, :prec)))) === 0 + @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :m, :prec)))) === 0 + @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :μ, :prec)))) === 0 + + @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :μ, :σ)))) === StaticInterfaces((:out, :μ, :σ)) + @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :mean, :variance)))) === StaticInterfaces((:out, :μ, :σ)) + @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :μ, :variance)))) === StaticInterfaces((:out, :μ, :σ)) + @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :mean, :σ)))) === StaticInterfaces((:out, :μ, :σ)) + + @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :μ, :σ)))) === 0 + @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :mean, :variance)))) === 0 + @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :mean, :σ)))) === 0 + @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :μ, :variance)))) === 0 +end + +@testitem "factor_alias" begin + import GraphPPL: factor_alias, StaticInterfaces + + include("testutils.jl") + + function abc end + function xyz end + + GraphPPL.factor_alias(::TestUtils.TestGraphPPLBackend, ::typeof(abc), ::StaticInterfaces{(:a, :b)}) = abc + GraphPPL.factor_alias(::TestUtils.TestGraphPPLBackend, ::typeof(abc), ::StaticInterfaces{(:x, :y)}) = xyz + + GraphPPL.factor_alias(::TestUtils.TestGraphPPLBackend, ::typeof(xyz), ::StaticInterfaces{(:a, :b)}) = abc + GraphPPL.factor_alias(::TestUtils.TestGraphPPLBackend, ::typeof(xyz), ::StaticInterfaces{(:x, :y)}) = xyz + + model = create_test_model() + + @test factor_alias(model, abc, StaticInterfaces((:a, :b))) === abc + @test factor_alias(model, abc, StaticInterfaces((:x, :y))) === xyz + + @test factor_alias(model, xyz, StaticInterfaces((:a, :b))) === abc + @test factor_alias(model, xyz, StaticInterfaces((:x, :y))) === xyz +end + +@testitem "default_parametrization" begin + import GraphPPL: default_parametrization, Composite, Atomic + + include("testutils.jl") + + using .TestUtils.ModelZoo + + model = create_test_model() + + # Test 1: Add default arguments to Normal call + @test default_parametrization(model, Atomic(), Normal, (0, 1)) == (μ = 0, σ = 1) + + # Test 2: Add :in to function call that has default behaviour + @test default_parametrization(model, Atomic(), +, (1, 2)) == (in = (1, 2),) + + # Test 3: Add :in to function call that has default behaviour with nested interfaces + @test default_parametrization(model, Atomic(), +, ([1, 1], 2)) == (in = ([1, 1], 2),) + + @test_throws ErrorException default_parametrization(model, Composite(), gcv, (1, 2)) +end + +@testitem "getindex for StaticInterfaces" begin + import GraphPPL: StaticInterfaces + + interfaces = (:a, :b, :c) + sinterfaces = StaticInterfaces(interfaces) + + for (i, interface) in enumerate(interfaces) + @test sinterfaces[i] === interface + end +end + +@testitem "missing_interfaces" begin + import GraphPPL: missing_interfaces, interfaces + + include("testutils.jl") + + model = create_test_model() + + function abc end + + GraphPPL.interfaces(::TestUtils.TestGraphPPLBackend, ::typeof(abc), ::StaticInt{3}) = GraphPPL.StaticInterfaces((:in1, :in2, :out)) + + @test missing_interfaces(model, abc, static(3), (in1 = :x, in2 = :y)) == GraphPPL.StaticInterfaces((:out,)) + @test missing_interfaces(model, abc, static(3), (out = :y,)) == GraphPPL.StaticInterfaces((:in1, :in2)) + @test missing_interfaces(model, abc, static(3), NamedTuple()) == GraphPPL.StaticInterfaces((:in1, :in2, :out)) + + function xyz end + + GraphPPL.interfaces(::TestUtils.TestGraphPPLBackend, ::typeof(xyz), ::StaticInt{0}) = GraphPPL.StaticInterfaces(()) + @test missing_interfaces(model, xyz, static(0), (in1 = :x, in2 = :y)) == GraphPPL.StaticInterfaces(()) + + function foo end + + GraphPPL.interfaces(::TestUtils.TestGraphPPLBackend, ::typeof(foo), ::StaticInt{2}) = GraphPPL.StaticInterfaces((:a, :b)) + @test missing_interfaces(model, foo, static(2), (a = 1, b = 2)) == GraphPPL.StaticInterfaces(()) + + function bar end + GraphPPL.interfaces(::TestUtils.TestGraphPPLBackend, ::typeof(bar), ::StaticInt{2}) = GraphPPL.StaticInterfaces((:in1, :in2, :out)) + @test missing_interfaces(model, bar, static(2), (in1 = 1, in2 = 2, out = 3, test = 4)) == GraphPPL.StaticInterfaces(()) +end + +@testitem "sort_interfaces" begin + import GraphPPL: sort_interfaces + + include("testutils.jl") + + model = create_test_model() + + # Test 1: Test that sort_interfaces sorts the interfaces in the correct order + @test sort_interfaces(model, NormalMeanVariance, (μ = 1, σ = 1, out = 1)) == (out = 1, μ = 1, σ = 1) + @test sort_interfaces(model, NormalMeanVariance, (out = 1, μ = 1, σ = 1)) == (out = 1, μ = 1, σ = 1) + @test sort_interfaces(model, NormalMeanVariance, (σ = 1, out = 1, μ = 1)) == (out = 1, μ = 1, σ = 1) + @test sort_interfaces(model, NormalMeanVariance, (σ = 1, μ = 1, out = 1)) == (out = 1, μ = 1, σ = 1) + @test sort_interfaces(model, NormalMeanPrecision, (μ = 1, τ = 1, out = 1)) == (out = 1, μ = 1, τ = 1) + @test sort_interfaces(model, NormalMeanPrecision, (out = 1, μ = 1, τ = 1)) == (out = 1, μ = 1, τ = 1) + @test sort_interfaces(model, NormalMeanPrecision, (τ = 1, out = 1, μ = 1)) == (out = 1, μ = 1, τ = 1) + @test sort_interfaces(model, NormalMeanPrecision, (τ = 1, μ = 1, out = 1)) == (out = 1, μ = 1, τ = 1) + + @test_throws ErrorException sort_interfaces(model, NormalMeanVariance, (σ = 1, μ = 1, τ = 1)) +end + +@testitem "prepare_interfaces" begin + import GraphPPL: prepare_interfaces + + include("testutils.jl") + + using .TestUtils.ModelZoo + + model = create_test_model() + + @test prepare_interfaces(model, anonymous_in_loop, 1, (y = 1,)) == (x = 1, y = 1) + @test prepare_interfaces(model, anonymous_in_loop, 1, (x = 1,)) == (y = 1, x = 1) + + @test prepare_interfaces(model, type_arguments, 1, (x = 1,)) == (n = 1, x = 1) + @test prepare_interfaces(model, type_arguments, 1, (n = 1,)) == (x = 1, n = 1) +end \ No newline at end of file diff --git a/test/plugins/plugin_lifecycle_tests.jl b/test/plugins/plugin_lifecycle_tests.jl new file mode 100644 index 00000000..a60decbf --- /dev/null +++ b/test/plugins/plugin_lifecycle_tests.jl @@ -0,0 +1,70 @@ + +@testitem "Check that factor node plugins are uniquely recreated" begin + import GraphPPL: create_model, with_plugins, getplugins, factor_nodes, PluginsCollection, setextra!, getextra + + include("testutils.jl") + + using .TestUtils.ModelZoo + + struct AnArbitraryPluginForTestUniqeness end + + GraphPPL.plugin_type(::AnArbitraryPluginForTestUniqeness) = GraphPPL.FactorNodePlugin() + + count = Ref(0) + + function GraphPPL.preprocess_plugin(::AnArbitraryPluginForTestUniqeness, model, context, label, nodedata, options) + setextra!(nodedata, :count, count[]) + count[] = count[] + 1 + return label, nodedata + end + + for model_fn in ModelsInTheZooWithoutArguments + model = create_model(with_plugins(model_fn(), PluginsCollection(AnArbitraryPluginForTestUniqeness()))) + for f1 in factor_nodes(model), f2 in factor_nodes(model) + if f1 !== f2 + @test getextra(model[f1], :count) !== getextra(model[f2], :count) + else + @test getextra(model[f1], :count) === getextra(model[f2], :count) + end + end + end +end + +@testitem "Check that plugins may change the options" begin + import GraphPPL: + NodeData, + variable_nodes, + getname, + index, + is_constant, + getproperties, + value, + PluginsCollection, + VariableNodeProperties, + NodeCreationOptions, + create_model, + with_plugins + + include("testutils.jl") + + using .TestUtils.ModelZoo + + struct AnArbitraryPluginForChangingOptions end + + GraphPPL.plugin_type(::AnArbitraryPluginForChangingOptions) = GraphPPL.VariableNodePlugin() + + function GraphPPL.preprocess_plugin(::AnArbitraryPluginForChangingOptions, model, context, label, nodedata, options) + # Here we replace the original options entirely + return label, NodeData(context, convert(VariableNodeProperties, :x, nothing, NodeCreationOptions(kind = :constant, value = 1.0))) + end + + for model_fn in ModelsInTheZooWithoutArguments + model = create_model(with_plugins(model_fn(), PluginsCollection(AnArbitraryPluginForChangingOptions()))) + for v in variable_nodes(model) + @test getname(getproperties(model[v])) === :x + @test index(getproperties(model[v])) === nothing + @test is_constant(getproperties(model[v])) === true + @test value(getproperties(model[v])) === 1.0 + end + end +end \ No newline at end of file From e0ea7bb5955a7bceb8b5b84b0e83e3fad15eec23 Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Wed, 7 May 2025 11:18:33 +0200 Subject: [PATCH 4/6] Refactor tests --- test/Project.toml | 5 +- test/backends/default_tests.jl | 6 +- test/ext/graphviz_integration_tests.jl | 8 +- test/graph/graph_modification_tests.jl | 62 +- test/graph/graph_properties_tests.jl | 40 +- test/graph/graph_traversal_filtering_tests.jl | 97 +-- test/graph/indexing_refs_tests.jl | 34 +- test/graph_construction_tests.jl | 397 +++++------ test/model/context_tests.jl | 26 +- test/model/model_construction_tests.jl | 21 +- test/model/model_operations_tests.jl | 156 ++-- test/model_generator_tests.jl | 59 +- test/model_macro_tests.jl | 411 +++++------ test/nodes/node_data_tests.jl | 42 +- test/nodes/node_label_tests.jl | 24 +- test/nodes/node_semantics_tests.jl | 106 ++- test/plugins/meta/meta_engine_tests.jl | 99 ++- test/plugins/meta/meta_macro_tests.jl | 66 +- test/plugins/meta/meta_tests.jl | 43 +- test/plugins/node_created_by_tests.jl | 18 +- test/plugins/node_id_tests.jl | 5 +- test/plugins/node_tag_tests.jl | 6 +- test/plugins/plugin_lifecycle_tests.jl | 16 +- .../variational_constraints_engine_tests.jl | 246 +++---- .../variational_constraints_macro_tests.jl | 147 ++-- .../variational_constraints_tests.jl | 221 +++--- test/runtests.jl | 20 +- test/testutils.jl | 674 +++++++++--------- 28 files changed, 1366 insertions(+), 1689 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 1c252e10..926c8c65 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,12 +2,13 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BitSetTuples = "0f2f92aa-23a3-4d05-b791-88071d064721" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0" +Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377" -ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a" diff --git a/test/backends/default_tests.jl b/test/backends/default_tests.jl index a1d11d41..357a6cf4 100644 --- a/test/backends/default_tests.jl +++ b/test/backends/default_tests.jl @@ -4,14 +4,10 @@ @test instantiate(DefaultBackend) == DefaultBackend() end -@testitem "NodeBehaviour" begin +@testitem "NodeBehaviour" setup = [TestUtils] begin using Distributions import GraphPPL: DefaultBackend, NodeBehaviour, Stochastic, Deterministic - include("../testutils.jl") - - using .TestUtils.ModelZoo - # The `DefaultBackend` defines `Stochastic` behaviour for objects from the `Distributions` module @test NodeBehaviour(DefaultBackend(), Normal) == Stochastic() @test NodeBehaviour(DefaultBackend(), Gamma) == Stochastic() diff --git a/test/ext/graphviz_integration_tests.jl b/test/ext/graphviz_integration_tests.jl index b1454610..ef65d8c8 100644 --- a/test/ext/graphviz_integration_tests.jl +++ b/test/ext/graphviz_integration_tests.jl @@ -1,8 +1,6 @@ -@testitem "Model visualizations with GraphViz: generate DOT and save to file" begin +@testitem "Model visualizations with GraphViz: generate DOT and save to file" setup = [TestUtils] begin using GraphPPL, Distributions, GraphViz - include("../testutils.jl") - # test params for layout and strategy combinations layouts = ["dot", "neato"] strategies = [:bfs, :simple] @@ -12,10 +10,8 @@ mkdir(test_imgs_path) end - import .TestUtils.ModelZoo as A - # for all models in the models zoo - for model in TestUtils.ModelZoo.ModelsInTheZooWithoutArguments + for model in TestUtils.ModelsInTheZooWithoutArguments # for each combination of layout and strategy for gv_layout in layouts for gv_strategy in strategies diff --git a/test/graph/graph_modification_tests.jl b/test/graph/graph_modification_tests.jl index 6d4ff8f5..39e780f7 100644 --- a/test/graph/graph_modification_tests.jl +++ b/test/graph/graph_modification_tests.jl @@ -1,10 +1,8 @@ -@testitem "setindex!(::Model, ::NodeData, ::NodeLabel)" begin +@testitem "setindex!(::Model, ::NodeData, ::NodeLabel)" setup = [TestUtils] begin using Graphs import GraphPPL: getcontext, NodeLabel, NodeData, VariableNodeProperties, FactorNodeProperties - include("testutils.jl") - - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) model[NodeLabel(:μ, 1)] = NodeData(ctx, VariableNodeProperties(name = :μ, index = nothing)) @test nv(model) == 1 && ne(model) == 0 @@ -20,13 +18,11 @@ @test_throws MethodError model["string"] = NodeData(ctx, FactorNodeProperties(fform = sum)) end -@testitem "setindex!(::Model, ::EdgeLabel, ::NodeLabel, ::NodeLabel)" begin +@testitem "setindex!(::Model, ::EdgeLabel, ::NodeLabel, ::NodeLabel)" setup = [TestUtils] begin using Graphs import GraphPPL: getcontext, NodeLabel, NodeData, VariableNodeProperties, EdgeLabel - include("testutils.jl") - - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) μ = NodeLabel(:μ, 1) @@ -44,7 +40,7 @@ end @test ne(model) == 1 end -@testitem "add_variable_node!" begin +@testitem "add_variable_node!" setup = [TestUtils] begin import GraphPPL: create_model, add_variable_node!, @@ -59,10 +55,8 @@ end is_constant, value - include("testutils.jl") - # Test 1: simple add variable to model - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) node_id = add_variable_node!(model, ctx, NodeCreationOptions(), :x, nothing) @test nv(model) == 1 && haskey(ctx.individual_variables, :x) && ctx.individual_variables[:x] == node_id @@ -76,7 +70,7 @@ end @test_throws MethodError add_variable_node!(model, ctx, NodeCreationOptions(), 1, 1) # Test 4: Add a vector variable to the model - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) ctx[:x] = ResizableArray(NodeLabel, Val(1)) node_id = add_variable_node!(model, ctx, NodeCreationOptions(), :x, 2) @@ -87,7 +81,7 @@ end @test nv(model) == 2 && haskey(ctx, :x) && ctx[:x][1] == node_id && length(ctx[:x]) == 2 # Test 6: Add a tensor variable to the model - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) ctx[:x] = ResizableArray(NodeLabel, Val(2)) node_id = add_variable_node!(model, ctx, NodeCreationOptions(), :x, (2, 3)) @@ -98,7 +92,7 @@ end @test nv(model) == 2 && haskey(ctx, :x) && ctx[:x][2, 4] == node_id # Test 9: Add a variable with a non-integer index - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) ctx[:z] = ResizableArray(NodeLabel, Val(2)) @test_throws MethodError add_variable_node!(model, ctx, NodeCreationOptions(), :z, "a") @@ -111,7 +105,7 @@ end @test_throws BoundsError add_variable_node!(model, ctx, NodeCreationOptions(), :x, -1) # Test 11: Add a variable with options - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) var = add_variable_node!(model, ctx, NodeCreationOptions(kind = :constant, value = 1.0), :x, nothing) @test nv(model) == 1 && @@ -121,21 +115,19 @@ end value(getproperties(model[var])) == 1.0 # Test 12: Add a variable without options - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) var = add_variable_node!(model, ctx, :x, nothing) @test nv(model) == 1 && haskey(ctx, :x) && ctx[:x] == var end -@testitem "add_atomic_factor_node!" begin +@testitem "add_atomic_factor_node!" setup = [TestUtils] begin using Distributions using Graphs import GraphPPL: create_model, add_atomic_factor_node!, getorcreate!, getcontext, getorcreate!, label_for, getname, NodeCreationOptions - include("testutils.jl") - # Test 1: Add an atomic factor node to the model - model = create_test_model(plugins = GraphPPL.PluginsCollection(GraphPPL.MetaPlugin())) + model = TestUtils.create_test_model(plugins = GraphPPL.PluginsCollection(GraphPPL.MetaPlugin())) ctx = getcontext(model) options = NodeCreationOptions() xref = getorcreate!(model, ctx, NodeCreationOptions(), :x, nothing) @@ -157,7 +149,7 @@ end @test GraphPPL.getextra(node_data, :meta) == true # Test 4: Test that creating a node with an instantiated object is supported - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() prior = Normal(0, 1) @@ -166,14 +158,12 @@ end @test nv(model) == 1 && getname(label_for(model.graph, 1)) == Normal(0, 1) end -@testitem "add_composite_factor_node!" begin +@testitem "add_composite_factor_node!" setup = [TestUtils] begin using Graphs import GraphPPL: create_model, add_composite_factor_node!, getcontext, to_symbol, children, add_variable_node!, Context - include("testutils.jl") - # Add a composite factor node to the model - model = create_test_model() + model = TestUtils.create_test_model() parent_ctx = getcontext(model) child_ctx = getcontext(model) add_variable_node!(model, child_ctx, :x, nothing) @@ -200,13 +190,11 @@ end length(empty_ctx.individual_variables) == 0 end -@testitem "add_edge!(::Model, ::NodeLabel, ::NodeLabel, ::Symbol)" begin +@testitem "add_edge!(::Model, ::NodeLabel, ::NodeLabel, ::Symbol)" setup = [TestUtils] begin import GraphPPL: create_model, getcontext, nv, ne, NodeData, NodeLabel, EdgeLabel, add_edge!, getorcreate!, generate_nodelabel, NodeCreationOptions - include("testutils.jl") - - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() xref, xdata, xproperties = GraphPPL.add_atomic_factor_node!(model, ctx, options, sum) @@ -219,12 +207,10 @@ end @test_throws MethodError add_edge!(model, xref, xproperties, y, 123) end -@testitem "add_edge!(::Model, ::NodeLabel, ::Vector{NodeLabel}, ::Symbol)" begin +@testitem "add_edge!(::Model, ::NodeLabel, ::Vector{NodeLabel}, ::Symbol)" setup = [TestUtils] begin import GraphPPL: create_model, getcontext, nv, ne, NodeData, NodeLabel, EdgeLabel, add_edge!, getorcreate!, NodeCreationOptions - include("testutils.jl") - - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() y = getorcreate!(model, ctx, :y, nothing) @@ -236,21 +222,19 @@ end @test ne(model) == 3 && model[variable_nodes[1], xref] == EdgeLabel(:interface, 1) end -@testitem "prune!(m::Model)" begin +@testitem "prune!(m::Model)" setup = [TestUtils] begin using Graphs import GraphPPL: create_model, getcontext, getorcreate!, prune!, create_model, getorcreate!, add_edge!, NodeCreationOptions - include("testutils.jl") - # Test 1: Prune a node with no edges - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) xref = getorcreate!(model, ctx, :x, nothing) prune!(model) @test nv(model) == 0 # Test 2: Prune two nodes - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() xref = getorcreate!(model, ctx, :x, nothing) diff --git a/test/graph/graph_properties_tests.jl b/test/graph/graph_properties_tests.jl index 4a5f678f..070c4113 100644 --- a/test/graph/graph_properties_tests.jl +++ b/test/graph/graph_properties_tests.jl @@ -1,10 +1,8 @@ -@testitem "degree" begin +@testitem "degree" setup = [TestUtils] begin import GraphPPL: create_model, getcontext, getorcreate!, NodeCreationOptions, make_node!, degree - include("testutils.jl") - for n in 5:10 - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) unused = getorcreate!(model, ctx, :unusued, nothing) @@ -33,12 +31,10 @@ end end -@testitem "nv_ne(::Model)" begin +@testitem "nv_ne(::Model)" setup = [TestUtils] begin import GraphPPL: create_model, getcontext, nv, ne, NodeData, VariableNodeProperties, NodeLabel, EdgeLabel - include("testutils.jl") - - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) @test isempty(model) @test nv(model) == 0 @@ -56,7 +52,7 @@ end @test ne(model) == 1 end -@testitem "edges" begin +@testitem "edges" setup = [TestUtils] begin import GraphPPL: edges, create_model, @@ -72,10 +68,8 @@ end has_edge, getproperties - include("testutils.jl") - # Test 1: Test getting all edges from a model - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) a = NodeLabel(:a, 1) b = NodeLabel(:b, 2) @@ -105,7 +99,7 @@ end # @test getname.(edges(model, [a, b])) == [:edge, :edge, :edge] end -@testitem "neighbors(::Model, ::NodeData)" begin +@testitem "neighbors(::Model, ::NodeData)" setup = [TestUtils] begin import GraphPPL: create_model, getcontext, @@ -120,11 +114,7 @@ end add_edge!, getproperties - include("testutils.jl") - - using .TestUtils.ModelZoo - - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) a = NodeLabel(:a, 1) @@ -134,7 +124,7 @@ end add_edge!(model, a, getproperties(model[a]), b, :edge, 1) @test collect(neighbors(model, NodeLabel(:a, 1))) == [NodeLabel(:b, 2)] - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) a = ResizableArray(NodeLabel, Val(1)) b = ResizableArray(NodeLabel, Val(1)) @@ -149,26 +139,22 @@ end @test n ∈ neighbors(model, a) end # Test 2: Test getting sorted neighbors - model = create_model(simple_model()) + model = create_model(TestUtils.simple_model()) ctx = getcontext(model) node = first(neighbors(model, ctx[:z])) # Normal node we're investigating is the only neighbor of `z` in the graph. @test getname.(neighbors(model, node)) == [:z, :x, :y] # Test 3: Test getting sorted neighbors when one of the edge indices is nothing - model = create_model(vector_model()) + model = create_model(TestUtils.vector_model()) ctx = getcontext(model) node = first(neighbors(model, ctx[:z][1])) @test getname.(collect(neighbors(model, node))) == [:z, :x, :y] end -@testitem "save and load graph" begin +@testitem "save and load graph" setup = [TestUtils] begin import GraphPPL: create_model, with_plugins, savegraph, loadgraph, getextra, as_node - include("testutils.jl") - - using .TestUtils.ModelZoo - - model = create_model(with_plugins(vector_model(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin()))) + model = create_model(with_plugins(TestUtils.vector_model(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin()))) mktemp() do file, io file = file * ".jld2" savegraph(file, model) diff --git a/test/graph/graph_traversal_filtering_tests.jl b/test/graph/graph_traversal_filtering_tests.jl index 0cb4a0c7..ebcfd66d 100644 --- a/test/graph/graph_traversal_filtering_tests.jl +++ b/test/graph/graph_traversal_filtering_tests.jl @@ -1,11 +1,7 @@ -@testitem "factor_nodes" begin +@testitem "factor_nodes" setup = [TestUtils] begin import GraphPPL: create_model, factor_nodes, is_factor, labels - include("testutils.jl") - - using .TestUtils.ModelZoo - - for modelfn in ModelsInTheZooWithoutArguments + for modelfn in TestUtils.ModelsInTheZooWithoutArguments model = create_model(modelfn()) fnodes = collect(factor_nodes(model)) for node in fnodes @@ -19,14 +15,10 @@ end end -@testitem "factor_nodes with lambda function" begin +@testitem "factor_nodes with lambda function" setup = [TestUtils] begin import GraphPPL: create_model, factor_nodes, is_factor, labels - include("testutils.jl") - - using .TestUtils.ModelZoo - - for model_fn in ModelsInTheZooWithoutArguments + for model_fn in TestUtils.ModelsInTheZooWithoutArguments model = create_model(model_fn()) fnodes = collect(factor_nodes(model)) factor_nodes(model) do label, nodedata @@ -44,14 +36,10 @@ end end end -@testitem "variable_nodes" begin +@testitem "variable_nodes" setup = [TestUtils] begin import GraphPPL: create_model, variable_nodes, is_variable, labels - include("testutils.jl") - - using .TestUtils.ModelZoo - - for model_fn in ModelsInTheZooWithoutArguments + for model_fn in TestUtils.ModelsInTheZooWithoutArguments model = create_model(model_fn()) fnodes = collect(variable_nodes(model)) for node in fnodes @@ -65,14 +53,10 @@ end end end -@testitem "variable_nodes with lambda function" begin +@testitem "variable_nodes with lambda function" setup = [TestUtils] begin import GraphPPL: create_model, variable_nodes, is_variable, labels - include("testutils.jl") - - using .TestUtils.ModelZoo - - for model_fn in ModelsInTheZooWithoutArguments + for model_fn in TestUtils.ModelsInTheZooWithoutArguments model = create_model(model_fn()) fnodes = collect(variable_nodes(model)) variable_nodes(model) do label, nodedata @@ -90,24 +74,22 @@ end end end -@testitem "variable_nodes with anonymous variables" begin +@testitem "variable_nodes with anonymous variables" setup = [TestUtils] begin # The idea here is that the `variable_nodes` must return ALL anonymous variables as well using Distributions import GraphPPL: create_model, variable_nodes, getname, is_anonymous, getproperties - include("testutils.jl") - - @model function simple_submodel_with_2_anonymous_for_variable_nodes(z, x, y) + TestUtils.@model function simple_submodel_with_2_anonymous_for_variable_nodes(z, x, y) # Creates two anonymous variables here z ~ Normal(x + 1, y - 1) end - @model function simple_submodel_with_3_anonymous_for_variable_nodes(z, x, y) + TestUtils.@model function simple_submodel_with_3_anonymous_for_variable_nodes(z, x, y) # Creates three anonymous variables here z ~ Normal(x + 1, y - 1 + 1) end - @model function simple_model_for_variable_nodes(submodel) + TestUtils.@model function simple_model_for_variable_nodes(submodel) xref ~ Normal(0, 1) y ~ Gamma(1, 1) zref ~ submodel(x = xref, y = y) @@ -126,51 +108,42 @@ end end end -@testitem "filter(::Predicate, ::Model)" begin +@testitem "filter(::Predicate, ::Model)" setup = [TestUtils] begin import GraphPPL: create_model, as_node, as_context, as_variable + using Distributions - include("testutils.jl") - - using .TestUtils.ModelZoo - - model = create_model(simple_model()) + model = create_model(TestUtils.simple_model()) result = collect(filter(as_node(Normal) | as_variable(:x), model)) @test length(result) == 3 - model = create_model(outer()) - result = collect(filter(as_node(Gamma) & as_context(inner_inner), model)) + model = create_model(TestUtils.outer()) + result = collect(filter(as_node(Gamma) & as_context(TestUtils.inner_inner), model)) @test length(result) == 0 - result = collect(filter(as_node(Gamma) | as_context(inner_inner), model)) + result = collect(filter(as_node(Gamma) | as_context(TestUtils.inner_inner), model)) @test length(result) == 6 - result = collect(filter(as_node(Normal) & as_context(inner_inner; children = true), model)) + result = collect(filter(as_node(Normal) & as_context(TestUtils.inner_inner; children = true), model)) @test length(result) == 1 end -@testitem "filter(::FactorNodePredicate, ::Model)" begin +@testitem "filter(::FactorNodePredicate, ::Model)" setup = [TestUtils] begin import GraphPPL: create_model, as_node, getcontext + using Distributions - include("testutils.jl") - - using .TestUtils.ModelZoo - - model = create_model(simple_model()) + model = create_model(TestUtils.simple_model()) context = getcontext(model) result = filter(as_node(Normal), model) - @test collect(result) == [context[NormalMeanVariance, 1], context[NormalMeanVariance, 2]] + @test collect(result) == [context[TestUtils.NormalMeanVariance, 1], context[TestUtils.NormalMeanVariance, 2]] result = filter(as_node(), model) - @test collect(result) == [context[NormalMeanVariance, 1], context[GammaShapeScale, 1], context[NormalMeanVariance, 2]] + @test collect(result) == + [context[TestUtils.NormalMeanVariance, 1], context[TestUtils.GammaShapeScale, 1], context[TestUtils.NormalMeanVariance, 2]] end -@testitem "filter(::VariableNodePredicate, ::Model)" begin +@testitem "filter(::VariableNodePredicate, ::Model)" setup = [TestUtils] begin import GraphPPL: create_model, as_variable, getcontext, variable_nodes - include("testutils.jl") - - using .TestUtils.ModelZoo - - model = create_model(simple_model()) + model = create_model(TestUtils.simple_model()) context = getcontext(model) result = filter(as_variable(:x), model) @test collect(result) == [context[:x]...] @@ -178,24 +151,20 @@ end @test collect(result) == collect(variable_nodes(model)) end -@testitem "filter(::SubmodelPredicate, Model)" begin +@testitem "filter(::SubmodelPredicate, Model)" setup = [TestUtils] begin import GraphPPL: create_model, as_context - include("testutils.jl") - - using .TestUtils.ModelZoo - - model = create_model(outer()) + model = create_model(TestUtils.outer()) - result = filter(as_context(inner), model) + result = filter(as_context(TestUtils.inner), model) @test length(collect(result)) == 0 - result = filter(as_context(inner; children = true), model) + result = filter(as_context(TestUtils.inner; children = true), model) @test length(collect(result)) == 1 - result = filter(as_context(inner_inner), model) + result = filter(as_context(TestUtils.inner_inner), model) @test length(collect(result)) == 1 - result = filter(as_context(outer; children = true), model) + result = filter(as_context(TestUtils.outer; children = true), model) @test length(collect(result)) == 22 end \ No newline at end of file diff --git a/test/graph/indexing_refs_tests.jl b/test/graph/indexing_refs_tests.jl index 71206749..94c50b7d 100644 --- a/test/graph/indexing_refs_tests.jl +++ b/test/graph/indexing_refs_tests.jl @@ -135,13 +135,11 @@ end ) === 'o' end -@testitem "`VariableRef` iterators interface" begin +@testitem "`VariableRef` iterators interface" setup = [TestUtils] begin import GraphPPL: VariableRef, getcontext, NodeCreationOptions, VariableKindData, getorcreate! - include("testutils.jl") - @testset "Missing internal and external collections" begin - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (nothing,)) @@ -151,7 +149,7 @@ end end @testset "Existing internal and external collections" begin - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) xcollection = getorcreate!(model, ctx, NodeCreationOptions(), :x, 1) xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (1,), xcollection) @@ -162,7 +160,7 @@ end end @testset "Missing internal but existing external collections" begin - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) xref = VariableRef(model, ctx, NodeCreationOptions(kind = VariableKindData), :x, (nothing,), [1.0 1.0; 1.0 1.0]) @@ -172,7 +170,7 @@ end end end -@testitem "`VariableRef` in combination with `ProxyLabel` should create variables in the model" begin +@testitem "`VariableRef` in combination with `ProxyLabel` should create variables in the model" setup = [TestUtils] begin import GraphPPL: VariableRef, makevarref, @@ -191,12 +189,10 @@ end MissingCollection, getorcreate! - using Distributions - - include("testutils.jl") + using Distributions, Static @testset "Individual variable creation" begin - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (nothing,)) @test_throws "The variable `x` has been used, but has not been instantiated." getifcreated(model, ctx, xref) @@ -220,7 +216,7 @@ end end @testset "Vectored variable creation" begin - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (nothing,)) @test_throws "The variable `x` has been used, but has not been instantiated." getifcreated(model, ctx, xref) @@ -238,7 +234,7 @@ end end @testset "Tensor variable creation" begin - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (nothing,)) @test_throws "The variable `x` has been used, but has not been instantiated." getifcreated(model, ctx, xref) @@ -256,7 +252,7 @@ end end @testset "Variable should not be created if the `creation` flag is set to `False`" begin - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) # `x` is not created here, should fail during `unroll` xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (nothing,)) @@ -271,7 +267,7 @@ end end @testset "Variable should be created if the `Atomic` fform is used as a first argument with `makevarref`" begin - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) # `x` is not created here, but `makevarref` takes into account the `Atomic/Composite` # we always create a variable when used with `Atomic` @@ -282,7 +278,7 @@ end end @testset "It should be possible to toggle `maycreate` flag" begin - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (nothing,)) # The first time should throw since the variable has not been instantiated yet @@ -303,7 +299,7 @@ end end end -@testitem "`VariableRef` comparison" begin +@testitem "`VariableRef` comparison" setup = [TestUtils] begin import GraphPPL: VariableRef, makevarref, @@ -323,9 +319,7 @@ end using Distributions - include("testutils.jl") - - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (nothing,)) @test xref == xref diff --git a/test/graph_construction_tests.jl b/test/graph_construction_tests.jl index eb6dd758..ca1d4f21 100644 --- a/test/graph_construction_tests.jl +++ b/test/graph_construction_tests.jl @@ -2,11 +2,12 @@ # We don't use models from the `model_zoo.jl` file because they are subject to change # These tests are meant to be stable and not change often -@testitem "Simple model 1" begin +@testitem "Simple model 1" setup = [TestUtils] begin using Distributions import GraphPPL: create_model, getcontext, + getorcreate!, add_toplevel_model!, factor_nodes, variable_nodes, @@ -16,9 +17,7 @@ as_variable, degree - include("testutils.jl") - - @model function simple_model_1() + TestUtils.@model function simple_model_1() x ~ Normal(0, 1) y ~ Gamma(1, 1) z ~ Normal(x, y) @@ -47,13 +46,11 @@ @test degree(model, first(collect(filter(as_variable(:z), model)))) === 1 end -@testitem "Simple model 2" begin +@testitem "Simple model 2" setup = [TestUtils] begin using Distributions import GraphPPL: create_model, getcontext, getorcreate!, add_toplevel_model!, as_node, NodeCreationOptions, prune! - include("testutils.jl") - - @model function simple_model_2(a, b, c) + TestUtils.@model function simple_model_2(a, b, c) x ~ Gamma(α = b, θ = sqrt(c)) a ~ Normal(μ = x, σ = 1) end @@ -70,13 +67,11 @@ end @test length(collect(filter(as_node(sqrt), model))) === 0 # should be compiled out, c is a constant end -@testitem "Simple model but wrong indexing into a single random variable" begin +@testitem "Simple model but wrong indexing into a single random variable" setup = [TestUtils] begin using Distributions import GraphPPL: create_model, getorcreate!, NodeCreationOptions - include("testutils.jl") - - @model function simple_model_with_wrong_indexing(y) + TestUtils.@model function simple_model_with_wrong_indexing(y) x ~ MvNormal([0.0, 0.0], [1.0 0.0; 0.0 1.0]) y ~ Beta(x[1], x[2]) end @@ -89,13 +84,11 @@ end end end -@testitem "Simple model with lazy data (number) creation" begin +@testitem "Simple model with lazy data (number) creation" setup = [TestUtils] begin using Distributions import GraphPPL: create_model, getorcreate!, NodeCreationOptions, is_data, is_constant, is_random, getproperties, datalabel - include("testutils.jl") - - @model function simple_model_3(a, b, c, d) + TestUtils.@model function simple_model_3(a, b, c, d) x ~ Beta(a, b) y ~ Gamma(c, d) z ~ Normal(x, y) @@ -118,7 +111,7 @@ end @test length(filter(label -> is_constant(getproperties(model[label])), collect(filter(as_variable(), model)))) === 0 end -@testitem "Simple model with lazy data (vector) creation" begin +@testitem "Simple model with lazy data (vector) creation" setup = [TestUtils] begin using Distributions import GraphPPL: create_model, @@ -132,13 +125,11 @@ end datalabel, VariableKindData - include("testutils.jl") - - @model function simple_submodel_3(T, x, y, Λ) + TestUtils.@model function simple_submodel_3(T, x, y, Λ) T ~ Normal(x + y, Λ) end - @model function simple_model_3(y, Σ, n, T) + TestUtils.@model function simple_model_3(y, Σ, n, T) m ~ Beta(1, 1) for i in 1:n, j in 1:n T[i, j] ~ simple_submodel_3(x = m, Λ = Σ, y = y[i]) @@ -180,13 +171,11 @@ end end end -@testitem "Simple model with lazy data creation with attached data" begin +@testitem "Simple model with lazy data creation with attached data" setup = [TestUtils] begin using Distributions import GraphPPL: create_model, getorcreate!, NodeCreationOptions, index, getproperties, is_kind, datalabel, MissingCollection - include("testutils.jl") - - @model function simple_model_4_withlength(y, Σ) + TestUtils.@model function simple_model_4_withlength(y, Σ) m ~ Beta(1, 1) for i in 1:length(y) @@ -194,7 +183,7 @@ end end end - @model function simple_model_4_withsize(y, Σ) + TestUtils.@model function simple_model_4_withsize(y, Σ) m ~ Beta(1, 1) for i in 1:size(y, 1) @@ -202,7 +191,7 @@ end end end - @model function simple_model_4_witheachindex(y, Σ) + TestUtils.@model function simple_model_4_witheachindex(y, Σ) m ~ Beta(1, 1) for i in eachindex(y) @@ -210,7 +199,7 @@ end end end - @model function simple_model_4_with_firstindex_lastindex(y, Σ) + TestUtils.@model function simple_model_4_with_firstindex_lastindex(y, Σ) m ~ Beta(1, 1) for i in firstindex(y):lastindex(y) @@ -218,7 +207,7 @@ end end end - @model function simple_model_4_with_forloop(y, Σ) + TestUtils.@model function simple_model_4_with_forloop(y, Σ) m ~ Beta(1, 1) for yᵢ in y @@ -226,7 +215,7 @@ end end end - @model function simple_model_4_with_foreach(y, Σ) + TestUtils.@model function simple_model_4_with_foreach(y, Σ) m ~ Beta(1, 1) foreach(y) do yᵢ @@ -302,13 +291,11 @@ end end end -@testitem "Simple model with lazy data creation with attached data but out of bounds" begin +@testitem "Simple model with lazy data creation with attached data but out of bounds" setup = [TestUtils] begin using Distributions import GraphPPL: create_model, getorcreate!, NodeCreationOptions, index, getproperties, is_kind, datalabel - include("testutils.jl") - - @model function simple_model_a_vector(a) + TestUtils.@model function simple_model_a_vector(a) x ~ Beta(a[1], a[2]) # In the test the provided `a` will either a scalar or a vector of length 1 b ~ Gamma(a[3], a[4]) z ~ Normal(x, b) @@ -360,7 +347,7 @@ end end end - @model function simple_model_a_matrix(a) + TestUtils.@model function simple_model_a_matrix(a) x ~ Beta(a[1, 1], a[1, 2]) # In the test the provided `a` will either a scalar or a matrix of smaller size b ~ Gamma(a[2, 1], a[2, 2]) z ~ Normal(x, b) @@ -424,14 +411,12 @@ end end end -@testitem "Simple state space model" begin +@testitem "Simple state space model" setup = [TestUtils] begin using Distributions import GraphPPL: create_model, add_toplevel_model!, degree - include("testutils.jl") - # Test that graph construction creates the right amount of nodes and variables in a simple state space model - @model function state_space_model(n) + TestUtils.@model function state_space_model(n) γ ~ Gamma(1, 1) x[1] ~ Normal(0, 1) y[1] ~ Normal(x[1], γ) @@ -441,7 +426,7 @@ end end end for n in [10, 30, 50, 100, 1000] - model = create_test_model() + model = TestUtils.create_test_model() add_toplevel_model!(model, state_space_model, (n = n,)) @test length(collect(filter(as_node(Normal), model))) == 2 * n @test length(collect(filter(as_variable(:x), model))) == n @@ -458,13 +443,11 @@ end end end -@testitem "Simple state space model with lazy data creation with attached data" begin +@testitem "Simple state space model with lazy data creation with attached data" setup = [TestUtils] begin using Distributions import GraphPPL: create_model, getorcreate!, datalabel, NodeCreationOptions, index, getproperties, is_random, is_data, degree - include("testutils.jl") - - @model function state_space_model_with_lazy_data(y, Σ) + TestUtils.@model function state_space_model_with_lazy_data(y, Σ) x[1] ~ Normal(0, 1) y[1] ~ Normal(x[1], Σ) for i in 2:length(y) @@ -513,24 +496,21 @@ end end end -@testitem "Nested model structure" begin +@testitem "Nested model structure" setup = [TestUtils] begin using Distributions import GraphPPL: create_model, add_toplevel_model! - include("testutils.jl") - - # Test that graph construction creates the right amount of nodes and variables in a nested model structure - @model function gcv(κ, ω, z, x, y) + TestUtils.@model function gcv(κ, ω, z, x, y) log_σ := κ * z + ω y ~ Normal(x, exp(log_σ)) end - @model function gcv_lm(y, x_prev, x_next, z, ω, κ) + TestUtils.@model function gcv_lm(y, x_prev, x_next, z, ω, κ) x_next ~ gcv(x = x_prev, z = z, ω = ω, κ = κ) y ~ Normal(x_next, 1) end - @model function hgf(y) + TestUtils.@model function hgf(y) # Specify priors @@ -569,17 +549,15 @@ end end end -@testitem "Nested model structure but with constants" begin +@testitem "Nested model structure but with constants" setup = [TestUtils] begin using Distributions import GraphPPL: create_model, getorcreate!, datalabel, NodeCreationOptions - include("testutils.jl") - - @model function submodel(y, x, z) + TestUtils.@model function submodel(y, x, z) y ~ Normal(x, z) end - @model function mainmodel(y) + TestUtils.@model function mainmodel(y) y ~ submodel(x = 1, z = 2) end @@ -594,18 +572,16 @@ end @test length(collect(filter(as_variable(:z), model))) === 0 end -@testitem "Force create a new variable with the `new` syntax" begin +@testitem "Force create a new variable with the `new` syntax" setup = [TestUtils] begin using Distributions import GraphPPL: create_model, getorcreate!, datalabel, NodeCreationOptions - include("testutils.jl") - - @model function submodel(y, x_prev, x_next) + TestUtils.@model function submodel(y, x_prev, x_next) x_next ~ Normal(x_prev, 1) y ~ Normal(x_next, 1) end - @model function state_space_model(y) + TestUtils.@model function state_space_model(y) x[1] ~ Normal(0, 1) y[1] ~ Normal(x[1], 1) @@ -622,7 +598,7 @@ end return (y = y,) end - @model function state_space_model_with_new(y) + TestUtils.@model function state_space_model_with_new(y) x[1] ~ Normal(0, 1) y[1] ~ Normal(x[1], 1) for i in 2:length(y) @@ -641,13 +617,11 @@ end @test length(collect(filter(as_variable(:y), model))) === 10 end -@testitem "Anonymous variables should not be created from arithmetical operations on pure constants" begin +@testitem "Anonymous variables should not be created from arithmetical operations on pure constants" setup = [TestUtils] begin using Distributions, LinearAlgebra import GraphPPL: create_model, getorcreate!, NodeCreationOptions, datalabel, variable_nodes, getproperties, is_random, getname - include("testutils.jl") - - @model function mv_iid_inverse_wishart_known_mean(y, d) + TestUtils.@model function mv_iid_inverse_wishart_known_mean(y, d) m ~ MvNormal(zeros(d + 1 - 1 + 1 - 1), Matrix(Diagonal(ones(d + 1 - 1 + 1 - 1)))) C ~ InverseWishart(d + 1, Matrix(Diagonal(ones(d)))) @@ -685,12 +659,10 @@ end end end -@testitem "Aliases in the model should be resolved automatically" begin +@testitem "Aliases in the model should be resolved automatically" setup = [TestUtils] begin import GraphPPL: create_model, getorcreate!, NodeCreationOptions, fform, factor_nodes, getproperties - - include("testutils.jl") - - @model function aliases_for_normal(s4) + using Distributions + TestUtils.@model function aliases_for_normal(s4) r1 ~ Normal(μ = 1.0, τ = 1.0) r2 ~ Normal(m = r1, γ = 1.0) r3 ~ Normal(mean = r2, σ⁻² = 1.0) @@ -714,27 +686,25 @@ end # The manual search however does indicate that the aliases are resolved and `Normal` node has NOT been created (as intended) @test length(collect(filter(label -> fform(getproperties(model[label])) === Normal, collect(factor_nodes(model))))) === 0 # Double check the number of `NormalMeanPrecision` and `NormalMeanVariance` nodes - @test length(collect(filter(as_node(NormalMeanPrecision), model))) === 7 - @test length(collect(filter(as_node(NormalMeanVariance), model))) === 4 + @test length(collect(filter(as_node(TestUtils.NormalMeanPrecision), model))) === 7 + @test length(collect(filter(as_node(TestUtils.NormalMeanVariance), model))) === 4 end -@testitem "Submodels can be used in the keyword arguments" begin +@testitem "Submodels can be used in the keyword arguments" setup = [TestUtils] begin using Distributions, LinearAlgebra import GraphPPL: create_model, getorcreate!, NodeCreationOptions, datalabel, variable_nodes, getproperties, is_random, getname - include("testutils.jl") - - @model function prod_distributions(a, b, c) + TestUtils.@model function prod_distributions(a, b, c) a ~ b * c end # The test tests if we can write `μ = prod_distributions(b = A, c = x_prev)` - @model function state_transition_with_submodel(y_next, x_next, x_prev, A, B, P, Q) + TestUtils.@model function state_transition_with_submodel(y_next, x_next, x_prev, A, B, P, Q) x_next ~ MvNormal(μ = prod_distributions(b = A, c = x_prev), Σ = Q) y_next ~ MvNormal(μ = prod_distributions(b = B, c = x_next), Σ = P) end - @model function multivariate_lgssm_model_with_several_submodels(y, mean0, cov0, A, B, Q, P) + TestUtils.@model function multivariate_lgssm_model_with_several_submodels(y, mean0, cov0, A, B, Q, P) x_prev ~ MvNormal(μ = mean0, Σ = cov0) for i in eachindex(y) x[i] ~ state_transition_with_submodel(y_next = y[i], x_prev = x_prev, A = A, B = B, P = P, Q = Q) @@ -764,13 +734,11 @@ end @test length(collect(filter(as_variable(:x), model))) === 10 end -@testitem "Using distribution objects as priors" begin +@testitem "Using distribution objects as priors" setup = [TestUtils] begin using Distributions import GraphPPL: create_model, getorcreate!, NodeCreationOptions, datalabel - include("testutils.jl") - - @model function coin_model_priors(y, prior) + TestUtils.@model function coin_model_priors(y, prior) θ ~ prior for i in eachindex(y) y[i] ~ Bernoulli(θ) @@ -788,18 +756,18 @@ end @test length(collect(filter(as_node(prior), model))) === 1 end -@testitem "Model that passes a slice to child model" begin +@testitem "Model that passes a slice to child model" setup = [TestUtils] begin using GraphPPL - include("testutils.jl") + using Distributions - @model function mixed_v(y, v) + TestUtils.@model function mixed_v(y, v) for i in 1:3 v[i] ~ Normal(0, 1) end y ~ Normal(v[1], v[2]) end - @model function mixed_m() + TestUtils.@model function mixed_m() local m for i in 1:3 for j in 1:3 @@ -815,17 +783,16 @@ end @test haskey(context[mixed_v, 1], :v) end -@testitem "Model that constructs a new vector to pass to children" begin - include("testutils.jl") - - @model function mixed_v(y, v) +@testitem "Model that constructs a new vector to pass to children" setup = [TestUtils] begin + using Distributions + TestUtils.@model function mixed_v(y, v) for i in 1:3 v[i] ~ Normal(0, 1) end y ~ Normal(v[1], v[2]) end - @model function mixed_m() + TestUtils.@model function mixed_m() v1 ~ Normal(0, 1) v2 ~ Normal(0, 1) v3 ~ Normal(0, 1) @@ -838,17 +805,16 @@ end @test haskey(context[mixed_v, 1], :v) end -@testitem "Model that constructs a new matrix to pass to children" begin - include("testutils.jl") - - @model function mixed_v(y, v) +@testitem "Model that constructs a new matrix to pass to children" setup = [TestUtils] begin + using Distributions + TestUtils.@model function mixed_v(y, v) for i in 1:3 v[i] ~ Normal(0, 1) end y ~ Normal(v[1], v[3]) end - @model function mixed_m() + TestUtils.@model function mixed_m() v1 ~ Normal(0, 1) v2 ~ Normal(0, 1) v3 ~ Normal(0, 1) @@ -861,13 +827,11 @@ end @test haskey(context[mixed_v, 1], :v) end -@testitem "Model creation should throw if a `~` using with a constant on RHS" begin +@testitem "Model creation should throw if a `~` using with a constant on RHS" setup = [TestUtils] begin using Distributions import GraphPPL: create_model, getorcreate!, NodeCreationOptions, datalabel - include("testutils.jl") - - @model function broken_beta_bernoulli(y) + TestUtils.@model function broken_beta_bernoulli(y) # This should throw an error since `Matrix` is not defined as a proper node θ ~ Matrix([1.0 0.0; 0.0 1.0]) for i in eachindex(y) @@ -882,13 +846,11 @@ end end end -@testitem "Condition based initialization of variables" begin +@testitem "Condition based initialization of variables" setup = [TestUtils] begin using Distributions import GraphPPL: create_model - include("testutils.jl") - - @model function condition_based_initialization(condition) + TestUtils.@model function condition_based_initialization(condition) if condition y ~ Normal(0.0, 1.0) else @@ -909,23 +871,21 @@ end @test length(collect(filter(as_node(Gamma), model2))) == 1 end -@testitem "Attempt to trick Julia's parser" begin +@testitem "Attempt to trick Julia's parser" setup = [TestUtils] begin using Distributions import GraphPPL: create_model - include("testutils.jl") - # We use `@isdefined` macro inside the macro generator code to check if the variable is defined # The idea of this test is to double check that `@model` parser and Julia in particular # does not confuse the undefined `y` variable with the `y` variable defined in the model - @model function tricky_model_1() + TestUtils.@model function tricky_model_1() b ~ Normal(0.0, 1.0) if false b = nothing end end - @model function tricky_model_2() + TestUtils.@model function tricky_model_2() b ~ Normal(0.0, 1.0) if false b = nothing @@ -941,7 +901,7 @@ end global yy = 1 - @model function tricky_model_3() + TestUtils.@model function tricky_model_3() yy ~ Normal(0.0, 1.0) # This is technically not allowed in real models # However we want the `@model` macro to instantiate a different `yy` variable @@ -962,7 +922,7 @@ end # We double check though that the `@model` macro may depend on global variables if needed global boolean = true - @model function model_that_uses_global_variables_1() + TestUtils.@model function model_that_uses_global_variables_1() if boolean b ~ Normal(0.0, 1.0) else @@ -978,7 +938,7 @@ end global m = 0.0 global v = 1.0 - @model function model_that_uses_global_variables_2() + TestUtils.@model function model_that_uses_global_variables_2() b ~ Normal(m, v) end @@ -995,14 +955,12 @@ end @test GraphPPL.is_constant(nodeneighborsproperties[3]) && GraphPPL.value(nodeneighborsproperties[3]) === v end -@testitem "Broadcasting in the model" begin +@testitem "Broadcasting in the model" setup = [TestUtils] begin using Distributions import GraphPPL: create_model using LinearAlgebra - include("testutils.jl") - - @model function linreg() + TestUtils.@model function linreg() x .~ Normal(fill(0, 10), 1) a .~ Normal(fill(0, 10), 1) b .~ Normal(fill(0, 10), 1) @@ -1014,7 +972,7 @@ end @test length(collect(filter(as_node(sum), model))) == 10 @test length(collect(filter(as_node(prod), model))) == 10 - @model function nested_normal() + TestUtils.@model function nested_normal() x .~ Normal(fill(0, 10), 1) a .~ Gamma(fill(0, 10), 1) b .~ Normal(Normal.(Normal.(x, 1), a), 1) @@ -1027,7 +985,7 @@ end function foo end GraphPPL.NodeBehaviour(::TestUtils.TestGraphPPLBackend, ::typeof(foo)) = GraphPPL.Stochastic() - @model function emtpy_broadcast() + TestUtils.@model function emtpy_broadcast() x .~ Normal(fill(0, 10), 1) x .~ foo() end @@ -1035,7 +993,7 @@ end model = create_model(emtpy_broadcast()) @test length(collect(filter(as_node(foo), model))) == 10 - @model function coin_toss(x) + TestUtils.@model function coin_toss(x) π ~ Beta(1, 1) x .~ Bernoulli(π) end @@ -1045,20 +1003,18 @@ end @test length(collect(filter(as_node(Bernoulli), model))) == 10 @test length(collect(filter(as_node(Beta), model))) == 1 - @model function weird_broadcast() + TestUtils.@model function weird_broadcast() π ~ Beta(1, 1) z .~ Bernoulli(Normal.(0, 1)) end @test_throws ErrorException local model = create_model(weird_broadcast()) end -@testitem "Broadcasting with datalabels" begin +@testitem "Broadcasting with datalabels" setup = [TestUtils] begin using Distributions, LinearAlgebra import GraphPPL: create_model, getorcreate!, NodeCreationOptions, datalabel, MissingCollection - include("testutils.jl") - - @model function linear_regression_broadcasted(x, y) + TestUtils.@model function linear_regression_broadcasted(x, y) a ~ Normal(mean = 0.0, var = 1.0) b ~ Normal(mean = 0.0, var = 1.0) # Variance over-complicated for a purpose of checking that this expressions are allowed, it should be equal to `1.0` @@ -1106,7 +1062,7 @@ end ) end - @model function beta_bernoulli_broadcasted(x) + TestUtils.@model function beta_bernoulli_broadcasted(x) θ ~ Beta(1, 1) x .~ Bernoulli(θ) end @@ -1121,13 +1077,11 @@ end @test length(collect(filter(as_node(Beta), model))) == 1 end -@testitem "Ambiguous broadcasting should give a descriptive error" begin +@testitem "Ambiguous broadcasting should give a descriptive error" setup = [TestUtils] begin using Distributions, LinearAlgebra import GraphPPL: create_model, getorcreate!, NodeCreationOptions - include("testutils.jl") - - @model function faulty_beta_bernoulli_broadcasted() + TestUtils.@model function faulty_beta_bernoulli_broadcasted() θ ~ Beta(1, 1) x .~ Bernoulli(θ) end @@ -1138,13 +1092,11 @@ end ) end -@testitem "Broadcasting over ranges" begin +@testitem "Broadcasting over ranges" setup = [TestUtils] begin using Distributions, LinearAlgebra import GraphPPL: create_model, getproperties, neighbor_data, is_random, is_constant, value - include("testutils.jl") - - @model function broadcasting_over_range() + TestUtils.@model function broadcasting_over_range() # Should create 10 `x` variables x .~ Normal(ones(10), 1) @@ -1175,13 +1127,11 @@ end end end -@testitem "Complex ranges with `begin`/`end` should be supported" begin +@testitem "Complex ranges with `begin`/`end` should be supported" setup = [TestUtils] begin using Distributions import GraphPPL: create_model, getproperties, neighbor_data, is_constant, value - include("testutils.jl") - - @model function complex_ranges_with_begin_end_1() + TestUtils.@model function complex_ranges_with_begin_end_1() c = [1.0, 2.0] b[1] ~ Normal(0.0, c[begin + 1]) b[2] ~ Normal(0.0, c[end - 1]) @@ -1201,22 +1151,22 @@ end @test is_constant(c_for_y_2) && value(c_for_y_2) === 1.0 end - @model function complex_ranges_with_begin_end_2() + TestUtils.@model function complex_ranges_with_begin_end_2() c = [1.0, 2.0] b .~ Normal(0.0, c[1:(end - 1 + 1)]) end - @model function complex_ranges_with_begin_end_3() + TestUtils.@model function complex_ranges_with_begin_end_3() c = [1.0, 2.0] b .~ Normal(0.0, c[(begin + 1 - 1):2]) end - @model function complex_ranges_with_begin_end_4() + TestUtils.@model function complex_ranges_with_begin_end_4() c = [1.0, 2.0] b .~ Normal(0.0, c[(begin + 1 - 1):(end - 1 + 1)]) end - @model function complex_ranges_with_begin_end_5() + TestUtils.@model function complex_ranges_with_begin_end_5() c = [1.0, 2.0] b .~ Normal(0.0, c[begin:end]) end @@ -1237,14 +1187,12 @@ end end end -@testitem "Anonymous variables" begin - using GraphPPL +@testitem "Anonymous variables" setup = [TestUtils] begin + using Distributions import GraphPPL: create_model, VariableNameAnonymous - include("testutils.jl") - # Test whether generic anonymous variables are created correctly - @model function anonymous_variables() + TestUtils.@model function anonymous_variables() b ~ Normal(Normal(0, 1), 1) end @@ -1256,7 +1204,7 @@ end function foo end - @model function det_anonymous_variables() + TestUtils.@model function det_anonymous_variables() b .~ Bernoulli(fill(0.5, 10)) x ~ foo(foo(in = b)) end @@ -1266,7 +1214,10 @@ end @test length(collect(filter(as_variable(VariableNameAnonymous), model))) == 1 end -@testitem "data/const variables should automatically fold when used with anonymous variable and deterministic relationship" begin +@testitem "data/const variables should automatically fold when used with anonymous variable and deterministic relationship" setup = [ + TestUtils +] begin + using Distributions import GraphPPL: create_model, getorcreate!, @@ -1279,13 +1230,11 @@ end value, VariableNameAnonymous - include("testutils.jl") - - @model function fold_datavars_1(f, a, b) + TestUtils.@model function fold_datavars_1(f, a, b) y ~ Normal(f(a, b), 0.5) end - @model function fold_datavars_2(f, a, b) + TestUtils.@model function fold_datavars_2(f, a, b) y ~ Normal(f(f(a, b), f(a, b)), 0.5) end @@ -1440,18 +1389,16 @@ end end end -@testitem "return value from the `@model` should be saved in the Context" begin +@testitem "return value from the `@model` should be saved in the Context" setup = [TestUtils] begin using Distributions import GraphPPL: create_model, datalabel, getorcreate!, NodeCreationOptions, returnval, getcontext, children - include("testutils.jl") - - @model function submodel_with_return(y, x, z, subval) + TestUtils.@model function submodel_with_return(y, x, z, subval) y ~ Normal(x, z) return subval end - @model function model_with_return(y, val) + TestUtils.@model function model_with_return(y, val) x .~ Normal(ones(10), ones(10)) z .~ Normal(ones(10), ones(10)) @@ -1489,13 +1436,11 @@ end end end -@testitem "return value from the model must materialize `VariableRef`" begin +@testitem "return value from the model must materialize `VariableRef`" setup = [TestUtils] begin using Distributions import GraphPPL: create_model, datalabel, getorcreate!, NodeCreationOptions, returnval, getcontext, NodeLabel - include("testutils.jl") - - @model function model_with_return_of_var(y, x, z, val) + TestUtils.@model function model_with_return_of_var(y, x, z, val) y ~ Normal(x, z) return (y, val) end @@ -1511,12 +1456,11 @@ end @test returnval(toplevelcontext)[2] === 3 end -@testitem "`end` index should be allowed in the `~` operator" begin +@testitem "`end` index should be allowed in the `~` operator" setup = [TestUtils] begin + using Distributions import GraphPPL: create_model - include("testutils.jl") - - @model function begin_end_in_rhs() + TestUtils.@model function begin_end_in_rhs() s[1] ~ Beta(0.0, 1.0) b[1] ~ Normal(s[begin], 1.0) b[2] ~ Normal(s[end], 1.0) @@ -1528,7 +1472,7 @@ end @test length(collect(filter(as_variable(:s), model))) == 1 end - @model function begin_end_in_lhs() + TestUtils.@model function begin_end_in_lhs() s[1] ~ Beta(0.0, 1.0) s[begin] ~ Normal(0.0, 1.0) s[end] ~ Normal(0.0, 1.0) @@ -1541,12 +1485,11 @@ end end end -@testitem "Use local scoped variable in two different scopes" begin +@testitem "Use local scoped variable in two different scopes" setup = [TestUtils] begin + using Distributions import GraphPPL: create_model - include("testutils.jl") - - @model function scope_twice() + TestUtils.@model function scope_twice() for i in 1:5 tmp[i] ~ Normal(0, 1) end @@ -1561,12 +1504,11 @@ end end end -@testitem "datalabel should support empty indices if array is passed" begin +@testitem "datalabel should support empty indices if array is passed" setup = [TestUtils] begin + using Distributions import GraphPPL: create_model, getorcreate!, NodeCreationOptions, datalabel - include("testutils.jl") - - @model function foo(y) + TestUtils.@model function foo(y) x ~ MvNormal([1, 1], [1 0.0; 0.0 1.0]) y ~ MvNormal(x, [1.0 0.0; 0.0 1.0]) end @@ -1580,38 +1522,37 @@ end @test length(collect(filter(as_variable(:y), model))) == 1 end -@testitem "Node arguments must be unique" begin +@testitem "Node arguments must be unique" setup = [TestUtils] begin + using Distributions import GraphPPL: create_model, getorcreate!, NodeCreationOptions, datalabel - include("testutils.jl") - - @model function simple_model_duplicate_1() + TestUtils.@model function simple_model_duplicate_1() x ~ Normal(0.0, 1.0) b ~ x + x end - @model function simple_model_duplicate_2() + TestUtils.@model function simple_model_duplicate_2() x ~ Normal(0.0, 1.0) b ~ x + x + x end - @model function simple_model_duplicate_3() + TestUtils.@model function simple_model_duplicate_3() x ~ Normal(0.0, 1.0) b ~ Normal(x, x) end - @model function simple_model_duplicate_4() + TestUtils.@model function simple_model_duplicate_4() x ~ Normal(0.0, 1.0) hide_x = x b ~ Normal(hide_x, x) end - @model function simple_model_duplicate_5() + TestUtils.@model function simple_model_duplicate_5() x ~ Normal(0.0, 1.0) x ~ Normal(x, 1) end - @model function simple_model_duplicate_6() + TestUtils.@model function simple_model_duplicate_6() x ~ Normal(0.0, 1.0) hide_x = x hide_x ~ Normal(x, 1) @@ -1630,7 +1571,7 @@ end ) end - @model function my_model(obs, N, sigma) + TestUtils.@model function my_model(obs, N, sigma) local x for i in 1:N x[i] ~ Bernoulli(0.5) @@ -1650,7 +1591,7 @@ end return (obs = obs,) end - @model function my_model(obs, N, sigma) + TestUtils.@model function my_model(obs, N, sigma) local x for i in 1:N x[i] ~ Bernoulli(0.5) @@ -1672,12 +1613,11 @@ end end end -@testitem "Neural network model" begin +@testitem "Neural network model" setup = [TestUtils] begin + using Distributions import GraphPPL: create_model, datalabel, NodeCreationOptions - include("testutils.jl") - - @model function neural_dot(out, in, w) + TestUtils.@model function neural_dot(out, in, w) c[1] ~ in[1] * w[1] for i in 2:length(in) c[i] ~ c[i - 1] + in[i] * w[i] @@ -1685,7 +1625,7 @@ end out := identity(c[end]) end - @model function neuron(in, out) + TestUtils.@model function neuron(in, out) local w for i in 1:length(in) w[i] ~ Normal(0.0, 1.0) @@ -1693,13 +1633,13 @@ end out ~ neural_dot(in = in, w = w) end - @model function neural_network_layer(in, out, n) + TestUtils.@model function neural_network_layer(in, out, n) for i in 1:n out[i] ~ neuron(in = in) end end - @model function neural_net(in, out) + TestUtils.@model function neural_net(in, out) h1 ~ neural_network_layer(in = in, n = 10) h2 ~ neural_network_layer(in = h1, n = 16) out ~ neural_network_layer(in = h2, n = 2) @@ -1716,11 +1656,11 @@ end @test length(collect(filter(as_variable(:out), model))) == 2 end -@testitem "Comparing variables throws warning" begin +@testitem "Comparing variables throws warning" setup = [TestUtils] begin + using Distributions import GraphPPL: create_model, getorcreate! - include("testutils.jl") - @model function test_model(y) + TestUtils.@model function test_model(y) x ~ Normal(0.0, 1.0) if x == 0 z ~ Normal(0.0, 1.0) @@ -1733,7 +1673,7 @@ end test_model(y = 1) ) - @model function test_model(y) + TestUtils.@model function test_model(y) x ~ Normal(0.0, 1.0) if x > 0 z ~ Normal(0.0, 1.0) @@ -1745,7 +1685,7 @@ end @test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." create_model( test_model(y = 1) ) - @model function test_model(y) + TestUtils.@model function test_model(y) x ~ Normal(0.0, 1.0) if x < 0 z ~ Normal(0.0, 1.0) @@ -1758,7 +1698,7 @@ end test_model(y = 1) ) - @model function test_model(y) + TestUtils.@model function test_model(y) x ~ Normal(0.0, 1.0) if 0 >= x z ~ Normal(0.0, 1.0) @@ -1772,15 +1712,14 @@ end ) end -@testitem "Multivariate input to function" begin - using GraphPPL +@testitem "Multivariate input to function" setup = [TestUtils] begin + using Distributions import GraphPPL: create_model, getorcreate!, datalabel - include("testutils.jl") function dot end function relu end - @model function neuron(in, out) + TestUtils.@model function neuron(in, out) local w for i in 1:(length(in)) w[i] ~ Normal(0.0, 1.0) @@ -1790,13 +1729,13 @@ end out := relu(unactivated) end - @model function neural_network_layer(in, out, n) + TestUtils.@model function neural_network_layer(in, out, n) for i in 1:n out[i] ~ neuron(in = in) end end - @model function neural_net(in, out) + TestUtils.@model function neural_net(in, out) local softin for i in 1:length(in) softin[i] ~ Normal(in[i], 1.0) @@ -1816,8 +1755,8 @@ end @test length(collect(filter(as_variable(:in), model))) == 3 end -@testitem "Constraints over nested models" begin - using GraphPPL +@testitem "Constraints over nested models" setup = [TestUtils] begin + using Distributions import GraphPPL: create_model, getorcreate!, @@ -1829,14 +1768,12 @@ end hasextra, getextra - include("testutils.jl") - - @model function inner_model(x, y) + TestUtils.@model function inner_model(x, y) θ ~ Normal(0.0, 1.0) y ~ Normal(x, θ) end - @model function outer_model(y) + TestUtils.@model function outer_model(y) x ~ Normal(0.0, 1.0) y ~ inner_model(x = x) end @@ -1853,7 +1790,7 @@ end end context = GraphPPL.getcontext(model) - node = context[inner_model, 1][NormalMeanVariance, 2] + node = context[inner_model, 1][TestUtils.NormalMeanVariance, 2] @test hasextra(model[node], :factorization_constraint_indices) @test getextra(model[node], :factorization_constraint_indices) == ([1], [2], [3]) @@ -1869,17 +1806,17 @@ end end context = GraphPPL.getcontext(model) - node = context[inner_model, 1][NormalMeanVariance, 2] + node = context[inner_model, 1][TestUtils.NormalMeanVariance, 2] @test hasextra(model[node], :factorization_constraint_indices) @test getextra(model[node], :factorization_constraint_indices) == ([1], [2], [3]) end -@testitem "Inference with DataArray" begin +@testitem "Inference with DataArray" setup = [TestUtils] begin using Distributions using GraphPPL import GraphPPL: @model, create_model, datalabel, NodeCreationOptions, neighbors - @model function data_array_model(y) + TestUtils.@model function data_array_model(y) σ ~ Gamma(1.0, 1.0) for i in 1:10 y[i + 10] ~ Normal(y[i], σ) @@ -1901,33 +1838,33 @@ end end end -@testitem "Splatting in the `~` operator" begin +@testitem "Splatting in the `~` operator" setup = [TestUtils] begin using GraphPPL using Distributions - import GraphPPL: create_model, datalabel, NodeCreationOptions, @model + import GraphPPL: create_model, datalabel, NodeCreationOptions - @model function splatting_model_1(y) + TestUtils.@model function splatting_model_1(y) a ~ Normal(0.0, 1.0) b ~ InverseGamma(1.0, 1.0) x = [a, b] y ~ Normal(x...) end - @model function splatting_model_2(y) + TestUtils.@model function splatting_model_2(y) a ~ Normal(0.0, 1.0) b ~ InverseGamma(1.0, 1.0) x = [b] y ~ Normal(a, x...) end - @model function splatting_model_3(y) + TestUtils.@model function splatting_model_3(y) a ~ Normal(0.0, 1.0) b ~ InverseGamma(1.0, 1.0) x = [a] y ~ Normal(x..., b) end - @model function splatting_model_4(y) + TestUtils.@model function splatting_model_4(y) a ~ Normal(0.0, 1.0) b ~ InverseGamma(1.0, 1.0) x_1 = [a] @@ -1950,7 +1887,7 @@ end a = context[:a] b = context[:b] y = context[:y] - normal_node = context[Normal, 2] + normal_node = context[TestUtils.NormalMeanVariance, 2] @test a ∈ GraphPPL.neighbors(model, normal_node) @test b ∈ GraphPPL.neighbors(model, normal_node) @test y ∈ GraphPPL.neighbors(model, normal_node) @@ -1959,7 +1896,7 @@ end @test GraphPPL.getname(model[normal_node, y]) == :out end - @model function splatting_model_5(y) + TestUtils.@model function splatting_model_5(y) x[1] ~ Normal(0.0, 1.0) x[2] ~ InverseGamma(1.0, 1.0) y ~ Normal(x...) @@ -1976,7 +1913,7 @@ end context = GraphPPL.getcontext(model) x = context[:x] y = context[:y] - normal_node = context[Normal, 2] + normal_node = context[TestUtils.NormalMeanVariance, 2] @test x[1] ∈ GraphPPL.neighbors(model, normal_node) @test x[2] ∈ GraphPPL.neighbors(model, normal_node) @test y ∈ GraphPPL.neighbors(model, normal_node) @@ -1985,12 +1922,12 @@ end @test GraphPPL.getname(model[normal_node, y]) == :out end -@testitem "Multiple indices in rhs statement" begin +@testitem "Multiple indices in rhs statement" setup = [TestUtils] begin using Distributions using GraphPPL - import GraphPPL: @model, create_model, datalabel, NodeCreationOptions, neighbors + import GraphPPL: create_model, datalabel, NodeCreationOptions, neighbors - @model function multiple_indices(prior_params, y) + TestUtils.@model function multiple_indices(prior_params, y) x ~ Normal(prior_params[1][1], prior_params[1][2]) y ~ Normal(x, 1.0) end @@ -2004,12 +1941,12 @@ end @test length(collect(filter(as_variable(:x), model))) == 1 end -@testitem "Create empty array" begin +@testitem "Create empty array" setup = [TestUtils] begin using Distributions using GraphPPL - import GraphPPL: @model, create_model, datalabel, NodeCreationOptions, neighbors + import GraphPPL: create_model, datalabel, NodeCreationOptions, neighbors - @model function empty_array_model() + TestUtils.@model function empty_array_model() x = [] @test isempty(x) end diff --git a/test/model/context_tests.jl b/test/model/context_tests.jl index a0fe25b3..30d7f2cd 100644 --- a/test/model/context_tests.jl +++ b/test/model/context_tests.jl @@ -115,52 +115,44 @@ end @test ctx[FactorID(sum, 2)] == ctx.factor_nodes[FactorID(sum, 2)] end -@testitem "getcontext(::Model)" begin +@testitem "getcontext(::Model)" setup = [TestUtils] begin import GraphPPL: Context, getcontext, create_model, add_variable_node!, NodeCreationOptions - include("testutils.jl") - - model = create_test_model() + model = TestUtils.create_test_model() @test getcontext(model) == model.graph[] add_variable_node!(model, getcontext(model), NodeCreationOptions(), :x, nothing) @test getcontext(model)[:x] == model.graph[][:x] end -@testitem "path_to_root(::Context)" begin +@testitem "path_to_root(::Context)" setup = [TestUtils] begin import GraphPPL: create_model, Context, path_to_root, getcontext - include("testutils.jl") - - using .TestUtils.ModelZoo - ctx = Context() @test path_to_root(ctx) == [ctx] - model = create_model(outer()) + model = create_model(TestUtils.outer()) ctx = getcontext(model) - inner_context = ctx[inner, 1] - inner_inner_context = inner_context[inner_inner, 1] + inner_context = ctx[TestUtils.inner, 1] + inner_inner_context = inner_context[TestUtils.inner_inner, 1] @test path_to_root(inner_inner_context) == [inner_inner_context, inner_context, ctx] end -@testitem "VarDict" begin +@testitem "VarDict" setup = [TestUtils] begin using Distributions import GraphPPL: Context, VarDict, create_model, getorcreate!, datalabel, NodeCreationOptions, getcontext, is_random, is_data, getproperties - include("testutils.jl") - ctx = Context() vardict = VarDict(ctx) @test isa(vardict, VarDict) - @model function submodel(y, x_prev, x_next) + TestUtils.@model function submodel(y, x_prev, x_next) γ ~ Gamma(1, 1) x_next ~ Normal(x_prev, γ) y ~ Normal(x_next, 1) end - @model function state_space_model_with_new(y) + TestUtils.@model function state_space_model_with_new(y) x[1] ~ Normal(0, 1) y[1] ~ Normal(x[1], 1) for i in 2:length(y) diff --git a/test/model/model_construction_tests.jl b/test/model/model_construction_tests.jl index a2cf4a16..ae06a7f5 100644 --- a/test/model/model_construction_tests.jl +++ b/test/model/model_construction_tests.jl @@ -1,32 +1,25 @@ - -@testitem "model constructor" begin +@testitem "model constructor" setup = [TestUtils] begin import GraphPPL: create_model, Model - include("testutils.jl") - - @test typeof(create_test_model()) <: Model + @test typeof(TestUtils.create_test_model()) <: Model @test_throws MethodError Model() end # TODO this is not a test for GraphPPL but for the tests. -@testitem "create_test_model()" begin +@testitem "create_test_model()" setup = [TestUtils] begin import GraphPPL: create_model, Model, nv, ne - include("testutils.jl") - - model = create_test_model() + model = TestUtils.create_test_model() @test typeof(model) <: Model && nv(model) == 0 && ne(model) == 0 - @test_throws MethodError create_test_model(:x, :y, :z) + @test_throws MethodError TestUtils.create_test_model(:x, :y, :z) end -@testitem "getcounter and setcounter!" begin +@testitem "getcounter and setcounter!" setup = [TestUtils] begin import GraphPPL: create_model, setcounter!, getcounter - include("testutils.jl") - - model = create_test_model() + model = TestUtils.create_test_model() @test setcounter!(model, 1) == 1 @test getcounter(model) == 1 diff --git a/test/model/model_operations_tests.jl b/test/model/model_operations_tests.jl index 44b0a279..87a41762 100644 --- a/test/model/model_operations_tests.jl +++ b/test/model/model_operations_tests.jl @@ -1,9 +1,7 @@ -@testitem "getindex(::Model, ::NodeLabel)" begin +@testitem "getindex(::Model, ::NodeLabel)" setup = [TestUtils] begin import GraphPPL: create_model, getcontext, NodeLabel, NodeData, VariableNodeProperties, getproperties - include("testutils.jl") - - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) label = NodeLabel(:x, 1) model[label] = NodeData(ctx, VariableNodeProperties(name = :x, index = nothing)) @@ -13,14 +11,12 @@ @test_throws MethodError model[0] end -@testitem "copy_markov_blanket_to_child_context" begin +@testitem "copy_markov_blanket_to_child_context" setup = [TestUtils] begin import GraphPPL: create_model, copy_markov_blanket_to_child_context, Context, getorcreate!, proxylabel, unroll, getcontext, NodeCreationOptions - include("testutils.jl") - # Copy individual variables - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) function child end child_context = Context(ctx, child) @@ -29,7 +25,7 @@ end zref = getorcreate!(model, ctx, NodeCreationOptions(), :z, nothing) # Do not copy constant variables - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) xref = getorcreate!(model, ctx, NodeCreationOptions(), :x, nothing) child_context = Context(ctx, child) @@ -37,14 +33,14 @@ end @test !haskey(child_context, :in) # Do not copy vector valued constant variables - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) child_context = Context(ctx, child) copy_markov_blanket_to_child_context(child_context, (in = [1, 2, 3],)) @test !haskey(child_context, :in) # Copy ProxyLabel variables to child context - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) xref = getorcreate!(model, ctx, NodeCreationOptions(), :x, nothing) xref = proxylabel(:x, xref, nothing) @@ -53,7 +49,7 @@ end @test child_context[:in] == xref end -@testitem "getorcreate!" begin +@testitem "getorcreate!" setup = [TestUtils] begin using Graphs import GraphPPL: create_model, @@ -66,11 +62,9 @@ end getproperties, is_kind - include("testutils.jl") - let # let block to suppress the scoping warnings # Test 1: Creation of regular one-dimensional variable - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) x = getorcreate!(model, ctx, :x, nothing) @test nv(model) == 1 && ne(model) == 0 @@ -84,7 +78,7 @@ end @test x == x2 && nv(model) == 1 && ne(model) == 0 # Test 4: Test that creating a vector variable creates an array of the correct size - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) y = getorcreate!(model, ctx, :y, 1) @test nv(model) == 1 && ne(model) == 0 && y isa ResizableArray && y[1] isa NodeLabel @@ -104,7 +98,7 @@ end @test_throws ErrorException getorcreate!(model, ctx, :y, 1, 2) #Test 9: Test that creating a tensor variable creates a tensor of the correct size - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) z = getorcreate!(model, ctx, :z, 1, 1) @test nv(model) == 1 && ne(model) == 0 && z isa ResizableArray && z[1, 1] isa NodeLabel @@ -128,7 +122,7 @@ end # Test 15: Test that creating a variable that exists in the model scope but not in local scope still throws an error let # force local scope - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) getorcreate!(model, ctx, :a, nothing) @test_throws ErrorException a = getorcreate!(model, ctx, :a, 1) @@ -136,7 +130,7 @@ end end # Test 16. Test that the index is required to create a variable in the model - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) @test_throws ErrorException getorcreate!(model, ctx, :a) @test_throws ErrorException getorcreate!(model, ctx, NodeCreationOptions(), :a) @@ -144,13 +138,13 @@ end @test_throws ErrorException getorcreate!(model, ctx, NodeCreationOptions(kind = :constant, value = 2), :a) # Test 17. Range based getorcreate! - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) var = getorcreate!(model, ctx, :a, 1:2) @test nv(model) == 2 && var[1] isa NodeLabel && var[2] isa NodeLabel # Test 17.1 Range based getorcreate! should use the same options - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) var = getorcreate!(model, ctx, NodeCreationOptions(kind = :data), :a, 1:2) @test nv(model) == 2 && var[1] isa NodeLabel && var[2] isa NodeLabel @@ -158,7 +152,7 @@ end @test is_kind(getproperties(model[var[1]]), :data) # Test 18. Range x2 based getorcreate! - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) var = getorcreate!(model, ctx, :a, 1:2, 1:3) @test nv(model) == 6 @@ -167,7 +161,7 @@ end end # Test 18. Range x2 based getorcreate! should use the same options - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) var = getorcreate!(model, ctx, NodeCreationOptions(kind = :data), :a, 1:2, 1:3) @test nv(model) == 6 @@ -178,7 +172,7 @@ end end end -@testitem "getifcreated" begin +@testitem "getifcreated" setup = [TestUtils] begin using Graphs import GraphPPL: create_model, @@ -193,9 +187,7 @@ end value, NodeCreationOptions - include("testutils.jl") - - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) # Test case 1: check that getifcreated the variable created by getorcreate @@ -228,20 +220,20 @@ end @test output[1] == xref && value(getproperties(model[output[2]])) == 1 # Test case 10: check that getifcreated returns the variable node if we create a variable and call it by symbol in a vector - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) zref = getorcreate!(model, ctx, NodeCreationOptions(), :z, 1) z_fetched = getifcreated(model, ctx, zref[1]) @test z_fetched == zref[1] # Test case 11: Test that getifcreated returns a constant node when we call it with a symbol - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) zref = getifcreated(model, ctx, :Bernoulli) @test value(getproperties(model[zref])) == :Bernoulli # Test case 12: Test that getifcreated returns a vector of NodeLabels if called with a vector of NodeLabels - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) xref = getorcreate!(model, ctx, NodeCreationOptions(), :x, nothing) y = getorcreate!(model, ctx, NodeCreationOptions(), :y, nothing) @@ -249,7 +241,7 @@ end @test zref == [xref, y] # Test case 13: Test that getifcreated returns a ResizableArray tensor of NodeLabels if called with a ResizableArray tensor of NodeLabels - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) xref = getorcreate!(model, ctx, NodeCreationOptions(), :x, 1, 1) xref = getorcreate!(model, ctx, NodeCreationOptions(), :x, 2, 1) @@ -257,13 +249,13 @@ end @test zref == xref # Test case 14: Test that getifcreated returns multiple variables if called with a tuple of constants - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) zref = getifcreated(model, ctx, ([1, 1], 2)) @test nv(model) == 2 && value(getproperties(model[zref[1]])) == [1, 1] && value(getproperties(model[zref[2]])) == 2 # Test case 15: Test that getifcreated returns a ProxyLabel if called with a ProxyLabel - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) xref = getorcreate!(model, ctx, NodeCreationOptions(), :x, nothing) xref = proxylabel(:x, xref, nothing) @@ -271,7 +263,7 @@ end @test zref === xref end -@testitem "make_node!(::Atomic)" begin +@testitem "make_node!(::Atomic)" setup = [TestUtils] begin using Graphs, BitSetTuples import GraphPPL: getcontext, @@ -290,10 +282,8 @@ end NodeCreationOptions, getproperties - include("testutils.jl") - # Test 1: Deterministic call returns result of deterministic function and does not create new node - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() xref = AnonymousVariable(model, ctx) @@ -313,7 +303,7 @@ end @test getname.(edges(model, node_id)) == [:out, :μ, :σ] # Test 3: Stochastic atomic call with an AbstractArray as rhs_interfaces - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() xref = getorcreate!(model, ctx, :x, nothing) @@ -321,7 +311,7 @@ end @test nv(model) == 4 && ne(model) == 3 # Test 4: Deterministic atomic call with nodelabels should create the actual node - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() in1 = getorcreate!(model, ctx, :in1, nothing) @@ -331,7 +321,7 @@ end @test nv(model) == 4 && ne(model) == 3 # Test 5: Deterministic atomic call with nodelabels should create the actual node - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() in1 = getorcreate!(model, ctx, :in1, nothing) @@ -341,7 +331,7 @@ end @test nv(model) == 4 # Test 6: Stochastic node with default arguments - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() xref = getorcreate!(model, ctx, :x, nothing) @@ -351,7 +341,7 @@ end @test getname.(edges(model, node_id)) == [:out, :μ, :σ] # Test 7: Stochastic node with instantiated object - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() uprior = Normal(0, 1) @@ -360,7 +350,7 @@ end @test nv(model) == 2 # Test 8: Deterministic node with nodelabel objects where all interfaces are already defined (no missing interfaces) - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() in1 = getorcreate!(model, ctx, :in1, nothing) @@ -371,19 +361,19 @@ end ) # Test 8: Stochastic node with nodelabel objects where we have an array on the rhs (so should create 1 node for [0, 1]) - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() out = getorcreate!(model, ctx, :out, nothing) - nodeid, _ = make_node!(model, ctx, options, ArbitraryNode, out, (in = [0, 1],)) + nodeid, _ = make_node!(model, ctx, options, TestUtils.ArbitraryNode, out, (in = [0, 1],)) @test nv(model) == 3 && value(getproperties(model[ctx[:constvar_2]])) == [0, 1] # Test 9: Stochastic node with all interfaces defined as constants - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() out = getorcreate!(model, ctx, :out, nothing) - nodeid, _ = make_node!(model, ctx, options, ArbitraryNode, out, (1, 1)) + nodeid, _ = make_node!(model, ctx, options, TestUtils.ArbitraryNode, out, (1, 1)) @test nv(model) == 4 @test getname.(edges(model, nodeid)) == [:out, :in, :in] @test getname.(edges(model, nodeid)) == [:out, :in, :in] @@ -392,7 +382,7 @@ end function abc(; a = 1, b = 2) return a + b end - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() out = AnonymousVariable(model, ctx) @@ -402,14 +392,14 @@ end function abc(a; b = 2) return a + b end - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() out = AnonymousVariable(model, ctx) @test make_node!(model, ctx, options, abc, out, MixedArguments((2,), (b = 2,))) == (nothing, 4) # Test 12: Deterministic node with mixed arguments that has to be materialized should throw error - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() out = getorcreate!(model, ctx, :out, nothing) @@ -417,32 +407,32 @@ end @test_throws ErrorException make_node!(model, ctx, options, abc, out, MixedArguments((a,), (b = 2,))) # Test 13: Make stochastic node with aliases - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() xref = getorcreate!(model, ctx, :x, nothing) node_id = make_node!(model, ctx, options, Normal, xref, (μ = 0, τ = 1)) - @test any((key) -> fform(key) == NormalMeanPrecision, keys(ctx.factor_nodes)) + @test any((key) -> fform(key) == TestUtils.NormalMeanPrecision, keys(ctx.factor_nodes)) @test nv(model) == 4 - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() xref = getorcreate!(model, ctx, :x, nothing) node_id = make_node!(model, ctx, options, Normal, xref, (μ = 0, σ = 1)) - @test any((key) -> fform(key) == NormalMeanVariance, keys(ctx.factor_nodes)) + @test any((key) -> fform(key) == TestUtils.NormalMeanVariance, keys(ctx.factor_nodes)) @test nv(model) == 4 - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() xref = getorcreate!(model, ctx, :x, nothing) node_id = make_node!(model, ctx, options, Normal, xref, (0, 1)) - @test any((key) -> fform(key) == NormalMeanVariance, keys(ctx.factor_nodes)) + @test any((key) -> fform(key) == TestUtils.NormalMeanVariance, keys(ctx.factor_nodes)) @test nv(model) == 4 # Test 14: Make deterministic node with ProxyLabels as arguments - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() xref = getorcreate!(model, ctx, :x, nothing) @@ -455,7 +445,7 @@ end @test nv(model) == 4 # Test 15.1: Make stochastic node with aliased interfaces - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() μ = getorcreate!(model, ctx, :μ, nothing) @@ -463,12 +453,12 @@ end out = getorcreate!(model, ctx, :out, nothing) for keys in [(:mean, :variance), (:m, :variance), (:mean, :v)] local node_id, _ = make_node!(model, ctx, options, Normal, out, NamedTuple{keys}((μ, σ))) - @test GraphPPL.fform(GraphPPL.getproperties(model[node_id])) === NormalMeanVariance + @test GraphPPL.fform(GraphPPL.getproperties(model[node_id])) === TestUtils.NormalMeanVariance @test GraphPPL.neighbors(model, node_id) == [out, μ, σ] end # Test 15.2: Make stochastic node with aliased interfaces - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() μ = getorcreate!(model, ctx, :μ, nothing) @@ -476,12 +466,12 @@ end out = getorcreate!(model, ctx, :out, nothing) for keys in [(:mean, :precision), (:m, :precision), (:mean, :p)] local node_id, _ = make_node!(model, ctx, options, Normal, out, NamedTuple{keys}((μ, p))) - @test GraphPPL.fform(GraphPPL.getproperties(model[node_id])) === NormalMeanPrecision + @test GraphPPL.fform(GraphPPL.getproperties(model[node_id])) === TestUtils.NormalMeanPrecision @test GraphPPL.neighbors(model, node_id) == [out, μ, p] end end -@testitem "materialize_factor_node!" begin +@testitem "materialize_factor_node!" setup = [TestUtils] begin using Distributions using Graphs import GraphPPL: @@ -497,9 +487,7 @@ end edges, NodeCreationOptions - include("testutils.jl") - - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() xref = getorcreate!(model, ctx, options, :x, nothing) @@ -513,7 +501,7 @@ end @test getname.(edges(model, node_id)) == [:out, :μ, :σ] # Test 3: Stochastic atomic call with an AbstractArray as rhs_interfaces - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() xref = getorcreate!(model, ctx, :x, nothing) @@ -523,7 +511,7 @@ end @test nv(model) == 4 && ne(model) == 3 # Test 4: Deterministic atomic call with nodelabels should create the actual node - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() in1 = getorcreate!(model, ctx, :in1, nothing) @@ -533,7 +521,7 @@ end @test nv(model) == 4 && ne(model) == 3 # Test 14: Make deterministic node with ProxyLabels as arguments - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() xref = getorcreate!(model, ctx, :x, nothing) @@ -546,46 +534,36 @@ end @test nv(model) == 4 end -@testitem "make_node!(::Composite)" begin +@testitem "make_node!(::Composite)" setup = [TestUtils] begin using MetaGraphsNext, Graphs import GraphPPL: getcontext, make_node!, create_model, getorcreate!, proxylabel, NodeCreationOptions - include("testutils.jl") - - using .TestUtils.ModelZoo - #test make node for priors - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() xref = getorcreate!(model, ctx, :x, nothing) - make_node!(model, ctx, options, prior, proxylabel(:x, xref, nothing), ()) + make_node!(model, ctx, options, TestUtils.ModelZoo.prior, proxylabel(:x, xref, nothing), ()) @test nv(model) == 4 - @test ctx[prior, 1][:a] == proxylabel(:x, xref, nothing) + @test ctx[TestUtils.ModelZoo.prior, 1][:a] == proxylabel(:x, xref, nothing) #test make node for other composite models - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() xref = getorcreate!(model, ctx, :x, nothing) - @test_throws ErrorException make_node!(model, ctx, options, gcv, proxylabel(:x, xref, nothing), (0, 1)) + @test_throws ErrorException make_node!(model, ctx, options, TestUtils.ModelZoo.gcv, proxylabel(:x, xref, nothing), (0, 1)) # test make node of broadcastable composite model - model = create_test_model() - ctx = getcontext(model) - options = NodeCreationOptions() - out = getorcreate!(model, ctx, :out, nothing) - model = create_model(broadcaster()) + model = create_model(TestUtils.ModelZoo.broadcaster()) @test nv(model) == 103 end -@testitem "broadcast" begin +@testitem "broadcast" setup = [TestUtils] begin import GraphPPL: NodeLabel, ResizableArray, create_model, getcontext, getorcreate!, make_node!, NodeCreationOptions - include("testutils.jl") - # Test 1: Broadcast a vector node - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() xref = getorcreate!(model, ctx, :x, 1) @@ -600,7 +578,7 @@ end @test size(zref) == (2,) # Test 2: Broadcast a matrix node - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() xref = getorcreate!(model, ctx, :x, 1, 1) @@ -624,7 +602,7 @@ end @test size(zref) == (2, 2) # Test 3: Broadcast a vector node with a matrix node - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() xref = getorcreate!(model, ctx, :x, 1) diff --git a/test/model_generator_tests.jl b/test/model_generator_tests.jl index 8e463487..801941a2 100644 --- a/test/model_generator_tests.jl +++ b/test/model_generator_tests.jl @@ -1,10 +1,8 @@ -@testitem "Basic creation" begin +@testitem "Basic creation" setup = [TestUtils] begin using Distributions import GraphPPL: ModelGenerator, create_model, Model, NodeCreationOptions, getorcreate! - include("testutils.jl") - - @model function basic_model(a, b) + TestUtils.@model function basic_model(a, b) x ~ Normal(a, b) z ~ Gamma(1, 1) y ~ Normal(x, z) @@ -33,13 +31,11 @@ @test_throws "b = ..." basic_model(a = 1, 2) end -@testitem "Data creation via callback" begin +@testitem "Data creation via callback" setup = [TestUtils] begin using Distributions import GraphPPL: ModelGenerator, create_model, Model, NodeCreationOptions, getorcreate!, NodeLabel - include("testutils.jl") - - @model function simple_model_for_model_generator(observation, a, b) + TestUtils.@model function simple_model_for_model_generator(observation, a, b) x ~ Beta(0, 1) y ~ Gamma(a, b) observation ~ Normal(x, y) @@ -72,13 +68,11 @@ end @test GraphPPL.getname(outedge[2]) === :out end -@testitem "Indexing in provided fixed kwargs" begin +@testitem "Indexing in provided fixed kwargs" setup = [TestUtils] begin using Distributions import GraphPPL: ModelGenerator, create_model, Model, as_node, neighbors, NodeLabel, getname, is_data, is_constant, getproperties, value - include("testutils.jl") - - @model function basic_model(inputs) + TestUtils.@model function basic_model(inputs) x ~ Beta(inputs[1], inputs[2]) z ~ Gamma(1, 1) y ~ Normal(x, z) @@ -112,13 +106,11 @@ end end end -@testitem "Error messages" begin +@testitem "Error messages" setup = [TestUtils] begin using Distributions import GraphPPL: create_model, Model, ModelGenerator - include("testutils.jl") - - @model function simple_model_for_model_generator(observation, a, b) + TestUtils.@model function simple_model_for_model_generator(observation, a, b) x ~ Beta(0, 1) y ~ Gamma(a, b) observation ~ Normal(x, y) @@ -166,8 +158,9 @@ end end end -@testitem "with_plugins" begin - import GraphPPL: ModelGenerator, PluginsCollection, AbstractPluginTraitType, getplugins, with_plugins, @model +@testitem "with_plugins" setup = [TestUtils] begin + using Distributions + import GraphPPL: ModelGenerator, PluginsCollection, AbstractPluginTraitType, getplugins, with_plugins struct ArbitraryPluginForModelGeneratorTestsType1 <: AbstractPluginTraitType end struct ArbitraryPluginForModelGeneratorTestsType2 <: AbstractPluginTraitType end @@ -178,7 +171,7 @@ end GraphPPL.plugin_type(::ArbitraryPluginForModelGeneratorTests1) = ArbitraryPluginForModelGeneratorTestsType1() GraphPPL.plugin_type(::ArbitraryPluginForModelGeneratorTests2) = ArbitraryPluginForModelGeneratorTestsType2() - @model function simple_model(a) + TestUtils.@model function simple_model(a) y ~ Normal(a, 1) end @@ -210,13 +203,12 @@ end end end -@testitem "with_backend" begin - import GraphPPL: ModelGenerator, DefaultBackend, with_backend, getbackend, create_model - - include("testutils.jl") +@testitem "with_backend" setup = [TestUtils] begin + using Distributions + import GraphPPL: ModelGenerator, DefaultBackend, with_backend, getbackend, create_model, @model - # `GraphPPL.@model` uses the `DefaultBackend`, while `@model` from `testutils.jl` uses the `TestBackend` - GraphPPL.@model function simple_model(a) + # `GraphPPL.TestUtils.@model` uses the `DefaultBackend`, while `TestUtils.@model` from `testutils.jl` uses the `TestBackend` + @model function simple_model(a) y ~ Normal(a, 1) end @@ -235,11 +227,10 @@ end @test getbackend(model_with_a_different_backend) === TestUtils.TestGraphPPLBackend() end -@testitem "source code retrieval from ModelGenerator #1" begin - import GraphPPL: @model +@testitem "source code retrieval from ModelGenerator #1" setup = [TestUtils] begin using Distributions - @model function beta_bernoulli(y) + TestUtils.@model function beta_bernoulli(y) θ ~ Beta(1, 1) for i in eachindex(y) y[i] ~ Bernoulli(θ) @@ -255,19 +246,18 @@ end end""" end -@testitem "source code retrieval from ModelGenerator #2" begin - import GraphPPL: @model +@testitem "source code retrieval from ModelGenerator #2" setup = [TestUtils] begin using Distributions # Define a model with some random variables and deterministic statements - @model function test_model(x, y) + TestUtils.@model function test_model(x, y) z ~ Normal(x, 1) w := z + y q ~ Normal(w, 1) end # Define a second model with different structure - @model function another_model(μ, σ) + TestUtils.@model function another_model(μ, σ) x ~ Normal(μ, σ) y := x^2 z ~ Gamma(y, 1) @@ -304,11 +294,10 @@ end @test source1 != source2 end -@testitem "source code retrieval from ModelGenerator #3 partial kwargs" begin - import GraphPPL: @model +@testitem "source code retrieval from ModelGenerator #3 partial kwargs" setup = [TestUtils] begin using Distributions - @model function beta_bernoulli(y) + TestUtils.@model function beta_bernoulli(y) θ ~ Beta(1, 1) for i in eachindex(y) y[i] ~ Bernoulli(θ) diff --git a/test/model_macro_tests.jl b/test/model_macro_tests.jl index f09ee090..2362f60a 100644 --- a/test/model_macro_tests.jl +++ b/test/model_macro_tests.jl @@ -1,4 +1,4 @@ -@testitem "__guard_f" begin +@testitem "__guard_f" setup = [TestUtils] begin import GraphPPL.__guard_f f(e::Expr) = 10 @@ -6,11 +6,10 @@ @test __guard_f(f, :(1 + 1)) == 10 end -@testitem "apply_pipeline" begin +@testitem "apply_pipeline" setup = [TestUtils] begin + using MacroTools import GraphPPL: apply_pipeline - include("testutils.jl") - @testset "Default `what_walk`" begin # `Pipeline` that finds all `Expr` nodes in the AST in the form of `:(x + 1)` # And replaces them with `:(x + 2)` @@ -24,19 +23,19 @@ end input = :(x + 1) output = :(x + 2) - @test_expression_generating apply_pipeline(input, pipeline1) output + TestUtils.@test_expression_generating apply_pipeline(input, pipeline1) output input = :(x + y) output = :(x + y) - @test_expression_generating apply_pipeline(input, pipeline1) output + TestUtils.@test_expression_generating apply_pipeline(input, pipeline1) output input = :(x * 1) output = :(x * 1) - @test_expression_generating apply_pipeline(input, pipeline1) output + TestUtils.@test_expression_generating apply_pipeline(input, pipeline1) output input = :(y ~ Normal(x + 1, z)) output = :(y ~ Normal(x + 2, z)) - @test_expression_generating apply_pipeline(input, pipeline1) output + TestUtils.@test_expression_generating apply_pipeline(input, pipeline1) output end @testset "Guarded `what_walk`" begin @@ -55,28 +54,27 @@ end input = :(x + 1) output = :(x + 2) - @test_expression_generating apply_pipeline(input, pipeline2) output + TestUtils.@test_expression_generating apply_pipeline(input, pipeline2) output input = :(x + y) output = :(x + y) - @test_expression_generating apply_pipeline(input, pipeline2) output + TestUtils.@test_expression_generating apply_pipeline(input, pipeline2) output input = :(x * 1) output = :(x * 1) - @test_expression_generating apply_pipeline(input, pipeline2) output + TestUtils.@test_expression_generating apply_pipeline(input, pipeline2) output # Should not modift this one, since its guarded walk input = :(y ~ Normal(x + 1, z)) output = :(y ~ Normal(x + 1, z)) - @test_expression_generating apply_pipeline(input, pipeline2) output + TestUtils.@test_expression_generating apply_pipeline(input, pipeline2) output end end -@testitem "check_reserved_variable_names_model" begin +@testitem "check_reserved_variable_names_model" setup = [TestUtils] begin + using Distributions import GraphPPL: apply_pipeline, check_reserved_variable_names_model - include("testutils.jl") - # Test 1: test that reserved variable name __parent_options__ throws an error input = quote __parent_options__ = 1 @@ -99,11 +97,10 @@ end @test apply_pipeline(input, check_reserved_variable_names_model) == input end -@testitem "check_incomplete_factorization_constraint" begin +@testitem "check_incomplete_factorization_constraint" setup = [TestUtils] begin + using Distributions import GraphPPL: apply_pipeline, check_incomplete_factorization_constraint - include("testutils.jl") - input = quote q(x)q(y) end @@ -130,12 +127,10 @@ end @test apply_pipeline(input, check_incomplete_factorization_constraint) == input end -@testitem "guarded_walk" begin - import MacroTools: @capture +@testitem "guarded_walk" setup = [TestUtils] begin + using MacroTools import GraphPPL: guarded_walk - include("testutils.jl") - #Test 1: walk with indexing operation as guard g_walk = guarded_walk((x) -> x isa Expr && x.head == :ref) @@ -155,7 +150,7 @@ end end end - @test_expression_generating result output + TestUtils.@test_expression_generating result output #Test 2: walk with complexer guard function custom_guard(x) = x isa Expr && (x.head == :ref || x.head == :call) @@ -178,7 +173,7 @@ end end end - @test_expression_generating result output + TestUtils.@test_expression_generating result output #Test 3: walk with guard function that always returns true g_walk = guarded_walk((x) -> true) @@ -195,7 +190,7 @@ end end end - @test_expression_generating result input + TestUtils.@test_expression_generating result input #Test 4: walk with guard function that should not go into body if created_by is key g_walk = guarded_walk((x) -> x isa Expr && :created_by ∈ x.args) @@ -214,18 +209,17 @@ end return x end end - @test_expression_generating result output + TestUtils.@test_expression_generating result output end -@testitem "save_expression_in_tilde" begin +@testitem "save_expression_in_tilde" setup = [TestUtils] begin + using MacroTools import GraphPPL: save_expression_in_tilde, apply_pipeline - include("testutils.jl") - # Test 1: save expression in tilde input = :(x ~ Normal(0, 1)) output = :(x ~ Normal(0, 1) where {created_by = () -> :(x ~ Normal(0, 1))}) - @test_expression_generating save_expression_in_tilde(input) output + TestUtils.@test_expression_generating save_expression_in_tilde(input) output # Test 2: save expression in tilde with multiple expressions input = quote @@ -236,12 +230,12 @@ end x ~ Normal(0, 1) where {created_by = () -> :(x ~ Normal(0, 1))} y ~ Normal(0, 1) where {created_by = () -> :(y ~ Normal(0, 1))} end - @test_expression_generating apply_pipeline(input, save_expression_in_tilde) output + TestUtils.@test_expression_generating apply_pipeline(input, save_expression_in_tilde) output # Test 3: save expression in tilde with broadcasted operation input = :(x .~ Normal(0, 1)) output = :(x .~ Normal(0, 1) where {created_by = () -> :(x .~ Normal(0, 1))}) - @test_expression_generating save_expression_in_tilde(input) output + TestUtils.@test_expression_generating save_expression_in_tilde(input) output # Test 4: save expression in tilde with multiple broadcast expressions input = quote @@ -254,12 +248,12 @@ end y .~ Normal(0, 1) where {created_by = () -> :(y .~ Normal(0, 1))} end - @test_expression_generating apply_pipeline(input, save_expression_in_tilde) output + TestUtils.@test_expression_generating apply_pipeline(input, save_expression_in_tilde) output # Test 5: save expression in tilde with deterministic operation input = :(x := Normal(0, 1)) output = :(x := Normal(0, 1) where {created_by = () -> :(x := Normal(0, 1))}) - @test_expression_generating save_expression_in_tilde(input) output + TestUtils.@test_expression_generating save_expression_in_tilde(input) output # Test 6: save expression in tilde with multiple deterministic expressions input = quote @@ -272,7 +266,7 @@ end y := Normal(0, 1) where {created_by = () -> :(y := Normal(0, 1))} end - @test_expression_generating apply_pipeline(input, save_expression_in_tilde) output + TestUtils.@test_expression_generating apply_pipeline(input, save_expression_in_tilde) output # Test 7: save expression in tilde with additional options input = quote @@ -283,22 +277,22 @@ end x ~ Normal(0, 1) where {q = MeanField(), created_by = () -> :(x ~ Normal(0, 1) where {q = MeanField()})} y ~ Normal(0, 1) where {q = MeanField(), created_by = () -> :(y ~ Normal(0, 1) where {q = MeanField()})} end - @test_expression_generating apply_pipeline(input, save_expression_in_tilde) output + TestUtils.@test_expression_generating apply_pipeline(input, save_expression_in_tilde) output # Test 8: with different variable names input = :(y ~ Normal(0, 1)) output = :(y ~ Normal(0, 1) where {created_by = () -> :(y ~ Normal(0, 1))}) - @test_expression_generating save_expression_in_tilde(input) output + TestUtils.@test_expression_generating save_expression_in_tilde(input) output # Test 9: with different parameter options input = :(x ~ Normal(0, 1) where {mu = 2.0, sigma = 0.5}) output = :(x ~ Normal(0, 1) where {mu = 2.0, sigma = 0.5, created_by = () -> :(x ~ Normal(0, 1) where {mu = 2.0, sigma = 0.5})}) - @test_expression_generating save_expression_in_tilde(input) output + TestUtils.@test_expression_generating save_expression_in_tilde(input) output # Test 10: with different parameter options input = :(y ~ Normal(0, 1) where {mu = 1.0}) output = :(y ~ Normal(0, 1) where {mu = 1.0, created_by = () -> :(y ~ Normal(0, 1) where {mu = 1.0})}) - @test_expression_generating save_expression_in_tilde(input) output + TestUtils.@test_expression_generating save_expression_in_tilde(input) output # Test 11: with no parameter options input = :(x ~ Normal(0, 1) where {}) @@ -312,7 +306,7 @@ end x = i end end - @test_expression_generating save_expression_in_tilde(input) input + TestUtils.@test_expression_generating save_expression_in_tilde(input) input # Test 13: check matching pattern in loop input = quote @@ -325,7 +319,7 @@ end x[i] ~ Normal(0, 1) where {created_by = () -> :(x[i] ~ Normal(0, 1))} end end - @test_expression_generating save_expression_in_tilde(input) input + TestUtils.@test_expression_generating save_expression_in_tilde(input) input # Test 14: check local statements input = quote @@ -338,7 +332,7 @@ end local y ~ (Normal(0, 1)) where {created_by = () -> :(local y ~ Normal(0, 1))} end - @test_expression_generating save_expression_in_tilde(input) input + TestUtils.@test_expression_generating save_expression_in_tilde(input) input # Test 15: check arithmetic operations input = quote @@ -347,7 +341,7 @@ end output = quote x := (a + b) where {created_by = () -> :(x := a + b)} end - @test_expression_generating save_expression_in_tilde(input) output + TestUtils.@test_expression_generating save_expression_in_tilde(input) output # Test 16: test local for deterministic statement input = quote @@ -356,7 +350,7 @@ end output = quote local x := (a + b) where {created_by = () -> :(local x := a + b)} end - @test_expression_generating save_expression_in_tilde(input) output + TestUtils.@test_expression_generating save_expression_in_tilde(input) output # Test 17: test local for deterministic statement input = quote @@ -365,7 +359,7 @@ end output = quote local x := (a + b) where {q = q(x)q(a)q(b), created_by = () -> :(local x := (a + b) where {q = q(x)q(a)q(b)})} end - @test_expression_generating save_expression_in_tilde(input) output + TestUtils.@test_expression_generating save_expression_in_tilde(input) output # Test 18: test local for broadcasting statement input = quote @@ -374,7 +368,7 @@ end output = quote local x .~ Normal(μ, σ) where {created_by = () -> :(local x .~ Normal(μ, σ))} end - @test_expression_generating save_expression_in_tilde(input) output + TestUtils.@test_expression_generating save_expression_in_tilde(input) output # Test 19: test local for broadcasting statement input = quote @@ -383,14 +377,13 @@ end output = quote local x .~ Normal(μ, σ) where {q = q(x)q(μ)q(σ), created_by = () -> :(local x .~ Normal(μ, σ) where {q = q(x)q(μ)q(σ)})} end - @test_expression_generating save_expression_in_tilde(input) output + TestUtils.@test_expression_generating save_expression_in_tilde(input) output end -@testitem "get_created_by" begin +@testitem "get_created_by" setup = [TestUtils] begin + using MacroTools import GraphPPL.get_created_by - include("testutils.jl") - # Test 1: only created by input = [:(created_by = (x ~ Normal(0, 1)))] @test get_created_by(input) == :(x ~ Normal(0, 1)) @@ -401,22 +394,21 @@ end # Test 3: created by and other parameters input = [:(created_by = (x ~ Normal(0, 1) where {q} = MeanField())), :(q = MeanField())] - @test_expression_generating get_created_by(input) :(x ~ Normal(0, 1) where {q} = MeanField()) + TestUtils.@test_expression_generating get_created_by(input) :(x ~ Normal(0, 1) where {q} = MeanField()) @test_throws ErrorException get_created_by([:(q = MeanField())]) end -@testitem "convert_deterministic_statement" begin +@testitem "convert_deterministic_statement" setup = [TestUtils] begin + using MacroTools import GraphPPL: convert_deterministic_statement, apply_pipeline - include("testutils.jl") - # Test 1: no deterministic statement input = quote x ~ Normal(0, 1) where {created_by = (x ~ Normal(0, 1))} y ~ Normal(0, 1) where {created_by = (y ~ Normal(0, 1))} end - @test_expression_generating apply_pipeline(input, convert_deterministic_statement) input + TestUtils.@test_expression_generating apply_pipeline(input, convert_deterministic_statement) input # Test 2: deterministic statement input = quote @@ -427,7 +419,7 @@ end x ~ Normal(0, 1) where {created_by = (x := Normal(0, 1)), is_deterministic = true} y ~ Normal(0, 1) where {created_by = (y := Normal(0, 1)), is_deterministic = true} end - @test_expression_generating apply_pipeline(input, convert_deterministic_statement) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_deterministic_statement) output # Test case 3: Input expression with multiple matching patterns input = quote @@ -440,7 +432,7 @@ end y ~ Normal(0, 1) where {created_by = (y := Normal(0, 1)), is_deterministic = true} z ~ Bernoulli(0.5) where {created_by = (z := Bernoulli(0.5))} end - @test_expression_generating apply_pipeline(input, convert_deterministic_statement) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_deterministic_statement) output # Test case 5: Input expression with multiple matching patterns input = quote @@ -449,14 +441,13 @@ end output = quote x ~ (a + b) where {q = q(x)q(a)q(b), created_by = (x := a + b where {q = q(x)q(a)q(b)}), is_deterministic = true} end - @test_expression_generating apply_pipeline(input, convert_deterministic_statement) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_deterministic_statement) output end -@testitem "convert_local_statement" begin +@testitem "convert_local_statement" setup = [TestUtils] begin + using MacroTools import GraphPPL: convert_local_statement, apply_pipeline - include("testutils.jl") - # Test 1: one local statement input = quote local x ~ Normal(0, 1) where {created_by = (x ~ Normal(0, 1))} @@ -465,7 +456,7 @@ end x = GraphPPL.add_variable_node!(__model__, __context__, gensym(__model__, :x)) x ~ Normal(0, 1) where {created_by = (x ~ Normal(0, 1))} end - @test_expression_generating apply_pipeline(input, convert_local_statement) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_local_statement) output # Test 2: two local statements input = quote @@ -478,7 +469,7 @@ end y = GraphPPL.add_variable_node!(__model__, __context__, gensym(__model__, :y)) y ~ Normal(0, 1) where {created_by = (y ~ Normal(0, 1))} end - @test_expression_generating apply_pipeline(input, convert_local_statement) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_local_statement) output # Test 3: mixed local and non-local statements input = quote @@ -491,7 +482,7 @@ end y ~ Normal(0, 1) where {created_by = (y ~ Normal(0, 1))} end - @test_expression_generating apply_pipeline(input, convert_local_statement) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_local_statement) output #Test 4: local statement with multiple matching patterns input = quote local x ~ Normal(a, b) where {q = q(x)q(a)q(b), created_by = (x ~ Normal(a, b) where {q = q(x)q(a)q(b)})} @@ -500,7 +491,7 @@ end x = GraphPPL.add_variable_node!(__model__, __context__, gensym(__model__, :x)) x ~ Normal(a, b) where {q = q(x)q(a)q(b), created_by = (x ~ Normal(a, b) where {q = q(x)q(a)q(b)})} end - @test_expression_generating apply_pipeline(input, convert_local_statement) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_local_statement) output # Test 5: local statement with broadcasting statement input = quote @@ -509,15 +500,13 @@ end output = quote x .~ Normal(μ, σ) where {created_by = (x .~ Normal(μ, σ))} end - @test_expression_generating apply_pipeline(input, convert_local_statement) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_local_statement) output end -@testitem "is_kwargs_expression(::AbstractArray)" begin +@testitem "is_kwargs_expression(::AbstractArray)" setup = [TestUtils] begin import MacroTools: @capture import GraphPPL: is_kwargs_expression - include("testutils.jl") - func_def = :(foo(a, b)) @capture(func_def, (f_(args__))) @test !is_kwargs_expression(args) @@ -550,52 +539,52 @@ end @test !is_kwargs_expression(args) end -@testitem "convert_to_kwargs_expression" begin +@testitem "convert_to_kwargs_expression" setup = [TestUtils] begin + using Distributions + using MacroTools import GraphPPL: convert_to_kwargs_expression, apply_pipeline - include("testutils.jl") - # Test 1: Input expression with ~ expression and args and kwargs expressions input = quote x ~ Normal(0, 1; a = 1, b = 2) where {created_by = (x ~ Normal(0, 1; a = 1, b = 2))} end output = input - @test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output # Test 2: Input expression with ~ expression and args and kwargs expressions with symbols input = quote x ~ Normal(μ, σ; a = τ, b = θ) where {created_by = (x ~ Normal(μ, σ; a = τ, b = θ))} end output = input - @test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output # Test 3: Input expression with ~ expression and only kwargs expression input = quote x ~ Normal(; a = 1, b = 2) where {created_by = (x ~ Normal(; a = 1, b = 2))} end output = input - @test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output # Test 4: Input expression with ~ expression and only kwargs expression with symbols input = quote x ~ Normal(; a = τ, b = θ) where {created_by = (x ~ Normal(; a = τ, b = θ))} end output = input - @test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output # Test 5: Input expression with ~ expression and only args expression input = quote x ~ Normal(0, 1) where {created_by = (x ~ Normal(0, 1))} end output = input - @test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output # Test 6: Input expression with ~ expression and only args expression with symbols input = quote x ~ Normal(μ, σ) where {created_by = (x ~ Normal(μ, σ))} end output = input - @test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output # Test 7: Input expression with ~ expression and named args expression input = quote @@ -604,7 +593,7 @@ end output = quote x ~ Normal(; μ = 0, σ = 1) where {created_by = (x ~ Normal(μ = 0, σ = 1))} end - @test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output # Test 8: Input expression with ~ expression and named args expression with symbols input = quote @@ -613,49 +602,49 @@ end output = quote x ~ Normal(; μ = μ, σ = σ) where {created_by = (x ~ Normal(μ = μ, σ = σ))} end - @test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output # Test 9: Input expression with .~ expression and args and kwargs expressions input = quote x .~ Normal(0, 1; a = 1, b = 2) where {created_by = (x .~ Normal(0, 1; a = 1, b = 2))} end output = input - @test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output # Test 10: Input expression with .~ expression and args and kwargs expressions with symbols input = quote x .~ Normal(μ, σ; a = τ, b = θ) where {created_by = (x .~ Normal(μ, σ; a = τ, b = θ))} end output = input - @test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output # Test 11: Input expression with .~ expression and only kwargs expression input = quote x .~ Normal(; a = 1, b = 2) where {created_by = (x .~ Normal(; a = 1, b = 2))} end output = input - @test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output # Test 12: Input expression with .~ expression and only kwargs expression with symbols input = quote x .~ Normal(; a = τ, b = θ) where {created_by = (x .~ Normal(; a = τ, b = θ))} end output = input - @test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output # Test 13: Input expression with .~ expression and only args expression input = quote x .~ Normal(0, 1) where {created_by = (x .~ Normal(0, 1))} end output = input - @test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output # Test 14: Input expression with .~ expression and only args expression with symbols input = quote x .~ Normal(μ, σ) where {created_by = (x .~ Normal(μ, σ))} end output = input - @test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output # Test 15: Input expression with .~ expression and named args expression input = quote @@ -664,7 +653,7 @@ end output = quote x .~ Normal(; μ = 0, σ = 1) where {created_by = (x .~ Normal(μ = 0, σ = 1))} end - @test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output # Test 16: Input expression with .~ expression and named args expression with symbols input = quote @@ -673,49 +662,49 @@ end output = quote x .~ Normal(; μ = μ, σ = σ) where {created_by = (x .~ Normal(μ = μ, σ = σ))} end - @test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output # Test 17: Input expression with := expression and args and kwargs expressions input = quote x := Normal(0, 1; a = 1, b = 2) where {created_by = (x := Normal(0, 1; a = 1, b = 2))} end output = input - @test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output # Test 18: Input expression with := expression and args and kwargs expressions with symbols input = quote x := Normal(μ, σ; a = τ, b = θ) where {created_by = (x := Normal(μ, σ; a = τ, b = θ))} end output = input - @test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output # Test 19: Input expression with := expression and only kwargs expression input = quote x := Normal(; a = 1, b = 2) where {created_by = (x := Normal(; a = 1, b = 2))} end output = input - @test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output # Test 20: Input expression with := expression and only kwargs expression with symbols input = quote x := Normal(; a = τ, b = θ) where {created_by = (x := Normal(; a = τ, b = θ))} end output = input - @test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output # Test 21: Input expression with := expression and only args expression input = quote x := Normal(0, 1) where {created_by = (x := Normal(0, 1))} end output = input - @test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output # Test 22: Input expression with := expression and only args expression with symbols input = quote x := Normal(μ, σ) where {created_by = (x := Normal(μ, σ))} end output = input - @test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output # Test 23: Input expression with := expression and named args as args expression input = quote @@ -724,7 +713,7 @@ end output = quote x := Normal(; μ = 0, σ = 1) where {created_by = (x := Normal(μ = 0, σ = 1))} end - @test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output # Test 24: Input expression with := expression and named args as args expression with symbols input = quote @@ -733,21 +722,21 @@ end output = quote x := Normal(; μ = μ, σ = σ) where {created_by = (x := Normal(μ = μ, σ = σ))} end - @test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output # Test 25: Input expression with ~ expression and additional arguments in where clause input = quote x ~ Normal(0, 1) where {q = MeanField(), created_by = (x ~ Normal(0, 1)) where {q} = MeanField()} end output = input - @test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output # Test 26: Input expression with nested call in rhs input = quote x ~ Normal(Normal(0, 1)) where {created_by = (x ~ Normal(Normal(0, 1)))} end output = input - @test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output # Test 27: Input expression with additional where clause on rhs input = quote @@ -756,14 +745,14 @@ end output = quote x ~ Normal(; μ = μ, σ = σ) where {created_by = (x ~ Normal(μ = μ, σ = σ) where {q = MeanField()}), q = MeanField()} end - @test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_to_kwargs_expression) output end -@testitem "convert_to_anonymous" begin +@testitem "convert_to_anonymous" setup = [TestUtils] begin + using Distributions + using MacroTools import GraphPPL: convert_to_anonymous, apply_pipeline - include("testutils.jl") - # Test 1: convert function to anonymous function input = quote Normal(0, 1) @@ -774,7 +763,7 @@ end var"#anon" ~ Normal(0, 1) where {anonymous = true, created_by = x ~ Normal(0, 1)} end end - @test_expression_generating convert_to_anonymous(input, created_by) output + TestUtils.@test_expression_generating convert_to_anonymous(input, created_by) output # Test 2: leave number expression input = quote @@ -782,7 +771,7 @@ end end created_by = :(x ~ Normal(0, 1)) output = input - @test_expression_generating convert_to_anonymous(input, created_by) output + TestUtils.@test_expression_generating convert_to_anonymous(input, created_by) output # Test 3: leave symbol expression input = quote @@ -790,7 +779,7 @@ end end created_by = :(x ~ Normal(0, 1)) output = input - @test_expression_generating convert_to_anonymous(input, created_by) output + TestUtils.@test_expression_generating convert_to_anonymous(input, created_by) output # Test 4: handle broadcasted expression input = quote @@ -802,7 +791,7 @@ end var"#anon" .~ Normal(fill(0, 10), fill(1, 10)) where {anonymous = true, created_by = x ~ Normal(0, 1)} end end - @test_expression_generating convert_to_anonymous(input, created_by) output + TestUtils.@test_expression_generating convert_to_anonymous(input, created_by) output # Test 5: handle broadcasted expression with special cases input = quote @@ -814,14 +803,13 @@ end var"#anon" .~ ((a + b) where {anonymous = true, created_by = x ~ Normal(0, 1)}) end end - @test_expression_generating convert_to_anonymous(input, created_by) output + TestUtils.@test_expression_generating convert_to_anonymous(input, created_by) output end -@testitem "not_enter_indexed_walk" begin +@testitem "not_enter_indexed_walk" setup = [TestUtils] begin + using MacroTools import GraphPPL: not_enter_indexed_walk - include("testutils.jl") - # Test 1: not enter indexed walk input = quote x[1] ~ y[10 + 1] @@ -841,11 +829,11 @@ end end end -@testitem "convert_anonymous_variables" begin +@testitem "convert_anonymous_variables" setup = [TestUtils] begin + using Distributions + using MacroTools import GraphPPL: convert_anonymous_variables, apply_pipeline - include("testutils.jl") - #Test 1: Input expression with a function call in rhs arguments input = quote x ~ Normal(Normal(0, 1), 1) where {created_by = (x ~ Normal(Normal(0, 1), 1))} @@ -858,7 +846,7 @@ end 1 ) where {created_by = (x ~ Normal(Normal(0, 1), 1))} end - @test_expression_generating apply_pipeline(input, convert_anonymous_variables) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_anonymous_variables) output #Test 2: Input expression without pattern matching input = quote @@ -867,7 +855,7 @@ end output = quote x ~ Normal(0, 1) where {created_by = (x ~ Normal(0, 1))} end - @test_expression_generating apply_pipeline(input, convert_anonymous_variables) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_anonymous_variables) output #Test 3: Input expression with a function call as kwargs input = quote @@ -880,7 +868,7 @@ end end, σ = 1 ) where {created_by = (x ~ Normal(; μ = Normal(0, 1), σ = 1))} end - @test_expression_generating apply_pipeline(input, convert_anonymous_variables) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_anonymous_variables) output #Test 4: Input expression without pattern matching and kwargs input = quote @@ -889,7 +877,7 @@ end output = quote x ~ Normal(; μ = 0, σ = 1) where {created_by = (x ~ Normal(μ = 0, σ = 1))} end - @test_expression_generating apply_pipeline(input, convert_anonymous_variables) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_anonymous_variables) output #Test 5: Input expression with multiple function calls in rhs arguments input = quote @@ -905,7 +893,7 @@ end end ) where {created_by = (x ~ Normal(Normal(0, 1), Normal(0, 1)))} end - @test_expression_generating apply_pipeline(input, convert_anonymous_variables) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_anonymous_variables) output #Test 6: Input expression with multiple function calls in rhs arguments and kwargs input = quote @@ -921,7 +909,7 @@ end end ) where {created_by = (x ~ Normal(; μ = Normal(0, 1), σ = Normal(0, 1)))} end - @test_expression_generating apply_pipeline(input, convert_anonymous_variables) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_anonymous_variables) output #Test 7: Input expression with nested function call in rhs arguments input = quote @@ -941,7 +929,7 @@ end ) where {created_by = (x ~ Normal(Normal(Normal(0, 1), 1), 1))} end - @test_expression_generating apply_pipeline(input, convert_anonymous_variables) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_anonymous_variables) output #Test 8: Input expression with nested function call in rhs arguments and kwargs and additional where clause input = quote @@ -964,14 +952,14 @@ end 1 ) where {q = MeanField(), created_by = (x ~ Normal(Normal(Normal(0, 1), 1), 1) where {q = MeanField()})} end - @test_expression_generating apply_pipeline(input, convert_anonymous_variables) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_anonymous_variables) output # Test 9: Input expression with arithmetic indexed call on rhs input = quote x ~ Normal(x[i - 1], 1) where {created_by = (x ~ Normal(y[i - 1], 1))} end output = input - @test_expression_generating apply_pipeline(input, convert_anonymous_variables) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_anonymous_variables) output # Test 10: Input expression with broadcasted call input = quote @@ -994,7 +982,7 @@ end 1 ) where {q = MeanField(), created_by = (x ~ Normal(Normal(Normal(0, 1), 1), 1) where {q = MeanField()})} end - @test_expression_generating apply_pipeline(input, convert_anonymous_variables) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_anonymous_variables) output # Test 11: Input expression with broadcasted call as anonymous variable input = quote @@ -1008,7 +996,7 @@ end 1 ) where {created_by = (x ~ Normal(f.(y), 1))} end - @test_expression_generating apply_pipeline(input, convert_anonymous_variables) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_anonymous_variables) output # Test 12: Input expression with nested broadcasted call as anonymous variable input = quote @@ -1026,7 +1014,7 @@ end 1 ) where {created_by = (x ~ Normal(f.(g.(y)), 1))} end - @test_expression_generating apply_pipeline(input, convert_anonymous_variables) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_anonymous_variables) output input = quote y .~ Normal(a .* y .+ b, 1) where {created_by = y .~ Normal(a .* y .+ b, 1)} @@ -1047,7 +1035,7 @@ end ) where {(created_by = (y .~ Normal(a .* y .+ b, 1)))} ) end - @test_expression_generating apply_pipeline(input, convert_anonymous_variables) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_anonymous_variables) output input = quote y .~ Normal( @@ -1178,20 +1166,20 @@ end ) ) end - @test_expression_generating apply_pipeline(input, convert_anonymous_variables) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_anonymous_variables) output end -@testitem "proxy_args_rhs" begin +@testitem "proxy_args_rhs" setup = [TestUtils] begin + using Distributions + using MacroTools import GraphPPL: proxy_args_rhs, apply_pipeline, recursive_rhs_indexing - include("testutils.jl") - # Test 1: Input expression with a function call in rhs arguments input = :x output = quote GraphPPL.proxylabel(:x, x, nothing, GraphPPL.False()) end - @test_expression_generating proxy_args_rhs(input) output + TestUtils.@test_expression_generating proxy_args_rhs(input) output input = quote x[1] @@ -1199,7 +1187,7 @@ end output = quote GraphPPL.proxylabel(:x, x, (1,), GraphPPL.False()) end - @test_expression_generating proxy_args_rhs(input) output + TestUtils.@test_expression_generating proxy_args_rhs(input) output input = quote x[1, 2] @@ -1207,7 +1195,7 @@ end output = quote GraphPPL.proxylabel(:x, x, (1, 2), GraphPPL.False()) end - @test_expression_generating proxy_args_rhs(input) output + TestUtils.@test_expression_generating proxy_args_rhs(input) output input = quote x[1][1] @@ -1215,14 +1203,14 @@ end output = quote GraphPPL.proxylabel(:x, GraphPPL.proxylabel(:x, x, (1,), GraphPPL.False()), (1,), GraphPPL.False()) end - @test_expression_generating proxy_args_rhs(input) output + TestUtils.@test_expression_generating proxy_args_rhs(input) output end -@testitem "add_get_or_create_expression" begin +@testitem "add_get_or_create_expression" setup = [TestUtils] begin + using Distributions + using MacroTools import GraphPPL: add_get_or_create_expression, apply_pipeline - include("testutils.jl") - #Test 1: test scalar variable input = quote x ~ Normal(0, 1) where {created_by = (x ~ Normal(0, 1))} @@ -1235,7 +1223,7 @@ end end x ~ (Normal(0, 1) where {(created_by = (x ~ Normal(0, 1)))}) end - @test_expression_generating apply_pipeline(input, add_get_or_create_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, add_get_or_create_expression) output #Test 1.1: test scalar variable input = quote @@ -1249,7 +1237,7 @@ end end x ~ (Gamma(0, 1) where {(created_by = (x ~ Gamma(0, 1)))}) end - @test_expression_generating apply_pipeline(input, add_get_or_create_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, add_get_or_create_expression) output #Test 2: test vector variable input = quote @@ -1263,7 +1251,7 @@ end end x[1] ~ Normal(0, 1) where {created_by = (x[1] ~ Normal(0, 1))} end - @test_expression_generating apply_pipeline(input, add_get_or_create_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, add_get_or_create_expression) output #Test 3: test matrix variable input = quote @@ -1277,7 +1265,7 @@ end end x[1, 2] ~ Normal(0, 1) where {created_by = (x[1, 2] ~ Normal(0, 1))} end - @test_expression_generating apply_pipeline(input, add_get_or_create_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, add_get_or_create_expression) output #Test 4: test vector variable with variable as index input = quote @@ -1291,7 +1279,7 @@ end end x[i] ~ Normal(0, 1) where {created_by = (x[i] ~ Normal(0, 1))} end - @test_expression_generating apply_pipeline(input, add_get_or_create_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, add_get_or_create_expression) output #Test 5: test matrix variable with symbol as index input = quote @@ -1305,7 +1293,7 @@ end end x[i, j] ~ Normal(0, 1) where {created_by = (x[i, j] ~ Normal(0, 1))} end - @test_expression_generating apply_pipeline(input, add_get_or_create_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, add_get_or_create_expression) output #Test 4: test function call in parameters on rhs sym = gensym(:anon) @@ -1335,7 +1323,7 @@ end 1 ) where {created_by = (x ~ Normal(Normal(0, 1), 1))} end - @test_expression_generating apply_pipeline(input, add_get_or_create_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, add_get_or_create_expression) output # Test 5: Input expression with NodeLabel on rhs input = quote @@ -1349,7 +1337,7 @@ end end y ~ x where {created_by = (y := x), is_deterministic = true} end - @test_expression_generating apply_pipeline(input, add_get_or_create_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, add_get_or_create_expression) output # Test 6: Input expression with additional options on rhs input = quote @@ -1363,14 +1351,14 @@ end end x ~ Normal(0, 1) where {created_by = (x ~ Normal(0, 1) where {q = q(x)q(y)}), q = q(x)q(y)} end - @test_expression_generating apply_pipeline(input, add_get_or_create_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, add_get_or_create_expression) output end -@testitem "generate_get_or_create" begin +@testitem "generate_get_or_create" setup = [TestUtils] begin + using Distributions + using MacroTools import GraphPPL: generate_get_or_create, apply_pipeline - include("testutils.jl") - # Test 1: test scalar variable output = generate_get_or_create(:x, nothing, :(Normal(0, 1))) desired_result = quote @@ -1380,7 +1368,7 @@ end x end end - @test_expression_generating output desired_result + TestUtils.@test_expression_generating output desired_result # Test 2: test vector variable output = generate_get_or_create(:x, [1], :(Gamma(0, 1))) @@ -1391,7 +1379,7 @@ end x end end - @test_expression_generating output desired_result + TestUtils.@test_expression_generating output desired_result # Test 3: test matrix variable output = generate_get_or_create(:x, [1, 2], :(unknownsymbol)) @@ -1402,7 +1390,7 @@ end x end end - @test_expression_generating output desired_result + TestUtils.@test_expression_generating output desired_result # Test 5: test symbol-indexed variable output = generate_get_or_create(:x, [:i, :j], :(unknownsymbol)) @@ -1413,7 +1401,7 @@ end x end end - @test_expression_generating output desired_result + TestUtils.@test_expression_generating output desired_result # Test 6: test vector of single symbol output = generate_get_or_create(:x, [:i], :(f(argument))) @@ -1424,7 +1412,7 @@ end x end end - @test_expression_generating output desired_result + TestUtils.@test_expression_generating output desired_result # Test 7: test vector of symbols output = generate_get_or_create(:x, [:i, :j], :(f(keyword = 1))) @@ -1435,7 +1423,7 @@ end x end end - @test_expression_generating output desired_result + TestUtils.@test_expression_generating output desired_result # Test 8: test error if un-unrollable index @test_throws MethodError generate_get_or_create(:x, 2, :(Normal())) @@ -1444,12 +1432,11 @@ end @test_throws MethodError generate_get_or_create(:x, prod(0, 1), :(Normal())) end -@testitem "keyword_expressions_to_named_tuple" begin +@testitem "keyword_expressions_to_named_tuple" setup = [TestUtils] begin import MacroTools: @capture + using Distributions import GraphPPL: keyword_expressions_to_named_tuple, apply_pipeline, convert_to_kwargs_expression - include("testutils.jl") - expr = [:($(Expr(:kw, :in1, :y))), :($(Expr(:kw, :in2, :z)))] @test keyword_expressions_to_named_tuple(expr) == :((in1 = y, in2 = z)) @@ -1472,27 +1459,27 @@ end @test keyword_expressions_to_named_tuple(kwargs) == :((a = 1, b = 2)) end -@testitem "combine_broadcast_args" begin +@testitem "combine_broadcast_args" setup = [TestUtils] begin + using Distributions + using MacroTools import GraphPPL: combine_broadcast_args - include("testutils.jl") - - @test_expression_generating combine_broadcast_args([:μ, :σ], nothing) quote + TestUtils.@test_expression_generating combine_broadcast_args([:μ, :σ], nothing) quote args end - @test_expression_generating combine_broadcast_args([], [Expr(:kw, :μ, :μ), Expr(:kw, :σ, :σ)]) quote + TestUtils.@test_expression_generating combine_broadcast_args([], [Expr(:kw, :μ, :μ), Expr(:kw, :σ, :σ)]) quote NamedTuple{$(:μ, :σ)}(args) end - @test_expression_generating combine_broadcast_args([:μ, :σ], [Expr(:kw, :μ, :μ), Expr(:kw, :σ, :σ)]) quote + TestUtils.@test_expression_generating combine_broadcast_args([:μ, :σ], [Expr(:kw, :μ, :μ), Expr(:kw, :σ, :σ)]) quote GraphPPL.MixedArguments((μ, σ), NamedTuple{$(:μ, :σ)}(args)) end end -@testitem "convert_tilde_expression" begin +@testitem "convert_tilde_expression" setup = [TestUtils] begin + using Distributions + using MacroTools import GraphPPL: convert_tilde_expression, apply_pipeline - include("testutils.jl") - # Test 1: Test regular node creation input input = quote x ~ sum(0, 1) where {created_by = :(x ~ Normal(0, 1))} @@ -1513,7 +1500,7 @@ end var"#var" end end - @test_expression_generating apply_pipeline(input, convert_tilde_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_tilde_expression) output # Test 2: Test regular node creation input with kwargs input = quote @@ -1535,7 +1522,7 @@ end var"#var" end end - @test_expression_generating apply_pipeline(input, convert_tilde_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_tilde_expression) output # Test 3: Test regular node creation with indexed input input = quote @@ -1554,7 +1541,7 @@ end var"#var" end end - @test_expression_generating apply_pipeline(input, convert_tilde_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_tilde_expression) output # Test 4: Test node creation with anonymous variable input = quote @@ -1603,7 +1590,7 @@ end var"#var" end end - @test_expression_generating apply_pipeline(input, convert_tilde_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_tilde_expression) output # Test 5: Test node creation with non-function on rhs @@ -1623,7 +1610,7 @@ end var"#var" end end - @test_expression_generating apply_pipeline(input, convert_tilde_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_tilde_expression) output # Test 6: Test node creation with non-function on rhs with indexed statement @@ -1643,7 +1630,7 @@ end var"#var" end end - @test_expression_generating apply_pipeline(input, convert_tilde_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_tilde_expression) output # Test 7: Test node creation with non-function on rhs with multidimensional array @@ -1663,7 +1650,7 @@ end var"#var" end end - @test_expression_generating apply_pipeline(input, convert_tilde_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_tilde_expression) output # Test 8: Test node creation with mixed args and kwargs on rhs input = quote @@ -1691,7 +1678,7 @@ end var"#var" end end - @test_expression_generating apply_pipeline(input, convert_tilde_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_tilde_expression) output # Test 9: Test node creation with additional options input = quote @@ -1710,7 +1697,7 @@ end var"#var" end end - @test_expression_generating apply_pipeline(input, convert_tilde_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_tilde_expression) output # Test 10: Test node creation with kwargs and symbols_to_expression input = quote @@ -1729,7 +1716,7 @@ end var"#var" end end - @test_expression_generating apply_pipeline(input, convert_tilde_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_tilde_expression) output input = quote y ~ prior() where {created_by = :(y ~ prior())} @@ -1747,7 +1734,7 @@ end var"#var" end end - @test_expression_generating apply_pipeline(input, convert_tilde_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_tilde_expression) output # Test 11: Test node creation with broadcasting call input = quote @@ -1772,7 +1759,7 @@ end end GraphPPL.__check_vectorized_input(var"#rvar#") end - @test_expression_generating apply_pipeline(input, convert_tilde_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_tilde_expression) output # Test 12: Test node creation with broadcasting call with kwargs input = quote @@ -1798,7 +1785,7 @@ end GraphPPL.__check_vectorized_input(var"#rvar") end - @test_expression_generating apply_pipeline(input, convert_tilde_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_tilde_expression) output # Test 13: Test node creation with broadcasting call with mixed args and kwargs input = quote @@ -1828,7 +1815,7 @@ end end GraphPPL.__check_vectorized_input(var"#rvar") end - @test_expression_generating apply_pipeline(input, convert_tilde_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_tilde_expression) output # Test 14: Test node creation with splatting inside input = quote @@ -1845,14 +1832,14 @@ end ) var"#var" end - @test_expression_generating apply_pipeline(input, convert_tilde_expression) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_tilde_expression) output end -@testitem "options_vector_to_factoroptions" begin +@testitem "options_vector_to_factoroptions" setup = [TestUtils] begin + using Distributions + using MacroTools import GraphPPL: options_vector_to_named_tuple - include("testutils.jl") - # Test 1: Test with empty input input = [] output = :((;)) @@ -1878,8 +1865,11 @@ end @test_throws ErrorException options_vector_to_named_tuple(input) end -@testitem "model_macro_interior" begin - using LinearAlgebra, MetaGraphsNext, Graphs +@testitem "model_macro_interior" setup = [TestUtils] begin + using Distributions + using LinearAlgebra + using MacroTools + using Static import GraphPPL: model_macro_interior, create_model, @@ -1890,17 +1880,16 @@ end add_terminated_submodel!, NodeCreationOptions, getproperties, - Context - - include("testutils.jl") - - using .TestUtils.ModelZoo + Context, + @model, + nv, + ne # Test 1: Test regular node creation input @model function test_model(μ, σ) x ~ sum(μ, σ) end - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() μ = getorcreate!(model, ctx, :μ, nothing) @@ -1917,7 +1906,7 @@ end y ~ x[1] + x[10] end - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() μ = getorcreate!(model, ctx, :μ, nothing) @@ -1939,7 +1928,7 @@ end end y ~ x[1] + x[10] + x[11] end - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() μ = getorcreate!(model, ctx, :μ, nothing) @@ -1956,7 +1945,7 @@ end x ~ y + z end end - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() x = getorcreate!(model, ctx, :x, nothing) @@ -1969,7 +1958,7 @@ end z ~ Normal(x, Matrix{Float64}(Diagonal(ones(4)))) y ~ Normal(z, 1) end - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() x = getorcreate!(model, ctx, :x, nothing) @@ -2013,18 +2002,18 @@ end @test GraphPPL.nv(model) == 7 && GraphPPL.ne(model) == 6 # Test add_terminated_submodel! - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() for i in 1:10 getorcreate!(model, ctx, :y, i) end y = getorcreate!(model, ctx, :y, 1) - GraphPPL.add_terminated_submodel!(model, ctx, options, hgf, (y = y,), static(1)) + GraphPPL.add_terminated_submodel!(model, ctx, options, TestUtils.hgf, (y = y,), static(1)) @test haskey(ctx, :ω_2) && haskey(ctx, :x_1) && haskey(ctx, :x_2) && haskey(ctx, :x_3) # Test anonymous variable creation - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() for i in 1:10 @@ -2032,16 +2021,15 @@ end end x_arr = getorcreate!(model, ctx, :x, 1) y = getorcreate!(model, ctx, :y, nothing) - make_node!(model, ctx, options, anonymous_in_loop, proxylabel(:y, y, nothing), (x = x_arr,)) + make_node!(model, ctx, options, TestUtils.anonymous_in_loop, proxylabel(:y, y, nothing), (x = x_arr,)) @test nv(model) == 67 end -@testitem "ModelGenerator based constructor is being created" begin +@testitem "ModelGenerator based constructor is being created" setup = [TestUtils] begin + using Distributions import GraphPPL: ModelGenerator - include("testutils.jl") - - @model function foo(x, y) + TestUtils.@model function foo(x, y) x ~ y + 1 end @@ -2050,11 +2038,10 @@ end @test foo(x = 1, y = 1) isa ModelGenerator end -@testitem "`default_backend` should be set from the `model_macro_interior`" begin +@testitem "`default_backend` should be set from the `model_macro_interior`" setup = [TestUtils] begin + using Distributions import GraphPPL: default_backend, model_macro_interior - include("testutils.jl") - model_spec = quote function hello(a, b, c) a ~ Normal(b, c) @@ -2066,10 +2053,8 @@ end @test default_backend(hello) === TestUtils.TestGraphPPLBackend() end -@testitem "error message for other number of interfaces" begin - using GraphPPL +@testitem "error message for other number of interfaces" setup = [TestUtils] begin using Distributions - import GraphPPL: @model @model function somemodel(a, b, c) @@ -2085,8 +2070,8 @@ end @test !isnothing(GraphPPL.create_model(somemodel(a = 1, b = 2))) end -@testitem "model should warn users against incorrect usages of `=` operator with random variables" begin - using GraphPPL, Distributions +@testitem "model should warn users against incorrect usages of `=` operator with random variables" setup = [TestUtils] begin + using Distributions import GraphPPL: @model @model function somemodel() diff --git a/test/nodes/node_data_tests.jl b/test/nodes/node_data_tests.jl index fba52203..5f529d51 100644 --- a/test/nodes/node_data_tests.jl +++ b/test/nodes/node_data_tests.jl @@ -1,9 +1,7 @@ -@testitem "NodeData constructor" begin +@testitem "NodeData constructor" setup = [TestUtils] begin import GraphPPL: create_model, getcontext, NodeData, FactorNodeProperties, VariableNodeProperties, getproperties - include("testutils.jl") - - model = create_test_model() + model = TestUtils.create_test_model() context = getcontext(model) @testset "FactorNodeProperties" begin @@ -42,7 +40,7 @@ end end -@testitem "NodeDataExtraKey" begin +@testitem "NodeDataExtraKey" setup = [TestUtils] begin import GraphPPL: NodeDataExtraKey, getkey @test NodeDataExtraKey{:a, Int}() isa NodeDataExtraKey @@ -54,7 +52,7 @@ end @test getkey(NodeDataExtraKey{:b, Float64}()) === :b end -@testitem "NodeData extra properties" begin +@testitem "NodeData extra properties" setup = [TestUtils] begin import GraphPPL: create_model, getcontext, @@ -67,9 +65,7 @@ end hasextra, NodeDataExtraKey - include("testutils.jl") - - model = create_test_model() + model = TestUtils.create_test_model() context = getcontext(model) @testset for properties in (FactorNodeProperties(fform = String), VariableNodeProperties(name = :x, index = 1)) @@ -115,11 +111,9 @@ end end end -@testitem "NodeCreationOptions" begin +@testitem "NodeCreationOptions" setup = [TestUtils] begin import GraphPPL: NodeCreationOptions, withopts, withoutopts - include("testutils.jl") - @test NodeCreationOptions() == NodeCreationOptions() @test keys(NodeCreationOptions()) === () @test NodeCreationOptions(arbitrary_option = 1) == NodeCreationOptions((; arbitrary_option = 1)) @@ -158,14 +152,10 @@ end @test @inferred(withoutopts(NodeCreationOptions(a = 1, b = 2), Val((:c,)))) == NodeCreationOptions(a = 1, b = 2) end -@testitem "is_constant" begin +@testitem "is_constant" setup = [TestUtils] begin import GraphPPL: create_model, is_constant, variable_nodes, getname, getproperties - include("testutils.jl") - - using .TestUtils.ModelZoo - - for model_fn in ModelsInTheZooWithoutArguments + for model_fn in TestUtils.ModelsInTheZooWithoutArguments model = create_model(model_fn()) for label in variable_nodes(model) node = model[label] @@ -179,20 +169,16 @@ end end end -@testitem "is_data" begin +@testitem "is_data" setup = [TestUtils] begin import GraphPPL: is_data, create_model, getcontext, getorcreate!, variable_nodes, NodeCreationOptions, getproperties - include("testutils.jl") - - m = create_test_model() + m = TestUtils.create_test_model() ctx = getcontext(m) xref = getorcreate!(m, ctx, NodeCreationOptions(kind = :data), :x, nothing) @test is_data(getproperties(m[xref])) - using .TestUtils.ModelZoo - # Since the models here are without top arguments they cannot create `data` labels - for model_fn in ModelsInTheZooWithoutArguments + for model_fn in TestUtils.ModelsInTheZooWithoutArguments model = create_model(model_fn()) for label in variable_nodes(model) @test !is_data(getproperties(model[label])) @@ -200,13 +186,11 @@ end end end -@testitem "Predefined kinds of variable nodes" begin +@testitem "Predefined kinds of variable nodes" setup = [TestUtils] begin import GraphPPL: VariableKindRandom, VariableKindData, VariableKindConstant import GraphPPL: getcontext, getorcreate!, NodeCreationOptions, getproperties - include("testutils.jl") - - model = create_test_model() + model = TestUtils.create_test_model() context = getcontext(model) xref = getorcreate!(model, context, NodeCreationOptions(kind = VariableKindRandom), :x, nothing) y = getorcreate!(model, context, NodeCreationOptions(kind = VariableKindData), :y, nothing) diff --git a/test/nodes/node_label_tests.jl b/test/nodes/node_label_tests.jl index fe37a839..a31f9df3 100644 --- a/test/nodes/node_label_tests.jl +++ b/test/nodes/node_label_tests.jl @@ -1,4 +1,4 @@ -@testitem "NodeLabel properties" begin +@testitem "NodeLabel properties" setup = [TestUtils] begin import GraphPPL: NodeLabel xref = NodeLabel(:x, 1) @@ -10,7 +10,7 @@ @test xref < y end -@testitem "getname(::NodeLabel)" begin +@testitem "getname(::NodeLabel)" setup = [TestUtils] begin import GraphPPL: ResizableArray, NodeLabel, getname xref = NodeLabel(:x, 1) @@ -25,12 +25,10 @@ end @test getname(xref) == :x end -@testitem "generate_nodelabel(::Model, ::Symbol)" begin +@testitem "generate_nodelabel(::Model, ::Symbol)" setup = [TestUtils] begin import GraphPPL: create_model, gensym, NodeLabel, generate_nodelabel - include("testutils.jl") - - model = create_test_model() + model = TestUtils.create_test_model() first_sym = generate_nodelabel(model, :x) @test typeof(first_sym) == NodeLabel @@ -41,7 +39,7 @@ end @test id.name == :c && id.global_counter == 3 end -@testitem "proxy labels" begin +@testitem "proxy labels" setup = [TestUtils] begin import GraphPPL: NodeLabel, ProxyLabel, proxylabel, getname, unroll, ResizableArray, FunctionalIndex y = NodeLabel(:y, 1) @@ -106,12 +104,10 @@ end end end -@testitem "datalabel" begin +@testitem "datalabel" setup = [TestUtils] begin import GraphPPL: getcontext, datalabel, NodeCreationOptions, VariableKindData, VariableKindRandom, unroll, proxylabel - include("testutils.jl") - - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) ylabel = datalabel(model, ctx, NodeCreationOptions(kind = VariableKindData), :y) yvar = unroll(ylabel) @@ -140,12 +136,10 @@ end ) end -@testitem "contains_nodelabel" begin +@testitem "contains_nodelabel" setup = [TestUtils] begin import GraphPPL: create_model, getcontext, getorcreate!, contains_nodelabel, NodeCreationOptions, True, False, MixedArguments - include("testutils.jl") - - model = create_test_model() + model = TestUtils.create_test_model() ctx = getcontext(model) a = getorcreate!(model, ctx, :x, nothing) b = getorcreate!(model, ctx, NodeCreationOptions(kind = :data), :x, nothing) diff --git a/test/nodes/node_semantics_tests.jl b/test/nodes/node_semantics_tests.jl index a5f1b54d..7ba2c5e5 100644 --- a/test/nodes/node_semantics_tests.jl +++ b/test/nodes/node_semantics_tests.jl @@ -1,54 +1,46 @@ -@testitem "NodeType" begin +@testitem "NodeType" setup = [TestUtils] begin + using Distributions import GraphPPL: NodeType, Composite, Atomic - include("testutils.jl") - - using .TestUtils.ModelZoo - - model = create_test_model() + model = TestUtils.create_test_model() @test NodeType(model, Composite) == Atomic() @test NodeType(model, Atomic) == Atomic() @test NodeType(model, abs) == Atomic() @test NodeType(model, Normal) == Atomic() - @test NodeType(model, NormalMeanVariance) == Atomic() - @test NodeType(model, NormalMeanPrecision) == Atomic() + @test NodeType(model, TestUtils.NormalMeanVariance) == Atomic() + @test NodeType(model, TestUtils.NormalMeanPrecision) == Atomic() # Could test all here - for model_fn in ModelsInTheZooWithoutArguments + for model_fn in TestUtils.ModelsInTheZooWithoutArguments @test NodeType(model, model_fn) == Composite() end end -@testitem "NodeBehaviour" begin +@testitem "NodeBehaviour" setup = [TestUtils] begin + using Distributions import GraphPPL: NodeBehaviour, Deterministic, Stochastic - include("testutils.jl") - - using .TestUtils.ModelZoo - - model = create_test_model() + model = TestUtils.create_test_model() @test NodeBehaviour(model, () -> 1) == Deterministic() @test NodeBehaviour(model, Matrix) == Deterministic() @test NodeBehaviour(model, abs) == Deterministic() @test NodeBehaviour(model, Normal) == Stochastic() - @test NodeBehaviour(model, NormalMeanVariance) == Stochastic() - @test NodeBehaviour(model, NormalMeanPrecision) == Stochastic() + @test NodeBehaviour(model, TestUtils.NormalMeanVariance) == Stochastic() + @test NodeBehaviour(model, TestUtils.NormalMeanPrecision) == Stochastic() # Could test all here - for model_fn in ModelsInTheZooWithoutArguments + for model_fn in TestUtils.ModelsInTheZooWithoutArguments @test NodeBehaviour(model, model_fn) == Stochastic() end end -@testitem "interface_alias" begin - using GraphPPL +@testitem "interface_alias" setup = [TestUtils] begin + using Distributions import GraphPPL: interface_aliases, StaticInterfaces - include("testutils.jl") - - model = create_test_model() + model = TestUtils.create_test_model() @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :μ, :τ)))) === StaticInterfaces((:out, :μ, :τ)) @test @inferred(interface_aliases(model, Normal, StaticInterfaces((:out, :mean, :precision)))) === StaticInterfaces((:out, :μ, :τ)) @@ -87,11 +79,9 @@ end @test @allocated(interface_aliases(model, Normal, StaticInterfaces((:out, :μ, :variance)))) === 0 end -@testitem "factor_alias" begin +@testitem "factor_alias" setup = [TestUtils] begin import GraphPPL: factor_alias, StaticInterfaces - include("testutils.jl") - function abc end function xyz end @@ -101,7 +91,7 @@ end GraphPPL.factor_alias(::TestUtils.TestGraphPPLBackend, ::typeof(xyz), ::StaticInterfaces{(:a, :b)}) = abc GraphPPL.factor_alias(::TestUtils.TestGraphPPLBackend, ::typeof(xyz), ::StaticInterfaces{(:x, :y)}) = xyz - model = create_test_model() + model = TestUtils.create_test_model() @test factor_alias(model, abc, StaticInterfaces((:a, :b))) === abc @test factor_alias(model, abc, StaticInterfaces((:x, :y))) === xyz @@ -110,14 +100,11 @@ end @test factor_alias(model, xyz, StaticInterfaces((:x, :y))) === xyz end -@testitem "default_parametrization" begin +@testitem "default_parametrization" setup = [TestUtils] begin + using Distributions import GraphPPL: default_parametrization, Composite, Atomic - include("testutils.jl") - - using .TestUtils.ModelZoo - - model = create_test_model() + model = TestUtils.create_test_model() # Test 1: Add default arguments to Normal call @test default_parametrization(model, Atomic(), Normal, (0, 1)) == (μ = 0, σ = 1) @@ -128,10 +115,10 @@ end # Test 3: Add :in to function call that has default behaviour with nested interfaces @test default_parametrization(model, Atomic(), +, ([1, 1], 2)) == (in = ([1, 1], 2),) - @test_throws ErrorException default_parametrization(model, Composite(), gcv, (1, 2)) + @test_throws ErrorException default_parametrization(model, Composite(), TestUtils.gcv, (1, 2)) end -@testitem "getindex for StaticInterfaces" begin +@testitem "getindex for StaticInterfaces" setup = [TestUtils] begin import GraphPPL: StaticInterfaces interfaces = (:a, :b, :c) @@ -142,12 +129,11 @@ end end end -@testitem "missing_interfaces" begin +@testitem "missing_interfaces" setup = [TestUtils] begin + using Static import GraphPPL: missing_interfaces, interfaces - include("testutils.jl") - - model = create_test_model() + model = TestUtils.create_test_model() function abc end @@ -172,38 +158,32 @@ end @test missing_interfaces(model, bar, static(2), (in1 = 1, in2 = 2, out = 3, test = 4)) == GraphPPL.StaticInterfaces(()) end -@testitem "sort_interfaces" begin +@testitem "sort_interfaces" setup = [TestUtils] begin import GraphPPL: sort_interfaces - include("testutils.jl") - - model = create_test_model() + model = TestUtils.create_test_model() # Test 1: Test that sort_interfaces sorts the interfaces in the correct order - @test sort_interfaces(model, NormalMeanVariance, (μ = 1, σ = 1, out = 1)) == (out = 1, μ = 1, σ = 1) - @test sort_interfaces(model, NormalMeanVariance, (out = 1, μ = 1, σ = 1)) == (out = 1, μ = 1, σ = 1) - @test sort_interfaces(model, NormalMeanVariance, (σ = 1, out = 1, μ = 1)) == (out = 1, μ = 1, σ = 1) - @test sort_interfaces(model, NormalMeanVariance, (σ = 1, μ = 1, out = 1)) == (out = 1, μ = 1, σ = 1) - @test sort_interfaces(model, NormalMeanPrecision, (μ = 1, τ = 1, out = 1)) == (out = 1, μ = 1, τ = 1) - @test sort_interfaces(model, NormalMeanPrecision, (out = 1, μ = 1, τ = 1)) == (out = 1, μ = 1, τ = 1) - @test sort_interfaces(model, NormalMeanPrecision, (τ = 1, out = 1, μ = 1)) == (out = 1, μ = 1, τ = 1) - @test sort_interfaces(model, NormalMeanPrecision, (τ = 1, μ = 1, out = 1)) == (out = 1, μ = 1, τ = 1) - - @test_throws ErrorException sort_interfaces(model, NormalMeanVariance, (σ = 1, μ = 1, τ = 1)) + @test sort_interfaces(model, TestUtils.NormalMeanVariance, (μ = 1, σ = 1, out = 1)) == (out = 1, μ = 1, σ = 1) + @test sort_interfaces(model, TestUtils.NormalMeanVariance, (out = 1, μ = 1, σ = 1)) == (out = 1, μ = 1, σ = 1) + @test sort_interfaces(model, TestUtils.NormalMeanVariance, (σ = 1, out = 1, μ = 1)) == (out = 1, μ = 1, σ = 1) + @test sort_interfaces(model, TestUtils.NormalMeanVariance, (σ = 1, μ = 1, out = 1)) == (out = 1, μ = 1, σ = 1) + @test sort_interfaces(model, TestUtils.NormalMeanPrecision, (μ = 1, τ = 1, out = 1)) == (out = 1, μ = 1, τ = 1) + @test sort_interfaces(model, TestUtils.NormalMeanPrecision, (out = 1, μ = 1, τ = 1)) == (out = 1, μ = 1, τ = 1) + @test sort_interfaces(model, TestUtils.NormalMeanPrecision, (τ = 1, out = 1, μ = 1)) == (out = 1, μ = 1, τ = 1) + @test sort_interfaces(model, TestUtils.NormalMeanPrecision, (τ = 1, μ = 1, out = 1)) == (out = 1, μ = 1, τ = 1) + + @test_throws ErrorException sort_interfaces(model, TestUtils.NormalMeanVariance, (σ = 1, μ = 1, τ = 1)) end -@testitem "prepare_interfaces" begin +@testitem "prepare_interfaces" setup = [TestUtils] begin import GraphPPL: prepare_interfaces - include("testutils.jl") - - using .TestUtils.ModelZoo - - model = create_test_model() + model = TestUtils.create_test_model() - @test prepare_interfaces(model, anonymous_in_loop, 1, (y = 1,)) == (x = 1, y = 1) - @test prepare_interfaces(model, anonymous_in_loop, 1, (x = 1,)) == (y = 1, x = 1) + @test prepare_interfaces(model, TestUtils.anonymous_in_loop, 1, (y = 1,)) == (x = 1, y = 1) + @test prepare_interfaces(model, TestUtils.anonymous_in_loop, 1, (x = 1,)) == (y = 1, x = 1) - @test prepare_interfaces(model, type_arguments, 1, (x = 1,)) == (n = 1, x = 1) - @test prepare_interfaces(model, type_arguments, 1, (n = 1,)) == (x = 1, n = 1) + @test prepare_interfaces(model, TestUtils.type_arguments, 1, (x = 1,)) == (n = 1, x = 1) + @test prepare_interfaces(model, TestUtils.type_arguments, 1, (n = 1,)) == (x = 1, n = 1) end \ No newline at end of file diff --git a/test/plugins/meta/meta_engine_tests.jl b/test/plugins/meta/meta_engine_tests.jl index 64dc6450..b8b638f6 100644 --- a/test/plugins/meta/meta_engine_tests.jl +++ b/test/plugins/meta/meta_engine_tests.jl @@ -1,8 +1,7 @@ @testitem "FactorMetaDescriptor" begin + using Distributions import GraphPPL: FactorMetaDescriptor, IndexedVariable - include("../../testutils.jl") - @test FactorMetaDescriptor(Normal, (:x, :k, :w)) isa FactorMetaDescriptor{<:Tuple} @test FactorMetaDescriptor(Gamma, nothing) isa FactorMetaDescriptor{Nothing} end @@ -15,10 +14,9 @@ end end @testitem "MetaObject" begin + using Distributions import GraphPPL: MetaObject, FactorMetaDescriptor, IndexedVariable, VariableMetaDescriptor - include("../../testutils.jl") - struct SomeMeta end @test MetaObject(FactorMetaDescriptor(Normal, (IndexedVariable(:x, nothing), :k, :w)), SomeMeta()) isa @@ -38,13 +36,10 @@ end @test MetaSpecification() isa MetaSpecification end -@testitem "SpecificSubModelMeta" begin +@testitem "SpecificSubModelMeta" setup = [TestUtils] begin + using Distributions import GraphPPL: SpecificSubModelMeta, GeneralSubModelMeta, MetaSpecification, IndexedVariable, FactorMetaDescriptor, MetaObject - include("../../testutils.jl") - - using .TestUtils.ModelZoo - struct SomeMeta end @test SpecificSubModelMeta(GraphPPL.FactorID(sum, 1), MetaSpecification()) isa SpecificSubModelMeta @@ -56,10 +51,11 @@ end SpecificSubModelMeta(GraphPPL.FactorID(sum, 1), MetaSpecification()), SpecificSubModelMeta(GraphPPL.FactorID(sum, 1), MetaSpecification()) ) - push!(SpecificSubModelMeta(GraphPPL.FactorID(sum, 1), MetaSpecification()), GeneralSubModelMeta(gcv, MetaSpecification())) + push!(SpecificSubModelMeta(GraphPPL.FactorID(sum, 1), MetaSpecification()), GeneralSubModelMeta(TestUtils.gcv, MetaSpecification())) end -@testitem "GeneralSubModelMeta" begin +@testitem "GeneralSubModelMeta" setup = [TestUtils] begin + using Distributions import GraphPPL: SpecificSubModelMeta, GeneralSubModelMeta, @@ -69,20 +65,16 @@ end MetaObject, getgeneralssubmodelmeta - include("../../testutils.jl") - - using .TestUtils.ModelZoo - struct SomeMeta end - @test GeneralSubModelMeta(gcv, MetaSpecification()) isa GeneralSubModelMeta + @test GeneralSubModelMeta(TestUtils.gcv, MetaSpecification()) isa GeneralSubModelMeta push!( - GeneralSubModelMeta(gcv, MetaSpecification()), + GeneralSubModelMeta(TestUtils.gcv, MetaSpecification()), MetaObject(FactorMetaDescriptor(Normal, (IndexedVariable(:x, nothing), :k, :w)), SomeMeta()) ) - push!(GeneralSubModelMeta(gcv, MetaSpecification()), SpecificSubModelMeta(GraphPPL.FactorID(sum, 1), MetaSpecification())) + push!(GeneralSubModelMeta(TestUtils.gcv, MetaSpecification()), SpecificSubModelMeta(GraphPPL.FactorID(sum, 1), MetaSpecification())) meta = MetaSpecification() - push!(meta, GeneralSubModelMeta(gcv, MetaSpecification())) + push!(meta, GeneralSubModelMeta(TestUtils.gcv, MetaSpecification())) end @testitem "filter general and specific submodel meta" begin @@ -105,7 +97,7 @@ end @test getgeneralsubmodelmeta(meta, sin).fform == sin end -@testitem "apply!(::Model, ::Context, ::MetaObject)" begin +@testitem "apply!(::Model, ::Context, ::MetaObject)" setup = [TestUtils] begin import GraphPPL: create_model, apply_meta!, @@ -118,54 +110,50 @@ end VariableMetaDescriptor, as_node - include("../../testutils.jl") - - using .TestUtils.ModelZoo - struct SomeMeta end # Test apply for a FactorMeta over a single factor - model = create_model(simple_model()) + model = create_model(TestUtils.simple_model()) context = getcontext(model) metadata = MetaObject( - FactorMetaDescriptor(NormalMeanVariance, (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))), SomeMeta() + FactorMetaDescriptor(TestUtils.NormalMeanVariance, (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))), SomeMeta() ) apply_meta!(model, context, metadata) - node = last(filter(as_node(NormalMeanVariance), model)) + node = last(filter(as_node(TestUtils.NormalMeanVariance), model)) @test getextra(model[node], :meta) == SomeMeta() - node = first(filter(as_node(NormalMeanVariance), model)) + node = first(filter(as_node(TestUtils.NormalMeanVariance), model)) @test !hasextra(model[node], :meta) # Test apply for a FactorMeta over a single factor where variables are not specified - model = create_model(simple_model()) + model = create_model(TestUtils.simple_model()) context = GraphPPL.getcontext(model) - metadata = MetaObject(FactorMetaDescriptor(NormalMeanVariance, nothing), SomeMeta()) + metadata = MetaObject(FactorMetaDescriptor(TestUtils.NormalMeanVariance, nothing), SomeMeta()) apply_meta!(model, context, metadata) @test getextra(model[node], :meta) == SomeMeta() # Test apply for a FactorMeta over a vector of factors - model = create_model(vector_model()) + model = create_model(TestUtils.vector_model()) context = GraphPPL.getcontext(model) - metadata = MetaObject(FactorMetaDescriptor(NormalMeanVariance, (:x, :y)), SomeMeta()) + metadata = MetaObject(FactorMetaDescriptor(TestUtils.NormalMeanVariance, (:x, :y)), SomeMeta()) apply_meta!(model, context, metadata) for node in intersect(GraphPPL.neighbors(model, context[:x]), GraphPPL.neighbors(model, context[:y])) @test getextra(model[node], :meta) == SomeMeta() end # Test apply for a FactorMeta over a vector of factors without specifying variables - model = create_model(vector_model()) + model = create_model(TestUtils.vector_model()) context = GraphPPL.getcontext(model) - metadata = MetaObject(FactorMetaDescriptor(NormalMeanVariance, nothing), SomeMeta()) + metadata = MetaObject(FactorMetaDescriptor(TestUtils.NormalMeanVariance, nothing), SomeMeta()) apply_meta!(model, context, metadata) for node in intersect(GraphPPL.neighbors(model, context[:x]), GraphPPL.neighbors(model, context[:y])) @test getextra(model[node], :meta) == SomeMeta() end # Test apply for a FactorMeta over a single factor with NamedTuple as meta - model = create_model(simple_model()) + model = create_model(TestUtils.simple_model()) context = GraphPPL.getcontext(model) metadata = MetaObject( - FactorMetaDescriptor(NormalMeanVariance, (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))), + FactorMetaDescriptor(TestUtils.NormalMeanVariance, (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))), (meta = SomeMeta(), other = 1) ) apply_meta!(model, context, metadata) @@ -174,19 +162,19 @@ end @test getextra(model[node], :other) == 1 # Test apply for a FactorMeta over a single factor with NamedTuple as meta - model = create_model(simple_model()) + model = create_model(TestUtils.simple_model()) context = GraphPPL.getcontext(model) - metadata = MetaObject(FactorMetaDescriptor(NormalMeanVariance, nothing), (meta = SomeMeta(), other = 1)) + metadata = MetaObject(FactorMetaDescriptor(TestUtils.NormalMeanVariance, nothing), (meta = SomeMeta(), other = 1)) apply_meta!(model, context, metadata) node = first(intersect(GraphPPL.neighbors(model, context[:x]), GraphPPL.neighbors(model, context[:y]))) @test getextra(model[node], :meta) == SomeMeta() @test getextra(model[node], :other) == 1 # Test apply for a FactorMeta over a vector of factors with NamedTuple as meta - model = create_model(vector_model()) + model = create_model(TestUtils.vector_model()) context = GraphPPL.getcontext(model) metadata = MetaObject( - FactorMetaDescriptor(NormalMeanVariance, (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))), + FactorMetaDescriptor(TestUtils.NormalMeanVariance, (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))), (meta = SomeMeta(), other = 1) ) apply_meta!(model, context, metadata) @@ -196,10 +184,11 @@ end end # Test apply for a FactorMeta over a factor that is specified by an Index - model = create_model(vector_model()) + model = create_model(TestUtils.vector_model()) context = GraphPPL.getcontext(model) metadata = MetaObject( - FactorMetaDescriptor(NormalMeanVariance, (IndexedVariable(:x, 1), IndexedVariable(:y, nothing))), (meta = SomeMeta(), other = 1) + FactorMetaDescriptor(TestUtils.NormalMeanVariance, (IndexedVariable(:x, 1), IndexedVariable(:y, nothing))), + (meta = SomeMeta(), other = 1) ) apply_meta!(model, context, metadata) node = first(intersect(GraphPPL.neighbors(model, context[:x][1]), GraphPPL.neighbors(model, context[:y]))) @@ -209,9 +198,9 @@ end @test !hasextra(model[other_node], :meta) # Test apply for a FactorMeta over a vector of factors with NamedTuple as meta - model = create_model(vector_model()) + model = create_model(TestUtils.vector_model()) context = GraphPPL.getcontext(model) - metaobject = MetaObject(FactorMetaDescriptor(NormalMeanVariance, nothing), (meta = SomeMeta(), other = 1)) + metaobject = MetaObject(FactorMetaDescriptor(TestUtils.NormalMeanVariance, nothing), (meta = SomeMeta(), other = 1)) apply_meta!(model, context, metaobject) for node in intersect(GraphPPL.neighbors(model, context[:x]), GraphPPL.neighbors(model, context[:y])) @test getextra(model[node], :meta) == SomeMeta() @@ -219,21 +208,21 @@ end end # Test apply for a VariableMeta - model = create_model(simple_model()) + model = create_model(TestUtils.simple_model()) context = GraphPPL.getcontext(model) metaobject = MetaObject(VariableMetaDescriptor(IndexedVariable(:x, nothing)), SomeMeta()) apply_meta!(model, context, metaobject) @test getextra(model[context[:x]], :meta) == SomeMeta() # Test apply for a VariableMeta with NamedTuple as meta - model = create_model(simple_model()) + model = create_model(TestUtils.simple_model()) context = GraphPPL.getcontext(model) metaobject = MetaObject(VariableMetaDescriptor(IndexedVariable(:x, nothing)), (meta = SomeMeta(), other = 1)) apply_meta!(model, context, metaobject) @test getextra(model[context[:x]], :meta) == SomeMeta() # Test apply for a VariableMeta with NamedTuple as meta - model = create_model(vector_model()) + model = create_model(TestUtils.vector_model()) context = GraphPPL.getcontext(model) metaobject = MetaObject(VariableMetaDescriptor(IndexedVariable(:x, nothing)), (meta = SomeMeta(), other = 1)) apply_meta!(model, context, metaobject) @@ -242,7 +231,7 @@ end @test getextra(model[context[:x][3]], :meta) == SomeMeta() # Test apply for a VariableMeta with NamedTuple as meta - model = create_model(vector_model()) + model = create_model(TestUtils.vector_model()) context = GraphPPL.getcontext(model) metaobject = MetaObject(VariableMetaDescriptor(IndexedVariable(:x, nothing)), (meta = SomeMeta(), other = 1)) apply_meta!(model, context, metaobject) @@ -254,7 +243,7 @@ end @test getextra(model[context[:x][3]], :other) == 1 end -@testitem "save_meta!(::Model, ::NodeLabel, ::MetaObject)" begin +@testitem "save_meta!(::Model, ::NodeLabel, ::MetaObject)" setup = [TestUtils] begin import GraphPPL: create_model, save_meta!, @@ -267,25 +256,21 @@ end VariableMetaDescriptor, neighbors - include("../../testutils.jl") - - using .TestUtils.ModelZoo - struct SomeMeta end # Test save_meta! for a FactorMeta over a single factor - model = create_model(simple_model()) + model = create_model(TestUtils.simple_model()) context = getcontext(model) node = first(intersect(neighbors(model, context[:x]), neighbors(model, context[:y]))) - metaobj = MetaObject(FactorMetaDescriptor(NormalMeanVariance, (:x, :y)), SomeMeta()) + metaobj = MetaObject(FactorMetaDescriptor(TestUtils.NormalMeanVariance, (:x, :y)), SomeMeta()) save_meta!(model, node, metaobj) @test getextra(model[node], :meta) == SomeMeta() # Test save_meta! for a FactorMeta with a NamedTuple as meta - model = create_model(simple_model()) + model = create_model(TestUtils.simple_model()) context = GraphPPL.getcontext(model) node = first(intersect(neighbors(model, context[:x]), neighbors(model, context[:y]))) - metaobj = MetaObject(FactorMetaDescriptor(NormalMeanVariance, (:x, :y)), (meta = SomeMeta(), other = 1)) + metaobj = MetaObject(FactorMetaDescriptor(TestUtils.NormalMeanVariance, (:x, :y)), (meta = SomeMeta(), other = 1)) save_meta!(model, node, metaobj) @test getextra(model[node], :meta) == SomeMeta() @test getextra(model[node], :other) == 1 diff --git a/test/plugins/meta/meta_macro_tests.jl b/test/plugins/meta/meta_macro_tests.jl index be428616..4f1eac02 100644 --- a/test/plugins/meta/meta_macro_tests.jl +++ b/test/plugins/meta/meta_macro_tests.jl @@ -1,13 +1,12 @@ -@testitem "check_for_returns" begin +@testitem "check_for_returns" setup = [TestUtils] begin + using MacroTools import GraphPPL: check_for_returns_meta, apply_pipeline - include("../../testutils.jl") - # Test 1: check_for_returns_meta with one statement input = quote Normal(x, y) -> some_meta() end - @test_expression_generating apply_pipeline(input, check_for_returns_meta) input + TestUtils.@test_expression_generating apply_pipeline(input, check_for_returns_meta) input # Test 2: check_for_returns_meta with a return statement input = quote @@ -16,11 +15,10 @@ end @test_throws ErrorException("The meta macro does not support return statements.") apply_pipeline(input, check_for_returns_meta) end -@testitem "add_meta_constructor" begin +@testitem "add_meta_constructor" setup = [TestUtils] begin + using MacroTools import GraphPPL: add_meta_construction - include("../../testutils.jl") - # Test 1: add_constraints_construction to regular constraint specification input = quote GCV(x, k, w) -> GCVMetadata(GaussHermiteCubature(20)) @@ -36,7 +34,7 @@ end __meta__ end end - @test_expression_generating add_meta_construction(input) output + TestUtils.@test_expression_generating add_meta_construction(input) output # Test 2: add_constraints_construction to constraint specification with nested model specification input = quote @@ -58,7 +56,7 @@ end __meta__ end end - @test_expression_generating add_meta_construction(input) output + TestUtils.@test_expression_generating add_meta_construction(input) output # Test 3: add_constraints_construction to constraint specification with function specification input = quote @@ -80,7 +78,7 @@ end return __meta__ end end - @test_expression_generating add_meta_construction(input) output + TestUtils.@test_expression_generating add_meta_construction(input) output # Test 4: add_constraints_construction to constraint specification with function specification and arguments input = quote @@ -102,7 +100,7 @@ end return __meta__ end end - @test_expression_generating add_meta_construction(input) output + TestUtils.@test_expression_generating add_meta_construction(input) output # Test 5: add_constraints_construction to constraint specification with function specification and arguments and keyword arguments input = quote @@ -124,7 +122,7 @@ end return __meta__ end end - @test_expression_generating add_meta_construction(input) output + TestUtils.@test_expression_generating add_meta_construction(input) output # Test 6: add_constraints_construction to constraint specification with function specification and only keyword arguments input = quote @@ -146,14 +144,13 @@ end return __meta__ end end - @test_expression_generating add_meta_construction(input) output + TestUtils.@test_expression_generating add_meta_construction(input) output end -@testitem "create_submodel_meta" begin +@testitem "create_submodel_meta" setup = [TestUtils] begin + using MacroTools import GraphPPL: create_submodel_meta, apply_pipeline - include("../../testutils.jl") - # Test 1: create_submodel_meta with one nested layer input = quote __meta__ = GraphPPL.MetaSpecification() @@ -182,7 +179,7 @@ end x -> MySecondCustomMetaObject(arg3) __meta__ end - @test_expression_generating apply_pipeline(input, create_submodel_meta) output + TestUtils.@test_expression_generating apply_pipeline(input, create_submodel_meta) output # Test 2: create_submodel_meta with two nested layers input = quote @@ -223,14 +220,13 @@ end NormalMeanVariance() -> MyCustomMetaObject(arg1, arg2) __meta__ end - @test_expression_generating apply_pipeline(input, create_submodel_meta) output + TestUtils.@test_expression_generating apply_pipeline(input, create_submodel_meta) output end -@testitem "convert_meta_variables" begin +@testitem "convert_meta_variables" setup = [TestUtils] begin + using MacroTools import GraphPPL: convert_meta_variables, apply_pipeline - include("../../testutils.jl") - # Test 1: convert_meta_variables with non-indexed variables in Factor meta call input = quote some_function(x, y) -> some_meta() @@ -238,7 +234,7 @@ end output = quote some_function(GraphPPL.IndexedVariable(:x, nothing), GraphPPL.IndexedVariable(:y, nothing)) -> some_meta() end - @test_expression_generating apply_pipeline(input, convert_meta_variables) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_meta_variables) output # Test 2: convert_meta_variables with indexed variables in Factor meta call input = quote @@ -247,7 +243,7 @@ end output = quote some_function(GraphPPL.IndexedVariable(:x, i), GraphPPL.IndexedVariable(:y, j)) -> some_meta() end - @test_expression_generating apply_pipeline(input, convert_meta_variables) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_meta_variables) output # Test 3: convert_meta_variables with non-indexed variables in Variable meta call input = quote @@ -256,7 +252,7 @@ end output = quote GraphPPL.IndexedVariable(:x, nothing) -> some_meta() end - @test_expression_generating apply_pipeline(input, convert_meta_variables) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_meta_variables) output # Test 4: convert_meta_variables with indexed variables in Variable meta call input = quote @@ -265,14 +261,13 @@ end output = quote GraphPPL.IndexedVariable(:x, i) -> some_meta() end - @test_expression_generating apply_pipeline(input, convert_meta_variables) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_meta_variables) output end -@testitem "convert_meta_object" begin +@testitem "convert_meta_object" setup = [TestUtils] begin + using MacroTools import GraphPPL: convert_meta_object, apply_pipeline - include("../../testutils.jl") - # Test 1: convert_meta_object with Factor meta call input = quote @@ -289,7 +284,7 @@ end ) ) end - @test_expression_generating apply_pipeline(input, convert_meta_object) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_meta_object) output # Test 2: convert_meta_object with Variable meta call input = quote @@ -298,14 +293,13 @@ end output = quote push!(__meta__, GraphPPL.MetaObject(GraphPPL.VariableMetaDescriptor(GraphPPL.IndexedVariable(:x, nothing)), some_meta())) end - @test_expression_generating apply_pipeline(input, convert_meta_object) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_meta_object) output end -@testitem "meta_macro_interior" begin +@testitem "meta_macro_interior" setup = [TestUtils] begin + using MacroTools import GraphPPL: meta_macro_interior - include("../../testutils.jl") - # Test 1: meta_macro_interor with one statement input = quote x -> some_meta() @@ -319,7 +313,7 @@ end __meta__ end end - @test_expression_generating meta_macro_interior(input) output + TestUtils.@test_expression_generating meta_macro_interior(input) output # Test 2: meta_macro_interor with multiple statements input = quote @@ -342,7 +336,7 @@ end __meta__ end end - @test_expression_generating meta_macro_interior(input) output + TestUtils.@test_expression_generating meta_macro_interior(input) output # Test 3: meta_macro_interor with multiple statements and a submodel definition input = quote @@ -384,5 +378,5 @@ end __meta__ end end - @test_expression_generating meta_macro_interior(input) output + TestUtils.@test_expression_generating meta_macro_interior(input) output end diff --git a/test/plugins/meta/meta_tests.jl b/test/plugins/meta/meta_tests.jl index 6ef844b2..91a050b0 100644 --- a/test/plugins/meta/meta_tests.jl +++ b/test/plugins/meta/meta_tests.jl @@ -5,15 +5,12 @@ @test MetaPlugin(nothing) == MetaPlugin(EmptyMeta) end -@testitem "@meta macro pipeline" begin +@testitem "@meta macro pipeline" setup = [TestUtils] begin + using Distributions using GraphPPL import GraphPPL: create_model, with_plugins, getextra, hasextra, PluginsCollection, MetaPlugin, apply_meta! - include("../../testutils.jl") - - using .TestUtils.ModelZoo - struct SomeMeta end # Test constraints macro with single variables and no nesting @@ -22,11 +19,11 @@ end x -> SomeMeta() y -> (meta = SomeMeta(), other = 1) end - model = create_model(with_plugins(simple_model(), PluginsCollection(MetaPlugin(metaspec)))) + model = create_model(with_plugins(TestUtils.simple_model(), PluginsCollection(MetaPlugin(metaspec)))) ctx = GraphPPL.getcontext(model) - @test !hasextra(model[ctx[NormalMeanVariance, 1]], :meta) - @test getextra(model[ctx[NormalMeanVariance, 2]], :meta) == SomeMeta() + @test !hasextra(model[ctx[TestUtils.NormalMeanVariance, 1]], :meta) + @test getextra(model[ctx[TestUtils.NormalMeanVariance, 2]], :meta) == SomeMeta() @test getextra(model[ctx[:x]], :meta) == SomeMeta() @test getextra(model[ctx[:y]], :meta) == SomeMeta() @@ -36,56 +33,56 @@ end metaobj = @meta begin Gamma(w) -> SomeMeta() end - model = create_model(with_plugins(outer(), PluginsCollection(MetaPlugin(metaobj)))) + model = create_model(with_plugins(TestUtils.outer(), PluginsCollection(MetaPlugin(metaobj)))) ctx = GraphPPL.getcontext(model) - for node in filter(GraphPPL.as_node(Gamma) & GraphPPL.as_context(outer), model) + for node in filter(GraphPPL.as_node(Gamma) & GraphPPL.as_context(TestUtils.outer), model) @test getextra(model[node], :meta) == SomeMeta() end # Test meta macro with nested model metaobj = @meta begin - for meta in inner + for meta in TestUtils.inner α -> SomeMeta() end end - model = create_model(with_plugins(outer(), PluginsCollection(MetaPlugin(metaobj)))) + model = create_model(with_plugins(TestUtils.outer(), PluginsCollection(MetaPlugin(metaobj)))) ctx = GraphPPL.getcontext(model) @test getextra(model[ctx[:y]], :meta) == SomeMeta() # Test with specifying specific submodel metaobj = @meta begin - for meta in (child_model, 1) + for meta in (TestUtils.child_model, 1) Normal(in, out) -> SomeMeta() end end - model = create_model(with_plugins(parent_model(), PluginsCollection(MetaPlugin(metaobj)))) + model = create_model(with_plugins(TestUtils.parent_model(), PluginsCollection(MetaPlugin(metaobj)))) ctx = GraphPPL.getcontext(model) - @test getextra(model[ctx[child_model, 1][NormalMeanVariance, 1]], :meta) == SomeMeta() + @test getextra(model[ctx[TestUtils.child_model, 1][TestUtils.NormalMeanVariance, 1]], :meta) == SomeMeta() for i in 2:99 - @test !hasextra(model[ctx[child_model, i][NormalMeanVariance, 1]], :meta) + @test !hasextra(model[ctx[TestUtils.child_model, i][TestUtils.NormalMeanVariance, 1]], :meta) end # Test with specifying general submodel metaobj = @meta begin - for meta in child_model + for meta in TestUtils.child_model Normal(in, out) -> SomeMeta() end end - model = create_model(with_plugins(parent_model(), PluginsCollection(MetaPlugin(metaobj)))) + model = create_model(with_plugins(TestUtils.parent_model(), PluginsCollection(MetaPlugin(metaobj)))) ctx = GraphPPL.getcontext(model) - for node in filter(GraphPPL.as_node(NormalMeanVariance) & GraphPPL.as_context(child_model), model) + for node in filter(GraphPPL.as_node(TestUtils.NormalMeanVariance) & GraphPPL.as_context(TestUtils.child_model), model) @test getextra(model[node], :meta) == SomeMeta() end end -@testitem "Meta setting via the `where` block" begin - include("../../testutils.jl") +@testitem "Meta setting via the `where` block" setup = [TestUtils] begin + using Distributions - @model function some_model() + TestUtils.@model function some_model() x ~ Beta(1.0, 2.0) where {meta = "Hello, world!"} end @@ -97,7 +94,7 @@ end end @testitem "Meta should save source code " begin - include("../../testutils.jl") + using Distributions meta = @meta begin Normal(in, out) -> 1 diff --git a/test/plugins/node_created_by_tests.jl b/test/plugins/node_created_by_tests.jl index 1f0e10c4..b01505f6 100644 --- a/test/plugins/node_created_by_tests.jl +++ b/test/plugins/node_created_by_tests.jl @@ -1,4 +1,4 @@ -@testitem "NodeCreatedByPlugin: model with the plugin" begin +@testitem "NodeCreatedByPlugin: model with the plugin" setup = [TestUtils] begin using Distributions import GraphPPL: @@ -13,9 +13,7 @@ hasextra, getextra - include("../testutils.jl") - - model = create_test_model(plugins = PluginsCollection(NodeCreatedByPlugin())) + model = TestUtils.create_test_model(plugins = PluginsCollection(NodeCreatedByPlugin())) ctx = getcontext(model) @testset begin @@ -52,7 +50,7 @@ end end -@testitem "NodeCreatedByPlugin: model without the plugin" begin +@testitem "NodeCreatedByPlugin: model without the plugin" setup = [TestUtils] begin using Distributions import GraphPPL: @@ -67,9 +65,7 @@ end hasextra, getextra - include("../testutils.jl") - - model = create_test_model(plugins = PluginsCollection()) + model = TestUtils.create_test_model(plugins = PluginsCollection()) ctx = getcontext(model) @testset begin @@ -90,7 +86,7 @@ end end end -@testitem "Usage with the actual model" begin +@testitem "Usage with the actual model" setup = [TestUtils] begin using Distributions import GraphPPL: @@ -105,9 +101,7 @@ end NodeCreatedByPlugin, getextra - include("../testutils.jl") - - @model function simple_model() + TestUtils.@model function simple_model() x ~ Normal(0, 1) y ~ Gamma(1, 1) z ~ Beta(x, y) diff --git a/test/plugins/node_id_tests.jl b/test/plugins/node_id_tests.jl index cc0097a7..282c2d16 100644 --- a/test/plugins/node_id_tests.jl +++ b/test/plugins/node_id_tests.jl @@ -1,4 +1,4 @@ -@testitem "NodeIdPlugin: model with the plugin" begin +@testitem "NodeIdPlugin: model with the plugin" setup = [TestUtils] begin using Distributions import GraphPPL: @@ -13,8 +13,7 @@ create_model, with_plugins - include("../testutils.jl") - @model function node_with_two_anonymous() + TestUtils.@model function node_with_two_anonymous() x[1] ~ Normal(0, 1) y[1] ~ Normal(0, 1) for i in 2:10 diff --git a/test/plugins/node_tag_tests.jl b/test/plugins/node_tag_tests.jl index e36eed6c..21e7eced 100644 --- a/test/plugins/node_tag_tests.jl +++ b/test/plugins/node_tag_tests.jl @@ -1,4 +1,4 @@ -@testitem "NodeTagPlugin: model with the plugin" begin +@testitem "NodeTagPlugin: model with the plugin" setup = [TestUtils] begin using Distributions import GraphPPL: @@ -13,9 +13,7 @@ getextra, by_nodetag - include("../testutils.jl") - - model = create_test_model(plugins = PluginsCollection(NodeTagPlugin())) + model = TestUtils.create_test_model(plugins = PluginsCollection(NodeTagPlugin())) ctx = getcontext(model) @testset begin diff --git a/test/plugins/plugin_lifecycle_tests.jl b/test/plugins/plugin_lifecycle_tests.jl index a60decbf..1e546dc3 100644 --- a/test/plugins/plugin_lifecycle_tests.jl +++ b/test/plugins/plugin_lifecycle_tests.jl @@ -1,11 +1,7 @@ -@testitem "Check that factor node plugins are uniquely recreated" begin +@testitem "Check that factor node plugins are uniquely recreated" setup = [TestUtils] begin import GraphPPL: create_model, with_plugins, getplugins, factor_nodes, PluginsCollection, setextra!, getextra - include("testutils.jl") - - using .TestUtils.ModelZoo - struct AnArbitraryPluginForTestUniqeness end GraphPPL.plugin_type(::AnArbitraryPluginForTestUniqeness) = GraphPPL.FactorNodePlugin() @@ -18,7 +14,7 @@ return label, nodedata end - for model_fn in ModelsInTheZooWithoutArguments + for model_fn in TestUtils.ModelsInTheZooWithoutArguments model = create_model(with_plugins(model_fn(), PluginsCollection(AnArbitraryPluginForTestUniqeness()))) for f1 in factor_nodes(model), f2 in factor_nodes(model) if f1 !== f2 @@ -30,7 +26,7 @@ end end -@testitem "Check that plugins may change the options" begin +@testitem "Check that plugins may change the options" setup = [TestUtils] begin import GraphPPL: NodeData, variable_nodes, @@ -45,10 +41,6 @@ end create_model, with_plugins - include("testutils.jl") - - using .TestUtils.ModelZoo - struct AnArbitraryPluginForChangingOptions end GraphPPL.plugin_type(::AnArbitraryPluginForChangingOptions) = GraphPPL.VariableNodePlugin() @@ -58,7 +50,7 @@ end return label, NodeData(context, convert(VariableNodeProperties, :x, nothing, NodeCreationOptions(kind = :constant, value = 1.0))) end - for model_fn in ModelsInTheZooWithoutArguments + for model_fn in TestUtils.ModelsInTheZooWithoutArguments model = create_model(with_plugins(model_fn(), PluginsCollection(AnArbitraryPluginForChangingOptions()))) for v in variable_nodes(model) @test getname(getproperties(model[v])) === :x diff --git a/test/plugins/variational_constraints/variational_constraints_engine_tests.jl b/test/plugins/variational_constraints/variational_constraints_engine_tests.jl index b69db741..05557f02 100644 --- a/test/plugins/variational_constraints/variational_constraints_engine_tests.jl +++ b/test/plugins/variational_constraints/variational_constraints_engine_tests.jl @@ -1,4 +1,4 @@ -@testitem "FactorizationConstraintEntry" begin +@testitem "FactorizationConstraintEntry" setup = [TestUtils] begin import GraphPPL: FactorizationConstraintEntry, IndexedVariable # Test 1: Test FactorisationConstraintEntry @@ -16,7 +16,7 @@ a = FactorizationConstraintEntry((IndexedVariable(:x, 1), IndexedVariable(:y, nothing))) end -@testitem "CombinedRange" begin +@testitem "CombinedRange" setup = [TestUtils] begin import GraphPPL: CombinedRange, is_splitted, FunctionalIndex, IndexedVariable for left in 1:3, right in 5:8 cr = CombinedRange(left, right) @@ -47,7 +47,7 @@ end @test lhs !== IndexedVariable(:y, CombinedRange(1, 2)) end -@testitem "SplittedRange" begin +@testitem "SplittedRange" setup = [TestUtils] begin import GraphPPL: SplittedRange, is_splitted, FunctionalIndex, IndexedVariable for left in 1:3, right in 5:8 cr = SplittedRange(left, right) @@ -78,7 +78,7 @@ end @test lhs !== IndexedVariable(:y, SplittedRange(1, 2)) end -@testitem "__factorization_specification_resolve_index" begin +@testitem "__factorization_specification_resolve_index" setup = [TestUtils] begin using GraphPPL import GraphPPL: __factorization_specification_resolve_index, FunctionalIndex, CombinedRange, SplittedRange, NodeLabel, ResizableArray @@ -130,7 +130,7 @@ end end end -@testitem "factorization_split" begin +@testitem "factorization_split" setup = [TestUtils] begin import GraphPPL: factorization_split, FactorizationConstraintEntry, IndexedVariable, FunctionalIndex, CombinedRange, SplittedRange # Test 1: Test factorization_split with single split @@ -219,7 +219,7 @@ end ) end -@testitem "FactorizationConstraint" begin +@testitem "FactorizationConstraint" setup = [TestUtils] begin import GraphPPL: FactorizationConstraint, FactorizationConstraintEntry, IndexedVariable, FunctionalIndex, CombinedRange, SplittedRange # Test 1: Test FactorizationConstraint with single variables @@ -381,7 +381,8 @@ end @test_throws ErrorException push!(constraints, constraint) end -@testitem "push!(::SubModelConstraints, c::Constraint)" begin +@testitem "push!(::SubModelConstraints, c::Constraint)" setup = [TestUtils] begin + using Distributions import GraphPPL: Constraint, GeneralSubModelConstraints, @@ -394,12 +395,8 @@ end Constraints, IndexedVariable - include("../../testutils.jl") - - using .TestUtils.ModelZoo - # Test 1: Test push! with FactorizationConstraint - constraints = GeneralSubModelConstraints(gcv) + constraints = GeneralSubModelConstraints(TestUtils.gcv) constraint = FactorizationConstraint( (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)), (FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))),) @@ -414,21 +411,21 @@ end @test_throws MethodError push!(constraints, "string") # Test 2: Test push! with MarginalFormConstraint - constraints = GeneralSubModelConstraints(gcv) + constraints = GeneralSubModelConstraints(TestUtils.gcv) constraint = MarginalFormConstraint(IndexedVariable(:x, nothing), Normal) push!(constraints, constraint) @test getconstraint(constraints) == Constraints([MarginalFormConstraint(IndexedVariable(:x, nothing), Normal)],) @test_throws MethodError push!(constraints, "string") # Test 3: Test push! with MessageFormConstraint - constraints = GeneralSubModelConstraints(gcv) + constraints = GeneralSubModelConstraints(TestUtils.gcv) constraint = MessageFormConstraint(IndexedVariable(:x, nothing), Normal) push!(constraints, constraint) @test getconstraint(constraints) == Constraints([MessageFormConstraint(IndexedVariable(:x, nothing), Normal)],) @test_throws MethodError push!(constraints, "string") # Test 4: Test push! with SpecificSubModelConstraints - constraints = SpecificSubModelConstraints(GraphPPL.FactorID(gcv, 3)) + constraints = SpecificSubModelConstraints(GraphPPL.FactorID(TestUtils.gcv, 3)) constraint = FactorizationConstraint( (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)), (FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))),) @@ -443,26 +440,24 @@ end @test_throws MethodError push!(constraints, "string") # Test 5: Test push! with MarginalFormConstraint - constraints = GeneralSubModelConstraints(gcv) + constraints = GeneralSubModelConstraints(TestUtils.gcv) constraint = MarginalFormConstraint(IndexedVariable(:x, nothing), Normal) push!(constraints, constraint) @test getconstraint(constraints) == Constraints([MarginalFormConstraint(IndexedVariable(:x, nothing), Normal)],) @test_throws MethodError push!(constraints, "string") # Test 6: Test push! with MessageFormConstraint - constraints = GeneralSubModelConstraints(gcv) + constraints = GeneralSubModelConstraints(TestUtils.gcv) constraint = MessageFormConstraint(IndexedVariable(:x, nothing), Normal) push!(constraints, constraint) @test getconstraint(constraints) == Constraints([MessageFormConstraint(IndexedVariable(:x, nothing), Normal)],) @test_throws MethodError push!(constraints, "string") end -@testitem "is_factorized" begin +@testitem "is_factorized" setup = [TestUtils] begin import GraphPPL: is_factorized, create_model, getcontext, getproperties, getorcreate!, variable_nodes, NodeCreationOptions - include("../../testutils.jl") - - m = create_test_model(plugins = GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin())) + m = TestUtils.create_test_model(plugins = GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin())) ctx = getcontext(m) x_1 = getorcreate!(m, ctx, NodeCreationOptions(factorized = true), :x_1, nothing) @@ -484,20 +479,16 @@ end @test is_factorized(m[x_6[1, 2, 3]]) end -@testitem "is_factorized || is_constant" begin +@testitem "is_factorized || is_constant" setup = [TestUtils] begin import GraphPPL: is_constant, is_factorized, create_model, with_plugins, getcontext, getproperties, getorcreate!, variable_nodes, NodeCreationOptions - include("../../testutils.jl") - - using .TestUtils.ModelZoo - - m = create_test_model(plugins = GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin())) + m = TestUtils.create_test_model(plugins = GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin())) ctx = getcontext(m) x = getorcreate!(m, ctx, NodeCreationOptions(kind = :data, factorized = true), :x, nothing) @test is_factorized(m[x]) - for model_fn in ModelsInTheZooWithoutArguments + for model_fn in TestUtils.ModelsInTheZooWithoutArguments model = create_model(with_plugins(model_fn(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin()))) for label in variable_nodes(model) nodedata = model[label] @@ -510,7 +501,7 @@ end end end -@testitem "Application of MarginalFormConstraint" begin +@testitem "Application of MarginalFormConstraint" setup = [TestUtils] begin import GraphPPL: create_model, MarginalFormConstraint, @@ -520,14 +511,10 @@ end hasextra, VariationalConstraintsMarginalFormConstraintKey - include("../../testutils.jl") - - using .TestUtils.ModelZoo - struct ArbitraryFunctionalFormConstraint end # Test saving of MarginalFormConstraint in single variable - model = create_model(simple_model()) + model = create_model(TestUtils.simple_model()) context = GraphPPL.getcontext(model) constraint = MarginalFormConstraint(IndexedVariable(:x, nothing), ArbitraryFunctionalFormConstraint()) apply_constraints!(model, context, constraint) @@ -536,7 +523,7 @@ end end # Test saving of MarginalFormConstraint in multiple variables - model = create_model(vector_model()) + model = create_model(TestUtils.vector_model()) context = GraphPPL.getcontext(model) constraint = MarginalFormConstraint(IndexedVariable(:x, nothing), ArbitraryFunctionalFormConstraint()) apply_constraints!(model, context, constraint) @@ -548,7 +535,7 @@ end end # Test saving of MarginalFormConstraint in single variable in array - model = create_model(vector_model()) + model = create_model(TestUtils.vector_model()) context = GraphPPL.getcontext(model) constraint = MarginalFormConstraint(IndexedVariable(:x, 1), ArbitraryFunctionalFormConstraint()) apply_constraints!(model, context, constraint) @@ -562,7 +549,7 @@ end end end -@testitem "Application of MessageFormConstraint" begin +@testitem "Application of MessageFormConstraint" setup = [TestUtils] begin import GraphPPL: create_model, MessageFormConstraint, @@ -572,14 +559,10 @@ end getextra, VariationalConstraintsMessagesFormConstraintKey - include("../../testutils.jl") - - using .TestUtils.ModelZoo - struct ArbitraryMessageFormConstraint end # Test saving of MessageFormConstraint in single variable - model = create_model(simple_model()) + model = create_model(TestUtils.simple_model()) context = GraphPPL.getcontext(model) constraint = MessageFormConstraint(IndexedVariable(:x, nothing), ArbitraryMessageFormConstraint()) node = first(filter(GraphPPL.as_variable(:x), model)) @@ -587,7 +570,7 @@ end @test getextra(model[node], VariationalConstraintsMessagesFormConstraintKey) == ArbitraryMessageFormConstraint() # Test saving of MessageFormConstraint in multiple variables - model = create_model(vector_model()) + model = create_model(TestUtils.vector_model()) context = GraphPPL.getcontext(model) constraint = MessageFormConstraint(IndexedVariable(:x, nothing), ArbitraryMessageFormConstraint()) apply_constraints!(model, context, constraint) @@ -599,7 +582,7 @@ end end # Test saving of MessageFormConstraint in single variable in array - model = create_model(vector_model()) + model = create_model(TestUtils.vector_model()) context = GraphPPL.getcontext(model) constraint = MessageFormConstraint(IndexedVariable(:x, 1), ArbitraryMessageFormConstraint()) apply_constraints!(model, context, constraint) @@ -613,7 +596,7 @@ end end end -@testitem "save constraints with constants via `mean_field_constraint!`" begin +@testitem "save constraints with constants via `mean_field_constraint!`" setup = [TestUtils] begin using BitSetTuples import GraphPPL: create_model, @@ -625,79 +608,82 @@ end PluginsCollection, VariationalConstraintsFactorizationBitSetKey - include("../../testutils.jl") - - using .TestUtils.ModelZoo - - model = create_model(with_plugins(simple_model(), GraphPPL.PluginsCollection(VariationalConstraintsPlugin()))) + model = create_model(with_plugins(TestUtils.simple_model(), GraphPPL.PluginsCollection(VariationalConstraintsPlugin()))) ctx = GraphPPL.getcontext(model) @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(3), 1)) == ((1,), (2, 3), (2, 3)) @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(3), 2)) == ((1, 3), (2,), (1, 3)) @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(3), 3)) == ((1, 2), (1, 2), (3,)) - node = ctx[NormalMeanVariance, 2] + node = ctx[TestUtils.NormalMeanVariance, 2] constraint_bitset = getextra(model[node], VariationalConstraintsFactorizationBitSetKey) @test tupled_contents(intersect!(constraint_bitset, mean_field_constraint!(BoundedBitSetTuple(3), 1))) == ((1,), (2, 3), (2, 3)) @test tupled_contents(intersect!(constraint_bitset, mean_field_constraint!(BoundedBitSetTuple(3), 2))) == ((1,), (2,), (3,)) - node = ctx[NormalMeanVariance, 1] + node = ctx[TestUtils.NormalMeanVariance, 1] constraint_bitset = getextra(model[node], VariationalConstraintsFactorizationBitSetKey) # Here it is the mean field because the original model has `x ~ Normal(0, 1)` and `0` and `1` are constants @test tupled_contents(intersect!(constraint_bitset, mean_field_constraint!(BoundedBitSetTuple(3), 1))) == ((1,), (2,), (3,)) end -@testitem "materialize_constraints!(:Model, ::NodeLabel, ::FactorNodeData)" begin +@testitem "materialize_constraints!(:Model, ::NodeLabel, ::FactorNodeData)" setup = [TestUtils] begin using BitSetTuples import GraphPPL: - create_model, with_plugins, materialize_constraints!, EdgeLabel, get_constraint_names, getproperties, getextra, setextra! - - include("../../testutils.jl") + create_model, + with_plugins, + materialize_constraints!, + EdgeLabel, + get_constraint_names, + getproperties, + getextra, + setextra!, + VariationalConstraintsPlugin - using .TestUtils.ModelZoo + model = create_model(TestUtils.simple_model()) + ctx = GraphPPL.getcontext(model) + node = ctx[TestUtils.NormalMeanVariance, 2] # Test 1: Test materialize with a Full Factorization constraint - model = create_model(simple_model()) - ctx = GraphPPL.getcontext(model) - node = ctx[NormalMeanVariance, 2] + node = ctx[TestUtils.NormalMeanVariance, 2] # Force overwrite the bitset and the constraints setextra!(model[node], :factorization_constraint_bitset, BoundedBitSetTuple(3)) materialize_constraints!(model, node) @test Tuple.(getextra(model[node], :factorization_constraint_indices)) == ((1, 2, 3),) - node = ctx[NormalMeanVariance, 1] + node = ctx[TestUtils.NormalMeanVariance, 1] setextra!(model[node], :factorization_constraint_bitset, BoundedBitSetTuple(((1,), (2,), (3,)))) materialize_constraints!(model, node) @test Tuple.(getextra(model[node], :factorization_constraint_indices)) == ((1,), (2,), (3,)) # Test 2: Test materialize with an applied constraint - model = create_model(simple_model()) + model = create_model(TestUtils.simple_model()) ctx = GraphPPL.getcontext(model) - node = ctx[NormalMeanVariance, 2] + node = ctx[TestUtils.NormalMeanVariance, 2] setextra!(model[node], :factorization_constraint_bitset, BoundedBitSetTuple(((1,), (2, 3), (2, 3)))) materialize_constraints!(model, node) @test Tuple.(getextra(model[node], :factorization_constraint_indices)) == ((1,), (2, 3)) # # Test 3: Check that materialize_constraints! throws if the constraint is not a valid partition - model = create_model(simple_model()) + model = create_model(TestUtils.simple_model()) ctx = GraphPPL.getcontext(model) - node = ctx[NormalMeanVariance, 2] + node = ctx[TestUtils.NormalMeanVariance, 2] setextra!(model[node], :factorization_constraint_bitset, BoundedBitSetTuple(((1,), (3,), (1, 3)))) @test_throws ErrorException materialize_constraints!(model, node) # Test 4: Check that materialize_constraints! throws if the constraint is not a valid partition - model = create_model(simple_model()) + model = create_model(TestUtils.simple_model()) ctx = GraphPPL.getcontext(model) - node = ctx[NormalMeanVariance, 2] + node = ctx[TestUtils.NormalMeanVariance, 2] setextra!(model[node], :factorization_constraint_bitset, BoundedBitSetTuple(((1,), (1,), (3,)))) @test_throws ErrorException materialize_constraints!(model, node) end -@testitem "Resolve Factorization Constraints" begin +@testitem "Resolve Factorization Constraints" setup = [TestUtils] begin + using Distributions import GraphPPL: create_model, FactorizationConstraint, @@ -709,15 +695,12 @@ end ResolvedFactorizationConstraintEntry, ResolvedIndexedVariable, CombinedRange, - SplittedRange - - include("../../testutils.jl") - - using .TestUtils.ModelZoo + SplittedRange, + @model - model = create_model(outer()) + model = create_model(TestUtils.outer()) ctx = GraphPPL.getcontext(model) - inner_context = ctx[inner, 1] + inner_context = ctx[TestUtils.inner, 1] # Test resolve constraint in child model @@ -768,7 +751,7 @@ end @test resolve(model, ctx, constraint) == result end - model = create_model(filled_matrix_model()) + model = create_model(TestUtils.filled_matrix_model()) ctx = GraphPPL.getcontext(model) let constraint = FactorizationConstraint( @@ -786,7 +769,7 @@ end ) @test resolve(model, ctx, constraint) == result end - model = create_model(filled_matrix_model()) + model = create_model(TestUtils.filled_matrix_model()) ctx = GraphPPL.getcontext(model) let constraint = FactorizationConstraint( @@ -838,8 +821,6 @@ end ) @test resolve(model, ctx, constraint) == result end - - model = create_model(uneven_matrix()) end @testitem "Resolved Constraints in" begin @@ -894,7 +875,7 @@ end @test node_data ∈ variable end -@testitem "convert_to_bitsets" begin +@testitem "convert_to_bitsets" setup = [TestUtils] begin using BitSetTuples import GraphPPL: create_model, @@ -908,16 +889,12 @@ end apply_constraints!, getproperties - include("../../testutils.jl") - - using .TestUtils.ModelZoo - - model = create_model(with_plugins(outer(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin()))) + model = create_model(with_plugins(TestUtils.outer(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin()))) context = GraphPPL.getcontext(model) - inner_context = context[inner, 1] - inner_inner_context = inner_context[inner_inner, 1] + inner_context = context[TestUtils.inner, 1] + inner_inner_context = inner_context[TestUtils.inner_inner, 1] - normal_node = inner_inner_context[NormalMeanVariance, 1] + normal_node = inner_inner_context[TestUtils.NormalMeanVariance, 1] neighbors = model[GraphPPL.neighbors(model, normal_node)] let constraint = ResolvedFactorizationConstraint( @@ -984,9 +961,9 @@ end @test tupled_contents(GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint)) == ((1,), (2, 3), (2, 3)) end - model = create_model(with_plugins(multidim_array(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin()))) + model = create_model(with_plugins(TestUtils.multidim_array(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin()))) context = GraphPPL.getcontext(model) - normal_node = context[NormalMeanVariance, 5] + normal_node = context[TestUtils.NormalMeanVariance, 5] neighbors = model[GraphPPL.neighbors(model, normal_node)] let constraint = ResolvedFactorizationConstraint( @@ -997,9 +974,9 @@ end @test tupled_contents(GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint)) == ((1, 3), (2, 3), (1, 2, 3)) end - model = create_model(with_plugins(multidim_array(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin()))) + model = create_model(with_plugins(TestUtils.multidim_array(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin()))) context = GraphPPL.getcontext(model) - normal_node = context[NormalMeanVariance, 5] + normal_node = context[TestUtils.NormalMeanVariance, 5] neighbors = model[GraphPPL.neighbors(model, normal_node)] let constraint = ResolvedFactorizationConstraint( @@ -1012,9 +989,11 @@ end # Test ResolvedFactorizationConstraints over anonymous variables - model = create_model(with_plugins(node_with_only_anonymous(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin()))) + model = create_model( + with_plugins(TestUtils.node_with_only_anonymous(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin())) + ) context = GraphPPL.getcontext(model) - normal_node = context[NormalMeanVariance, 6] + normal_node = context[TestUtils.NormalMeanVariance, 6] neighbors = model[GraphPPL.neighbors(model, normal_node)] let constraint = ResolvedFactorizationConstraint( ResolvedConstraintLHS((ResolvedIndexedVariable(:y, nothing, context),),), @@ -1024,9 +1003,11 @@ end end # Test ResolvedFactorizationConstraints over multiple anonymous variables - model = create_model(with_plugins(node_with_two_anonymous(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin()))) + model = create_model( + with_plugins(TestUtils.node_with_two_anonymous(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin())) + ) context = GraphPPL.getcontext(model) - normal_node = context[NormalMeanVariance, 6] + normal_node = context[TestUtils.NormalMeanVariance, 6] neighbors = model[GraphPPL.neighbors(model, normal_node)] let constraint = ResolvedFactorizationConstraint( ResolvedConstraintLHS((ResolvedIndexedVariable(:y, nothing, context),),), @@ -1039,9 +1020,11 @@ end end # Test ResolvedFactorizationConstraints over ambiguous anonymouys variables - model = create_model(with_plugins(node_with_ambiguous_anonymous(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin()))) + model = create_model( + with_plugins(TestUtils.node_with_ambiguous_anonymous(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin())) + ) context = GraphPPL.getcontext(model) - normal_node = last(filter(GraphPPL.as_node(NormalMeanVariance), model)) + normal_node = last(filter(GraphPPL.as_node(TestUtils.NormalMeanVariance), model)) neighbors = model[GraphPPL.neighbors(model, normal_node)] let constraint = ResolvedFactorizationConstraint( ResolvedConstraintLHS((ResolvedIndexedVariable(:y, nothing, context),),), @@ -1056,9 +1039,9 @@ end end # Test ResolvedFactorizationConstraint with a Mixture node - model = create_model(with_plugins(mixture(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin()))) + model = create_model(with_plugins(TestUtils.mixture(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin()))) context = GraphPPL.getcontext(model) - mixture_node = first(filter(GraphPPL.as_node(Mixture), model)) + mixture_node = first(filter(GraphPPL.as_node(TestUtils.Mixture), model)) neighbors = model[GraphPPL.neighbors(model, mixture_node)] let constraint = ResolvedFactorizationConstraint( ResolvedConstraintLHS(( @@ -1124,7 +1107,7 @@ end end end -@testitem "default_constraints" begin +@testitem "default_constraints" setup = [TestUtils] begin import GraphPPL: create_model, with_plugins, @@ -1136,50 +1119,46 @@ end getextra, UnspecifiedConstraints - include("../../testutils.jl") - - using .TestUtils.ModelZoo - - @test default_constraints(simple_model) == UnspecifiedConstraints - @test default_constraints(model_with_default_constraints) == @constraints( + @test default_constraints(TestUtils.simple_model) == UnspecifiedConstraints + @test default_constraints(TestUtils.model_with_default_constraints) == @constraints( begin q(a, d) = q(a)q(d) end ) - model = create_model(with_plugins(contains_default_constraints(), PluginsCollection(VariationalConstraintsPlugin()))) + model = create_model(with_plugins(TestUtils.contains_default_constraints(), PluginsCollection(VariationalConstraintsPlugin()))) ctx = GraphPPL.getcontext(model) # Test that default constraints are applied for i in 1:10 - node = model[ctx[model_with_default_constraints, i][NormalMeanVariance, 1]] + node = model[ctx[TestUtils.model_with_default_constraints, i][TestUtils.NormalMeanVariance, 1]] @test hasextra(node, :factorization_constraint_indices) @test Tuple.(getextra(node, :factorization_constraint_indices)) == ((1,), (2,), (3,)) end # Test that default constraints are not applied if we specify constraints in the context c = @constraints begin - for q in model_with_default_constraints + for q in TestUtils.model_with_default_constraints q(a, d) = q(a, d) end end - model = create_model(with_plugins(contains_default_constraints(), PluginsCollection(VariationalConstraintsPlugin(c)))) + model = create_model(with_plugins(TestUtils.contains_default_constraints(), PluginsCollection(VariationalConstraintsPlugin(c)))) ctx = GraphPPL.getcontext(model) for i in 1:10 - node = model[ctx[model_with_default_constraints, i][NormalMeanVariance, 1]] + node = model[ctx[TestUtils.model_with_default_constraints, i][TestUtils.NormalMeanVariance, 1]] @test hasextra(node, :factorization_constraint_indices) @test Tuple.(getextra(node, :factorization_constraint_indices)) == ((1, 2), (3,)) end # Test that default constraints are not applied if we specify constraints for a specific instance of the submodel c = @constraints begin - for q in (model_with_default_constraints, 1) + for q in (TestUtils.model_with_default_constraints, 1) q(a, d) = q(a, d) end end - model = create_model(with_plugins(contains_default_constraints(), PluginsCollection(VariationalConstraintsPlugin(c)))) + model = create_model(with_plugins(TestUtils.contains_default_constraints(), PluginsCollection(VariationalConstraintsPlugin(c)))) ctx = GraphPPL.getcontext(model) for i in 1:10 - node = model[ctx[model_with_default_constraints, i][NormalMeanVariance, 1]] + node = model[ctx[TestUtils.model_with_default_constraints, i][TestUtils.NormalMeanVariance, 1]] @test hasextra(node, :factorization_constraint_indices) if i == 1 @test Tuple.(getextra(node, :factorization_constraint_indices)) == ((1, 2), (3,)) @@ -1208,7 +1187,8 @@ end @test_throws BoundsError mean_field_constraint!(BoundedBitSetTuple(5), (1, 2, 3, 4, 5, 6)) == ((1,), (2,), (3,), (4,), (5,)) end -@testitem "Apply constraints to matrix variables" begin +@testitem "Apply constraints to matrix variables" setup = [TestUtils] begin + using Distributions import GraphPPL: getproperties, PluginsCollection, @@ -1217,19 +1197,16 @@ end getcontext, with_plugins, create_model, - NotImplementedError - - include("../../testutils.jl") - - using .TestUtils.ModelZoo + NotImplementedError, + @model # Test for constraints applied to a model with matrix variables c = @constraints begin q(x, y) = q(x)q(y) end - model = create_model(with_plugins(filled_matrix_model(), PluginsCollection(VariationalConstraintsPlugin(c)))) + model = create_model(with_plugins(TestUtils.filled_matrix_model(), PluginsCollection(VariationalConstraintsPlugin(c)))) - for node in filter(as_node(Normal), model) + for node in filter(TestUtils.as_node(TestUtils.Normal), model) @test getextra(model[node], :factorization_constraint_indices) == ([1], [2], [3]) end @@ -1410,7 +1387,7 @@ end q(x, z, y) = q(z)(q(x[begin + 1]) .. q(x[end]))(q(y[begin + 1]) .. q(y[end])) end - model = create_model(with_plugins(vector_model(), PluginsCollection(VariationalConstraintsPlugin(constraints_11)))) + model = create_model(with_plugins(TestUtils.vector_model(), PluginsCollection(VariationalConstraintsPlugin(constraints_11)))) ctx = getcontext(model) for node in filter(as_node(Normal), model) @@ -1450,6 +1427,7 @@ end end @testitem "Test factorization constraint with automatically folded data/const variables" begin + using Distributions import GraphPPL: getproperties, PluginsCollection, @@ -1459,9 +1437,8 @@ end with_plugins, create_model, getextra, - VariationalConstraintsFactorizationIndicesKey - - include("../../testutils.jl") + VariationalConstraintsFactorizationIndicesKey, + @model @model function fold_datavars(f, a, b) y ~ Normal(f(f(a, b), f(a, b)), 0.5) @@ -1497,14 +1474,13 @@ end end @testitem "show constraints" begin + using Distributions using GraphPPL - include("../../testutils.jl") - constraint = @constraints begin - q(x)::PointMass + q(x)::Normal end - @test occursin(r"q\(x\) ::(.*?)PointMass", repr(constraint)) + @test occursin(r"q\(x\) ::(.*?)Normal", repr(constraint)) constraint = @constraints begin q(x, y) = q(x)q(y) @@ -1512,14 +1488,14 @@ end @test occursin(r"q\(x, y\) = q\(x\)q\(y\)", repr(constraint)) constraint = @constraints begin - μ(x)::PointMass + μ(x)::Normal end - @test occursin(r"μ\(x\) ::(.*?)PointMass", repr(constraint)) + @test occursin(r"μ\(x\) ::(.*?)Normal", repr(constraint)) constraint = @constraints begin q(x, y) = q(x)q(y) - μ(x)::PointMass + μ(x)::Normal end @test occursin(r"q\(x, y\) = q\(x\)q\(y\)", repr(constraint)) - @test occursin(r"μ\(x\) ::(.*?)PointMass", repr(constraint)) + @test occursin(r"μ\(x\) ::(.*?)Normal", repr(constraint)) end diff --git a/test/plugins/variational_constraints/variational_constraints_macro_tests.jl b/test/plugins/variational_constraints/variational_constraints_macro_tests.jl index 9b96702d..0ce01aa7 100644 --- a/test/plugins/variational_constraints/variational_constraints_macro_tests.jl +++ b/test/plugins/variational_constraints/variational_constraints_macro_tests.jl @@ -1,5 +1,6 @@ -@testitem "check_reserved_variable_names_constraints" begin +@testitem "check_reserved_variable_names_constraints" setup = [TestUtils] begin import GraphPPL: apply_pipeline, check_reserved_variable_names_constraints + using MacroTools # Test 1: test that reserved variable name __parent_options__ throws an error input = quote @@ -22,10 +23,9 @@ @test apply_pipeline(input, check_reserved_variable_names_constraints) == input end -@testitem "check_for_returns" begin +@testitem "check_for_returns" setup = [TestUtils] begin import GraphPPL: check_for_returns_constraints, apply_pipeline - - include("../../testutils.jl") + using MacroTools # Test 1: check_for_returns with no returns input = quote @@ -33,7 +33,7 @@ end q(x)::PointMass end output = input - @test_expression_generating apply_pipeline(input, check_for_returns_constraints) output + TestUtils.@test_expression_generating apply_pipeline(input, check_for_returns_constraints) output # Test 2: check_for_returns with one return input = quote @@ -57,10 +57,9 @@ end ) end -@testitem "add_constraints_construction" begin +@testitem "add_constraints_construction" setup = [TestUtils] begin import GraphPPL: add_constraints_construction - - include("../../testutils.jl") + using MacroTools # Test 1: add_constraints_construction to regular constraint specification input = quote @@ -77,7 +76,7 @@ end __constraints__ end end - @test_expression_generating add_constraints_construction(input) output + TestUtils.@test_expression_generating add_constraints_construction(input) output # Test 2: add_constraints_construction to constraint specification with nested model specification input = quote @@ -102,7 +101,7 @@ end __constraints__ end end - @test_expression_generating add_constraints_construction(input) output + TestUtils.@test_expression_generating add_constraints_construction(input) output # Test 3: add_constraints_construction to constraint specification with function specification input = quote @@ -122,7 +121,7 @@ end return __constraints__ end end - @test_expression_generating add_constraints_construction(input) output + TestUtils.@test_expression_generating add_constraints_construction(input) output # Test 4: add_constraints_construction to constraint specification with function specification with arguments input = quote @@ -142,7 +141,7 @@ end return __constraints__ end end - @test_expression_generating add_constraints_construction(input) output + TestUtils.@test_expression_generating add_constraints_construction(input) output # Test 5: add_constraints_construction to constraint specification with function specification with arguments and kwargs input = quote @@ -162,13 +161,12 @@ end return __constraints__ end end - @test_expression_generating add_constraints_construction(input) output + TestUtils.@test_expression_generating add_constraints_construction(input) output end -@testitem "rewrite_stacked_constraints" begin +@testitem "rewrite_stacked_constraints" setup = [TestUtils] begin import GraphPPL: rewrite_stacked_constraints, apply_pipeline - - include("../../testutils.jl") + using MacroTools # Test 1: rewrite_stacked_constraints with no stacked constraints input = quote @@ -176,7 +174,7 @@ end q(x)::PointMass() end output = input - @test_expression_generating apply_pipeline(input, rewrite_stacked_constraints) output + TestUtils.@test_expression_generating apply_pipeline(input, rewrite_stacked_constraints) output # Test 2: rewrite_stacked_constraints with two stacked constraints input = quote @@ -187,7 +185,7 @@ end q(x, y) = q(x)q(y) q(x)::GraphPPL.stack_constraints(PointMass(), SampleList()) end - @test_expression_generating apply_pipeline(input, rewrite_stacked_constraints) output + TestUtils.@test_expression_generating apply_pipeline(input, rewrite_stacked_constraints) output # Test 3: rewrite_stacked_constraints with three stacked constraints input = quote @@ -198,11 +196,12 @@ end q(x, y) = q(x)q(y) q(x)::GraphPPL.stack_constraints(GraphPPL.stack_constraints(PointMass(), Sample), SampleList()) end - @test_expression_generating apply_pipeline(input, rewrite_stacked_constraints) output + TestUtils.@test_expression_generating apply_pipeline(input, rewrite_stacked_constraints) output end -@testitem "stack_constraints" begin +@testitem "stack_constraints" setup = [TestUtils] begin import GraphPPL: stack_constraints + using MacroTools @test stack_constraints(1, 2) == (1, 2) @test stack_constraints(1, (2, 1)) == (1, 2, 1) @@ -210,10 +209,9 @@ end @test stack_constraints((1, 3), (2, 1)) == (1, 3, 2, 1) end -@testitem "replace_begin_end" begin +@testitem "replace_begin_end" setup = [TestUtils] begin import GraphPPL: replace_begin_end, apply_pipeline - - include("../../testutils.jl") + using MacroTools # Test 1: replace_begin_end with one begin and end input = quote @@ -222,7 +220,7 @@ end output = quote q(x) = q(x[GraphPPL.FunctionalIndex{:begin}(firstindex)]) .. q(x[GraphPPL.FunctionalIndex{:end}(lastindex)]) end - @test_expression_generating apply_pipeline(input, replace_begin_end) output + TestUtils.@test_expression_generating apply_pipeline(input, replace_begin_end) output # Test 2: replace_begin_end with two begins and ends input = quote @@ -233,7 +231,7 @@ end q(x[GraphPPL.FunctionalIndex{:begin}(firstindex), GraphPPL.FunctionalIndex{:begin}(firstindex)]) .. q(x[GraphPPL.FunctionalIndex{:end}(lastindex), GraphPPL.FunctionalIndex{:end}(lastindex)]) end - @test_expression_generating apply_pipeline(input, replace_begin_end) output + TestUtils.@test_expression_generating apply_pipeline(input, replace_begin_end) output # Test 3: replace_begin_end with mixed begin and ends input = quote @@ -242,7 +240,7 @@ end output = quote q(x) = q(x[GraphPPL.FunctionalIndex{:begin}(firstindex), 1]) .. q(x[GraphPPL.FunctionalIndex{:end}(lastindex), 2]) end - @test_expression_generating apply_pipeline(input, replace_begin_end) output + TestUtils.@test_expression_generating apply_pipeline(input, replace_begin_end) output # Test 4: replace_begin_end with composite index input = quote @@ -255,7 +253,7 @@ end q(x[1]) * q(x[GraphPPL.FunctionalIndex{:end}(lastindex)]) end - @test_expression_generating apply_pipeline(input, replace_begin_end) output + TestUtils.@test_expression_generating apply_pipeline(input, replace_begin_end) output # Test 5: replace_begin_end with random begin and ends @@ -266,7 +264,7 @@ end end end end - @test_expression_generating apply_pipeline(input, replace_begin_end) input + TestUtils.@test_expression_generating apply_pipeline(input, replace_begin_end) input # Test 6: replace_begin_end with model specification begin and ends input = quote @@ -275,13 +273,12 @@ end output = quote y ~ Normal(μ = x[GraphPPL.FunctionalIndex{:end}(lastindex)], σ = 1.0) end - @test_expression_generating apply_pipeline(input, replace_begin_end) output + TestUtils.@test_expression_generating apply_pipeline(input, replace_begin_end) output end -@testitem "create_submodel_constraints" begin +@testitem "create_submodel_constraints" setup = [TestUtils] begin import GraphPPL: create_submodel_constraints, apply_pipeline - - include("../../testutils.jl") + using MacroTools # Test 1: create_submodel_constraints with one nested layer input = quote @@ -309,7 +306,7 @@ end __constraints__ end end - @test_expression_generating apply_pipeline(input, create_submodel_constraints) output + TestUtils.@test_expression_generating apply_pipeline(input, create_submodel_constraints) output # Test 2: create_submodel_constraints with two nested layers input = quote @@ -342,7 +339,7 @@ end q(a, b, c) = q(a)q(b)q(c) return __constraints__ end - @test_expression_generating apply_pipeline(input, create_submodel_constraints) output + TestUtils.@test_expression_generating apply_pipeline(input, create_submodel_constraints) output # Test 3: create_submodel_constraints with one nested layer and specific subconstraints input = quote @@ -368,7 +365,7 @@ end q(a, b, c) = q(a)q(b)q(c) return __constraints__ end - @test_expression_generating apply_pipeline(input, create_submodel_constraints) output + TestUtils.@test_expression_generating apply_pipeline(input, create_submodel_constraints) output # Test 2: create_submodel_constraints with two nested layers input = quote @@ -401,13 +398,12 @@ end q(a, b, c) = q(a)q(b)q(c) return __constraints__ end - @test_expression_generating apply_pipeline(input, create_submodel_constraints) output + TestUtils.@test_expression_generating apply_pipeline(input, create_submodel_constraints) output end -@testitem "create_factorization_split" begin +@testitem "create_factorization_split" setup = [TestUtils] begin import GraphPPL: create_factorization_split, apply_pipeline - - include("../../testutils.jl") + using MacroTools # Test 1: create_factorization_split with one factorization split input = quote @@ -416,7 +412,7 @@ end output = quote q(x) = GraphPPL.factorization_split(q(x[begin]), q(x[end])) end - @test_expression_generating apply_pipeline(input, create_factorization_split) output + TestUtils.@test_expression_generating apply_pipeline(input, create_factorization_split) output # Test 2: create_factorization_split with two factorization splits input = quote @@ -425,7 +421,7 @@ end output = quote q(x, y) = GraphPPL.factorization_split(q(x[begin], y[begin]), q(x[end], y[end])) end - @test_expression_generating apply_pipeline(input, create_factorization_split) output + TestUtils.@test_expression_generating apply_pipeline(input, create_factorization_split) output # Test 3: create_factorization_split with two a factorization split and more entries input = quote @@ -434,15 +430,14 @@ end output = quote q(x, y, z) = GraphPPL.factorization_split(q(y)q(x[begin]), q(x[end])q(z)) end - @test_expression_generating apply_pipeline(input, create_factorization_split) output + TestUtils.@test_expression_generating apply_pipeline(input, create_factorization_split) output # Test 4: create_factorization_split with two factorization splits and more entries end -@testitem "create_factorization_combinedrange" begin +@testitem "create_factorization_combinedrange" setup = [TestUtils] begin import GraphPPL: create_factorization_combinedrange, apply_pipeline - - include("../../testutils.jl") + using MacroTools # Test 1: create_factorization_combinedrange with one combined range input = quote @@ -451,13 +446,12 @@ end output = quote q(x) = q(x[GraphPPL.CombinedRange(begin, end)]) end - @test_expression_generating apply_pipeline(input, create_factorization_combinedrange) output + TestUtils.@test_expression_generating apply_pipeline(input, create_factorization_combinedrange) output end -@testitem "convert_variable_statements" begin +@testitem "convert_variable_statements" setup = [TestUtils] begin import GraphPPL: convert_variable_statements, apply_pipeline - - include("../../testutils.jl") + using MacroTools # Test 1: convert_variable_statements with a single variable statement input = quote @@ -469,7 +463,7 @@ end q(GraphPPL.IndexedVariable(:x, GraphPPL.FunctionalIndex{:end}(lastindex))) ) end - @test_expression_generating apply_pipeline(input, convert_variable_statements) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_variable_statements) output # Test 2: convert_variable_statements with a multi-indexed variable statement input = quote @@ -479,7 +473,7 @@ end q(GraphPPL.IndexedVariable(:x, nothing), GraphPPL.IndexedVariable(:y, nothing)) = q(GraphPPL.IndexedVariable(:x, nothing))q(GraphPPL.IndexedVariable(:y, [1, 1]))q(GraphPPL.IndexedVariable(:y, [2, 2])) end - @test_expression_generating apply_pipeline(input, convert_variable_statements) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_variable_statements) output # Test 3: convert_variable_statements with a message constraint input = quote @@ -488,7 +482,7 @@ end output = quote μ(GraphPPL.IndexedVariable(:x, nothing))::PointMass end - @test_expression_generating apply_pipeline(input, convert_variable_statements) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_variable_statements) output # Test 4: convert_variable_statements with a message constraint with indcides input = quote @@ -497,7 +491,7 @@ end output = quote μ(GraphPPL.IndexedVariable(:x, [1, 1]))::PointMass end - @test_expression_generating apply_pipeline(input, convert_variable_statements) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_variable_statements) output # Test 5: convert_variable_statements with a CombinedRange input = quote @@ -506,7 +500,7 @@ end output = quote μ(GraphPPL.IndexedVariable(:x, CombinedRange(1, 2)))::PointMass end - @test_expression_generating apply_pipeline(input, convert_variable_statements) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_variable_statements) output # Test 6: convert_variable_statements with a CombinedRange input = quote @@ -515,13 +509,12 @@ end output = quote q(GraphPPL.IndexedVariable(:x, nothing)) = q(GraphPPL.IndexedVariable(:x, CombinedRange(1, 2))) end - @test_expression_generating apply_pipeline(input, convert_variable_statements) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_variable_statements) output end -@testitem "convert_functionalform_constraints" begin +@testitem "convert_functionalform_constraints" setup = [TestUtils] begin import GraphPPL: convert_functionalform_constraints, apply_pipeline, IndexedVariable - - include("../../testutils.jl") + using MacroTools # Test 1: convert_functionalform_constraints with a single functional form constraint input = quote @@ -530,7 +523,7 @@ end output = quote push!(__constraints__, GraphPPL.MarginalFormConstraint(GraphPPL.IndexedVariable(:x, nothing), PointMass)) end - @test_expression_generating apply_pipeline(input, convert_functionalform_constraints) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_functionalform_constraints) output # Test 2: convert_functionalform_constraints with a functional form constraint over multiple variables input = quote @@ -542,7 +535,7 @@ end GraphPPL.MarginalFormConstraint((GraphPPL.IndexedVariable(:x, nothing), GraphPPL.IndexedVariable(:y, nothing)), PointMass) ) end - @test_expression_generating apply_pipeline(input, convert_functionalform_constraints) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_functionalform_constraints) output # Test 3: convert_functionalform_constraints with a functional form constraint in a nested constraint specification input = quote @@ -588,13 +581,12 @@ end end end end - @test_expression_generating apply_pipeline(input, convert_functionalform_constraints) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_functionalform_constraints) output end -@testitem "convert_message_constraints" begin +@testitem "convert_message_constraints" setup = [TestUtils] begin import GraphPPL: convert_message_constraints, apply_pipeline, IndexedVariable - - include("../../testutils.jl") + using MacroTools # Test 1: convert_message_constraints with a single functional form constraint input = quote @@ -603,13 +595,12 @@ end output = quote push!(__constraints__, GraphPPL.MessageFormConstraint(GraphPPL.IndexedVariable(:x, nothing), PointMass)) end - @test_expression_generating apply_pipeline(input, convert_message_constraints) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_message_constraints) output end -@testitem "convert_factorization_constraints" begin +@testitem "convert_factorization_constraints" setup = [TestUtils] begin import GraphPPL: convert_factorization_constraints, apply_pipeline, IndexedVariable - - include("../../testutils.jl") + using MacroTools # Test 1: convert_factorization_constraints with a single factorization constraint input = quote @@ -626,7 +617,7 @@ end ) ) end - @test_expression_generating apply_pipeline(input, convert_factorization_constraints) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_factorization_constraints) output # Test 2: convert_factorization_constraints with a factorization constraint that has no multiplication input = quote @@ -641,13 +632,12 @@ end ) ) end - @test_expression_generating apply_pipeline(input, convert_factorization_constraints) output + TestUtils.@test_expression_generating apply_pipeline(input, convert_factorization_constraints) output end -@testitem "constraints_macro_interior" begin +@testitem "constraints_macro_interior" setup = [TestUtils] begin import GraphPPL: constraints_macro_interior - - include("../../testutils.jl") + using MacroTools input = quote q(x)::Normal @@ -683,21 +673,20 @@ end end end - @test_expression_generating constraints_macro_interior(input) output + TestUtils.@test_expression_generating constraints_macro_interior(input) output end -@testitem "constraints_macro" begin +@testitem "constraints_macro" setup = [TestUtils] begin import GraphPPL: Constraints + using MacroTools - include("../../testutils.jl") - - using .TestUtils.ModelZoo + struct PointMass end constraints = @constraints begin q(x, y) = q(x)q(y) q(x) = q(x[begin]) .. q(x[end]) q(μ)::PointMass - for q in prior + for q in TestUtils.prior q(u, v, k) = q(u)q(v)q(k) end end diff --git a/test/plugins/variational_constraints/variational_constraints_tests.jl b/test/plugins/variational_constraints/variational_constraints_tests.jl index 5b94d7fc..bf3f5816 100644 --- a/test/plugins/variational_constraints/variational_constraints_tests.jl +++ b/test/plugins/variational_constraints/variational_constraints_tests.jl @@ -8,9 +8,15 @@ end @testitem "simple @model + various constraints" begin using Distributions import GraphPPL: - create_model, with_plugins, PluginsCollection, VariationalConstraintsPlugin, getorcreate!, NodeCreationOptions, hasextra, getextra - - include("../../testutils.jl") + create_model, + with_plugins, + PluginsCollection, + VariationalConstraintsPlugin, + getorcreate!, + NodeCreationOptions, + hasextra, + getextra, + @model @model function simple_model() x ~ Beta(1, 1) @@ -278,7 +284,7 @@ end end end -@testitem "simple @model + mean field @constraints + anonymous variable linked through a deterministic relation" begin +@testitem "simple @model + mean field @constraints + anonymous variable linked through a deterministic relation" setup = [TestUtils] begin using Distributions using GraphPPL: create_model, @@ -293,9 +299,7 @@ end VariationalConstraintsPlugin, with_plugins - include("../../testutils.jl") - - @model function simple_model(a, b, c) + TestUtils.@model function simple_model(a, b, c) x ~ Gamma(α = b, θ = sqrt(c)) a ~ Normal(μ = x, τ = 1) end @@ -323,7 +327,9 @@ end end end -@testitem "state space model @model + mean field @constraints + anonymous variable linked through a deterministic relation" begin +@testitem "state space model @model + mean field @constraints + anonymous variable linked through a deterministic relation" setup = [ + TestUtils +] begin using Distributions using GraphPPL: create_model, @@ -340,17 +346,13 @@ end with_plugins, datalabel - include("../../testutils.jl") - - using .TestUtils.ModelZoo - - @model function random_walk(y, a, b) - x[1] ~ NormalMeanVariance(0, 1) - y[1] ~ NormalMeanVariance(x[1], 1) + TestUtils.@model function random_walk(y, a, b) + x[1] ~ TestUtils.NormalMeanVariance(0, 1) + y[1] ~ TestUtils.NormalMeanVariance(x[1], 1) for i in 2:length(y) - x[i] ~ NormalMeanPrecision(a * x[i - 1] + b, 1) - y[i] ~ NormalMeanVariance(x[i], 1) + x[i] ~ TestUtils.NormalMeanPrecision(a * x[i - 1] + b, 1) + y[i] ~ TestUtils.NormalMeanVariance(x[i], 1) end end @@ -369,19 +371,19 @@ end end @test length(collect(filter(as_node(Normal), model))) === 2 * n - @test length(collect(filter(as_node(NormalMeanVariance), model))) === n + 1 - @test length(collect(filter(as_node(NormalMeanPrecision), model))) === n - 1 + @test length(collect(filter(as_node(TestUtils.NormalMeanVariance), model))) === n + 1 + @test length(collect(filter(as_node(TestUtils.NormalMeanPrecision), model))) === n - 1 @test length(collect(filter(as_node(prod), model))) === n - 1 @test length(collect(filter(as_node(sum), model))) === n - 1 - @test all(filter(as_node(NormalMeanVariance), model)) do node + @test all(filter(as_node(TestUtils.NormalMeanVariance), model)) do node # This must be factorized out just because of the implicit constraint for conststs and datavars interfaces = GraphPPL.edges(model, node) @test hasextra(model[node], :factorization_constraint_indices) return Tuple.(getextra(model[node], :factorization_constraint_indices)) === ((1,), (2,), (3,)) end - @test all(filter(as_node(NormalMeanPrecision), model)) do node + @test all(filter(as_node(TestUtils.NormalMeanPrecision), model)) do node # The test tests that the factorization constraint around the node `x[i] ~ Normal(a * x[i - 1] + b, 1)` # is correctly resolved to structured, since empty constraints do not factorize out this case interfaces = GraphPPL.edges(model, node) @@ -398,12 +400,12 @@ end end @test length(collect(filter(as_node(Normal), model))) == 2 * n - @test length(collect(filter(as_node(NormalMeanVariance), model))) === n + 1 - @test length(collect(filter(as_node(NormalMeanPrecision), model))) === n - 1 + @test length(collect(filter(as_node(TestUtils.NormalMeanVariance), model))) === n + 1 + @test length(collect(filter(as_node(TestUtils.NormalMeanPrecision), model))) === n - 1 @test length(collect(filter(as_node(prod), model))) === n - 1 @test length(collect(filter(as_node(sum), model))) === n - 1 - @test all(filter(as_node(NormalMeanPrecision) | as_node(NormalMeanVariance), model)) do node + @test all(filter(as_node(TestUtils.NormalMeanPrecision) | as_node(TestUtils.NormalMeanVariance), model)) do node # The test tests that the factorization constraint around the node `x[i] ~ Normal(a * x[i - 1] + b, 1)` # is correctly resolved to mean-field, because `a * x[i - 1] + b` is deterministically linked to `x[i - 1]`, thus # the interfaces must be factorized out @@ -416,7 +418,9 @@ end end end -@testitem "simple @model + structured @constraints + anonymous variable linked through a deterministic relation with constants/datavars" begin +@testitem "simple @model + structured @constraints + anonymous variable linked through a deterministic relation with constants/datavars" setup = [ + TestUtils +] begin using Distributions, LinearAlgebra using GraphPPL: create_model, @@ -432,9 +436,7 @@ end VariationalConstraintsPlugin, with_plugins - include("../../testutils.jl") - - @model function simple_model(y, a, b) + TestUtils.@model function simple_model(y, a, b) τ ~ Gamma(10, 10) # wrong for MvNormal, but test is for a different purpose θ ~ Gamma(10, 10) @@ -476,7 +478,7 @@ end end end -@testitem "state space @model (nested) + @constraints + anonymous variable linked through a deterministic relation" begin +@testitem "state space @model (nested) + @constraints + anonymous variable linked through a deterministic relation" setup = [TestUtils] begin using Distributions using GraphPPL: create_model, @@ -493,17 +495,15 @@ end with_plugins, datalabel - include("../../testutils.jl") - - @model function nested2(u, θ, c, d) + TestUtils.@model function nested2(u, θ, c, d) u ~ Normal(c * θ + d, 1) end - @model function nested1(z, g, a, b) + TestUtils.@model function nested1(z, g, a, b) z ~ nested2(θ = g, c = a, d = b) end - @model function random_walk(y, a, b) + TestUtils.@model function random_walk(y, a, b) x[1] ~ Normal(0, 1) y[1] ~ Normal(x[1], 1) @@ -581,7 +581,7 @@ end end end -@testitem "Simple @model + functional form constraints" begin +@testitem "Simple @model + functional form constraints" setup = [TestUtils] begin using Distributions import GraphPPL: @@ -597,9 +597,7 @@ end VariationalConstraintsMarginalFormConstraintKey, VariationalConstraintsMessagesFormConstraintKey - include("../../testutils.jl") - - @model function simple_model_for_fform_constraints() + TestUtils.@model function simple_model_for_fform_constraints() x ~ Normal(0, 1) y ~ Gamma(1, 1) z ~ Normal(x, y) @@ -689,7 +687,7 @@ end end end -@testitem "@constraints macro pipeline" begin +@testitem "@constraints macro pipeline" setup = [TestUtils] begin import GraphPPL: create_model, with_plugins, @@ -703,44 +701,40 @@ end VariationalConstraintsMessagesFormConstraintKey, VariationalConstraintsFactorizationIndicesKey - include("../../testutils.jl") - - using .TestUtils.ModelZoo - constraints = @constraints begin q(x, y) = q(x)q(y) q(y, z) = q(y)q(z) - q(x)::NormalMeanVariance() - μ(y)::NormalMeanVariance() + q(x)::TestUtils.NormalMeanVariance() + μ(y)::TestUtils.NormalMeanVariance() end # Test constraints macro with single variables and no nesting - model = create_model(with_plugins(simple_model(), PluginsCollection(VariationalConstraintsPlugin(constraints)))) + model = create_model(with_plugins(TestUtils.simple_model(), PluginsCollection(VariationalConstraintsPlugin(constraints)))) ctx = GraphPPL.getcontext(model) for node in filter(GraphPPL.as_variable(:x), model) - @test getextra(model[node], VariationalConstraintsMarginalFormConstraintKey) == NormalMeanVariance() + @test getextra(model[node], VariationalConstraintsMarginalFormConstraintKey) == TestUtils.NormalMeanVariance() @test !hasextra(model[node], VariationalConstraintsMessagesFormConstraintKey) end for node in filter(GraphPPL.as_variable(:y), model) @test !hasextra(model[node], VariationalConstraintsMarginalFormConstraintKey) - @test getextra(model[node], VariationalConstraintsMessagesFormConstraintKey) == NormalMeanVariance() + @test getextra(model[node], VariationalConstraintsMessagesFormConstraintKey) == TestUtils.NormalMeanVariance() end for node in filter(GraphPPL.as_variable(:z), model) @test !hasextra(model[node], VariationalConstraintsMarginalFormConstraintKey) @test !hasextra(model[node], VariationalConstraintsMessagesFormConstraintKey) end - @test Tuple.(getextra(model[ctx[NormalMeanVariance, 1]], VariationalConstraintsFactorizationIndicesKey)) == ((1,), (2,), (3,)) - @test Tuple.(getextra(model[ctx[NormalMeanVariance, 2]], VariationalConstraintsFactorizationIndicesKey)) == ((1, 2), (3,)) + @test Tuple.(getextra(model[ctx[TestUtils.NormalMeanVariance, 1]], VariationalConstraintsFactorizationIndicesKey)) == ((1,), (2,), (3,)) + @test Tuple.(getextra(model[ctx[TestUtils.NormalMeanVariance, 2]], VariationalConstraintsFactorizationIndicesKey)) == ((1, 2), (3,)) # Test constriants macro with nested model constraints = @constraints begin - for q in inner + for q in TestUtils.inner q(α, θ) = q(α)q(θ) - q(α)::NormalMeanVariance() - μ(θ)::NormalMeanVariance() + q(α)::TestUtils.NormalMeanVariance() + μ(θ)::TestUtils.NormalMeanVariance() end end - model = create_model(with_plugins(outer(), PluginsCollection(VariationalConstraintsPlugin(constraints)))) + model = create_model(with_plugins(TestUtils.outer(), PluginsCollection(VariationalConstraintsPlugin(constraints)))) ctx = GraphPPL.getcontext(model) @test hasextra(model[ctx[:w][1]], VariationalConstraintsMarginalFormConstraintKey) === false @@ -750,44 +744,47 @@ end @test hasextra(model[ctx[:w][5]], VariationalConstraintsMarginalFormConstraintKey) === false @test hasextra(model[ctx[:w][1]], VariationalConstraintsMessagesFormConstraintKey) === false - @test getextra(model[ctx[:w][2]], VariationalConstraintsMessagesFormConstraintKey) === NormalMeanVariance() - @test getextra(model[ctx[:w][3]], VariationalConstraintsMessagesFormConstraintKey) === NormalMeanVariance() + @test getextra(model[ctx[:w][2]], VariationalConstraintsMessagesFormConstraintKey) === TestUtils.NormalMeanVariance() + @test getextra(model[ctx[:w][3]], VariationalConstraintsMessagesFormConstraintKey) === TestUtils.NormalMeanVariance() @test hasextra(model[ctx[:w][4]], VariationalConstraintsMessagesFormConstraintKey) === false @test hasextra(model[ctx[:w][5]], VariationalConstraintsMessagesFormConstraintKey) === false - @test getextra(model[ctx[:y]], VariationalConstraintsMarginalFormConstraintKey) == NormalMeanVariance() - for node in filter(GraphPPL.as_node(NormalMeanVariance) & GraphPPL.as_context(inner_inner), model) + @test getextra(model[ctx[:y]], VariationalConstraintsMarginalFormConstraintKey) == TestUtils.NormalMeanVariance() + for node in filter(GraphPPL.as_node(TestUtils.NormalMeanVariance) & GraphPPL.as_context(TestUtils.inner_inner), model) @test Tuple.(getextra(model[node], VariationalConstraintsFactorizationIndicesKey)) == ((1,), (2, 3)) end # Test with specifying specific submodel constraints = @constraints begin - for q in (child_model, 1) + for q in (TestUtils.child_model, 1) q(in, out, σ) = q(in, out)q(σ) end end - model = create_model(with_plugins(parent_model(), PluginsCollection(VariationalConstraintsPlugin(constraints)))) + model = create_model(with_plugins(TestUtils.parent_model(), PluginsCollection(VariationalConstraintsPlugin(constraints)))) ctx = GraphPPL.getcontext(model) - @test Tuple.(getextra(model[ctx[child_model, 1][NormalMeanVariance, 1]], VariationalConstraintsFactorizationIndicesKey)) == - ((1, 2), (3,)) + @test Tuple.( + getextra(model[ctx[TestUtils.child_model, 1][TestUtils.NormalMeanVariance, 1]], VariationalConstraintsFactorizationIndicesKey) + ) == ((1, 2), (3,)) for i in 2:99 - @test Tuple.(getextra(model[ctx[child_model, i][NormalMeanVariance, 1]], VariationalConstraintsFactorizationIndicesKey)) == - ((1, 2, 3),) + @test Tuple.( + getextra(model[ctx[TestUtils.child_model, i][TestUtils.NormalMeanVariance, 1]], VariationalConstraintsFactorizationIndicesKey) + ) == ((1, 2, 3),) end # Test with specifying general submodel constraints = @constraints begin - for q in child_model + for q in TestUtils.child_model q(in, out, σ) = q(in, out)q(σ) end end - model = create_model(with_plugins(parent_model(), PluginsCollection(VariationalConstraintsPlugin(constraints)))) + model = create_model(with_plugins(TestUtils.parent_model(), PluginsCollection(VariationalConstraintsPlugin(constraints)))) ctx = GraphPPL.getcontext(model) - @test Tuple.(getextra(model[ctx[child_model, 1][NormalMeanVariance, 1]], VariationalConstraintsFactorizationIndicesKey)) == - ((1, 2), (3,)) - for node in filter(GraphPPL.as_node(NormalMeanVariance) & GraphPPL.as_context(child_model), model) + @test Tuple.( + getextra(model[ctx[TestUtils.child_model, 1][TestUtils.NormalMeanVariance, 1]], VariationalConstraintsFactorizationIndicesKey) + ) == ((1, 2), (3,)) + for node in filter(GraphPPL.as_node(TestUtils.NormalMeanVariance) & GraphPPL.as_context(TestUtils.child_model), model) @test Tuple.(getextra(model[node], VariationalConstraintsFactorizationIndicesKey)) == ((1, 2), (3,)) end @@ -795,42 +792,42 @@ end constraints = @constraints begin q(x, y) = q(x)q(y) end - @test_throws ErrorException create_model(with_plugins(simple_model(), PluginsCollection(VariationalConstraintsPlugin(constraints)))) + @test_throws ErrorException create_model( + with_plugins(TestUtils.simple_model(), PluginsCollection(VariationalConstraintsPlugin(constraints))) + ) end -@testitem "A complex hierarchical constraints with lots of renaming and interleaving with constants" begin +@testitem "A complex hierarchical constraints with lots of renaming and interleaving with constants" setup = [TestUtils] begin using Distributions using BitSetTuples import GraphPPL: create_model, with_plugins, PluginsCollection, VariationalConstraintsPlugin, getorcreate!, NodeCreationOptions, hasextra, getextra - include("../../testutils.jl") - - @model function submodel_3_1(b, n, m) + TestUtils.@model function submodel_3_1(b, n, m) b ~ Normal(n, m) end - @model function submodel_3_2(b, n, m) + TestUtils.@model function submodel_3_2(b, n, m) b ~ Normal(n + 1, m + 1) end - @model function submodel_2_1(a, b, c, submodel_3) + TestUtils.@model function submodel_2_1(a, b, c, submodel_3) c ~ submodel_3(b = a, m = b) end - @model function submodel_2_2(a, b, c, submodel_3) + TestUtils.@model function submodel_2_2(a, b, c, submodel_3) c ~ submodel_3(b = a + 1, m = b + 1) end - @model function submodel_1_1(x, y, z, submodel_2, submodel_3) + TestUtils.@model function submodel_1_1(x, y, z, submodel_2, submodel_3) z ~ submodel_2(a = x, b = y, submodel_3 = submodel_3) end - @model function submodel_1_2(x, y, z, submodel_2, submodel_3) + TestUtils.@model function submodel_1_2(x, y, z, submodel_2, submodel_3) z ~ submodel_2(a = x + 1, b = y + 1, submodel_3 = submodel_3) end - @model function main_model(case, submodel_1, submodel_2, submodel_3) + TestUtils.@model function main_model(case, submodel_1, submodel_2, submodel_3) r ~ Gamma(1, 1) u ~ Beta(1, 1) # In the test we impose the mean-field factorization @@ -909,7 +906,7 @@ end end end -@testitem "A joint constraint over 'initial variable' and 'state variables' aka `q(x0, x)q(γ)`" begin +@testitem "A joint constraint over 'initial variable' and 'state variables' aka `q(x0, x)q(γ)`" setup = [TestUtils] begin using Distributions import GraphPPL: @@ -923,9 +920,7 @@ end hasextra, datalabel - include("../../testutils.jl") - - @model function some_state_space_model(y) + TestUtils.@model function some_state_space_model(y) γ ~ Gamma(1, 1) θ ~ Gamma(1, 1) μ0 ~ Beta(1, 1) @@ -971,15 +966,11 @@ end end end -@testitem "Apply MeanField constraints" begin +@testitem "Apply MeanField constraints" setup = [TestUtils] begin using GraphPPL import GraphPPL: create_model, with_plugins, getproperties, neighbor_data - include("../../testutils.jl") - - using .TestUtils.ModelZoo - - for model_fform in ModelsInTheZooWithoutArguments + for model_fform in TestUtils.ModelsInTheZooWithoutArguments model = create_model(with_plugins(model_fform(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin(MeanField())))) for node in filter(as_node(), model) node_data = model[node] @@ -989,17 +980,13 @@ end end end -@testitem "Apply BetheFactorization constraints" begin +@testitem "Apply BetheFactorization constraints" setup = [TestUtils] begin using GraphPPL import GraphPPL: create_model, with_plugins, getproperties, neighbor_data, is_factorized - include("../../testutils.jl") - - using .TestUtils.ModelZoo - # BetheFactorization uses `default_constraints` for `contains_default_constraints` # So it is not tested here - for model_fform in setdiff(Set(ModelsInTheZooWithoutArguments), Set([contains_default_constraints])) + for model_fform in setdiff(Set(TestUtils.ModelsInTheZooWithoutArguments), Set([TestUtils.contains_default_constraints])) model = create_model( with_plugins(model_fform(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin(BetheFactorization()))) ) @@ -1017,13 +1004,11 @@ end end end -@testitem "Default constraints of top level model" begin - using GraphPPL +@testitem "Default constraints of top level model" setup = [TestUtils] begin + using Distributions import GraphPPL: create_model, with_plugins, getproperties, neighbor_data, is_factorized - include("../../testutils.jl") - - @model function model_with_default_constraints() + TestUtils.@model function model_with_default_constraints() x ~ Normal(0, 1) y ~ Normal(x, 1) z ~ Normal(y, 1) @@ -1044,12 +1029,11 @@ end end end -@testitem "Constraint over a mixture model" begin +@testitem "Constraint over a mixture model" setup = [TestUtils] begin + using Distributions import GraphPPL: create_model, with_plugins, getproperties, neighbor_data, is_factorized - include("../../testutils.jl") - - @model function mixture() + TestUtils.@model function mixture() m1 ~ Normal(0, 1) m2 ~ Normal(0, 1) m3 ~ Normal(0, 1) @@ -1071,10 +1055,12 @@ end end for constraints in [constraints_1, constraints_2] - model = create_model(with_plugins(mixture(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin(constraints)))) + model = create_model( + with_plugins(TestUtils.mixture(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin(constraints))) + ) - @test length(collect(filter(as_node(Mixture), model))) === 1 - for node in filter(as_node(Mixture), model) + @test length(collect(filter(as_node(TestUtils.Mixture), model))) === 1 + for node in filter(as_node(TestUtils.Mixture), model) node_data = model[node] @test Tuple.(GraphPPL.getextra(node_data, :factorization_constraint_indices)) == ((1, 2, 3, 4, 5, 6, 7, 8, 9),) end @@ -1091,15 +1077,17 @@ end end for constraints in [constraints_1, constraints_2] - model = create_model(with_plugins(mixture(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin(constraints)))) + model = create_model( + with_plugins(TestUtils.mixture(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin(constraints))) + ) for node in filter(as_node(), model) node_data = model[node] @test GraphPPL.getextra(node_data, :factorization_constraint_indices) == Tuple([[i] for i in 1:(length(neighbor_data(getproperties(node_data))))]) end - @test length(collect(filter(as_node(Mixture), model))) === 1 - for node in filter(as_node(Mixture), model) + @test length(collect(filter(as_node(TestUtils.Mixture), model))) === 1 + for node in filter(as_node(TestUtils.Mixture), model) node_data = model[node] @test Tuple.(GraphPPL.getextra(node_data, :factorization_constraint_indices)) == ((1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,)) @@ -1108,18 +1096,17 @@ end end end -@testitem "Issue 262, factorization constraint should not attempt to create a variable from submodels" begin +@testitem "Issue 262, factorization constraint should not attempt to create a variable from submodels" setup = [TestUtils] begin + using Distributions import GraphPPL: create_model, with_plugins, getproperties, neighbor_data, is_factorized - include("../../testutils.jl") - - @model function submodel(y, x) + TestUtils.@model function submodel(y, x) for i in 1:10 y[i] ~ Normal(x, 1) end end - @model function main_model() + TestUtils.@model function main_model() x ~ Normal(0, 1) y ~ submodel(x = x) end @@ -1135,7 +1122,7 @@ end @test length(collect(filter(as_node(Normal), model))) == 11 end -@testitem "`@constraints` should save the source code #1" begin +@testitem "`@constraints` should save the source code #1" setup = [TestUtils] begin using GraphPPL constraints = @constraints begin diff --git a/test/runtests.jl b/test/runtests.jl index 78df110d..4ab07d94 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,19 @@ -using ReTestItems, GraphPPL, Aqua +using GraphPPL +using Test +using Aqua +using JET +using TestItemRunner -Aqua.test_all(GraphPPL; ambiguities = (broken = true,)) +# include("testutils.jl") -runtests(GraphPPL) +@testset "GraphPPL.jl" begin + @testset "Code quality (Aqua.jl)" begin + Aqua.test_all(GraphPPL; ambiguities = (broken = true,)) + end + + # @testset "Code linting (JET.jl)" begin + # JET.test_package(GraphPPL; target_defined_modules = true) + # end + + TestItemRunner.@run_package_tests() +end \ No newline at end of file diff --git a/test/testutils.jl b/test/testutils.jl index a349a180..1461492c 100644 --- a/test/testutils.jl +++ b/test/testutils.jl @@ -1,411 +1,405 @@ -module TestUtils - -using GraphPPL, MacroTools, Static, Distributions - -export @test_expression_generating - -macro test_expression_generating(lhs, rhs) - test_expr_gen = gensym(:text_expr_gen) - return esc( - quote - $test_expr_gen = (prettify($lhs) == prettify($rhs)) - if !$test_expr_gen - println("Expressions do not match: ") - println("lhs: ", prettify($lhs)) - println("rhs: ", prettify($rhs)) +@testmodule TestUtils begin + using GraphPPL + using MacroTools + using Static + using Distributions + + export @test_expression_generating + + macro test_expression_generating(lhs, rhs) + test_expr_gen = gensym(:text_expr_gen) + return esc( + quote + $test_expr_gen = (prettify($lhs) == prettify($rhs)) + if !$test_expr_gen + println("Expressions do not match: ") + println("lhs: ", prettify($lhs)) + println("rhs: ", prettify($rhs)) + end + @test (prettify($lhs) == prettify($rhs)) end - @test (prettify($lhs) == prettify($rhs)) - end - ) -end - -export @test_expression_generating_broken - -macro test_expression_generating_broken(lhs, rhs) - return esc(:(@test_broken (prettify($lhs) == prettify($rhs)))) -end - -# We use a custom backend for testing purposes, instead of using the `DefaultBackend` -# The `TestGraphPPLBackend` is a simple backend that specifies how to handle objects from `Distributions.jl` -# It does use the default pipeline collection for the `@model` macro -struct TestGraphPPLBackend <: GraphPPL.AbstractBackend end + ) + end -GraphPPL.model_macro_interior_pipelines(::TestGraphPPLBackend) = GraphPPL.model_macro_interior_pipelines(GraphPPL.DefaultBackend()) + export @test_expression_generating_broken -# The `TestGraphPPLBackend` redirects some of the methods to the `DefaultBackend` -# (not all though, `TestGraphPPLBackend` implements some of them for the custom structures defined also below) -# The `DefaultBackend` has extension rules for `Distributions.jl` types for example -GraphPPL.NodeBehaviour(::TestGraphPPLBackend, fform) = GraphPPL.NodeBehaviour(GraphPPL.DefaultBackend(), fform) -GraphPPL.NodeType(::TestGraphPPLBackend, fform) = GraphPPL.NodeType(GraphPPL.DefaultBackend(), fform) -GraphPPL.aliases(::TestGraphPPLBackend, fform) = GraphPPL.aliases(GraphPPL.DefaultBackend(), fform) -GraphPPL.interfaces(::TestGraphPPLBackend, fform, n) = GraphPPL.interfaces(GraphPPL.DefaultBackend(), fform, n) -GraphPPL.factor_alias(::TestGraphPPLBackend, f, interfaces) = GraphPPL.factor_alias(GraphPPL.DefaultBackend(), f, interfaces) -GraphPPL.interface_aliases(::TestGraphPPLBackend, f) = GraphPPL.interface_aliases(GraphPPL.DefaultBackend(), f) -GraphPPL.default_parametrization(::TestGraphPPLBackend, nodetype, f, rhs) = - GraphPPL.default_parametrization(GraphPPL.DefaultBackend(), nodetype, f, rhs) -GraphPPL.instantiate(::Type{TestGraphPPLBackend}) = TestGraphPPLBackend() + macro test_expression_generating_broken(lhs, rhs) + return esc(:(@test_broken (prettify($lhs) == prettify($rhs)))) + end -# Check that we can alias the `+` into `sum` and `*` into `prod` -GraphPPL.factor_alias(::TestGraphPPLBackend, ::typeof(+), interfaces) = sum -GraphPPL.factor_alias(::TestGraphPPLBackend, ::typeof(*), interfaces) = prod + # We use a custom backend for testing purposes, instead of using the `DefaultBackend` + # The `TestGraphPPLBackend` is a simple backend that specifies how to handle objects from `Distributions.jl` + # It does use the default pipeline collection for the `@model` macro + struct TestGraphPPLBackend <: GraphPPL.AbstractBackend end + + GraphPPL.model_macro_interior_pipelines(::TestGraphPPLBackend) = GraphPPL.model_macro_interior_pipelines(GraphPPL.DefaultBackend()) + + # The `TestGraphPPLBackend` redirects some of the methods to the `DefaultBackend` + # (not all though, `TestGraphPPLBackend` implements some of them for the custom structures defined also below) + # The `DefaultBackend` has extension rules for `Distributions.jl` types for example + GraphPPL.NodeBehaviour(::TestGraphPPLBackend, fform) = GraphPPL.NodeBehaviour(GraphPPL.DefaultBackend(), fform) + GraphPPL.NodeType(::TestGraphPPLBackend, fform) = GraphPPL.NodeType(GraphPPL.DefaultBackend(), fform) + GraphPPL.aliases(::TestGraphPPLBackend, fform) = GraphPPL.aliases(GraphPPL.DefaultBackend(), fform) + GraphPPL.interfaces(::TestGraphPPLBackend, fform, n) = GraphPPL.interfaces(GraphPPL.DefaultBackend(), fform, n) + GraphPPL.factor_alias(::TestGraphPPLBackend, f, interfaces) = GraphPPL.factor_alias(GraphPPL.DefaultBackend(), f, interfaces) + GraphPPL.interface_aliases(::TestGraphPPLBackend, f) = GraphPPL.interface_aliases(GraphPPL.DefaultBackend(), f) + GraphPPL.default_parametrization(::TestGraphPPLBackend, nodetype, f, rhs) = + GraphPPL.default_parametrization(GraphPPL.DefaultBackend(), nodetype, f, rhs) + GraphPPL.instantiate(::Type{TestGraphPPLBackend}) = TestGraphPPLBackend() + + # Check that we can alias the `+` into `sum` and `*` into `prod` + GraphPPL.factor_alias(::TestGraphPPLBackend, ::typeof(+), interfaces) = sum + GraphPPL.factor_alias(::TestGraphPPLBackend, ::typeof(*), interfaces) = prod + + export @model + + # This is a special `@model` macro that should be used in tests + macro model(model_specification) + return esc(GraphPPL.model_macro_interior(TestGraphPPLBackend, model_specification)) + end -export @model + export create_test_model -# This is a special `@model` macro that should be used in tests -macro model(model_specification) - return esc(GraphPPL.model_macro_interior(TestGraphPPLBackend, model_specification)) -end + function create_test_model(; + fform = identity, plugins = GraphPPL.PluginsCollection(), backend = TestGraphPPLBackend(), source = nothing + ) + # `identity` is not really a probabilistic model and also does not have a backend nor a source code + # for testing purposes however it should be fine + return GraphPPL.Model(fform, plugins, backend, source) + end -export create_test_model + # Node zoo fo tests -function create_test_model(; fform = identity, plugins = GraphPPL.PluginsCollection(), backend = TestGraphPPLBackend(), source = nothing) - # `identity` is not really a probabilistic model and also does not have a backend nor a source code - # for testing purposes however it should be fine - return GraphPPL.Model(fform, plugins, backend, source) -end + export PointMass, ArbitraryNode, NormalMeanVariance, NormalMeanPrecision, GammaShapeRate, GammaShapeScale, Mixture -# Node zoo fo tests + struct PointMass end -export PointMass, ArbitraryNode, NormalMeanVariance, NormalMeanPrecision, GammaShapeRate, GammaShapeScale, Mixture + GraphPPL.prettyname(::Type{PointMass}) = "δ" -struct PointMass end + GraphPPL.NodeBehaviour(::TestGraphPPLBackend, ::Type{PointMass}) = GraphPPL.Deterministic() -GraphPPL.prettyname(::Type{PointMass}) = "δ" + struct ArbitraryNode end -GraphPPL.NodeBehaviour(::TestGraphPPLBackend, ::Type{PointMass}) = GraphPPL.Deterministic() + GraphPPL.prettyname(::Type{ArbitraryNode}) = "ArbitraryNode" -struct ArbitraryNode end + GraphPPL.NodeBehaviour(::TestGraphPPLBackend, ::Type{ArbitraryNode}) = GraphPPL.Stochastic() -GraphPPL.prettyname(::Type{ArbitraryNode}) = "ArbitraryNode" + struct NormalMeanVariance end -GraphPPL.NodeBehaviour(::TestGraphPPLBackend, ::Type{ArbitraryNode}) = GraphPPL.Stochastic() + GraphPPL.prettyname(::Type{NormalMeanVariance}) = "𝓝(μ, σ^2)" -struct NormalMeanVariance end + GraphPPL.NodeBehaviour(::TestGraphPPLBackend, ::Type{NormalMeanVariance}) = GraphPPL.Stochastic() -GraphPPL.prettyname(::Type{NormalMeanVariance}) = "𝓝(μ, σ^2)" + struct NormalMeanPrecision end -GraphPPL.NodeBehaviour(::TestGraphPPLBackend, ::Type{NormalMeanVariance}) = GraphPPL.Stochastic() + GraphPPL.prettyname(::Type{NormalMeanPrecision}) = "𝓝(μ, σ^-2)" -struct NormalMeanPrecision end + GraphPPL.NodeBehaviour(::TestGraphPPLBackend, ::Type{NormalMeanPrecision}) = GraphPPL.Stochastic() -GraphPPL.prettyname(::Type{NormalMeanPrecision}) = "𝓝(μ, σ^-2)" + GraphPPL.aliases(::TestGraphPPLBackend, ::Type{Normal}) = (Normal, NormalMeanVariance, NormalMeanPrecision) -GraphPPL.NodeBehaviour(::TestGraphPPLBackend, ::Type{NormalMeanPrecision}) = GraphPPL.Stochastic() + GraphPPL.interfaces(::TestGraphPPLBackend, ::Type{NormalMeanVariance}, ::StaticInt{3}) = GraphPPL.StaticInterfaces((:out, :μ, :σ)) + GraphPPL.interfaces(::TestGraphPPLBackend, ::Type{NormalMeanPrecision}, ::StaticInt{3}) = GraphPPL.StaticInterfaces((:out, :μ, :τ)) -GraphPPL.aliases(::TestGraphPPLBackend, ::Type{Normal}) = (Normal, NormalMeanVariance, NormalMeanPrecision) + GraphPPL.factor_alias(::TestGraphPPLBackend, ::Type{Normal}, ::GraphPPL.StaticInterfaces{(:μ, :σ)}) = NormalMeanVariance + GraphPPL.factor_alias(::TestGraphPPLBackend, ::Type{Normal}, ::GraphPPL.StaticInterfaces{(:μ, :τ)}) = NormalMeanPrecision -GraphPPL.interfaces(::TestGraphPPLBackend, ::Type{NormalMeanVariance}, ::StaticInt{3}) = GraphPPL.StaticInterfaces((:out, :μ, :σ)) -GraphPPL.interfaces(::TestGraphPPLBackend, ::Type{NormalMeanPrecision}, ::StaticInt{3}) = GraphPPL.StaticInterfaces((:out, :μ, :τ)) + GraphPPL.interface_aliases(::TestGraphPPLBackend, ::Type{Normal}) = GraphPPL.StaticInterfaceAliases(( + (:mean, :μ), + (:m, :μ), + (:variance, :σ), + (:var, :σ), + (:v, :σ), + (:τ⁻¹, :σ), + (:precision, :τ), + (:prec, :τ), + (:p, :τ), + (:w, :τ), + (:σ⁻², :τ), + (:γ, :τ) + )) -GraphPPL.factor_alias(::TestGraphPPLBackend, ::Type{Normal}, ::GraphPPL.StaticInterfaces{(:μ, :σ)}) = NormalMeanVariance -GraphPPL.factor_alias(::TestGraphPPLBackend, ::Type{Normal}, ::GraphPPL.StaticInterfaces{(:μ, :τ)}) = NormalMeanPrecision + struct GammaShapeRate end + struct GammaShapeScale end -GraphPPL.interface_aliases(::TestGraphPPLBackend, ::Type{Normal}) = GraphPPL.StaticInterfaceAliases(( - (:mean, :μ), - (:m, :μ), - (:variance, :σ), - (:var, :σ), - (:v, :σ), - (:τ⁻¹, :σ), - (:precision, :τ), - (:prec, :τ), - (:p, :τ), - (:w, :τ), - (:σ⁻², :τ), - (:γ, :τ) -)) + GraphPPL.aliases(::TestGraphPPLBackend, ::Type{Gamma}) = (Gamma, GammaShapeRate, GammaShapeScale) -struct GammaShapeRate end -struct GammaShapeScale end + GraphPPL.interfaces(::TestGraphPPLBackend, ::Type{GammaShapeRate}, ::StaticInt{3}) = GraphPPL.StaticInterfaces((:out, :α, :β)) + GraphPPL.interfaces(::TestGraphPPLBackend, ::Type{GammaShapeScale}, ::StaticInt{3}) = GraphPPL.StaticInterfaces((:out, :α, :θ)) -GraphPPL.aliases(::TestGraphPPLBackend, ::Type{Gamma}) = (Gamma, GammaShapeRate, GammaShapeScale) + GraphPPL.factor_alias(::TestGraphPPLBackend, ::Type{Gamma}, ::GraphPPL.StaticInterfaces{(:α, :β)}) = GammaShapeRate + GraphPPL.factor_alias(::TestGraphPPLBackend, ::Type{Gamma}, ::GraphPPL.StaticInterfaces{(:α, :θ)}) = GammaShapeScale -GraphPPL.interfaces(::TestGraphPPLBackend, ::Type{GammaShapeRate}, ::StaticInt{3}) = GraphPPL.StaticInterfaces((:out, :α, :β)) -GraphPPL.interfaces(::TestGraphPPLBackend, ::Type{GammaShapeScale}, ::StaticInt{3}) = GraphPPL.StaticInterfaces((:out, :α, :θ)) + struct Mixture end -GraphPPL.factor_alias(::TestGraphPPLBackend, ::Type{Gamma}, ::GraphPPL.StaticInterfaces{(:α, :β)}) = GammaShapeRate -GraphPPL.factor_alias(::TestGraphPPLBackend, ::Type{Gamma}, ::GraphPPL.StaticInterfaces{(:α, :θ)}) = GammaShapeScale + GraphPPL.interfaces(::TestGraphPPLBackend, ::Type{Mixture}, ::StaticInt{3}) = GraphPPL.StaticInterfaces((:out, :m, :τ)) -struct Mixture end + GraphPPL.NodeBehaviour(::TestGraphPPLBackend, ::Type{Mixture}) = GraphPPL.Stochastic() -GraphPPL.interfaces(::TestGraphPPLBackend, ::Type{Mixture}, ::StaticInt{3}) = GraphPPL.StaticInterfaces((:out, :m, :τ)) + # Model zoo for tests -GraphPPL.NodeBehaviour(::TestGraphPPLBackend, ::Type{Mixture}) = GraphPPL.Stochastic() + export simple_model, + vector_model, + tensor_model, + anonymous_in_loop, + node_with_only_anonymous, + node_with_two_anonymous, + type_arguments, + node_with_ambiguous_anonymous, + gcv, + gcv_lm, + hgf, + prior, + broadcastable, + broadcaster, + inner_inner, + inner, + outer, + multidim_array, + child_model, + parent_model, + model_with_default_constraints, + contains_default_constraints, + mixture, + filled_matrix_model -# Model zoo for tests + @model function simple_model() + x ~ Normal(0, 1) + y ~ Gamma(1, 1) + z ~ Normal(x, y) + end -module ModelZoo + @model function vector_model() + local x + local y + for i in 1:3 + x[i] ~ Normal(0, 1) + y[i] ~ Gamma(1, 1) + z[i] ~ Normal(x[i], y[i]) + end + end -export simple_model, - vector_model, - tensor_model, - anonymous_in_loop, - node_with_only_anonymous, - node_with_two_anonymous, - type_arguments, - node_with_ambiguous_anonymous, - gcv, - gcv_lm, - hgf, - prior, - broadcastable, - broadcaster, - inner_inner, - inner, - outer, - multidim_array, - child_model, - parent_model, - model_with_default_constraints, - contains_default_constraints, - mixture, - filled_matrix_model + @model function tensor_model() + local x + local y + for i in 1:3 + x[i, i] ~ Normal(0, 1) + y[i, i] ~ Gamma(1, 1) + z[i, i] ~ Normal(x[i, i], y[i, i]) + end + end -using GraphPPL, MacroTools, Static, Distributions -using ..TestUtils + @model function filled_matrix_model() + local x + local y + for i in 1:3 + for j in 1:3 + y[i, j] ~ Gamma(1, 1) + x[i, j] ~ Normal(0, y[i, j]) + end + end + end -@model function simple_model() - x ~ Normal(0, 1) - y ~ Gamma(1, 1) - z ~ Normal(x, y) -end + @model function anonymous_in_loop(x, y) + x_0 ~ Normal(μ = 0, σ = 1.0) + x_prev = x_0 + for i in 1:length(x) + x[i] ~ Normal(μ = x_prev + 1, σ = 1.0) + x_prev = x[i] + end -@model function vector_model() - local x - local y - for i in 1:3 - x[i] ~ Normal(0, 1) - y[i] ~ Gamma(1, 1) - z[i] ~ Normal(x[i], y[i]) + y ~ Normal(μ = x[end], σ = 1.0) end -end - -@model function tensor_model() - local x - local y - for i in 1:3 - x[i, i] ~ Normal(0, 1) - y[i, i] ~ Gamma(1, 1) - z[i, i] ~ Normal(x[i, i], y[i, i]) - end -end - -@model function filled_matrix_model() - local x - local y - for i in 1:3 - for j in 1:3 - y[i, j] ~ Gamma(1, 1) - x[i, j] ~ Normal(0, y[i, j]) + + @model function node_with_only_anonymous() + x[1] ~ Normal(0, 1) + y[1] ~ Normal(0, 1) + for i in 2:10 + y[i] ~ Normal(0, 1) + x[i] ~ Normal(y[i - 1] + 1, 1) end end -end - -@model function anonymous_in_loop(x, y) - x_0 ~ Normal(μ = 0, σ = 1.0) - x_prev = x_0 - for i in 1:length(x) - x[i] ~ Normal(μ = x_prev + 1, σ = 1.0) - x_prev = x[i] - end - y ~ Normal(μ = x[end], σ = 1.0) -end + @model function node_with_two_anonymous() + x[1] ~ Normal(0, 1) + y[1] ~ Normal(0, 1) + for i in 2:10 + y[i] ~ Normal(0, 1) + x[i] ~ Normal(y[i - 1] + 1, y[i] + 1) + end + end -@model function node_with_only_anonymous() - x[1] ~ Normal(0, 1) - y[1] ~ Normal(0, 1) - for i in 2:10 - y[i] ~ Normal(0, 1) - x[i] ~ Normal(y[i - 1] + 1, 1) + @model function type_arguments(n, x) + local y + for i in 1:n + y[i] ~ Normal(0, 1) + x[i] ~ Normal(y[i], 1) + end end -end - -@model function node_with_two_anonymous() - x[1] ~ Normal(0, 1) - y[1] ~ Normal(0, 1) - for i in 2:10 - y[i] ~ Normal(0, 1) - x[i] ~ Normal(y[i - 1] + 1, y[i] + 1) + + @model function node_with_ambiguous_anonymous() + x[1] ~ Normal(0, 1) + y[1] ~ Normal(0, 1) + for i in 2:10 + x[i] ~ Normal(x[i - 1], 1) + y[i] ~ Normal(x[i] + y[i - 1], 1) + end end -end -@model function type_arguments(n, x) - local y - for i in 1:n - y[i] ~ Normal(0, 1) - x[i] ~ Normal(y[i], 1) + @model function gcv(κ, ω, z, x, y) + log_σ := κ * z + ω + y ~ Normal(x, exp(log_σ)) end -end - -@model function node_with_ambiguous_anonymous() - x[1] ~ Normal(0, 1) - y[1] ~ Normal(0, 1) - for i in 2:10 - x[i] ~ Normal(x[i - 1], 1) - y[i] ~ Normal(x[i] + y[i - 1], 1) + + @model function gcv_lm(y, x_prev, x_next, z, ω, κ) + x_next ~ gcv(x = x_prev, z = z, ω = ω, κ = κ) + y ~ Normal(x_next, 1) end -end -@model function gcv(κ, ω, z, x, y) - log_σ := κ * z + ω - y ~ Normal(x, exp(log_σ)) -end + @model function hgf(y) -@model function gcv_lm(y, x_prev, x_next, z, ω, κ) - x_next ~ gcv(x = x_prev, z = z, ω = ω, κ = κ) - y ~ Normal(x_next, 1) -end + # Specify priors -@model function hgf(y) + ξ ~ Gamma(1, 1) + ω_1 ~ Normal(0, 1) + ω_2 ~ Normal(0, 1) + κ_1 ~ Normal(0, 1) + κ_2 ~ Normal(0, 1) + x_1[1] ~ Normal(0, 1) + x_2[1] ~ Normal(0, 1) + x_3[1] ~ Normal(0, 1) - # Specify priors + # Specify generative model + + for i in 2:(length(y) + 1) + x_3[i] ~ Normal(μ = x_3[i - 1], τ = ξ) + x_2[i] ~ gcv(x = x_2[i - 1], z = x_3[i], ω = ω_2, κ = κ_2) + x_1[i] ~ gcv_lm(x_prev = x_1[i - 1], z = x_2[i], ω = ω_1, κ = κ_1, y = y[i - 1]) + end + end + + @model function prior(a) + a ~ Normal(0, 1) + end - ξ ~ Gamma(1, 1) - ω_1 ~ Normal(0, 1) - ω_2 ~ Normal(0, 1) - κ_1 ~ Normal(0, 1) - κ_2 ~ Normal(0, 1) - x_1[1] ~ Normal(0, 1) - x_2[1] ~ Normal(0, 1) - x_3[1] ~ Normal(0, 1) + @model function broadcastable(μ, σ, out) + out ~ Normal(μ, σ) + end - # Specify generative model + @model function broadcaster() + local μ + local σ + for i in 1:10 + μ[i] ~ Normal(0, 1) + σ[i] ~ Gamma(1, 1) + end + z .~ broadcastable(μ = μ, σ = σ) + out ~ Normal(z[10], 1) + end - for i in 2:(length(y) + 1) - x_3[i] ~ Normal(μ = x_3[i - 1], τ = ξ) - x_2[i] ~ gcv(x = x_2[i - 1], z = x_3[i], ω = ω_2, κ = κ_2) - x_1[i] ~ gcv_lm(x_prev = x_1[i - 1], z = x_2[i], ω = ω_1, κ = κ_1, y = y[i - 1]) + @model function inner_inner(τ, y) + y ~ Normal(τ[1], τ[2]) end -end - -@model function prior(a) - a ~ Normal(0, 1) -end - -@model function broadcastable(μ, σ, out) - out ~ Normal(μ, σ) -end - -@model function broadcaster() - local μ - local σ - for i in 1:10 - μ[i] ~ Normal(0, 1) - σ[i] ~ Gamma(1, 1) + + @model function inner(θ, α) + α ~ inner_inner(τ = θ) end - z .~ broadcastable(μ = μ, σ = σ) - out ~ Normal(z[10], 1) -end - -@model function inner_inner(τ, y) - y ~ Normal(τ[1], τ[2]) -end - -@model function inner(θ, α) - α ~ inner_inner(τ = θ) -end - -@model function outer() - local w - for i in 1:5 - w[i] ~ Gamma(1, 1) + + @model function outer() + local w + for i in 1:5 + w[i] ~ Gamma(1, 1) + end + y ~ inner(θ = w[2:3]) end - y ~ inner(θ = w[2:3]) -end - -@model function multidim_array() - local x - for i in 1:3 - x[i, 1] ~ Normal(0, 1) - for j in 2:3 - x[i, j] ~ Normal(x[i, j - 1], 1) + + @model function multidim_array() + local x + for i in 1:3 + x[i, 1] ~ Normal(0, 1) + for j in 2:3 + x[i, j] ~ Normal(x[i, j - 1], 1) + end end end -end -@model function child_model(in, out) - σ ~ Gamma(1, 1) - out ~ Normal(in, σ) -end + @model function child_model(in, out) + σ ~ Gamma(1, 1) + out ~ Normal(in, σ) + end -@model function parent_model() - x[1] ~ Normal(0, 1) - for i in 2:100 - x[i] ~ child_model(in = x[i - 1]) + @model function parent_model() + x[1] ~ Normal(0, 1) + for i in 2:100 + x[i] ~ child_model(in = x[i - 1]) + end end -end - -@model function model_with_default_constraints(a, b, c, d) - a := b + c - d ~ Normal(a, 1) -end - -@model function contains_default_constraints() - a ~ Normal(0, 1) - b ~ Normal(0, 1) - c ~ Normal(0, 1) - for i in 1:10 - d[i] ~ model_with_default_constraints(a = a, b = b, c = c) + + @model function model_with_default_constraints(a, b, c, d) + a := b + c + d ~ Normal(a, 1) + end + + @model function contains_default_constraints() + a ~ Normal(0, 1) + b ~ Normal(0, 1) + c ~ Normal(0, 1) + for i in 1:10 + d[i] ~ model_with_default_constraints(a = a, b = b, c = c) + end end -end -GraphPPL.default_constraints(::typeof(model_with_default_constraints)) = @constraints( - begin - q(a, d) = q(a)q(d) + GraphPPL.default_constraints(::typeof(model_with_default_constraints)) = @constraints( + begin + q(a, d) = q(a)q(d) + end + ) + + @model function mixture() + m1 ~ Normal(0, 1) + m2 ~ Normal(0, 1) + m3 ~ Normal(0, 1) + m4 ~ Normal(0, 1) + t1 ~ Normal(0, 1) + t2 ~ Normal(0, 1) + t3 ~ Normal(0, 1) + t4 ~ Normal(0, 1) + y ~ Mixture(m = [m1, m2, m3, m4], τ = [t1, t2, t3, t4]) end -) - -@model function mixture() - m1 ~ Normal(0, 1) - m2 ~ Normal(0, 1) - m3 ~ Normal(0, 1) - m4 ~ Normal(0, 1) - t1 ~ Normal(0, 1) - t2 ~ Normal(0, 1) - t3 ~ Normal(0, 1) - t4 ~ Normal(0, 1) - y ~ Mixture(m = [m1, m2, m3, m4], τ = [t1, t2, t3, t4]) -end - -@model function filled_matrix_model() - local x - local y - for i in 1:3 - for j in 1:3 - y[i, j] ~ Gamma(1, 1) - x[i, j] ~ Normal(0, y[i, j]) + + @model function filled_matrix_model() + local x + local y + for i in 1:3 + for j in 1:3 + y[i, j] ~ Gamma(1, 1) + x[i, j] ~ Normal(0, y[i, j]) + end end end -end -@model function coin_toss_model() - θ ~ Beta(1, 2) - for i in 1:5 - y[i] ~ Bernoulli(θ) + @model function coin_toss_model() + θ ~ Beta(1, 2) + for i in 1:5 + y[i] ~ Bernoulli(θ) + end end -end - -const ModelsInTheZooWithoutArguments = [ - coin_toss_model, - simple_model, - vector_model, - tensor_model, - node_with_only_anonymous, - node_with_two_anonymous, - node_with_ambiguous_anonymous, - outer, - multidim_array, - parent_model, - contains_default_constraints, - mixture, - filled_matrix_model -] - -export ModelsInTheZooWithoutArguments - -end -end - -using GraphPPL, MacroTools, Static, Distributions -using .TestUtils + + const ModelsInTheZooWithoutArguments = [ + coin_toss_model, + simple_model, + vector_model, + tensor_model, + node_with_only_anonymous, + node_with_two_anonymous, + node_with_ambiguous_anonymous, + outer, + multidim_array, + parent_model, + contains_default_constraints, + mixture, + filled_matrix_model + ] + + export ModelsInTheZooWithoutArguments +end # TestUtils testmodule From 7c486bbf7aca9d0e84d31d78a91ba7f7e471de0f Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Wed, 7 May 2025 11:21:22 +0200 Subject: [PATCH 5/6] Fix model_operations_tests.jl --- test/model/model_operations_tests.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/model/model_operations_tests.jl b/test/model/model_operations_tests.jl index 87a41762..8f5db588 100644 --- a/test/model/model_operations_tests.jl +++ b/test/model/model_operations_tests.jl @@ -264,6 +264,7 @@ end end @testitem "make_node!(::Atomic)" setup = [TestUtils] begin + using Distributions using Graphs, BitSetTuples import GraphPPL: getcontext, @@ -543,19 +544,19 @@ end ctx = getcontext(model) options = NodeCreationOptions() xref = getorcreate!(model, ctx, :x, nothing) - make_node!(model, ctx, options, TestUtils.ModelZoo.prior, proxylabel(:x, xref, nothing), ()) + make_node!(model, ctx, options, TestUtils.prior, proxylabel(:x, xref, nothing), ()) @test nv(model) == 4 - @test ctx[TestUtils.ModelZoo.prior, 1][:a] == proxylabel(:x, xref, nothing) + @test ctx[TestUtils.prior, 1][:a] == proxylabel(:x, xref, nothing) #test make node for other composite models model = TestUtils.create_test_model() ctx = getcontext(model) options = NodeCreationOptions() xref = getorcreate!(model, ctx, :x, nothing) - @test_throws ErrorException make_node!(model, ctx, options, TestUtils.ModelZoo.gcv, proxylabel(:x, xref, nothing), (0, 1)) + @test_throws ErrorException make_node!(model, ctx, options, TestUtils.gcv, proxylabel(:x, xref, nothing), (0, 1)) # test make node of broadcastable composite model - model = create_model(TestUtils.ModelZoo.broadcaster()) + model = create_model(TestUtils.broadcaster()) @test nv(model) == 103 end From 420523742a157a5bde4a952920b7cf83b9677981 Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Wed, 7 May 2025 13:51:58 +0200 Subject: [PATCH 6/6] Refactor constraints engine tests --- .../variational_constraints_engine.jl | 4 +- .../components/bitset_operations_tests.jl | 18 + .../constraint_application_tests.jl | 1074 ++++++++++++ .../components/constraint_types_tests.jl | 113 ++ .../components/constraints_container_tests.jl | 138 ++ .../components/constraints_defaults_tests.jl | 87 + .../components/range_types_tests.jl | 202 +++ .../components/resolvers_tests.jl | 327 ++++ .../components/utils_tests.jl | 32 + .../variational_constraints_tests.jl | 0 .../variational_constraints_macro_tests.jl | 0 .../variational_constraints_engine_tests.jl | 1501 ----------------- 12 files changed, 1993 insertions(+), 1503 deletions(-) create mode 100644 test/plugins/variational_constraints/components/bitset_operations_tests.jl create mode 100644 test/plugins/variational_constraints/components/constraint_application_tests.jl create mode 100644 test/plugins/variational_constraints/components/constraint_types_tests.jl create mode 100644 test/plugins/variational_constraints/components/constraints_container_tests.jl create mode 100644 test/plugins/variational_constraints/components/constraints_defaults_tests.jl create mode 100644 test/plugins/variational_constraints/components/range_types_tests.jl create mode 100644 test/plugins/variational_constraints/components/resolvers_tests.jl create mode 100644 test/plugins/variational_constraints/components/utils_tests.jl rename test/plugins/variational_constraints/{ => integration}/variational_constraints_tests.jl (100%) rename test/plugins/variational_constraints/{ => macro}/variational_constraints_macro_tests.jl (100%) delete mode 100644 test/plugins/variational_constraints/variational_constraints_engine_tests.jl diff --git a/src/plugins/variational_constraints/variational_constraints_engine.jl b/src/plugins/variational_constraints/variational_constraints_engine.jl index e4ddd7ca..acf25237 100644 --- a/src/plugins/variational_constraints/variational_constraints_engine.jl +++ b/src/plugins/variational_constraints/variational_constraints_engine.jl @@ -818,8 +818,8 @@ function resolve(model::Model, context::Context, constraint::FactorizationConstr return ResolvedFactorizationConstraint(ResolvedConstraintLHS(lhs), rhs) end -function is_factorized(nodedata::NodeData) - properties = getproperties(nodedata)::VariableNodeProperties +function is_factorized(nodedata::AbstractNodeData) + properties = getproperties(nodedata) if is_constant(properties) return true end diff --git a/test/plugins/variational_constraints/components/bitset_operations_tests.jl b/test/plugins/variational_constraints/components/bitset_operations_tests.jl new file mode 100644 index 00000000..6666d109 --- /dev/null +++ b/test/plugins/variational_constraints/components/bitset_operations_tests.jl @@ -0,0 +1,18 @@ +@testitem "mean_field_constraint!" begin + using BitSetTuples + import GraphPPL: mean_field_constraint! + + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(5))) == ((1,), (2,), (3,), (4,), (5,)) + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(10))) == ((1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,), (10,)) + + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(1), 1)) == ((1,),) + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(5), 3)) == + ((1, 2, 4, 5), (1, 2, 4, 5), (3,), (1, 2, 4, 5), (1, 2, 4, 5)) + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(1), (1,))) == ((1,),) + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(2), (1,))) == ((1,), (2,)) + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(2), (2,))) == ((1,), (2,)) + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(5), (1, 2))) == ((1,), (2,), (3, 4, 5), (3, 4, 5), (3, 4, 5)) + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(5), (1, 3, 5))) == ((1,), (2, 4), (3,), (2, 4), (5,)) + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(5), (1, 2, 3, 4, 5))) == ((1,), (2,), (3,), (4,), (5,)) + @test_throws BoundsError mean_field_constraint!(BoundedBitSetTuple(5), (1, 2, 3, 4, 5, 6)) == ((1,), (2,), (3,), (4,), (5,)) +end \ No newline at end of file diff --git a/test/plugins/variational_constraints/components/constraint_application_tests.jl b/test/plugins/variational_constraints/components/constraint_application_tests.jl new file mode 100644 index 00000000..197796da --- /dev/null +++ b/test/plugins/variational_constraints/components/constraint_application_tests.jl @@ -0,0 +1,1074 @@ +@testitem "is_factorized" setup = [TestUtils] begin + import GraphPPL: is_factorized, AbstractNodeData, AbstractNodeProperties, getproperties, getlink + + mutable struct MockNodeData{T} <: AbstractNodeData + properties::T + extras::Dict{Symbol, Any} + end + + # Mock implementation to test pure functions without requiring model setup + struct MockVariableNodeProperties <: AbstractNodeProperties + is_constant::Bool + link::Union{Nothing, Vector{MockNodeData}} + end + + Base.getproperty(p::MockVariableNodeProperties, name::Symbol) = + if name === :link + getfield(p, :link) + else + getfield(p, name) + end + + getproperties(data::MockNodeData) = data.properties + GraphPPL.getextra(data::MockNodeData, key::Symbol) = get(data.extras, key, nothing) + GraphPPL.hasextra(data::MockNodeData, key::Symbol) = haskey(data.extras, key) + GraphPPL.getlink(props::MockVariableNodeProperties) = props.link + GraphPPL.is_constant(props::MockVariableNodeProperties) = props.is_constant + # Test 1: Basic constant variable + node1_props = MockVariableNodeProperties(true, nothing) + node1 = MockNodeData{MockVariableNodeProperties}(node1_props, Dict{Symbol, Any}()) + @test is_factorized(node1) + + # Test 2: Variable with factorized flag + node2_props = MockVariableNodeProperties(false, nothing) + node2 = MockNodeData{MockVariableNodeProperties}(node2_props, Dict{Symbol, Any}(:factorized => true)) + @test is_factorized(node2) + + # Test 3: Variable without factorized flag + node3_props = MockVariableNodeProperties(false, nothing) + node3 = MockNodeData{MockVariableNodeProperties}(node3_props, Dict{Symbol, Any}()) + @test !is_factorized(node3) + + # Test 4: Variable with factorized links + node4_link1_props = MockVariableNodeProperties(true, nothing) + node4_link1 = MockNodeData{MockVariableNodeProperties}(node4_link1_props, Dict{Symbol, Any}()) + node4_link2_props = MockVariableNodeProperties(false, nothing) + node4_link2 = MockNodeData{MockVariableNodeProperties}(node4_link2_props, Dict{Symbol, Any}(:factorized => true)) + node4_links = [node4_link1, node4_link2] + node4_props = MockVariableNodeProperties(false, node4_links) + node4 = MockNodeData{MockVariableNodeProperties}(node4_props, Dict{Symbol, Any}()) + @test is_factorized(node4) + + # Test 5: Variable with non-factorized links + node5_link1_props = MockVariableNodeProperties(true, nothing) + node5_link1 = MockNodeData{MockVariableNodeProperties}(node5_link1_props, Dict{Symbol, Any}()) + node5_link2_props = MockVariableNodeProperties(false, nothing) + node5_link2 = MockNodeData{MockVariableNodeProperties}(node5_link2_props, Dict{Symbol, Any}()) + node5_links = [node5_link1, node5_link2] + node5_props = MockVariableNodeProperties(false, node5_links) + node5 = MockNodeData{MockVariableNodeProperties}(node5_props, Dict{Symbol, Any}()) + @test !is_factorized(node5) + + # Test 6: Variable with mixed factorized/non-factorized links + node6_link1_props = MockVariableNodeProperties(true, nothing) + node6_link1 = MockNodeData{MockVariableNodeProperties}(node6_link1_props, Dict{Symbol, Any}()) + node6_link2_props = MockVariableNodeProperties(false, nothing) + node6_link2 = MockNodeData{MockVariableNodeProperties}(node6_link2_props, Dict{Symbol, Any}()) + node6_links = [node6_link1, node6_link2] + node6_props = MockVariableNodeProperties(false, node6_links) + node6 = MockNodeData{MockVariableNodeProperties}(node6_props, Dict{Symbol, Any}()) + @test !is_factorized(node6) +end + +@testitem "is_factorized || is_constant" setup = [TestUtils] begin + import GraphPPL: + is_constant, is_factorized, create_model, with_plugins, getcontext, getproperties, getorcreate!, variable_nodes, NodeCreationOptions + + m = TestUtils.create_test_model(plugins = GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin())) + ctx = getcontext(m) + x = getorcreate!(m, ctx, NodeCreationOptions(kind = :data, factorized = true), :x, nothing) + @test is_factorized(m[x]) + + for model_fn in TestUtils.ModelsInTheZooWithoutArguments + model = create_model(with_plugins(model_fn(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin()))) + for label in variable_nodes(model) + nodedata = model[label] + if is_constant(getproperties(nodedata)) + @test is_factorized(nodedata) + else + @test !is_factorized(nodedata) + end + end + end +end + +@testitem "mean_field_constraint!" setup = [TestUtils] begin + import GraphPPL: mean_field_constraint!, BoundedBitSetTuple, contents + + # Test 1: Basic mean field constraint + bitset = BoundedBitSetTuple(3) + fill!(contents(bitset), true) + + mean_field_constraint!(bitset) + for i in 1:3 + for j in 1:3 + if i == j + @test bitset[i, j] + else + @test !bitset[i, j] + end + end + end + + # Test 2: Mean field constraint with specific index + bitset = BoundedBitSetTuple(3) + fill!(contents(bitset), true) + + mean_field_constraint!(bitset, 2) + # Check that row/column 2 is all zeros except for [2,2] + for i in 1:3 + for j in 1:3 + if i == 2 && j == 2 + @test bitset[i, j] + elseif i == 2 || j == 2 + @test !bitset[i, j] + else + @test bitset[i, j] + end + end + end + + # Test 3: Mean field constraint with multiple indices + bitset = BoundedBitSetTuple(4) + fill!(contents(bitset), true) + + mean_field_constraint!(bitset, (1, 3)) + for i in 1:4 + for j in 1:4 + if (i == 1 && j == 1) || (i == 3 && j == 3) + @test bitset[i, j] + elseif i == 1 || j == 1 || i == 3 || j == 3 + @test !bitset[i, j] + else + @test bitset[i, j] + end + end + end +end + +@testitem "is_valid_partition" setup = [TestUtils] begin + import GraphPPL: is_valid_partition + + # Test valid partitions + valid1 = [[1, 0, 0], [0, 1, 1]] + @test is_valid_partition(valid1) + + valid2 = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] + @test is_valid_partition(valid2) + + # Test invalid partitions + + # Element missing from any partition + invalid1 = [[1, 0, 0], [0, 1, 0]] + @test !is_valid_partition(invalid1) + + # Element in multiple partitions + invalid2 = [[1, 1, 0], [0, 1, 1]] + @test !is_valid_partition(invalid2) + + # Empty partition set + invalid3 = Vector{Int}[] + @test_broken !is_valid_partition(invalid3) +end + +@testitem "materialize_is_factorized_neighbors!" setup = [TestUtils] begin + import GraphPPL: materialize_is_factorized_neighbors!, BoundedBitSetTuple, NodeData, is_factorized, AbstractNodeData + + # Mock implementation + mutable struct MockNodeData <: AbstractNodeData + factorized::Bool + end + + GraphPPL.is_factorized(n::MockNodeData) = n.factorized + + # Test 1: All factorized neighbors + bitset = BoundedBitSetTuple(3) + fill!(bitset.contents, true) + neighbors = [MockNodeData(true), MockNodeData(true), MockNodeData(true)] + + materialize_is_factorized_neighbors!(bitset, neighbors) + for i in 1:3 + for j in 1:3 + if i == j + @test bitset[i, j] + else + @test !bitset[i, j] + end + end + end + + # Test 2: Mixed factorized/non-factorized neighbors + bitset = BoundedBitSetTuple(3) + fill!(bitset.contents, true) + neighbors = [MockNodeData(true), MockNodeData(false), MockNodeData(true)] + + materialize_is_factorized_neighbors!(bitset, neighbors) + for i in 1:3 + for j in 1:3 + if (i == j) || (i != 1 && i != 3 && j != 1 && j != 3) + @test bitset[i, j] + elseif (i == 1 || i == 3 || j == 1 || j == 3) && i != j + @test !bitset[i, j] + end + end + end + + # Test 3: No factorized neighbors + bitset = BoundedBitSetTuple(3) + fill!(bitset.contents, true) + neighbors = [MockNodeData(false), MockNodeData(false), MockNodeData(false)] + + materialize_is_factorized_neighbors!(bitset, neighbors) + for i in 1:3 + for j in 1:3 + @test bitset[i, j] + end + end +end + +@testitem "convert_to_bitsets" setup = [TestUtils] begin + using BitSetTuples + import GraphPPL: + create_model, + with_plugins, + ResolvedFactorizationConstraint, + ResolvedConstraintLHS, + ResolvedFactorizationConstraintEntry, + ResolvedIndexedVariable, + SplittedRange, + CombinedRange, + apply_constraints!, + getproperties + + model = create_model(with_plugins(TestUtils.outer(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin()))) + context = GraphPPL.getcontext(model) + inner_context = context[TestUtils.inner, 1] + inner_inner_context = inner_context[TestUtils.inner_inner, 1] + + normal_node = inner_inner_context[TestUtils.NormalMeanVariance, 1] + neighbors = model[GraphPPL.neighbors(model, normal_node)] + + let constraint = ResolvedFactorizationConstraint( + ResolvedConstraintLHS((ResolvedIndexedVariable(:w, 2:3, context),)), + ( + ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:w, 2, context),)), + ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:w, 3, context),)) + ) + ) + @test GraphPPL.is_applicable(neighbors, constraint) + @test tupled_contents(GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint)) == ((1, 2, 3), (1, 2), (1, 3)) + end + + let constraint = ResolvedFactorizationConstraint( + ResolvedConstraintLHS((ResolvedIndexedVariable(:w, 4:5, context),)), + ( + ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:w, 4, context),)), + ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:w, 5, context),)) + ) + ) + @test !GraphPPL.is_applicable(neighbors, constraint) + end + + let constraint = ResolvedFactorizationConstraint( + ResolvedConstraintLHS((ResolvedIndexedVariable(:w, 2:3, context),)), + (ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:w, SplittedRange(2, 3), context),)),) + ) + @test GraphPPL.is_applicable(neighbors, constraint) + @test tupled_contents(GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint)) == ((1, 2, 3), (1, 2), (1, 3)) + end + + let constraint = ResolvedFactorizationConstraint( + ResolvedConstraintLHS((ResolvedIndexedVariable(:w, 2:3, context), ResolvedIndexedVariable(:y, nothing, context))), + ( + ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:w, SplittedRange(2, 3), context),)), + ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:y, nothing, context),)) + ) + ) + @test GraphPPL.is_applicable(neighbors, constraint) + @test tupled_contents(GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint)) == ((1,), (2,), (3,)) + end + + let constraint = ResolvedFactorizationConstraint( + ResolvedConstraintLHS((ResolvedIndexedVariable(:w, 2:3, context), ResolvedIndexedVariable(:y, nothing, context))), + ( + ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:w, 2, context),)), + ResolvedFactorizationConstraintEntry(( + ResolvedIndexedVariable(:w, 3, context), ResolvedIndexedVariable(:y, nothing, context) + )) + ) + ) + @test GraphPPL.is_applicable(neighbors, constraint) + @test tupled_contents(GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint)) == ((1, 3), (2,), (1, 3)) + end + + let constraint = ResolvedFactorizationConstraint( + ResolvedConstraintLHS((ResolvedIndexedVariable(:w, 2:3, context), ResolvedIndexedVariable(:y, nothing, context))), + ( + ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:w, CombinedRange(2, 3), context),)), + ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:y, nothing, context),)) + ) + ) + @test GraphPPL.is_applicable(neighbors, constraint) + @test tupled_contents(GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint)) == ((1,), (2, 3), (2, 3)) + end + + model = create_model(with_plugins(TestUtils.multidim_array(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin()))) + context = GraphPPL.getcontext(model) + normal_node = context[TestUtils.NormalMeanVariance, 5] + neighbors = model[GraphPPL.neighbors(model, normal_node)] + + let constraint = ResolvedFactorizationConstraint( + ResolvedConstraintLHS((ResolvedIndexedVariable(:x, nothing, context),),), + (ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:x, SplittedRange(1, 9), context),)),) + ) + @test GraphPPL.is_applicable(neighbors, constraint) + @test tupled_contents(GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint)) == ((1, 3), (2, 3), (1, 2, 3)) + end + + model = create_model(with_plugins(TestUtils.multidim_array(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin()))) + context = GraphPPL.getcontext(model) + normal_node = context[TestUtils.NormalMeanVariance, 5] + neighbors = model[GraphPPL.neighbors(model, normal_node)] + + let constraint = ResolvedFactorizationConstraint( + ResolvedConstraintLHS((ResolvedIndexedVariable(:x, nothing, context),),), + (ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:x, CombinedRange(1, 9), context),)),) + ) + @test GraphPPL.is_applicable(neighbors, constraint) + @test tupled_contents(GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint)) == ((1, 2, 3), (1, 2, 3), (1, 2, 3)) + end + + # Test ResolvedFactorizationConstraints over anonymous variables + + model = create_model( + with_plugins(TestUtils.node_with_only_anonymous(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin())) + ) + context = GraphPPL.getcontext(model) + normal_node = context[TestUtils.NormalMeanVariance, 6] + neighbors = model[GraphPPL.neighbors(model, normal_node)] + let constraint = ResolvedFactorizationConstraint( + ResolvedConstraintLHS((ResolvedIndexedVariable(:y, nothing, context),),), + (ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:y, SplittedRange(1, 10), context),)),) + ) + @test GraphPPL.is_applicable(neighbors, constraint) + end + + # Test ResolvedFactorizationConstraints over multiple anonymous variables + model = create_model( + with_plugins(TestUtils.node_with_two_anonymous(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin())) + ) + context = GraphPPL.getcontext(model) + normal_node = context[TestUtils.NormalMeanVariance, 6] + neighbors = model[GraphPPL.neighbors(model, normal_node)] + let constraint = ResolvedFactorizationConstraint( + ResolvedConstraintLHS((ResolvedIndexedVariable(:y, nothing, context),),), + (ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:y, SplittedRange(1, 10), context),)),) + ) + @test GraphPPL.is_applicable(neighbors, constraint) + + # This shouldn't throw and resolve because both anonymous variables are 1-to-1 and referenced by constraint. + @test tupled_contents(GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint)) == ((1, 2, 3), (1, 2), (1, 3)) + end + + # Test ResolvedFactorizationConstraints over ambiguous anonymouys variables + model = create_model( + with_plugins(TestUtils.node_with_ambiguous_anonymous(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin())) + ) + context = GraphPPL.getcontext(model) + normal_node = last(filter(GraphPPL.as_node(TestUtils.NormalMeanVariance), model)) + neighbors = model[GraphPPL.neighbors(model, normal_node)] + let constraint = ResolvedFactorizationConstraint( + ResolvedConstraintLHS((ResolvedIndexedVariable(:y, nothing, context),),), + (ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:y, SplittedRange(1, 10), context),)),) + ) + @test GraphPPL.is_applicable(neighbors, constraint) + + # This test should throw since we cannot resolve the constraint + @test_throws GraphPPL.UnresolvableFactorizationConstraintError GraphPPL.convert_to_bitsets( + model, normal_node, neighbors, constraint + ) + end + + # Test ResolvedFactorizationConstraint with a Mixture node + model = create_model(with_plugins(TestUtils.mixture(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin()))) + context = GraphPPL.getcontext(model) + mixture_node = first(filter(GraphPPL.as_node(TestUtils.Mixture), model)) + neighbors = model[GraphPPL.neighbors(model, mixture_node)] + let constraint = ResolvedFactorizationConstraint( + ResolvedConstraintLHS(( + ResolvedIndexedVariable(:m1, nothing, context), + ResolvedIndexedVariable(:m2, nothing, context), + ResolvedIndexedVariable(:m3, nothing, context), + ResolvedIndexedVariable(:m4, nothing, context) + ),), + ( + ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:m1, nothing, context),)), + ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:m2, nothing, context),)), + ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:m3, nothing, context),)), + ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:m4, nothing, context),)) + ) + ) + @test GraphPPL.is_applicable(neighbors, constraint) + @test tupled_contents(GraphPPL.convert_to_bitsets(model, mixture_node, neighbors, constraint)) == tupled_contents( + BitSetTuple([ + collect(1:9), + [1, 2, 6, 7, 8, 9], + [1, 3, 6, 7, 8, 9], + [1, 4, 6, 7, 8, 9], + [1, 5, 6, 7, 8, 9], + collect(1:9), + collect(1:9), + collect(1:9), + collect(1:9) + ]) + ) + end +end + +@testitem "Application of MarginalFormConstraint" setup = [TestUtils] begin + import GraphPPL: + create_model, + MarginalFormConstraint, + IndexedVariable, + apply_constraints!, + getextra, + hasextra, + VariationalConstraintsMarginalFormConstraintKey + + struct ArbitraryFunctionalFormConstraint end + + # Test saving of MarginalFormConstraint in single variable + model = create_model(TestUtils.simple_model()) + context = GraphPPL.getcontext(model) + constraint = MarginalFormConstraint(IndexedVariable(:x, nothing), ArbitraryFunctionalFormConstraint()) + apply_constraints!(model, context, constraint) + for node in filter(GraphPPL.as_variable(:x), model) + @test getextra(model[node], VariationalConstraintsMarginalFormConstraintKey) == ArbitraryFunctionalFormConstraint() + end + + # Test saving of MarginalFormConstraint in multiple variables + model = create_model(TestUtils.vector_model()) + context = GraphPPL.getcontext(model) + constraint = MarginalFormConstraint(IndexedVariable(:x, nothing), ArbitraryFunctionalFormConstraint()) + apply_constraints!(model, context, constraint) + for node in filter(GraphPPL.as_variable(:x), model) + @test getextra(model[node], VariationalConstraintsMarginalFormConstraintKey) == ArbitraryFunctionalFormConstraint() + end + for node in filter(GraphPPL.as_variable(:y), model) + @test !hasextra(model[node], VariationalConstraintsMarginalFormConstraintKey) + end + + # Test saving of MarginalFormConstraint in single variable in array + model = create_model(TestUtils.vector_model()) + context = GraphPPL.getcontext(model) + constraint = MarginalFormConstraint(IndexedVariable(:x, 1), ArbitraryFunctionalFormConstraint()) + apply_constraints!(model, context, constraint) + applied_node = context[:x][1] + for node in filter(GraphPPL.as_variable(:x), model) + if node == applied_node + @test getextra(model[node], VariationalConstraintsMarginalFormConstraintKey) == ArbitraryFunctionalFormConstraint() + else + @test !hasextra(model[node], VariationalConstraintsMarginalFormConstraintKey) + end + end +end + +@testitem "Application of MessageFormConstraint" setup = [TestUtils] begin + import GraphPPL: + create_model, + MessageFormConstraint, + IndexedVariable, + apply_constraints!, + hasextra, + getextra, + VariationalConstraintsMessagesFormConstraintKey + + struct ArbitraryMessageFormConstraint end + + # Test saving of MessageFormConstraint in single variable + model = create_model(TestUtils.simple_model()) + context = GraphPPL.getcontext(model) + constraint = MessageFormConstraint(IndexedVariable(:x, nothing), ArbitraryMessageFormConstraint()) + node = first(filter(GraphPPL.as_variable(:x), model)) + apply_constraints!(model, context, constraint) + @test getextra(model[node], VariationalConstraintsMessagesFormConstraintKey) == ArbitraryMessageFormConstraint() + + # Test saving of MessageFormConstraint in multiple variables + model = create_model(TestUtils.vector_model()) + context = GraphPPL.getcontext(model) + constraint = MessageFormConstraint(IndexedVariable(:x, nothing), ArbitraryMessageFormConstraint()) + apply_constraints!(model, context, constraint) + for node in filter(GraphPPL.as_variable(:x), model) + @test getextra(model[node], VariationalConstraintsMessagesFormConstraintKey) == ArbitraryMessageFormConstraint() + end + for node in filter(GraphPPL.as_variable(:y), model) + @test !hasextra(model[node], VariationalConstraintsMessagesFormConstraintKey) + end + + # Test saving of MessageFormConstraint in single variable in array + model = create_model(TestUtils.vector_model()) + context = GraphPPL.getcontext(model) + constraint = MessageFormConstraint(IndexedVariable(:x, 1), ArbitraryMessageFormConstraint()) + apply_constraints!(model, context, constraint) + applied_node = context[:x][1] + for node in filter(GraphPPL.as_variable(:x), model) + if node == applied_node + @test getextra(model[node], VariationalConstraintsMessagesFormConstraintKey) == ArbitraryMessageFormConstraint() + else + @test !hasextra(model[node], VariationalConstraintsMessagesFormConstraintKey) + end + end +end + +@testitem "save constraints with constants via `mean_field_constraint!`" setup = [TestUtils] begin + using BitSetTuples + import GraphPPL: + create_model, + with_plugins, + getextra, + mean_field_constraint!, + getproperties, + VariationalConstraintsPlugin, + PluginsCollection, + VariationalConstraintsFactorizationBitSetKey + + model = create_model(with_plugins(TestUtils.simple_model(), GraphPPL.PluginsCollection(VariationalConstraintsPlugin()))) + ctx = GraphPPL.getcontext(model) + + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(3), 1)) == ((1,), (2, 3), (2, 3)) + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(3), 2)) == ((1, 3), (2,), (1, 3)) + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(3), 3)) == ((1, 2), (1, 2), (3,)) + + node = ctx[TestUtils.NormalMeanVariance, 2] + constraint_bitset = getextra(model[node], VariationalConstraintsFactorizationBitSetKey) + @test tupled_contents(intersect!(constraint_bitset, mean_field_constraint!(BoundedBitSetTuple(3), 1))) == ((1,), (2, 3), (2, 3)) + @test tupled_contents(intersect!(constraint_bitset, mean_field_constraint!(BoundedBitSetTuple(3), 2))) == ((1,), (2,), (3,)) + + node = ctx[TestUtils.NormalMeanVariance, 1] + constraint_bitset = getextra(model[node], VariationalConstraintsFactorizationBitSetKey) + # Here it is the mean field because the original model has `x ~ Normal(0, 1)` and `0` and `1` are constants + @test tupled_contents(intersect!(constraint_bitset, mean_field_constraint!(BoundedBitSetTuple(3), 1))) == ((1,), (2,), (3,)) +end + +@testitem "materialize_constraints!(:Model, ::NodeLabel, ::FactorNodeData)" setup = [TestUtils] begin + using BitSetTuples + import GraphPPL: + create_model, + with_plugins, + materialize_constraints!, + EdgeLabel, + get_constraint_names, + getproperties, + getextra, + setextra!, + VariationalConstraintsPlugin + + model = create_model(TestUtils.simple_model()) + ctx = GraphPPL.getcontext(model) + node = ctx[TestUtils.NormalMeanVariance, 2] + + # Test 1: Test materialize with a Full Factorization constraint + node = ctx[TestUtils.NormalMeanVariance, 2] + + # Force overwrite the bitset and the constraints + setextra!(model[node], :factorization_constraint_bitset, BoundedBitSetTuple(3)) + materialize_constraints!(model, node) + @test Tuple.(getextra(model[node], :factorization_constraint_indices)) == ((1, 2, 3),) + + node = ctx[TestUtils.NormalMeanVariance, 1] + setextra!(model[node], :factorization_constraint_bitset, BoundedBitSetTuple(((1,), (2,), (3,)))) + materialize_constraints!(model, node) + @test Tuple.(getextra(model[node], :factorization_constraint_indices)) == ((1,), (2,), (3,)) + + # Test 2: Test materialize with an applied constraint + model = create_model(TestUtils.simple_model()) + ctx = GraphPPL.getcontext(model) + node = ctx[TestUtils.NormalMeanVariance, 2] + + setextra!(model[node], :factorization_constraint_bitset, BoundedBitSetTuple(((1,), (2, 3), (2, 3)))) + materialize_constraints!(model, node) + @test Tuple.(getextra(model[node], :factorization_constraint_indices)) == ((1,), (2, 3)) + + # # Test 3: Check that materialize_constraints! throws if the constraint is not a valid partition + model = create_model(TestUtils.simple_model()) + ctx = GraphPPL.getcontext(model) + node = ctx[TestUtils.NormalMeanVariance, 2] + + setextra!(model[node], :factorization_constraint_bitset, BoundedBitSetTuple(((1,), (3,), (1, 3)))) + @test_throws ErrorException materialize_constraints!(model, node) + + # Test 4: Check that materialize_constraints! throws if the constraint is not a valid partition + model = create_model(TestUtils.simple_model()) + ctx = GraphPPL.getcontext(model) + node = ctx[TestUtils.NormalMeanVariance, 2] + + setextra!(model[node], :factorization_constraint_bitset, BoundedBitSetTuple(((1,), (1,), (3,)))) + @test_throws ErrorException materialize_constraints!(model, node) +end + +@testitem "Apply constraints to matrix variables" setup = [TestUtils] begin + using Distributions + import GraphPPL: + getproperties, + PluginsCollection, + VariationalConstraintsPlugin, + getextra, + getcontext, + with_plugins, + create_model, + NotImplementedError, + @model + + # Test for constraints applied to a model with matrix variables + c = @constraints begin + q(x, y) = q(x)q(y) + end + model = create_model(with_plugins(TestUtils.filled_matrix_model(), PluginsCollection(VariationalConstraintsPlugin(c)))) + + for node in filter(TestUtils.as_node(TestUtils.Normal), model) + @test getextra(model[node], :factorization_constraint_indices) == ([1], [2], [3]) + end + + @model function uneven_matrix() + local prec + local y + for i in 1:3 + for j in 1:3 + prec[i, j] ~ Gamma(1, 1) + y[i, j] ~ Normal(0, prec[i, j]) + end + end + prec[2, 4] ~ Gamma(1, 1) + y[2, 4] ~ Normal(0, prec[2, 4]) + end + constraints_1 = @constraints begin + q(prec, y) = q(prec)q(y) + end + + model = create_model(with_plugins(uneven_matrix(), PluginsCollection(VariationalConstraintsPlugin(constraints_1)))) + for node in filter(as_node(Normal), model) + @test getextra(model[node], :factorization_constraint_indices) == ([1], [2], [3]) + end + + constraints_2 = @constraints begin + q(prec[1], y) = q(prec[1])q(y) + end + + model = create_model(with_plugins(uneven_matrix(), PluginsCollection(VariationalConstraintsPlugin(constraints_2)))) + ctx = getcontext(model) + for node in filter(as_node(Normal), model) + if any(x -> x ∈ GraphPPL.neighbors(model, node), ctx[:prec][1]) + @test getextra(model[node], :factorization_constraint_indices) == ([1], [2], [3]) + else + @test getextra(model[node], :factorization_constraint_indices) == ([1, 3], [2]) + end + end + + constraints_3 = @constraints begin + q(prec[2], y) = q(prec[2])q(y) + end + + model = create_model(with_plugins(uneven_matrix(), PluginsCollection(VariationalConstraintsPlugin(constraints_3)))) + ctx = getcontext(model) + for node in filter(as_node(Normal), model) + if any(x -> x ∈ GraphPPL.neighbors(model, node), ctx[:prec][2]) + @test getextra(model[node], :factorization_constraint_indices) == ([1], [2], [3]) + else + @test getextra(model[node], :factorization_constraint_indices) == ([1, 3], [2]) + end + end + + constraints_4 = @constraints begin + q(prec[1, 3], y) = q(prec[1, 3])q(y) + end + model = create_model(with_plugins(uneven_matrix(), PluginsCollection(VariationalConstraintsPlugin(constraints_4)))) + ctx = getcontext(model) + for node in filter(as_node(Normal), model) + if any(x -> x ∈ GraphPPL.neighbors(model, node), ctx[:prec][1, 3]) + @test getextra(model[node], :factorization_constraint_indices) == ([1], [2], [3]) + else + @test getextra(model[node], :factorization_constraint_indices) == ([1, 3], [2]) + end + end + + constraints_5 = @constraints begin + q(prec, y) = q(prec[(1, 1):(3, 3)])q(y) + end + @test_throws GraphPPL.UnresolvableFactorizationConstraintError local model = create_model( + with_plugins(uneven_matrix(), PluginsCollection(VariationalConstraintsPlugin(constraints_5))) + ) + + @test_throws GraphPPL.NotImplementedError local constraints_5 = @constraints begin + q(prec, y) = q(prec[(1, 1)]) .. q(prec[(3, 3)])q(y) + end + + @model function inner_matrix(y, mat) + for i in 1:2 + for j in 1:2 + mat[i, j] ~ Normal(0, 1) + end + end + y ~ Normal(mat[1, 1], mat[2, 2]) + end + + @model function outer_matrix() + local mat + for i in 1:3 + for j in 1:3 + mat[i, j] ~ Normal(0, 1) + end + end + y ~ inner_matrix(mat = mat[2:3, 2:3]) + end + + constraints_7 = @constraints begin + for q in inner_matrix + q(mat, y) = q(mat)q(y) + end + end + @test_throws GraphPPL.UnresolvableFactorizationConstraintError local model = create_model( + with_plugins(outer_matrix(), PluginsCollection(VariationalConstraintsPlugin(constraints_7))) + ) + + @model function mixed_v(y, v) + for i in 1:3 + v[i] ~ Normal(0, 1) + end + y ~ Normal(v[1], v[2]) + end + + @model function mixed_m() + v1 ~ Normal(0, 1) + v2 ~ Normal(0, 1) + v3 ~ Normal(0, 1) + y ~ mixed_v(v = [v1, v2, v3]) + end + + constraints_8 = @constraints begin + for q in mixed_v + q(v, y) = q(v)q(y) + end + end + + @test_throws GraphPPL.UnresolvableFactorizationConstraintError local model = create_model( + with_plugins(mixed_m(), PluginsCollection(VariationalConstraintsPlugin(constraints_8))) + ) + + @model function ordinary_v() + local v + for i in 1:3 + v[i] ~ Normal(0, 1) + end + y ~ Normal(v[1], v[2]) + end + + constraints_9 = @constraints begin + q(v[1:2]) = q(v[1])q(v[2]) + q(v, y) = q(v)q(y) + end + + model = create_model(with_plugins(ordinary_v(), PluginsCollection(VariationalConstraintsPlugin(constraints_9)))) + ctx = getcontext(model) + for node in filter(as_node(Normal), model) + @test getextra(model[node], :factorization_constraint_indices) == ([1], [2], [3]) + end + + @model function operate_slice(y, v) + local v + for i in 1:3 + v[i] ~ Normal(0, 1) + end + y ~ Normal(v[1], v[2]) + end + + @model function pass_slice() + local m + for i in 1:3 + for j in 1:3 + m[i, j] ~ Normal(0, 1) + end + end + v = GraphPPL.ResizableArray(m[:, 1]) + y ~ operate_slice(v = v) + end + + constraints_10 = @constraints begin + for q in operate_slice + q(v, y) = q(v[begin]) .. q(v[end])q(y) + end + end + + @test_throws GraphPPL.NotImplementedError local model = create_model( + with_plugins(pass_slice(), PluginsCollection(VariationalConstraintsPlugin(constraints_10))) + ) + + constraints_11 = @constraints begin + q(x, z, y) = q(z)(q(x[begin + 1]) .. q(x[end]))(q(y[begin + 1]) .. q(y[end])) + end + + model = create_model(with_plugins(TestUtils.vector_model(), PluginsCollection(VariationalConstraintsPlugin(constraints_11)))) + + ctx = getcontext(model) + for node in filter(as_node(Normal), model) + if any(x -> x ∈ GraphPPL.neighbors(model, node), ctx[:y][1]) + @test getextra(model[node], :factorization_constraint_indices) == ([1], [2, 3]) + else + @test getextra(model[node], :factorization_constraint_indices) == ([1], [2], [3]) + end + end + + constraints_12 = @constraints begin + q(mat) = q(mat[begin]) .. q(mat[end]) + end + @test_throws NotImplementedError local model = create_model( + with_plugins(outer_matrix(), PluginsCollection(VariationalConstraintsPlugin(constraints_12))) + ) + + @model function some_matrix() + local mat + for i in 1:3 + for j in 1:3 + mat[i, j] ~ Normal(0, 1) + end + end + y ~ Normal(mat[1, 1], mat[2, 2]) + end + + constraints_13 = @constraints begin + q(mat) = MeanField() + q(mat, y) = q(mat)q(y) + end + model = create_model(with_plugins(some_matrix(), PluginsCollection(VariationalConstraintsPlugin(constraints_13)))) + ctx = getcontext(model) + for node in filter(as_node(Normal), model) + @test getextra(model[node], :factorization_constraint_indices) == ([1], [2], [3]) + end +end + +@testitem "Test factorization constraint with automatically folded data/const variables" begin + using Distributions + import GraphPPL: + getproperties, + PluginsCollection, + VariationalConstraintsPlugin, + NodeCreationOptions, + getorcreate!, + with_plugins, + create_model, + getextra, + VariationalConstraintsFactorizationIndicesKey, + @model + + @model function fold_datavars(f, a, b) + y ~ Normal(f(f(a, b), f(a, b)), 0.5) + end + + @testset for f in (+, *, (a, b) -> a + b, (a, b) -> a * b), case in (1, 2, 3) + model = create_model(with_plugins(fold_datavars(f = f), PluginsCollection(VariationalConstraintsPlugin()))) do model, ctx + if case === 1 + return ( + a = getorcreate!(model, ctx, NodeCreationOptions(kind = :constant, value = 0.35), :a, nothing), + b = getorcreate!(model, ctx, NodeCreationOptions(kind = :constant, value = 0.54), :b, nothing) + ) + elseif case === 2 + return ( + a = getorcreate!(model, ctx, NodeCreationOptions(kind = :data, factorized = true), :a, nothing), + b = getorcreate!(model, ctx, NodeCreationOptions(kind = :constant, value = 0.54), :b, nothing) + ) + elseif case === 3 + return ( + a = getorcreate!(model, ctx, NodeCreationOptions(kind = :data, factorized = true), :a, nothing), + b = getorcreate!(model, ctx, NodeCreationOptions(kind = :data, factorized = true), :b, nothing) + ) + end + end + + @test length(collect(filter(as_node(Normal), model))) === 1 + @test length(collect(filter(as_node(f), model))) === 0 + + foreach(collect(filter(as_node(Normal), model))) do node + @test getextra(model[node], VariationalConstraintsFactorizationIndicesKey) == ([1], [2], [3]) + end + end +end + +@testitem "Application of MarginalFormConstraint" setup = [TestUtils] begin + import GraphPPL: + create_model, + MarginalFormConstraint, + IndexedVariable, + apply_constraints!, + getextra, + hasextra, + VariationalConstraintsMarginalFormConstraintKey + + struct ArbitraryFunctionalFormConstraint end + + # Test saving of MarginalFormConstraint in single variable + model = create_model(TestUtils.simple_model()) + context = GraphPPL.getcontext(model) + constraint = MarginalFormConstraint(IndexedVariable(:x, nothing), ArbitraryFunctionalFormConstraint()) + apply_constraints!(model, context, constraint) + for node in filter(GraphPPL.as_variable(:x), model) + @test getextra(model[node], VariationalConstraintsMarginalFormConstraintKey) == ArbitraryFunctionalFormConstraint() + end + + # Test saving of MarginalFormConstraint in multiple variables + model = create_model(TestUtils.vector_model()) + context = GraphPPL.getcontext(model) + constraint = MarginalFormConstraint(IndexedVariable(:x, nothing), ArbitraryFunctionalFormConstraint()) + apply_constraints!(model, context, constraint) + for node in filter(GraphPPL.as_variable(:x), model) + @test getextra(model[node], VariationalConstraintsMarginalFormConstraintKey) == ArbitraryFunctionalFormConstraint() + end + for node in filter(GraphPPL.as_variable(:y), model) + @test !hasextra(model[node], VariationalConstraintsMarginalFormConstraintKey) + end + + # Test saving of MarginalFormConstraint in single variable in array + model = create_model(TestUtils.vector_model()) + context = GraphPPL.getcontext(model) + constraint = MarginalFormConstraint(IndexedVariable(:x, 1), ArbitraryFunctionalFormConstraint()) + apply_constraints!(model, context, constraint) + applied_node = context[:x][1] + for node in filter(GraphPPL.as_variable(:x), model) + if node == applied_node + @test getextra(model[node], VariationalConstraintsMarginalFormConstraintKey) == ArbitraryFunctionalFormConstraint() + else + @test !hasextra(model[node], VariationalConstraintsMarginalFormConstraintKey) + end + end +end + +@testitem "Application of MessageFormConstraint" setup = [TestUtils] begin + import GraphPPL: + create_model, + MessageFormConstraint, + IndexedVariable, + apply_constraints!, + hasextra, + getextra, + VariationalConstraintsMessagesFormConstraintKey + + struct ArbitraryMessageFormConstraint end + + # Test saving of MessageFormConstraint in single variable + model = create_model(TestUtils.simple_model()) + context = GraphPPL.getcontext(model) + constraint = MessageFormConstraint(IndexedVariable(:x, nothing), ArbitraryMessageFormConstraint()) + node = first(filter(GraphPPL.as_variable(:x), model)) + apply_constraints!(model, context, constraint) + @test getextra(model[node], VariationalConstraintsMessagesFormConstraintKey) == ArbitraryMessageFormConstraint() + + # Test saving of MessageFormConstraint in multiple variables + model = create_model(TestUtils.vector_model()) + context = GraphPPL.getcontext(model) + constraint = MessageFormConstraint(IndexedVariable(:x, nothing), ArbitraryMessageFormConstraint()) + apply_constraints!(model, context, constraint) + for node in filter(GraphPPL.as_variable(:x), model) + @test getextra(model[node], VariationalConstraintsMessagesFormConstraintKey) == ArbitraryMessageFormConstraint() + end + for node in filter(GraphPPL.as_variable(:y), model) + @test !hasextra(model[node], VariationalConstraintsMessagesFormConstraintKey) + end + + # Test saving of MessageFormConstraint in single variable in array + model = create_model(TestUtils.vector_model()) + context = GraphPPL.getcontext(model) + constraint = MessageFormConstraint(IndexedVariable(:x, 1), ArbitraryMessageFormConstraint()) + apply_constraints!(model, context, constraint) + applied_node = context[:x][1] + for node in filter(GraphPPL.as_variable(:x), model) + if node == applied_node + @test getextra(model[node], VariationalConstraintsMessagesFormConstraintKey) == ArbitraryMessageFormConstraint() + else + @test !hasextra(model[node], VariationalConstraintsMessagesFormConstraintKey) + end + end +end + +@testitem "save constraints with constants via `mean_field_constraint!`" setup = [TestUtils] begin + using BitSetTuples + import GraphPPL: + create_model, + with_plugins, + getextra, + mean_field_constraint!, + getproperties, + VariationalConstraintsPlugin, + PluginsCollection, + VariationalConstraintsFactorizationBitSetKey + + model = create_model(with_plugins(TestUtils.simple_model(), GraphPPL.PluginsCollection(VariationalConstraintsPlugin()))) + ctx = GraphPPL.getcontext(model) + + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(3), 1)) == ((1,), (2, 3), (2, 3)) + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(3), 2)) == ((1, 3), (2,), (1, 3)) + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(3), 3)) == ((1, 2), (1, 2), (3,)) + + node = ctx[TestUtils.NormalMeanVariance, 2] + constraint_bitset = getextra(model[node], VariationalConstraintsFactorizationBitSetKey) + @test tupled_contents(intersect!(constraint_bitset, mean_field_constraint!(BoundedBitSetTuple(3), 1))) == ((1,), (2, 3), (2, 3)) + @test tupled_contents(intersect!(constraint_bitset, mean_field_constraint!(BoundedBitSetTuple(3), 2))) == ((1,), (2,), (3,)) + + node = ctx[TestUtils.NormalMeanVariance, 1] + constraint_bitset = getextra(model[node], VariationalConstraintsFactorizationBitSetKey) + # Here it is the mean field because the original model has `x ~ Normal(0, 1)` and `0` and `1` are constants + @test tupled_contents(intersect!(constraint_bitset, mean_field_constraint!(BoundedBitSetTuple(3), 1))) == ((1,), (2,), (3,)) +end + +@testitem "materialize_constraints!(:Model, ::NodeLabel, ::FactorNodeData)" setup = [TestUtils] begin + using BitSetTuples + import GraphPPL: + create_model, + with_plugins, + materialize_constraints!, + EdgeLabel, + get_constraint_names, + getproperties, + getextra, + setextra!, + VariationalConstraintsPlugin + + model = create_model(TestUtils.simple_model()) + ctx = GraphPPL.getcontext(model) + node = ctx[TestUtils.NormalMeanVariance, 2] + + # Test 1: Test materialize with a Full Factorization constraint + node = ctx[TestUtils.NormalMeanVariance, 2] + + # Force overwrite the bitset and the constraints + setextra!(model[node], :factorization_constraint_bitset, BoundedBitSetTuple(3)) + materialize_constraints!(model, node) + @test Tuple.(getextra(model[node], :factorization_constraint_indices)) == ((1, 2, 3),) + + node = ctx[TestUtils.NormalMeanVariance, 1] + setextra!(model[node], :factorization_constraint_bitset, BoundedBitSetTuple(((1,), (2,), (3,)))) + materialize_constraints!(model, node) + @test Tuple.(getextra(model[node], :factorization_constraint_indices)) == ((1,), (2,), (3,)) + + # Test 2: Test materialize with an applied constraint + model = create_model(TestUtils.simple_model()) + ctx = GraphPPL.getcontext(model) + node = ctx[TestUtils.NormalMeanVariance, 2] + + setextra!(model[node], :factorization_constraint_bitset, BoundedBitSetTuple(((1,), (2, 3), (2, 3)))) + materialize_constraints!(model, node) + @test Tuple.(getextra(model[node], :factorization_constraint_indices)) == ((1,), (2, 3)) + + # # Test 3: Check that materialize_constraints! throws if the constraint is not a valid partition + model = create_model(TestUtils.simple_model()) + ctx = GraphPPL.getcontext(model) + node = ctx[TestUtils.NormalMeanVariance, 2] + + setextra!(model[node], :factorization_constraint_bitset, BoundedBitSetTuple(((1,), (3,), (1, 3)))) + @test_throws ErrorException materialize_constraints!(model, node) + + # Test 4: Check that materialize_constraints! throws if the constraint is not a valid partition + model = create_model(TestUtils.simple_model()) + ctx = GraphPPL.getcontext(model) + node = ctx[TestUtils.NormalMeanVariance, 2] + + setextra!(model[node], :factorization_constraint_bitset, BoundedBitSetTuple(((1,), (1,), (3,)))) + @test_throws ErrorException materialize_constraints!(model, node) +end diff --git a/test/plugins/variational_constraints/components/constraint_types_tests.jl b/test/plugins/variational_constraints/components/constraint_types_tests.jl new file mode 100644 index 00000000..078ca23c --- /dev/null +++ b/test/plugins/variational_constraints/components/constraint_types_tests.jl @@ -0,0 +1,113 @@ +@testitem "FactorizationConstraintEntry" setup = [TestUtils] begin + import GraphPPL: FactorizationConstraintEntry, IndexedVariable + + # Test 1: Test FactorisationConstraintEntry + @test FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))) isa FactorizationConstraintEntry + + a = FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))) + b = FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))) + @test a == b + c = FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing), IndexedVariable(:z, nothing))) + @test a != c + d = FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:p, nothing))) + @test a != d + + # Test 2: Test FactorisationConstraintEntry with mixed IndexedVariable types + a = FactorizationConstraintEntry((IndexedVariable(:x, 1), IndexedVariable(:y, nothing))) +end + +@testitem "multiply(::FactorizationConstraintEntry, ::FactorizationConstraintEntry)" begin + import GraphPPL: FactorizationConstraintEntry, IndexedVariable + + entry = FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))) + global x = entry + for i in 1:3 + global x = x * x + @test x == Tuple([entry for _ in 1:(2^i)]) + end +end + +@testitem "FactorizationConstraint" setup = [TestUtils] begin + import GraphPPL: FactorizationConstraint, FactorizationConstraintEntry, IndexedVariable, FunctionalIndex, CombinedRange, SplittedRange + + # Test 1: Test FactorizationConstraint with single variables + @test FactorizationConstraint( + (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)), + (FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))),) + ) isa Any + @test FactorizationConstraint( + (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)), + (FactorizationConstraintEntry((IndexedVariable(:x, nothing),)), FactorizationConstraintEntry((IndexedVariable(:y, nothing),))) + ) isa Any + @test_throws ErrorException FactorizationConstraint( + (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)), (FactorizationConstraintEntry((IndexedVariable(:x, nothing),)),) + ) + @test_throws ErrorException FactorizationConstraint( + (IndexedVariable(:x, nothing),), (FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))),) + ) + + # Test 2: Test FactorizationConstraint with indexed variables + @test FactorizationConstraint( + (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)), + (FactorizationConstraintEntry((IndexedVariable(:x, 1), IndexedVariable(:y, 1))),) + ) isa Any + @test FactorizationConstraint( + (IndexedVariable(:x, 1), IndexedVariable(:y, 1)), + (FactorizationConstraintEntry((IndexedVariable(:x, 1),)), FactorizationConstraintEntry((IndexedVariable(:y, 1),))) + ) isa FactorizationConstraint + @test_throws ErrorException FactorizationConstraint( + (IndexedVariable(:x, 1), IndexedVariable(:y, 1)), (FactorizationConstraintEntry((IndexedVariable(:x, 1),)),) + ) + @test_throws ErrorException FactorizationConstraint( + (IndexedVariable(:x, 1),), (FactorizationConstraintEntry((IndexedVariable(:x, 1), IndexedVariable(:y, 1))),) + ) + + # Test 3: Test FactorizationConstraint with SplittedRanges + @test FactorizationConstraint( + (IndexedVariable(:x, nothing),), + ( + FactorizationConstraintEntry(( + IndexedVariable(:x, SplittedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex))), + )), + ) + ) isa FactorizationConstraint + @test_throws ErrorException FactorizationConstraint( + (IndexedVariable(:x, nothing),), + ( + FactorizationConstraintEntry(( + IndexedVariable(:x, SplittedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex))), + IndexedVariable(:y, nothing) + )), + ) + ) + + # Test 4: Test FactorizationConstraint with CombinedRanges + @test FactorizationConstraint( + (IndexedVariable(:x, nothing),), + ( + FactorizationConstraintEntry(( + IndexedVariable(:x, CombinedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex))), + )), + ) + ) isa FactorizationConstraint + @test_throws ErrorException FactorizationConstraint( + (IndexedVariable(:x, nothing)), + ( + FactorizationConstraintEntry(( + IndexedVariable(:x, CombinedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex))), + IndexedVariable(:y, nothing) + )), + ) + ) + + # Test 5: Test FactorizationConstraint with duplicate entries + @test_throws ErrorException constraint = FactorizationConstraint( + (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing), IndexedVariable(:out, nothing)), + ( + FactorizationConstraintEntry((IndexedVariable(:x, nothing),)), + FactorizationConstraintEntry((IndexedVariable(:x, nothing),)), + FactorizationConstraintEntry((IndexedVariable(:y, nothing),)), + FactorizationConstraintEntry((IndexedVariable(:out, nothing),)) + ) + ) +end \ No newline at end of file diff --git a/test/plugins/variational_constraints/components/constraints_container_tests.jl b/test/plugins/variational_constraints/components/constraints_container_tests.jl new file mode 100644 index 00000000..3fec67aa --- /dev/null +++ b/test/plugins/variational_constraints/components/constraints_container_tests.jl @@ -0,0 +1,138 @@ +@testitem "push!(::Constraints, ::Constraint)" begin + using Distributions + import GraphPPL: + Constraints, + FactorizationConstraint, + FactorizationConstraintEntry, + MarginalFormConstraint, + MessageFormConstraint, + SpecificSubModelConstraints, + GeneralSubModelConstraints, + IndexedVariable, + FactorID + + # Test 1: Test push! with FactorizationConstraint + constraints = Constraints() + constraint = FactorizationConstraint( + (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)), + (FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)),),) + ) + push!(constraints, constraint) + @test_throws ErrorException push!(constraints, constraint) + constraint = FactorizationConstraint( + (IndexedVariable(:x, 1), IndexedVariable(:y, 1)), + (FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))),) + ) + push!(constraints, constraint) + @test_throws ErrorException push!(constraints, constraint) + constraint = FactorizationConstraint( + (IndexedVariable(:y, nothing), IndexedVariable(:x, nothing)), + (FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))),) + ) + @test_throws ErrorException push!(constraints, constraint) + + # Test 2: Test push! with MarginalFormConstraint + constraint = MarginalFormConstraint(IndexedVariable(:x, nothing), Normal) + push!(constraints, constraint) + @test_throws ErrorException push!(constraints, constraint) + constraint = MarginalFormConstraint((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)), Normal) + push!(constraints, constraint) + @test_throws ErrorException push!(constraints, constraint) + constraint = MarginalFormConstraint(IndexedVariable(:x, 1), Normal) + push!(constraints, constraint) + @test_throws ErrorException push!(constraints, constraint) + constraint = MarginalFormConstraint([IndexedVariable(:x, 1), IndexedVariable(:y, 1)], Normal) + push!(constraints, constraint) + @test_throws ErrorException push!(constraints, constraint) + + # Test 3: Test push! with MessageFormConstraint + constraint = MessageFormConstraint(IndexedVariable(:x, nothing), Normal) + push!(constraints, constraint) + @test_throws ErrorException push!(constraints, constraint) + constraint = MessageFormConstraint(IndexedVariable(:x, 2), Normal) + push!(constraints, constraint) + @test_throws ErrorException push!(constraints, constraint) + + # Test 4: Test push! with SpecificSubModelConstraints + constraint = SpecificSubModelConstraints(FactorID(sum, 3), Constraints()) + push!(constraints, constraint) + @test_throws ErrorException push!(constraints, constraint) + + # Test 5: Test push! with GeneralSubModelConstraints + constraint = GeneralSubModelConstraints(sum, Constraints()) + push!(constraints, constraint) + @test_throws ErrorException push!(constraints, constraint) +end + +@testitem "push!(::SubModelConstraints, c::Constraint)" setup = [TestUtils] begin + using Distributions + import GraphPPL: + Constraint, + GeneralSubModelConstraints, + SpecificSubModelConstraints, + FactorizationConstraint, + FactorizationConstraintEntry, + MarginalFormConstraint, + MessageFormConstraint, + getconstraint, + Constraints, + IndexedVariable + + # Test 1: Test push! with FactorizationConstraint + constraints = GeneralSubModelConstraints(TestUtils.gcv) + constraint = FactorizationConstraint( + (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)), + (FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))),) + ) + push!(constraints, constraint) + @test getconstraint(constraints) == Constraints([ + FactorizationConstraint( + (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)), + (FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)),),) + ) + ],) + @test_throws MethodError push!(constraints, "string") + + # Test 2: Test push! with MarginalFormConstraint + constraints = GeneralSubModelConstraints(TestUtils.gcv) + constraint = MarginalFormConstraint(IndexedVariable(:x, nothing), Normal) + push!(constraints, constraint) + @test getconstraint(constraints) == Constraints([MarginalFormConstraint(IndexedVariable(:x, nothing), Normal)],) + @test_throws MethodError push!(constraints, "string") + + # Test 3: Test push! with MessageFormConstraint + constraints = GeneralSubModelConstraints(TestUtils.gcv) + constraint = MessageFormConstraint(IndexedVariable(:x, nothing), Normal) + push!(constraints, constraint) + @test getconstraint(constraints) == Constraints([MessageFormConstraint(IndexedVariable(:x, nothing), Normal)],) + @test_throws MethodError push!(constraints, "string") + + # Test 4: Test push! with SpecificSubModelConstraints + constraints = SpecificSubModelConstraints(GraphPPL.FactorID(TestUtils.gcv, 3)) + constraint = FactorizationConstraint( + (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)), + (FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))),) + ) + push!(constraints, constraint) + @test getconstraint(constraints) == Constraints([ + FactorizationConstraint( + (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)), + (FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)),),) + ) + ],) + @test_throws MethodError push!(constraints, "string") + + # Test 5: Test push! with MarginalFormConstraint + constraints = GeneralSubModelConstraints(TestUtils.gcv) + constraint = MarginalFormConstraint(IndexedVariable(:x, nothing), Normal) + push!(constraints, constraint) + @test getconstraint(constraints) == Constraints([MarginalFormConstraint(IndexedVariable(:x, nothing), Normal)],) + @test_throws MethodError push!(constraints, "string") + + # Test 6: Test push! with MessageFormConstraint + constraints = GeneralSubModelConstraints(TestUtils.gcv) + constraint = MessageFormConstraint(IndexedVariable(:x, nothing), Normal) + push!(constraints, constraint) + @test getconstraint(constraints) == Constraints([MessageFormConstraint(IndexedVariable(:x, nothing), Normal)],) + @test_throws MethodError push!(constraints, "string") +end \ No newline at end of file diff --git a/test/plugins/variational_constraints/components/constraints_defaults_tests.jl b/test/plugins/variational_constraints/components/constraints_defaults_tests.jl new file mode 100644 index 00000000..7dbfab2a --- /dev/null +++ b/test/plugins/variational_constraints/components/constraints_defaults_tests.jl @@ -0,0 +1,87 @@ +@testitem "default_constraints" setup = [TestUtils] begin + import GraphPPL: + create_model, + with_plugins, + default_constraints, + getproperties, + PluginsCollection, + VariationalConstraintsPlugin, + hasextra, + getextra, + UnspecifiedConstraints + + @test default_constraints(TestUtils.simple_model) == UnspecifiedConstraints + @test default_constraints(TestUtils.model_with_default_constraints) == @constraints( + begin + q(a, d) = q(a)q(d) + end + ) + + model = create_model(with_plugins(TestUtils.contains_default_constraints(), PluginsCollection(VariationalConstraintsPlugin()))) + ctx = GraphPPL.getcontext(model) + # Test that default constraints are applied + for i in 1:10 + node = model[ctx[TestUtils.model_with_default_constraints, i][TestUtils.NormalMeanVariance, 1]] + @test hasextra(node, :factorization_constraint_indices) + @test Tuple.(getextra(node, :factorization_constraint_indices)) == ((1,), (2,), (3,)) + end + + # Test that default constraints are not applied if we specify constraints in the context + c = @constraints begin + for q in TestUtils.model_with_default_constraints + q(a, d) = q(a, d) + end + end + model = create_model(with_plugins(TestUtils.contains_default_constraints(), PluginsCollection(VariationalConstraintsPlugin(c)))) + ctx = GraphPPL.getcontext(model) + for i in 1:10 + node = model[ctx[TestUtils.model_with_default_constraints, i][TestUtils.NormalMeanVariance, 1]] + @test hasextra(node, :factorization_constraint_indices) + @test Tuple.(getextra(node, :factorization_constraint_indices)) == ((1, 2), (3,)) + end + + # Test that default constraints are not applied if we specify constraints for a specific instance of the submodel + c = @constraints begin + for q in (TestUtils.model_with_default_constraints, 1) + q(a, d) = q(a, d) + end + end + model = create_model(with_plugins(TestUtils.contains_default_constraints(), PluginsCollection(VariationalConstraintsPlugin(c)))) + ctx = GraphPPL.getcontext(model) + for i in 1:10 + node = model[ctx[TestUtils.model_with_default_constraints, i][TestUtils.NormalMeanVariance, 1]] + @test hasextra(node, :factorization_constraint_indices) + if i == 1 + @test Tuple.(getextra(node, :factorization_constraint_indices)) == ((1, 2), (3,)) + else + @test Tuple.(getextra(node, :factorization_constraint_indices)) == ((1,), (2,), (3,)) + end + end +end + +@testitem "show constraints" begin + using Distributions + using GraphPPL + + constraint = @constraints begin + q(x)::Normal + end + @test occursin(r"q\(x\) ::(.*?)Normal", repr(constraint)) + + constraint = @constraints begin + q(x, y) = q(x)q(y) + end + @test occursin(r"q\(x, y\) = q\(x\)q\(y\)", repr(constraint)) + + constraint = @constraints begin + μ(x)::Normal + end + @test occursin(r"μ\(x\) ::(.*?)Normal", repr(constraint)) + + constraint = @constraints begin + q(x, y) = q(x)q(y) + μ(x)::Normal + end + @test occursin(r"q\(x, y\) = q\(x\)q\(y\)", repr(constraint)) + @test occursin(r"μ\(x\) ::(.*?)Normal", repr(constraint)) +end \ No newline at end of file diff --git a/test/plugins/variational_constraints/components/range_types_tests.jl b/test/plugins/variational_constraints/components/range_types_tests.jl new file mode 100644 index 00000000..6823956c --- /dev/null +++ b/test/plugins/variational_constraints/components/range_types_tests.jl @@ -0,0 +1,202 @@ +@testitem "CombinedRange" setup = [TestUtils] begin + import GraphPPL: CombinedRange, is_splitted, FunctionalIndex, IndexedVariable + for left in 1:3, right in 5:8 + cr = CombinedRange(left, right) + + @test firstindex(cr) === left + @test lastindex(cr) === right + @test !is_splitted(cr) + @test length(cr) === lastindex(cr) - firstindex(cr) + 1 + + for i in left:right + @test i ∈ cr + @test !((i + lastindex(cr) + 1) ∈ cr) + end + end + range = CombinedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex)) + @test firstindex(range).f === firstindex + @test lastindex(range).f === lastindex + @test_throws MethodError length(range) + + # Test IndexedVariable with CombinedRange equality + lhs = IndexedVariable(:x, CombinedRange(1, 2)) + rhs = IndexedVariable(:x, CombinedRange(1, 2)) + @test lhs == rhs + @test lhs === rhs + @test lhs != IndexedVariable(:x, CombinedRange(1, 3)) + @test lhs !== IndexedVariable(:x, CombinedRange(1, 3)) + @test lhs != IndexedVariable(:y, CombinedRange(1, 2)) + @test lhs !== IndexedVariable(:y, CombinedRange(1, 2)) +end + +@testitem "SplittedRange" setup = [TestUtils] begin + import GraphPPL: SplittedRange, is_splitted, FunctionalIndex, IndexedVariable + for left in 1:3, right in 5:8 + cr = SplittedRange(left, right) + + @test firstindex(cr) === left + @test lastindex(cr) === right + @test is_splitted(cr) + @test length(cr) === lastindex(cr) - firstindex(cr) + 1 + + for i in left:right + @test i ∈ cr + @test !((i + lastindex(cr) + 1) ∈ cr) + end + end + range = SplittedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex)) + @test firstindex(range).f === firstindex + @test lastindex(range).f === lastindex + @test_throws MethodError length(range) + + # Test IndexedVariable with SplittedRange equality + lhs = IndexedVariable(:x, SplittedRange(1, 2)) + rhs = IndexedVariable(:x, SplittedRange(1, 2)) + @test lhs == rhs + @test lhs === rhs + @test lhs != IndexedVariable(:x, SplittedRange(1, 3)) + @test lhs !== IndexedVariable(:x, SplittedRange(1, 3)) + @test lhs != IndexedVariable(:y, SplittedRange(1, 2)) + @test lhs !== IndexedVariable(:y, SplittedRange(1, 2)) +end + +@testitem "__factorization_specification_resolve_index" setup = [TestUtils] begin + using GraphPPL + import GraphPPL: __factorization_specification_resolve_index, FunctionalIndex, CombinedRange, SplittedRange, NodeLabel, ResizableArray + + collection = ResizableArray(NodeLabel, Val(1)) + for i in 1:10 + collection[i] = NodeLabel(:x, i) + end + + # Test 1: Test __factorization_specification_resolve_index with FunctionalIndex + index = FunctionalIndex{:begin}(firstindex) + @test __factorization_specification_resolve_index(index, collection) === firstindex(collection) + + @test_throws ErrorException __factorization_specification_resolve_index(index, collection[1]) + + # Test 2: Test __factorization_specification_resolve_index with CombinedRange + index = CombinedRange(1, 5) + @test __factorization_specification_resolve_index(index, collection) === index + index = CombinedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex)) + @test __factorization_specification_resolve_index(index, collection) === CombinedRange(1, 10) + index = CombinedRange(5, FunctionalIndex{:end}(lastindex)) + @test __factorization_specification_resolve_index(index, collection) === CombinedRange(5, 10) + index = CombinedRange(1, 20) + @test_throws ErrorException __factorization_specification_resolve_index(index, collection) + + @test_throws ErrorException __factorization_specification_resolve_index(index, collection[1]) + + # Test 3: Test __factorization_specification_resolve_index with SplittedRange + index = SplittedRange(1, 5) + @test __factorization_specification_resolve_index(index, collection) === index + index = SplittedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex)) + @test __factorization_specification_resolve_index(index, collection) === SplittedRange(1, 10) + index = SplittedRange(5, FunctionalIndex{:end}(lastindex)) + @test __factorization_specification_resolve_index(index, collection) === SplittedRange(5, 10) + index = SplittedRange(1, 20) + @test_throws ErrorException __factorization_specification_resolve_index(index, collection) + + @test_throws ErrorException __factorization_specification_resolve_index(index, collection[1]) + + # Test 4: Test __factorization_specification_resolve_index with Array of indices + index = SplittedRange( + [FunctionalIndex{:begin}(firstindex), FunctionalIndex{:begin}(firstindex)], + [FunctionalIndex{:end}(lastindex), FunctionalIndex{:end}(lastindex)] + ) + collection = GraphPPL.ResizableArray(GraphPPL.NodeLabel, Val(2)) + for i in 1:3 + for j in 1:5 + collection[i, j] = GraphPPL.NodeLabel(:x, i * j) + end + end +end + +@testitem "factorization_split" setup = [TestUtils] begin + import GraphPPL: factorization_split, FactorizationConstraintEntry, IndexedVariable, FunctionalIndex, CombinedRange, SplittedRange + + # Test 1: Test factorization_split with single split + @test factorization_split( + (FactorizationConstraintEntry((IndexedVariable(:x, FunctionalIndex{:begin}(firstindex)),)),), + (FactorizationConstraintEntry((IndexedVariable(:x, FunctionalIndex{:end}(lastindex)),)),) + ) == ( + FactorizationConstraintEntry(( + IndexedVariable(:x, SplittedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex))), + ),), + ) + + @test factorization_split( + ( + FactorizationConstraintEntry((IndexedVariable(:y, nothing),)), + FactorizationConstraintEntry((IndexedVariable(:x, FunctionalIndex{:begin}(firstindex)),)) + ), + ( + FactorizationConstraintEntry((IndexedVariable(:x, FunctionalIndex{:end}(lastindex)),)), + FactorizationConstraintEntry((IndexedVariable(:z, nothing),)) + ) + ) == ( + FactorizationConstraintEntry((IndexedVariable(:y, nothing),)), + FactorizationConstraintEntry(( + IndexedVariable(:x, SplittedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex))), + )), + FactorizationConstraintEntry((IndexedVariable(:z, nothing),)) + ) + + @test factorization_split( + ( + FactorizationConstraintEntry(( + IndexedVariable(:x, FunctionalIndex{:begin}(firstindex)), IndexedVariable(:y, FunctionalIndex{:begin}(firstindex)) + )), + ), + ( + FactorizationConstraintEntry(( + IndexedVariable(:x, FunctionalIndex{:end}(lastindex)), IndexedVariable(:y, FunctionalIndex{:end}(lastindex)) + )), + ) + ) == ( + FactorizationConstraintEntry(( + IndexedVariable(:x, SplittedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex))), + IndexedVariable(:y, SplittedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex))) + )), + ) + + # Test factorization_split with only FactorizationConstraintEntrys + @test factorization_split( + FactorizationConstraintEntry(( + IndexedVariable(:x, FunctionalIndex{:begin}(firstindex)), IndexedVariable(:y, FunctionalIndex{:begin}(firstindex)) + )), + FactorizationConstraintEntry(( + IndexedVariable(:x, FunctionalIndex{:end}(lastindex)), IndexedVariable(:y, FunctionalIndex{:end}(lastindex)) + )) + ) == FactorizationConstraintEntry(( + IndexedVariable(:x, SplittedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex))), + IndexedVariable(:y, SplittedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex))) + )) + + # Test mixed behaviour + @test factorization_split( + ( + FactorizationConstraintEntry((IndexedVariable(:y, nothing),)), + FactorizationConstraintEntry((IndexedVariable(:x, FunctionalIndex{:begin}(firstindex)),)) + ), + FactorizationConstraintEntry((IndexedVariable(:x, FunctionalIndex{:end}(lastindex)),)) + ) == ( + FactorizationConstraintEntry((IndexedVariable(:y, nothing),)), + FactorizationConstraintEntry(( + IndexedVariable(:x, SplittedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex))), + )) + ) + + @test factorization_split( + FactorizationConstraintEntry((IndexedVariable(:x, FunctionalIndex{:begin}(firstindex)),)), + ( + FactorizationConstraintEntry((IndexedVariable(:x, FunctionalIndex{:end}(lastindex)),)), + FactorizationConstraintEntry((IndexedVariable(:z, nothing),),) + ) + ) == ( + FactorizationConstraintEntry(( + IndexedVariable(:x, SplittedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex))), + )), + FactorizationConstraintEntry((IndexedVariable(:z, nothing),)) + ) +end \ No newline at end of file diff --git a/test/plugins/variational_constraints/components/resolvers_tests.jl b/test/plugins/variational_constraints/components/resolvers_tests.jl new file mode 100644 index 00000000..076cefef --- /dev/null +++ b/test/plugins/variational_constraints/components/resolvers_tests.jl @@ -0,0 +1,327 @@ +@testitem "ResolvedIndexedVariable" setup = [TestUtils] begin + import GraphPPL: ResolvedIndexedVariable, IndexedVariable, getname, index, getcontext + + var = ResolvedIndexedVariable(IndexedVariable(:x, 1), Context()) + @test getname(var) == :x + @test index(var) == 1 + @test getcontext(var) isa Context +end + +@testitem "ResolvedConstraintLHS" setup = [TestUtils] begin + import GraphPPL: ResolvedConstraintLHS, ResolvedIndexedVariable, IndexedVariable, getvariables + + ctx = Context() + var1 = ResolvedIndexedVariable(IndexedVariable(:x, 1), ctx) + var2 = ResolvedIndexedVariable(IndexedVariable(:y, 2), ctx) + + lhs = ResolvedConstraintLHS((var1, var2)) + @test getvariables(lhs) == (var1, var2) + + lhs1 = ResolvedConstraintLHS((var1, var2)) + lhs2 = ResolvedConstraintLHS((var1, var2)) + @test lhs1 == lhs2 + + lhs3 = ResolvedConstraintLHS((var2, var1)) + @test lhs1 != lhs3 +end + +@testitem "ResolvedFactorizationConstraintEntry" setup = [TestUtils] begin + import GraphPPL: ResolvedFactorizationConstraintEntry, ResolvedIndexedVariable, IndexedVariable, getvariables + + ctx = Context() + var1 = ResolvedIndexedVariable(IndexedVariable(:x, 1), ctx) + var2 = ResolvedIndexedVariable(IndexedVariable(:y, 2), ctx) + + entry = ResolvedFactorizationConstraintEntry((var1, var2)) + @test getvariables(entry) == (var1, var2) +end + +@testitem "ResolvedFactorizationConstraint" setup = [TestUtils] begin + import GraphPPL: + ResolvedFactorizationConstraint, + ResolvedConstraintLHS, + ResolvedFactorizationConstraintEntry, + ResolvedIndexedVariable, + IndexedVariable, + lhs, + rhs + + ctx = Context() + var1 = ResolvedIndexedVariable(IndexedVariable(:x, 1), ctx) + var2 = ResolvedIndexedVariable(IndexedVariable(:y, 2), ctx) + + resolved_lhs = ResolvedConstraintLHS((var1, var2)) + entry1 = ResolvedFactorizationConstraintEntry((var1,)) + entry2 = ResolvedFactorizationConstraintEntry((var2,)) + + constraint = ResolvedFactorizationConstraint(resolved_lhs, (entry1, entry2)) + @test lhs(constraint) == resolved_lhs + @test rhs(constraint) == (entry1, entry2) + + constraint1 = ResolvedFactorizationConstraint(resolved_lhs, (entry1, entry2)) + constraint2 = ResolvedFactorizationConstraint(resolved_lhs, (entry1, entry2)) + @test constraint1 == constraint2 + + constraint3 = ResolvedFactorizationConstraint(resolved_lhs, (entry2, entry1)) + @test constraint1 != constraint3 +end + +@testitem "ResolvedFunctionalFormConstraint" setup = [TestUtils] begin + import GraphPPL: ResolvedFunctionalFormConstraint, ResolvedConstraintLHS, ResolvedIndexedVariable, IndexedVariable, lhs, rhs + using Distributions + + ctx = Context() + var1 = ResolvedIndexedVariable(IndexedVariable(:x, 1), ctx) + var2 = ResolvedIndexedVariable(IndexedVariable(:y, 2), ctx) + + resolved_lhs = ResolvedConstraintLHS((var1, var2)) + + constraint = ResolvedFunctionalFormConstraint(resolved_lhs, Normal) + @test lhs(constraint) == resolved_lhs + @test rhs(constraint) == Normal +end + +@testitem "ConstraintStack" setup = [TestUtils] begin + import GraphPPL: + ConstraintStack, + ResolvedFactorizationConstraint, + ResolvedConstraintLHS, + ResolvedFactorizationConstraintEntry, + ResolvedIndexedVariable, + IndexedVariable, + constraints, + context_counts + + ctx1 = Context() + ctx2 = Context() + var1 = ResolvedIndexedVariable(IndexedVariable(:x, 1), ctx1) + var2 = ResolvedIndexedVariable(IndexedVariable(:y, 2), ctx1) + + resolved_lhs = ResolvedConstraintLHS((var1, var2)) + entry1 = ResolvedFactorizationConstraintEntry((var1,)) + entry2 = ResolvedFactorizationConstraintEntry((var2,)) + + constraint1 = ResolvedFactorizationConstraint(resolved_lhs, (entry1, entry2)) + constraint2 = ResolvedFactorizationConstraint(resolved_lhs, (entry2, entry1)) + + stack = ConstraintStack() + @test isempty(context_counts(stack)) + + push!(stack, constraint1, ctx1) + @test context_counts(stack)[ctx1] == 1 + @test isempty(get(context_counts(stack), ctx2, Dict())) + + push!(stack, constraint2, ctx1) + @test context_counts(stack)[ctx1] == 2 + + push!(stack, constraint1, ctx2) + @test context_counts(stack)[ctx1] == 2 + @test context_counts(stack)[ctx2] == 1 + + @test pop!(stack, ctx1) + @test context_counts(stack)[ctx1] == 1 + + @test pop!(stack, ctx1) + @test context_counts(stack)[ctx1] == 0 + + @test !pop!(stack, ctx1) + @test context_counts(stack)[ctx1] == 0 + + @test pop!(stack, ctx2) + @test context_counts(stack)[ctx2] == 0 + + @test !pop!(stack, ctx2) +end + +@testitem "Resolve Factorization Constraints" setup = [TestUtils] begin + using Distributions + import GraphPPL: + create_model, + FactorizationConstraint, + FactorizationConstraintEntry, + IndexedVariable, + resolve, + ResolvedFactorizationConstraint, + ResolvedConstraintLHS, + ResolvedFactorizationConstraintEntry, + ResolvedIndexedVariable, + CombinedRange, + SplittedRange, + @model + + model = create_model(TestUtils.outer()) + ctx = GraphPPL.getcontext(model) + inner_context = ctx[TestUtils.inner, 1] + + # Test resolve constraint in child model + + let constraint = FactorizationConstraint( + (IndexedVariable(:α, nothing), IndexedVariable(:θ, nothing)), + (FactorizationConstraintEntry((IndexedVariable(:α, nothing),)), FactorizationConstraintEntry((IndexedVariable(:θ, nothing),))) + ) + result = ResolvedFactorizationConstraint( + ResolvedConstraintLHS((ResolvedIndexedVariable(:y, nothing, ctx), ResolvedIndexedVariable(:w, CombinedRange(2, 3), ctx)),), + ( + ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:y, nothing, ctx),)), + ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:w, CombinedRange(2, 3), ctx),)) + ) + ) + @test resolve(model, inner_context, constraint) == result + end + + # Test constraint in top level model + + let constraint = FactorizationConstraint( + (IndexedVariable(:y, nothing), IndexedVariable(:w, nothing)), + (FactorizationConstraintEntry((IndexedVariable(:y, nothing),)), FactorizationConstraintEntry((IndexedVariable(:w, nothing),))) + ) + result = ResolvedFactorizationConstraint( + ResolvedConstraintLHS((ResolvedIndexedVariable(:y, nothing, ctx), ResolvedIndexedVariable(:w, CombinedRange(1, 5), ctx)),), + ( + ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:y, nothing, ctx),)), + ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:w, CombinedRange(1, 5), ctx),)) + ) + ) + @test resolve(model, ctx, constraint) == result + end + + # Test a constraint that is not applicable at all + + let constraint = FactorizationConstraint( + (IndexedVariable(:i, nothing), IndexedVariable(:dont, nothing), IndexedVariable(:apply, nothing)), + ( + FactorizationConstraintEntry((IndexedVariable(:i, nothing),)), + FactorizationConstraintEntry((IndexedVariable(:dont, nothing),)), + FactorizationConstraintEntry((IndexedVariable(:apply, nothing),)) + ) + ) + result = ResolvedFactorizationConstraint( + ResolvedConstraintLHS((),), + (ResolvedFactorizationConstraintEntry(()), ResolvedFactorizationConstraintEntry(()), ResolvedFactorizationConstraintEntry(())) + ) + @test resolve(model, ctx, constraint) == result + end + + model = create_model(TestUtils.filled_matrix_model()) + ctx = GraphPPL.getcontext(model) + + let constraint = FactorizationConstraint( + (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)), + (FactorizationConstraintEntry((IndexedVariable(:x, nothing),)), FactorizationConstraintEntry((IndexedVariable(:y, nothing),))) + ) + result = ResolvedFactorizationConstraint( + ResolvedConstraintLHS(( + ResolvedIndexedVariable(:x, CombinedRange(1, 9), ctx), ResolvedIndexedVariable(:y, CombinedRange(1, 9), ctx) + ),), + ( + ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:x, CombinedRange(1, 9), ctx),)), + ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:y, CombinedRange(1, 9), ctx),)) + ) + ) + @test resolve(model, ctx, constraint) == result + end + model = create_model(TestUtils.filled_matrix_model()) + ctx = GraphPPL.getcontext(model) + + let constraint = FactorizationConstraint( + (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)), + (FactorizationConstraintEntry((IndexedVariable(:x, nothing),)), FactorizationConstraintEntry((IndexedVariable(:y, nothing),))) + ) + result = ResolvedFactorizationConstraint( + ResolvedConstraintLHS(( + ResolvedIndexedVariable(:x, CombinedRange(1, 9), ctx), ResolvedIndexedVariable(:y, CombinedRange(1, 9), ctx) + ),), + ( + ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:x, CombinedRange(1, 9), ctx),)), + ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:y, CombinedRange(1, 9), ctx),)) + ) + ) + @test resolve(model, ctx, constraint) == result + end + + # Test a constraint that mentions a lower-dimensional slice of a matrix variable + + @model function uneven_matrix() + local prec + local y + for i in 1:3 + for j in 1:3 + prec[i, j] ~ Gamma(1, 1) + y[i, j] ~ Normal(0, prec[i, j]) + end + end + prec[2, 4] ~ Gamma(1, 1) + y[2, 4] ~ Normal(0, prec[2, 4]) + end + + model = create_model(uneven_matrix()) + ctx = GraphPPL.getcontext(model) + let constraint = GraphPPL.FactorizationConstraint( + (IndexedVariable(:prec, [1, 3]), IndexedVariable(:y, nothing)), + ( + FactorizationConstraintEntry((IndexedVariable(:prec, [1, 3]),)), + FactorizationConstraintEntry((IndexedVariable(:y, nothing),)) + ) + ) + result = ResolvedFactorizationConstraint( + ResolvedConstraintLHS((ResolvedIndexedVariable(:prec, 3, ctx), ResolvedIndexedVariable(:y, CombinedRange(1, 10), ctx)),), + ( + ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:prec, 3, ctx),)), + ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:y, CombinedRange(1, 10), ctx),)) + ) + ) + @test resolve(model, ctx, constraint) == result + end +end + +@testitem "Resolved Constraints in" begin + import GraphPPL: + ResolvedFactorizationConstraint, + ResolvedConstraintLHS, + ResolvedFactorizationConstraintEntry, + ResolvedIndexedVariable, + SplittedRange, + getname, + index, + VariableNodeProperties, + NodeLabel, + ResizableArray + + context = GraphPPL.Context() + context[:w] = ResizableArray([NodeLabel(:w, 1), NodeLabel(:w, 2), NodeLabel(:w, 3), NodeLabel(:w, 4), NodeLabel(:w, 5)]) + context[:prec] = ResizableArray([ + [NodeLabel(:prec, 1), NodeLabel(:prec, 2), NodeLabel(:prec, 3)], [NodeLabel(:prec, 4), NodeLabel(:prec, 5), NodeLabel(:prec, 6)] + ]) + + variable = ResolvedIndexedVariable(:w, 2:3, context) + node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2)) + @test node_data ∈ variable + + variable = ResolvedIndexedVariable(:w, 2:3, context) + node_data = GraphPPL.NodeData(GraphPPL.Context(), VariableNodeProperties(name = :w, index = 2)) + @test !(node_data ∈ variable) + + variable = ResolvedIndexedVariable(:w, 2, context) + node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2)) + @test node_data ∈ variable + + variable = ResolvedIndexedVariable(:w, SplittedRange(2, 3), context) + node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2)) + @test node_data ∈ variable + + variable = ResolvedIndexedVariable(:w, SplittedRange(10, 15), context) + node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2)) + @test !(node_data ∈ variable) + + variable = ResolvedIndexedVariable(:x, nothing, context) + node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :x, index = 2)) + @test node_data ∈ variable + + variable = ResolvedIndexedVariable(:x, nothing, context) + node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :x, index = nothing)) + @test node_data ∈ variable + + variable = ResolvedIndexedVariable(:prec, 3, context) + node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :prec, index = (1, 3))) + @test node_data ∈ variable +end \ No newline at end of file diff --git a/test/plugins/variational_constraints/components/utils_tests.jl b/test/plugins/variational_constraints/components/utils_tests.jl new file mode 100644 index 00000000..27de412c --- /dev/null +++ b/test/plugins/variational_constraints/components/utils_tests.jl @@ -0,0 +1,32 @@ +@testitem "lazy_bool_allequal" begin + import GraphPPL: lazy_bool_allequal + + @testset begin + itr = [1, 2, 3, 4] + + outcome, value = lazy_bool_allequal(x -> x > 0, itr) + @test outcome === true + @test value === true + + outcome, value = lazy_bool_allequal(x -> x < 0, itr) + @test outcome === true + @test value === false + end + + @testset begin + itr = [1, 2, -1, -2] + + outcome, value = lazy_bool_allequal(x -> x > 0, itr) + @test outcome === false + @test value === true + + outcome, value = lazy_bool_allequal(x -> x < 0, itr) + @test outcome === false + @test value === false + end + + @testset begin + # We do not support it for now, but we can add it in the future + @test_throws ErrorException lazy_bool_allequal(x -> x > 0, []) + end +end diff --git a/test/plugins/variational_constraints/variational_constraints_tests.jl b/test/plugins/variational_constraints/integration/variational_constraints_tests.jl similarity index 100% rename from test/plugins/variational_constraints/variational_constraints_tests.jl rename to test/plugins/variational_constraints/integration/variational_constraints_tests.jl diff --git a/test/plugins/variational_constraints/variational_constraints_macro_tests.jl b/test/plugins/variational_constraints/macro/variational_constraints_macro_tests.jl similarity index 100% rename from test/plugins/variational_constraints/variational_constraints_macro_tests.jl rename to test/plugins/variational_constraints/macro/variational_constraints_macro_tests.jl diff --git a/test/plugins/variational_constraints/variational_constraints_engine_tests.jl b/test/plugins/variational_constraints/variational_constraints_engine_tests.jl deleted file mode 100644 index 05557f02..00000000 --- a/test/plugins/variational_constraints/variational_constraints_engine_tests.jl +++ /dev/null @@ -1,1501 +0,0 @@ -@testitem "FactorizationConstraintEntry" setup = [TestUtils] begin - import GraphPPL: FactorizationConstraintEntry, IndexedVariable - - # Test 1: Test FactorisationConstraintEntry - @test FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))) isa FactorizationConstraintEntry - - a = FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))) - b = FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))) - @test a == b - c = FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing), IndexedVariable(:z, nothing))) - @test a != c - d = FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:p, nothing))) - @test a != d - - # Test 2: Test FactorisationConstraintEntry with mixed IndexedVariable types - a = FactorizationConstraintEntry((IndexedVariable(:x, 1), IndexedVariable(:y, nothing))) -end - -@testitem "CombinedRange" setup = [TestUtils] begin - import GraphPPL: CombinedRange, is_splitted, FunctionalIndex, IndexedVariable - for left in 1:3, right in 5:8 - cr = CombinedRange(left, right) - - @test firstindex(cr) === left - @test lastindex(cr) === right - @test !is_splitted(cr) - @test length(cr) === lastindex(cr) - firstindex(cr) + 1 - - for i in left:right - @test i ∈ cr - @test !((i + lastindex(cr) + 1) ∈ cr) - end - end - range = CombinedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex)) - @test firstindex(range).f === firstindex - @test lastindex(range).f === lastindex - @test_throws MethodError length(range) - - # Test IndexedVariable with CombinedRange equality - lhs = IndexedVariable(:x, CombinedRange(1, 2)) - rhs = IndexedVariable(:x, CombinedRange(1, 2)) - @test lhs == rhs - @test lhs === rhs - @test lhs != IndexedVariable(:x, CombinedRange(1, 3)) - @test lhs !== IndexedVariable(:x, CombinedRange(1, 3)) - @test lhs != IndexedVariable(:y, CombinedRange(1, 2)) - @test lhs !== IndexedVariable(:y, CombinedRange(1, 2)) -end - -@testitem "SplittedRange" setup = [TestUtils] begin - import GraphPPL: SplittedRange, is_splitted, FunctionalIndex, IndexedVariable - for left in 1:3, right in 5:8 - cr = SplittedRange(left, right) - - @test firstindex(cr) === left - @test lastindex(cr) === right - @test is_splitted(cr) - @test length(cr) === lastindex(cr) - firstindex(cr) + 1 - - for i in left:right - @test i ∈ cr - @test !((i + lastindex(cr) + 1) ∈ cr) - end - end - range = SplittedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex)) - @test firstindex(range).f === firstindex - @test lastindex(range).f === lastindex - @test_throws MethodError length(range) - - # Test IndexedVariable with SplittedRange equality - lhs = IndexedVariable(:x, SplittedRange(1, 2)) - rhs = IndexedVariable(:x, SplittedRange(1, 2)) - @test lhs == rhs - @test lhs === rhs - @test lhs != IndexedVariable(:x, SplittedRange(1, 3)) - @test lhs !== IndexedVariable(:x, SplittedRange(1, 3)) - @test lhs != IndexedVariable(:y, SplittedRange(1, 2)) - @test lhs !== IndexedVariable(:y, SplittedRange(1, 2)) -end - -@testitem "__factorization_specification_resolve_index" setup = [TestUtils] begin - using GraphPPL - import GraphPPL: __factorization_specification_resolve_index, FunctionalIndex, CombinedRange, SplittedRange, NodeLabel, ResizableArray - - collection = ResizableArray(NodeLabel, Val(1)) - for i in 1:10 - collection[i] = NodeLabel(:x, i) - end - - # Test 1: Test __factorization_specification_resolve_index with FunctionalIndex - index = FunctionalIndex{:begin}(firstindex) - @test __factorization_specification_resolve_index(index, collection) === firstindex(collection) - - @test_throws ErrorException __factorization_specification_resolve_index(index, collection[1]) - - # Test 2: Test __factorization_specification_resolve_index with CombinedRange - index = CombinedRange(1, 5) - @test __factorization_specification_resolve_index(index, collection) === index - index = CombinedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex)) - @test __factorization_specification_resolve_index(index, collection) === CombinedRange(1, 10) - index = CombinedRange(5, FunctionalIndex{:end}(lastindex)) - @test __factorization_specification_resolve_index(index, collection) === CombinedRange(5, 10) - index = CombinedRange(1, 20) - @test_throws ErrorException __factorization_specification_resolve_index(index, collection) - - @test_throws ErrorException __factorization_specification_resolve_index(index, collection[1]) - - # Test 3: Test __factorization_specification_resolve_index with SplittedRange - index = SplittedRange(1, 5) - @test __factorization_specification_resolve_index(index, collection) === index - index = SplittedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex)) - @test __factorization_specification_resolve_index(index, collection) === SplittedRange(1, 10) - index = SplittedRange(5, FunctionalIndex{:end}(lastindex)) - @test __factorization_specification_resolve_index(index, collection) === SplittedRange(5, 10) - index = SplittedRange(1, 20) - @test_throws ErrorException __factorization_specification_resolve_index(index, collection) - - @test_throws ErrorException __factorization_specification_resolve_index(index, collection[1]) - - # Test 4: Test __factorization_specification_resolve_index with Array of indices - index = SplittedRange( - [FunctionalIndex{:begin}(firstindex), FunctionalIndex{:begin}(firstindex)], - [FunctionalIndex{:end}(lastindex), FunctionalIndex{:end}(lastindex)] - ) - collection = GraphPPL.ResizableArray(GraphPPL.NodeLabel, Val(2)) - for i in 1:3 - for j in 1:5 - collection[i, j] = GraphPPL.NodeLabel(:x, i * j) - end - end -end - -@testitem "factorization_split" setup = [TestUtils] begin - import GraphPPL: factorization_split, FactorizationConstraintEntry, IndexedVariable, FunctionalIndex, CombinedRange, SplittedRange - - # Test 1: Test factorization_split with single split - @test factorization_split( - (FactorizationConstraintEntry((IndexedVariable(:x, FunctionalIndex{:begin}(firstindex)),)),), - (FactorizationConstraintEntry((IndexedVariable(:x, FunctionalIndex{:end}(lastindex)),)),) - ) == ( - FactorizationConstraintEntry(( - IndexedVariable(:x, SplittedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex))), - ),), - ) - - @test factorization_split( - ( - FactorizationConstraintEntry((IndexedVariable(:y, nothing),)), - FactorizationConstraintEntry((IndexedVariable(:x, FunctionalIndex{:begin}(firstindex)),)) - ), - ( - FactorizationConstraintEntry((IndexedVariable(:x, FunctionalIndex{:end}(lastindex)),)), - FactorizationConstraintEntry((IndexedVariable(:z, nothing),)) - ) - ) == ( - FactorizationConstraintEntry((IndexedVariable(:y, nothing),)), - FactorizationConstraintEntry(( - IndexedVariable(:x, SplittedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex))), - )), - FactorizationConstraintEntry((IndexedVariable(:z, nothing),)) - ) - - @test factorization_split( - ( - FactorizationConstraintEntry(( - IndexedVariable(:x, FunctionalIndex{:begin}(firstindex)), IndexedVariable(:y, FunctionalIndex{:begin}(firstindex)) - )), - ), - ( - FactorizationConstraintEntry(( - IndexedVariable(:x, FunctionalIndex{:end}(lastindex)), IndexedVariable(:y, FunctionalIndex{:end}(lastindex)) - )), - ) - ) == ( - FactorizationConstraintEntry(( - IndexedVariable(:x, SplittedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex))), - IndexedVariable(:y, SplittedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex))) - )), - ) - - # Test factorization_split with only FactorizationConstraintEntrys - @test factorization_split( - FactorizationConstraintEntry(( - IndexedVariable(:x, FunctionalIndex{:begin}(firstindex)), IndexedVariable(:y, FunctionalIndex{:begin}(firstindex)) - )), - FactorizationConstraintEntry(( - IndexedVariable(:x, FunctionalIndex{:end}(lastindex)), IndexedVariable(:y, FunctionalIndex{:end}(lastindex)) - )) - ) == FactorizationConstraintEntry(( - IndexedVariable(:x, SplittedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex))), - IndexedVariable(:y, SplittedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex))) - )) - - # Test mixed behaviour - @test factorization_split( - ( - FactorizationConstraintEntry((IndexedVariable(:y, nothing),)), - FactorizationConstraintEntry((IndexedVariable(:x, FunctionalIndex{:begin}(firstindex)),)) - ), - FactorizationConstraintEntry((IndexedVariable(:x, FunctionalIndex{:end}(lastindex)),)) - ) == ( - FactorizationConstraintEntry((IndexedVariable(:y, nothing),)), - FactorizationConstraintEntry(( - IndexedVariable(:x, SplittedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex))), - )) - ) - - @test factorization_split( - FactorizationConstraintEntry((IndexedVariable(:x, FunctionalIndex{:begin}(firstindex)),)), - ( - FactorizationConstraintEntry((IndexedVariable(:x, FunctionalIndex{:end}(lastindex)),)), - FactorizationConstraintEntry((IndexedVariable(:z, nothing),),) - ) - ) == ( - FactorizationConstraintEntry(( - IndexedVariable(:x, SplittedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex))), - )), - FactorizationConstraintEntry((IndexedVariable(:z, nothing),)) - ) -end - -@testitem "FactorizationConstraint" setup = [TestUtils] begin - import GraphPPL: FactorizationConstraint, FactorizationConstraintEntry, IndexedVariable, FunctionalIndex, CombinedRange, SplittedRange - - # Test 1: Test FactorizationConstraint with single variables - @test FactorizationConstraint( - (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)), - (FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))),) - ) isa Any - @test FactorizationConstraint( - (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)), - (FactorizationConstraintEntry((IndexedVariable(:x, nothing),)), FactorizationConstraintEntry((IndexedVariable(:y, nothing),))) - ) isa Any - @test_throws ErrorException FactorizationConstraint( - (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)), (FactorizationConstraintEntry((IndexedVariable(:x, nothing),)),) - ) - @test_throws ErrorException FactorizationConstraint( - (IndexedVariable(:x, nothing),), (FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))),) - ) - - # Test 2: Test FactorizationConstraint with indexed variables - @test FactorizationConstraint( - (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)), - (FactorizationConstraintEntry((IndexedVariable(:x, 1), IndexedVariable(:y, 1))),) - ) isa Any - @test FactorizationConstraint( - (IndexedVariable(:x, 1), IndexedVariable(:y, 1)), - (FactorizationConstraintEntry((IndexedVariable(:x, 1),)), FactorizationConstraintEntry((IndexedVariable(:y, 1),))) - ) isa FactorizationConstraint - @test_throws ErrorException FactorizationConstraint( - (IndexedVariable(:x, 1), IndexedVariable(:y, 1)), (FactorizationConstraintEntry((IndexedVariable(:x, 1),)),) - ) - @test_throws ErrorException FactorizationConstraint( - (IndexedVariable(:x, 1),), (FactorizationConstraintEntry((IndexedVariable(:x, 1), IndexedVariable(:y, 1))),) - ) - - # Test 3: Test FactorizationConstraint with SplittedRanges - @test FactorizationConstraint( - (IndexedVariable(:x, nothing),), - ( - FactorizationConstraintEntry(( - IndexedVariable(:x, SplittedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex))), - )), - ) - ) isa FactorizationConstraint - @test_throws ErrorException FactorizationConstraint( - (IndexedVariable(:x, nothing),), - ( - FactorizationConstraintEntry(( - IndexedVariable(:x, SplittedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex))), - IndexedVariable(:y, nothing) - )), - ) - ) - - # Test 4: Test FactorizationConstraint with CombinedRanges - @test FactorizationConstraint( - (IndexedVariable(:x, nothing),), - ( - FactorizationConstraintEntry(( - IndexedVariable(:x, CombinedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex))), - )), - ) - ) isa FactorizationConstraint - @test_throws ErrorException FactorizationConstraint( - (IndexedVariable(:x, nothing)), - ( - FactorizationConstraintEntry(( - IndexedVariable(:x, CombinedRange(FunctionalIndex{:begin}(firstindex), FunctionalIndex{:end}(lastindex))), - IndexedVariable(:y, nothing) - )), - ) - ) - - # Test 5: Test FactorizationConstraint with duplicate entries - @test_throws ErrorException constraint = FactorizationConstraint( - (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing), IndexedVariable(:out, nothing)), - ( - FactorizationConstraintEntry((IndexedVariable(:x, nothing),)), - FactorizationConstraintEntry((IndexedVariable(:x, nothing),)), - FactorizationConstraintEntry((IndexedVariable(:y, nothing),)), - FactorizationConstraintEntry((IndexedVariable(:out, nothing),)) - ) - ) -end - -@testitem "multiply(::FactorizationConstraintEntry, ::FactorizationConstraintEntry)" begin - import GraphPPL: FactorizationConstraintEntry, IndexedVariable - - entry = FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))) - global x = entry - for i in 1:3 - global x = x * x - @test x == Tuple([entry for _ in 1:(2^i)]) - end -end - -@testitem "push!(::Constraints, ::Constraint)" begin - using Distributions - import GraphPPL: - Constraints, - FactorizationConstraint, - FactorizationConstraintEntry, - MarginalFormConstraint, - MessageFormConstraint, - SpecificSubModelConstraints, - GeneralSubModelConstraints, - IndexedVariable, - FactorID - - # Test 1: Test push! with FactorizationConstraint - constraints = Constraints() - constraint = FactorizationConstraint( - (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)), - (FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)),),) - ) - push!(constraints, constraint) - @test_throws ErrorException push!(constraints, constraint) - constraint = FactorizationConstraint( - (IndexedVariable(:x, 1), IndexedVariable(:y, 1)), - (FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))),) - ) - push!(constraints, constraint) - @test_throws ErrorException push!(constraints, constraint) - constraint = FactorizationConstraint( - (IndexedVariable(:y, nothing), IndexedVariable(:x, nothing)), - (FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))),) - ) - @test_throws ErrorException push!(constraints, constraint) - - # Test 2: Test push! with MarginalFormConstraint - constraint = MarginalFormConstraint(IndexedVariable(:x, nothing), Normal) - push!(constraints, constraint) - @test_throws ErrorException push!(constraints, constraint) - constraint = MarginalFormConstraint((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)), Normal) - push!(constraints, constraint) - @test_throws ErrorException push!(constraints, constraint) - constraint = MarginalFormConstraint(IndexedVariable(:x, 1), Normal) - push!(constraints, constraint) - @test_throws ErrorException push!(constraints, constraint) - constraint = MarginalFormConstraint([IndexedVariable(:x, 1), IndexedVariable(:y, 1)], Normal) - push!(constraints, constraint) - @test_throws ErrorException push!(constraints, constraint) - - # Test 3: Test push! with MessageFormConstraint - constraint = MessageFormConstraint(IndexedVariable(:x, nothing), Normal) - push!(constraints, constraint) - @test_throws ErrorException push!(constraints, constraint) - constraint = MessageFormConstraint(IndexedVariable(:x, 2), Normal) - push!(constraints, constraint) - @test_throws ErrorException push!(constraints, constraint) - - # Test 4: Test push! with SpecificSubModelConstraints - constraint = SpecificSubModelConstraints(FactorID(sum, 3), Constraints()) - push!(constraints, constraint) - @test_throws ErrorException push!(constraints, constraint) - - # Test 5: Test push! with GeneralSubModelConstraints - constraint = GeneralSubModelConstraints(sum, Constraints()) - push!(constraints, constraint) - @test_throws ErrorException push!(constraints, constraint) -end - -@testitem "push!(::SubModelConstraints, c::Constraint)" setup = [TestUtils] begin - using Distributions - import GraphPPL: - Constraint, - GeneralSubModelConstraints, - SpecificSubModelConstraints, - FactorizationConstraint, - FactorizationConstraintEntry, - MarginalFormConstraint, - MessageFormConstraint, - getconstraint, - Constraints, - IndexedVariable - - # Test 1: Test push! with FactorizationConstraint - constraints = GeneralSubModelConstraints(TestUtils.gcv) - constraint = FactorizationConstraint( - (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)), - (FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))),) - ) - push!(constraints, constraint) - @test getconstraint(constraints) == Constraints([ - FactorizationConstraint( - (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)), - (FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)),),) - ) - ],) - @test_throws MethodError push!(constraints, "string") - - # Test 2: Test push! with MarginalFormConstraint - constraints = GeneralSubModelConstraints(TestUtils.gcv) - constraint = MarginalFormConstraint(IndexedVariable(:x, nothing), Normal) - push!(constraints, constraint) - @test getconstraint(constraints) == Constraints([MarginalFormConstraint(IndexedVariable(:x, nothing), Normal)],) - @test_throws MethodError push!(constraints, "string") - - # Test 3: Test push! with MessageFormConstraint - constraints = GeneralSubModelConstraints(TestUtils.gcv) - constraint = MessageFormConstraint(IndexedVariable(:x, nothing), Normal) - push!(constraints, constraint) - @test getconstraint(constraints) == Constraints([MessageFormConstraint(IndexedVariable(:x, nothing), Normal)],) - @test_throws MethodError push!(constraints, "string") - - # Test 4: Test push! with SpecificSubModelConstraints - constraints = SpecificSubModelConstraints(GraphPPL.FactorID(TestUtils.gcv, 3)) - constraint = FactorizationConstraint( - (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)), - (FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing))),) - ) - push!(constraints, constraint) - @test getconstraint(constraints) == Constraints([ - FactorizationConstraint( - (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)), - (FactorizationConstraintEntry((IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)),),) - ) - ],) - @test_throws MethodError push!(constraints, "string") - - # Test 5: Test push! with MarginalFormConstraint - constraints = GeneralSubModelConstraints(TestUtils.gcv) - constraint = MarginalFormConstraint(IndexedVariable(:x, nothing), Normal) - push!(constraints, constraint) - @test getconstraint(constraints) == Constraints([MarginalFormConstraint(IndexedVariable(:x, nothing), Normal)],) - @test_throws MethodError push!(constraints, "string") - - # Test 6: Test push! with MessageFormConstraint - constraints = GeneralSubModelConstraints(TestUtils.gcv) - constraint = MessageFormConstraint(IndexedVariable(:x, nothing), Normal) - push!(constraints, constraint) - @test getconstraint(constraints) == Constraints([MessageFormConstraint(IndexedVariable(:x, nothing), Normal)],) - @test_throws MethodError push!(constraints, "string") -end - -@testitem "is_factorized" setup = [TestUtils] begin - import GraphPPL: is_factorized, create_model, getcontext, getproperties, getorcreate!, variable_nodes, NodeCreationOptions - - m = TestUtils.create_test_model(plugins = GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin())) - ctx = getcontext(m) - - x_1 = getorcreate!(m, ctx, NodeCreationOptions(factorized = true), :x_1, nothing) - @test is_factorized(m[x_1]) - - x_2 = getorcreate!(m, ctx, NodeCreationOptions(factorized = true), :x_2, nothing) - @test is_factorized(m[x_2]) - - x_3 = getorcreate!(m, ctx, NodeCreationOptions(factorized = true), :x_3, 1) - @test is_factorized(m[x_3[1]]) - - x_4 = getorcreate!(m, ctx, NodeCreationOptions(factorized = true), :x_4, 1) - @test is_factorized(m[x_4[1]]) - - x_5 = getorcreate!(m, ctx, NodeCreationOptions(factorized = true), :x_5, 1, 2) - @test is_factorized(m[x_5[1, 2]]) - - x_6 = getorcreate!(m, ctx, NodeCreationOptions(factorized = true), :x_6, 1, 2, 3) - @test is_factorized(m[x_6[1, 2, 3]]) -end - -@testitem "is_factorized || is_constant" setup = [TestUtils] begin - import GraphPPL: - is_constant, is_factorized, create_model, with_plugins, getcontext, getproperties, getorcreate!, variable_nodes, NodeCreationOptions - - m = TestUtils.create_test_model(plugins = GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin())) - ctx = getcontext(m) - x = getorcreate!(m, ctx, NodeCreationOptions(kind = :data, factorized = true), :x, nothing) - @test is_factorized(m[x]) - - for model_fn in TestUtils.ModelsInTheZooWithoutArguments - model = create_model(with_plugins(model_fn(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin()))) - for label in variable_nodes(model) - nodedata = model[label] - if is_constant(getproperties(nodedata)) - @test is_factorized(nodedata) - else - @test !is_factorized(nodedata) - end - end - end -end - -@testitem "Application of MarginalFormConstraint" setup = [TestUtils] begin - import GraphPPL: - create_model, - MarginalFormConstraint, - IndexedVariable, - apply_constraints!, - getextra, - hasextra, - VariationalConstraintsMarginalFormConstraintKey - - struct ArbitraryFunctionalFormConstraint end - - # Test saving of MarginalFormConstraint in single variable - model = create_model(TestUtils.simple_model()) - context = GraphPPL.getcontext(model) - constraint = MarginalFormConstraint(IndexedVariable(:x, nothing), ArbitraryFunctionalFormConstraint()) - apply_constraints!(model, context, constraint) - for node in filter(GraphPPL.as_variable(:x), model) - @test getextra(model[node], VariationalConstraintsMarginalFormConstraintKey) == ArbitraryFunctionalFormConstraint() - end - - # Test saving of MarginalFormConstraint in multiple variables - model = create_model(TestUtils.vector_model()) - context = GraphPPL.getcontext(model) - constraint = MarginalFormConstraint(IndexedVariable(:x, nothing), ArbitraryFunctionalFormConstraint()) - apply_constraints!(model, context, constraint) - for node in filter(GraphPPL.as_variable(:x), model) - @test getextra(model[node], VariationalConstraintsMarginalFormConstraintKey) == ArbitraryFunctionalFormConstraint() - end - for node in filter(GraphPPL.as_variable(:y), model) - @test !hasextra(model[node], VariationalConstraintsMarginalFormConstraintKey) - end - - # Test saving of MarginalFormConstraint in single variable in array - model = create_model(TestUtils.vector_model()) - context = GraphPPL.getcontext(model) - constraint = MarginalFormConstraint(IndexedVariable(:x, 1), ArbitraryFunctionalFormConstraint()) - apply_constraints!(model, context, constraint) - applied_node = context[:x][1] - for node in filter(GraphPPL.as_variable(:x), model) - if node == applied_node - @test getextra(model[node], VariationalConstraintsMarginalFormConstraintKey) == ArbitraryFunctionalFormConstraint() - else - @test !hasextra(model[node], VariationalConstraintsMarginalFormConstraintKey) - end - end -end - -@testitem "Application of MessageFormConstraint" setup = [TestUtils] begin - import GraphPPL: - create_model, - MessageFormConstraint, - IndexedVariable, - apply_constraints!, - hasextra, - getextra, - VariationalConstraintsMessagesFormConstraintKey - - struct ArbitraryMessageFormConstraint end - - # Test saving of MessageFormConstraint in single variable - model = create_model(TestUtils.simple_model()) - context = GraphPPL.getcontext(model) - constraint = MessageFormConstraint(IndexedVariable(:x, nothing), ArbitraryMessageFormConstraint()) - node = first(filter(GraphPPL.as_variable(:x), model)) - apply_constraints!(model, context, constraint) - @test getextra(model[node], VariationalConstraintsMessagesFormConstraintKey) == ArbitraryMessageFormConstraint() - - # Test saving of MessageFormConstraint in multiple variables - model = create_model(TestUtils.vector_model()) - context = GraphPPL.getcontext(model) - constraint = MessageFormConstraint(IndexedVariable(:x, nothing), ArbitraryMessageFormConstraint()) - apply_constraints!(model, context, constraint) - for node in filter(GraphPPL.as_variable(:x), model) - @test getextra(model[node], VariationalConstraintsMessagesFormConstraintKey) == ArbitraryMessageFormConstraint() - end - for node in filter(GraphPPL.as_variable(:y), model) - @test !hasextra(model[node], VariationalConstraintsMessagesFormConstraintKey) - end - - # Test saving of MessageFormConstraint in single variable in array - model = create_model(TestUtils.vector_model()) - context = GraphPPL.getcontext(model) - constraint = MessageFormConstraint(IndexedVariable(:x, 1), ArbitraryMessageFormConstraint()) - apply_constraints!(model, context, constraint) - applied_node = context[:x][1] - for node in filter(GraphPPL.as_variable(:x), model) - if node == applied_node - @test getextra(model[node], VariationalConstraintsMessagesFormConstraintKey) == ArbitraryMessageFormConstraint() - else - @test !hasextra(model[node], VariationalConstraintsMessagesFormConstraintKey) - end - end -end - -@testitem "save constraints with constants via `mean_field_constraint!`" setup = [TestUtils] begin - using BitSetTuples - import GraphPPL: - create_model, - with_plugins, - getextra, - mean_field_constraint!, - getproperties, - VariationalConstraintsPlugin, - PluginsCollection, - VariationalConstraintsFactorizationBitSetKey - - model = create_model(with_plugins(TestUtils.simple_model(), GraphPPL.PluginsCollection(VariationalConstraintsPlugin()))) - ctx = GraphPPL.getcontext(model) - - @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(3), 1)) == ((1,), (2, 3), (2, 3)) - @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(3), 2)) == ((1, 3), (2,), (1, 3)) - @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(3), 3)) == ((1, 2), (1, 2), (3,)) - - node = ctx[TestUtils.NormalMeanVariance, 2] - constraint_bitset = getextra(model[node], VariationalConstraintsFactorizationBitSetKey) - @test tupled_contents(intersect!(constraint_bitset, mean_field_constraint!(BoundedBitSetTuple(3), 1))) == ((1,), (2, 3), (2, 3)) - @test tupled_contents(intersect!(constraint_bitset, mean_field_constraint!(BoundedBitSetTuple(3), 2))) == ((1,), (2,), (3,)) - - node = ctx[TestUtils.NormalMeanVariance, 1] - constraint_bitset = getextra(model[node], VariationalConstraintsFactorizationBitSetKey) - # Here it is the mean field because the original model has `x ~ Normal(0, 1)` and `0` and `1` are constants - @test tupled_contents(intersect!(constraint_bitset, mean_field_constraint!(BoundedBitSetTuple(3), 1))) == ((1,), (2,), (3,)) -end - -@testitem "materialize_constraints!(:Model, ::NodeLabel, ::FactorNodeData)" setup = [TestUtils] begin - using BitSetTuples - import GraphPPL: - create_model, - with_plugins, - materialize_constraints!, - EdgeLabel, - get_constraint_names, - getproperties, - getextra, - setextra!, - VariationalConstraintsPlugin - - model = create_model(TestUtils.simple_model()) - ctx = GraphPPL.getcontext(model) - node = ctx[TestUtils.NormalMeanVariance, 2] - - # Test 1: Test materialize with a Full Factorization constraint - node = ctx[TestUtils.NormalMeanVariance, 2] - - # Force overwrite the bitset and the constraints - setextra!(model[node], :factorization_constraint_bitset, BoundedBitSetTuple(3)) - materialize_constraints!(model, node) - @test Tuple.(getextra(model[node], :factorization_constraint_indices)) == ((1, 2, 3),) - - node = ctx[TestUtils.NormalMeanVariance, 1] - setextra!(model[node], :factorization_constraint_bitset, BoundedBitSetTuple(((1,), (2,), (3,)))) - materialize_constraints!(model, node) - @test Tuple.(getextra(model[node], :factorization_constraint_indices)) == ((1,), (2,), (3,)) - - # Test 2: Test materialize with an applied constraint - model = create_model(TestUtils.simple_model()) - ctx = GraphPPL.getcontext(model) - node = ctx[TestUtils.NormalMeanVariance, 2] - - setextra!(model[node], :factorization_constraint_bitset, BoundedBitSetTuple(((1,), (2, 3), (2, 3)))) - materialize_constraints!(model, node) - @test Tuple.(getextra(model[node], :factorization_constraint_indices)) == ((1,), (2, 3)) - - # # Test 3: Check that materialize_constraints! throws if the constraint is not a valid partition - model = create_model(TestUtils.simple_model()) - ctx = GraphPPL.getcontext(model) - node = ctx[TestUtils.NormalMeanVariance, 2] - - setextra!(model[node], :factorization_constraint_bitset, BoundedBitSetTuple(((1,), (3,), (1, 3)))) - @test_throws ErrorException materialize_constraints!(model, node) - - # Test 4: Check that materialize_constraints! throws if the constraint is not a valid partition - model = create_model(TestUtils.simple_model()) - ctx = GraphPPL.getcontext(model) - node = ctx[TestUtils.NormalMeanVariance, 2] - - setextra!(model[node], :factorization_constraint_bitset, BoundedBitSetTuple(((1,), (1,), (3,)))) - @test_throws ErrorException materialize_constraints!(model, node) -end - -@testitem "Resolve Factorization Constraints" setup = [TestUtils] begin - using Distributions - import GraphPPL: - create_model, - FactorizationConstraint, - FactorizationConstraintEntry, - IndexedVariable, - resolve, - ResolvedFactorizationConstraint, - ResolvedConstraintLHS, - ResolvedFactorizationConstraintEntry, - ResolvedIndexedVariable, - CombinedRange, - SplittedRange, - @model - - model = create_model(TestUtils.outer()) - ctx = GraphPPL.getcontext(model) - inner_context = ctx[TestUtils.inner, 1] - - # Test resolve constraint in child model - - let constraint = FactorizationConstraint( - (IndexedVariable(:α, nothing), IndexedVariable(:θ, nothing)), - (FactorizationConstraintEntry((IndexedVariable(:α, nothing),)), FactorizationConstraintEntry((IndexedVariable(:θ, nothing),))) - ) - result = ResolvedFactorizationConstraint( - ResolvedConstraintLHS((ResolvedIndexedVariable(:y, nothing, ctx), ResolvedIndexedVariable(:w, CombinedRange(2, 3), ctx)),), - ( - ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:y, nothing, ctx),)), - ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:w, CombinedRange(2, 3), ctx),)) - ) - ) - @test resolve(model, inner_context, constraint) == result - end - - # Test constraint in top level model - - let constraint = FactorizationConstraint( - (IndexedVariable(:y, nothing), IndexedVariable(:w, nothing)), - (FactorizationConstraintEntry((IndexedVariable(:y, nothing),)), FactorizationConstraintEntry((IndexedVariable(:w, nothing),))) - ) - result = ResolvedFactorizationConstraint( - ResolvedConstraintLHS((ResolvedIndexedVariable(:y, nothing, ctx), ResolvedIndexedVariable(:w, CombinedRange(1, 5), ctx)),), - ( - ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:y, nothing, ctx),)), - ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:w, CombinedRange(1, 5), ctx),)) - ) - ) - @test resolve(model, ctx, constraint) == result - end - - # Test a constraint that is not applicable at all - - let constraint = FactorizationConstraint( - (IndexedVariable(:i, nothing), IndexedVariable(:dont, nothing), IndexedVariable(:apply, nothing)), - ( - FactorizationConstraintEntry((IndexedVariable(:i, nothing),)), - FactorizationConstraintEntry((IndexedVariable(:dont, nothing),)), - FactorizationConstraintEntry((IndexedVariable(:apply, nothing),)) - ) - ) - result = ResolvedFactorizationConstraint( - ResolvedConstraintLHS((),), - (ResolvedFactorizationConstraintEntry(()), ResolvedFactorizationConstraintEntry(()), ResolvedFactorizationConstraintEntry(())) - ) - @test resolve(model, ctx, constraint) == result - end - - model = create_model(TestUtils.filled_matrix_model()) - ctx = GraphPPL.getcontext(model) - - let constraint = FactorizationConstraint( - (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)), - (FactorizationConstraintEntry((IndexedVariable(:x, nothing),)), FactorizationConstraintEntry((IndexedVariable(:y, nothing),))) - ) - result = ResolvedFactorizationConstraint( - ResolvedConstraintLHS(( - ResolvedIndexedVariable(:x, CombinedRange(1, 9), ctx), ResolvedIndexedVariable(:y, CombinedRange(1, 9), ctx) - ),), - ( - ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:x, CombinedRange(1, 9), ctx),)), - ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:y, CombinedRange(1, 9), ctx),)) - ) - ) - @test resolve(model, ctx, constraint) == result - end - model = create_model(TestUtils.filled_matrix_model()) - ctx = GraphPPL.getcontext(model) - - let constraint = FactorizationConstraint( - (IndexedVariable(:x, nothing), IndexedVariable(:y, nothing)), - (FactorizationConstraintEntry((IndexedVariable(:x, nothing),)), FactorizationConstraintEntry((IndexedVariable(:y, nothing),))) - ) - result = ResolvedFactorizationConstraint( - ResolvedConstraintLHS(( - ResolvedIndexedVariable(:x, CombinedRange(1, 9), ctx), ResolvedIndexedVariable(:y, CombinedRange(1, 9), ctx) - ),), - ( - ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:x, CombinedRange(1, 9), ctx),)), - ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:y, CombinedRange(1, 9), ctx),)) - ) - ) - @test resolve(model, ctx, constraint) == result - end - - # Test a constraint that mentions a lower-dimensional slice of a matrix variable - - @model function uneven_matrix() - local prec - local y - for i in 1:3 - for j in 1:3 - prec[i, j] ~ Gamma(1, 1) - y[i, j] ~ Normal(0, prec[i, j]) - end - end - prec[2, 4] ~ Gamma(1, 1) - y[2, 4] ~ Normal(0, prec[2, 4]) - end - - model = create_model(uneven_matrix()) - ctx = GraphPPL.getcontext(model) - let constraint = GraphPPL.FactorizationConstraint( - (IndexedVariable(:prec, [1, 3]), IndexedVariable(:y, nothing)), - ( - FactorizationConstraintEntry((IndexedVariable(:prec, [1, 3]),)), - FactorizationConstraintEntry((IndexedVariable(:y, nothing),)) - ) - ) - result = ResolvedFactorizationConstraint( - ResolvedConstraintLHS((ResolvedIndexedVariable(:prec, 3, ctx), ResolvedIndexedVariable(:y, CombinedRange(1, 10), ctx)),), - ( - ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:prec, 3, ctx),)), - ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:y, CombinedRange(1, 10), ctx),)) - ) - ) - @test resolve(model, ctx, constraint) == result - end -end - -@testitem "Resolved Constraints in" begin - import GraphPPL: - ResolvedFactorizationConstraint, - ResolvedConstraintLHS, - ResolvedFactorizationConstraintEntry, - ResolvedIndexedVariable, - SplittedRange, - getname, - index, - VariableNodeProperties, - NodeLabel, - ResizableArray - - context = GraphPPL.Context() - context[:w] = ResizableArray([NodeLabel(:w, 1), NodeLabel(:w, 2), NodeLabel(:w, 3), NodeLabel(:w, 4), NodeLabel(:w, 5)]) - context[:prec] = ResizableArray([ - [NodeLabel(:prec, 1), NodeLabel(:prec, 2), NodeLabel(:prec, 3)], [NodeLabel(:prec, 4), NodeLabel(:prec, 5), NodeLabel(:prec, 6)] - ]) - - variable = ResolvedIndexedVariable(:w, 2:3, context) - node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2)) - @test node_data ∈ variable - - variable = ResolvedIndexedVariable(:w, 2:3, context) - node_data = GraphPPL.NodeData(GraphPPL.Context(), VariableNodeProperties(name = :w, index = 2)) - @test !(node_data ∈ variable) - - variable = ResolvedIndexedVariable(:w, 2, context) - node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2)) - @test node_data ∈ variable - - variable = ResolvedIndexedVariable(:w, SplittedRange(2, 3), context) - node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2)) - @test node_data ∈ variable - - variable = ResolvedIndexedVariable(:w, SplittedRange(10, 15), context) - node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2)) - @test !(node_data ∈ variable) - - variable = ResolvedIndexedVariable(:x, nothing, context) - node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :x, index = 2)) - @test node_data ∈ variable - - variable = ResolvedIndexedVariable(:x, nothing, context) - node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :x, index = nothing)) - @test node_data ∈ variable - - variable = ResolvedIndexedVariable(:prec, 3, context) - node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :prec, index = (1, 3))) - @test node_data ∈ variable -end - -@testitem "convert_to_bitsets" setup = [TestUtils] begin - using BitSetTuples - import GraphPPL: - create_model, - with_plugins, - ResolvedFactorizationConstraint, - ResolvedConstraintLHS, - ResolvedFactorizationConstraintEntry, - ResolvedIndexedVariable, - SplittedRange, - CombinedRange, - apply_constraints!, - getproperties - - model = create_model(with_plugins(TestUtils.outer(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin()))) - context = GraphPPL.getcontext(model) - inner_context = context[TestUtils.inner, 1] - inner_inner_context = inner_context[TestUtils.inner_inner, 1] - - normal_node = inner_inner_context[TestUtils.NormalMeanVariance, 1] - neighbors = model[GraphPPL.neighbors(model, normal_node)] - - let constraint = ResolvedFactorizationConstraint( - ResolvedConstraintLHS((ResolvedIndexedVariable(:w, 2:3, context),)), - ( - ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:w, 2, context),)), - ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:w, 3, context),)) - ) - ) - @test GraphPPL.is_applicable(neighbors, constraint) - @test tupled_contents(GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint)) == ((1, 2, 3), (1, 2), (1, 3)) - end - - let constraint = ResolvedFactorizationConstraint( - ResolvedConstraintLHS((ResolvedIndexedVariable(:w, 4:5, context),)), - ( - ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:w, 4, context),)), - ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:w, 5, context),)) - ) - ) - @test !GraphPPL.is_applicable(neighbors, constraint) - end - - let constraint = ResolvedFactorizationConstraint( - ResolvedConstraintLHS((ResolvedIndexedVariable(:w, 2:3, context),)), - (ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:w, SplittedRange(2, 3), context),)),) - ) - @test GraphPPL.is_applicable(neighbors, constraint) - @test tupled_contents(GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint)) == ((1, 2, 3), (1, 2), (1, 3)) - end - - let constraint = ResolvedFactorizationConstraint( - ResolvedConstraintLHS((ResolvedIndexedVariable(:w, 2:3, context), ResolvedIndexedVariable(:y, nothing, context))), - ( - ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:w, SplittedRange(2, 3), context),)), - ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:y, nothing, context),)) - ) - ) - @test GraphPPL.is_applicable(neighbors, constraint) - @test tupled_contents(GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint)) == ((1,), (2,), (3,)) - end - - let constraint = ResolvedFactorizationConstraint( - ResolvedConstraintLHS((ResolvedIndexedVariable(:w, 2:3, context), ResolvedIndexedVariable(:y, nothing, context))), - ( - ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:w, 2, context),)), - ResolvedFactorizationConstraintEntry(( - ResolvedIndexedVariable(:w, 3, context), ResolvedIndexedVariable(:y, nothing, context) - )) - ) - ) - @test GraphPPL.is_applicable(neighbors, constraint) - @test tupled_contents(GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint)) == ((1, 3), (2,), (1, 3)) - end - - let constraint = ResolvedFactorizationConstraint( - ResolvedConstraintLHS((ResolvedIndexedVariable(:w, 2:3, context), ResolvedIndexedVariable(:y, nothing, context))), - ( - ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:w, CombinedRange(2, 3), context),)), - ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:y, nothing, context),)) - ) - ) - @test GraphPPL.is_applicable(neighbors, constraint) - @test tupled_contents(GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint)) == ((1,), (2, 3), (2, 3)) - end - - model = create_model(with_plugins(TestUtils.multidim_array(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin()))) - context = GraphPPL.getcontext(model) - normal_node = context[TestUtils.NormalMeanVariance, 5] - neighbors = model[GraphPPL.neighbors(model, normal_node)] - - let constraint = ResolvedFactorizationConstraint( - ResolvedConstraintLHS((ResolvedIndexedVariable(:x, nothing, context),),), - (ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:x, SplittedRange(1, 9), context),)),) - ) - @test GraphPPL.is_applicable(neighbors, constraint) - @test tupled_contents(GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint)) == ((1, 3), (2, 3), (1, 2, 3)) - end - - model = create_model(with_plugins(TestUtils.multidim_array(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin()))) - context = GraphPPL.getcontext(model) - normal_node = context[TestUtils.NormalMeanVariance, 5] - neighbors = model[GraphPPL.neighbors(model, normal_node)] - - let constraint = ResolvedFactorizationConstraint( - ResolvedConstraintLHS((ResolvedIndexedVariable(:x, nothing, context),),), - (ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:x, CombinedRange(1, 9), context),)),) - ) - @test GraphPPL.is_applicable(neighbors, constraint) - @test tupled_contents(GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint)) == ((1, 2, 3), (1, 2, 3), (1, 2, 3)) - end - - # Test ResolvedFactorizationConstraints over anonymous variables - - model = create_model( - with_plugins(TestUtils.node_with_only_anonymous(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin())) - ) - context = GraphPPL.getcontext(model) - normal_node = context[TestUtils.NormalMeanVariance, 6] - neighbors = model[GraphPPL.neighbors(model, normal_node)] - let constraint = ResolvedFactorizationConstraint( - ResolvedConstraintLHS((ResolvedIndexedVariable(:y, nothing, context),),), - (ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:y, SplittedRange(1, 10), context),)),) - ) - @test GraphPPL.is_applicable(neighbors, constraint) - end - - # Test ResolvedFactorizationConstraints over multiple anonymous variables - model = create_model( - with_plugins(TestUtils.node_with_two_anonymous(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin())) - ) - context = GraphPPL.getcontext(model) - normal_node = context[TestUtils.NormalMeanVariance, 6] - neighbors = model[GraphPPL.neighbors(model, normal_node)] - let constraint = ResolvedFactorizationConstraint( - ResolvedConstraintLHS((ResolvedIndexedVariable(:y, nothing, context),),), - (ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:y, SplittedRange(1, 10), context),)),) - ) - @test GraphPPL.is_applicable(neighbors, constraint) - - # This shouldn't throw and resolve because both anonymous variables are 1-to-1 and referenced by constraint. - @test tupled_contents(GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint)) == ((1, 2, 3), (1, 2), (1, 3)) - end - - # Test ResolvedFactorizationConstraints over ambiguous anonymouys variables - model = create_model( - with_plugins(TestUtils.node_with_ambiguous_anonymous(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin())) - ) - context = GraphPPL.getcontext(model) - normal_node = last(filter(GraphPPL.as_node(TestUtils.NormalMeanVariance), model)) - neighbors = model[GraphPPL.neighbors(model, normal_node)] - let constraint = ResolvedFactorizationConstraint( - ResolvedConstraintLHS((ResolvedIndexedVariable(:y, nothing, context),),), - (ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:y, SplittedRange(1, 10), context),)),) - ) - @test GraphPPL.is_applicable(neighbors, constraint) - - # This test should throw since we cannot resolve the constraint - @test_throws GraphPPL.UnresolvableFactorizationConstraintError GraphPPL.convert_to_bitsets( - model, normal_node, neighbors, constraint - ) - end - - # Test ResolvedFactorizationConstraint with a Mixture node - model = create_model(with_plugins(TestUtils.mixture(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin()))) - context = GraphPPL.getcontext(model) - mixture_node = first(filter(GraphPPL.as_node(TestUtils.Mixture), model)) - neighbors = model[GraphPPL.neighbors(model, mixture_node)] - let constraint = ResolvedFactorizationConstraint( - ResolvedConstraintLHS(( - ResolvedIndexedVariable(:m1, nothing, context), - ResolvedIndexedVariable(:m2, nothing, context), - ResolvedIndexedVariable(:m3, nothing, context), - ResolvedIndexedVariable(:m4, nothing, context) - ),), - ( - ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:m1, nothing, context),)), - ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:m2, nothing, context),)), - ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:m3, nothing, context),)), - ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:m4, nothing, context),)) - ) - ) - @test GraphPPL.is_applicable(neighbors, constraint) - @test tupled_contents(GraphPPL.convert_to_bitsets(model, mixture_node, neighbors, constraint)) == tupled_contents( - BitSetTuple([ - collect(1:9), - [1, 2, 6, 7, 8, 9], - [1, 3, 6, 7, 8, 9], - [1, 4, 6, 7, 8, 9], - [1, 5, 6, 7, 8, 9], - collect(1:9), - collect(1:9), - collect(1:9), - collect(1:9) - ]) - ) - end -end - -@testitem "lazy_bool_allequal" begin - import GraphPPL: lazy_bool_allequal - - @testset begin - itr = [1, 2, 3, 4] - - outcome, value = lazy_bool_allequal(x -> x > 0, itr) - @test outcome === true - @test value === true - - outcome, value = lazy_bool_allequal(x -> x < 0, itr) - @test outcome === true - @test value === false - end - - @testset begin - itr = [1, 2, -1, -2] - - outcome, value = lazy_bool_allequal(x -> x > 0, itr) - @test outcome === false - @test value === true - - outcome, value = lazy_bool_allequal(x -> x < 0, itr) - @test outcome === false - @test value === false - end - - @testset begin - # We do not support it for now, but we can add it in the future - @test_throws ErrorException lazy_bool_allequal(x -> x > 0, []) - end -end - -@testitem "default_constraints" setup = [TestUtils] begin - import GraphPPL: - create_model, - with_plugins, - default_constraints, - getproperties, - PluginsCollection, - VariationalConstraintsPlugin, - hasextra, - getextra, - UnspecifiedConstraints - - @test default_constraints(TestUtils.simple_model) == UnspecifiedConstraints - @test default_constraints(TestUtils.model_with_default_constraints) == @constraints( - begin - q(a, d) = q(a)q(d) - end - ) - - model = create_model(with_plugins(TestUtils.contains_default_constraints(), PluginsCollection(VariationalConstraintsPlugin()))) - ctx = GraphPPL.getcontext(model) - # Test that default constraints are applied - for i in 1:10 - node = model[ctx[TestUtils.model_with_default_constraints, i][TestUtils.NormalMeanVariance, 1]] - @test hasextra(node, :factorization_constraint_indices) - @test Tuple.(getextra(node, :factorization_constraint_indices)) == ((1,), (2,), (3,)) - end - - # Test that default constraints are not applied if we specify constraints in the context - c = @constraints begin - for q in TestUtils.model_with_default_constraints - q(a, d) = q(a, d) - end - end - model = create_model(with_plugins(TestUtils.contains_default_constraints(), PluginsCollection(VariationalConstraintsPlugin(c)))) - ctx = GraphPPL.getcontext(model) - for i in 1:10 - node = model[ctx[TestUtils.model_with_default_constraints, i][TestUtils.NormalMeanVariance, 1]] - @test hasextra(node, :factorization_constraint_indices) - @test Tuple.(getextra(node, :factorization_constraint_indices)) == ((1, 2), (3,)) - end - - # Test that default constraints are not applied if we specify constraints for a specific instance of the submodel - c = @constraints begin - for q in (TestUtils.model_with_default_constraints, 1) - q(a, d) = q(a, d) - end - end - model = create_model(with_plugins(TestUtils.contains_default_constraints(), PluginsCollection(VariationalConstraintsPlugin(c)))) - ctx = GraphPPL.getcontext(model) - for i in 1:10 - node = model[ctx[TestUtils.model_with_default_constraints, i][TestUtils.NormalMeanVariance, 1]] - @test hasextra(node, :factorization_constraint_indices) - if i == 1 - @test Tuple.(getextra(node, :factorization_constraint_indices)) == ((1, 2), (3,)) - else - @test Tuple.(getextra(node, :factorization_constraint_indices)) == ((1,), (2,), (3,)) - end - end -end - -@testitem "mean_field_constraint!" begin - using BitSetTuples - import GraphPPL: mean_field_constraint! - - @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(5))) == ((1,), (2,), (3,), (4,), (5,)) - @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(10))) == ((1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,), (10,)) - - @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(1), 1)) == ((1,),) - @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(5), 3)) == - ((1, 2, 4, 5), (1, 2, 4, 5), (3,), (1, 2, 4, 5), (1, 2, 4, 5)) - @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(1), (1,))) == ((1,),) - @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(2), (1,))) == ((1,), (2,)) - @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(2), (2,))) == ((1,), (2,)) - @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(5), (1, 2))) == ((1,), (2,), (3, 4, 5), (3, 4, 5), (3, 4, 5)) - @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(5), (1, 3, 5))) == ((1,), (2, 4), (3,), (2, 4), (5,)) - @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(5), (1, 2, 3, 4, 5))) == ((1,), (2,), (3,), (4,), (5,)) - @test_throws BoundsError mean_field_constraint!(BoundedBitSetTuple(5), (1, 2, 3, 4, 5, 6)) == ((1,), (2,), (3,), (4,), (5,)) -end - -@testitem "Apply constraints to matrix variables" setup = [TestUtils] begin - using Distributions - import GraphPPL: - getproperties, - PluginsCollection, - VariationalConstraintsPlugin, - getextra, - getcontext, - with_plugins, - create_model, - NotImplementedError, - @model - - # Test for constraints applied to a model with matrix variables - c = @constraints begin - q(x, y) = q(x)q(y) - end - model = create_model(with_plugins(TestUtils.filled_matrix_model(), PluginsCollection(VariationalConstraintsPlugin(c)))) - - for node in filter(TestUtils.as_node(TestUtils.Normal), model) - @test getextra(model[node], :factorization_constraint_indices) == ([1], [2], [3]) - end - - @model function uneven_matrix() - local prec - local y - for i in 1:3 - for j in 1:3 - prec[i, j] ~ Gamma(1, 1) - y[i, j] ~ Normal(0, prec[i, j]) - end - end - prec[2, 4] ~ Gamma(1, 1) - y[2, 4] ~ Normal(0, prec[2, 4]) - end - constraints_1 = @constraints begin - q(prec, y) = q(prec)q(y) - end - - model = create_model(with_plugins(uneven_matrix(), PluginsCollection(VariationalConstraintsPlugin(constraints_1)))) - for node in filter(as_node(Normal), model) - @test getextra(model[node], :factorization_constraint_indices) == ([1], [2], [3]) - end - - constraints_2 = @constraints begin - q(prec[1], y) = q(prec[1])q(y) - end - - model = create_model(with_plugins(uneven_matrix(), PluginsCollection(VariationalConstraintsPlugin(constraints_2)))) - ctx = getcontext(model) - for node in filter(as_node(Normal), model) - if any(x -> x ∈ GraphPPL.neighbors(model, node), ctx[:prec][1]) - @test getextra(model[node], :factorization_constraint_indices) == ([1], [2], [3]) - else - @test getextra(model[node], :factorization_constraint_indices) == ([1, 3], [2]) - end - end - - constraints_3 = @constraints begin - q(prec[2], y) = q(prec[2])q(y) - end - - model = create_model(with_plugins(uneven_matrix(), PluginsCollection(VariationalConstraintsPlugin(constraints_3)))) - ctx = getcontext(model) - for node in filter(as_node(Normal), model) - if any(x -> x ∈ GraphPPL.neighbors(model, node), ctx[:prec][2]) - @test getextra(model[node], :factorization_constraint_indices) == ([1], [2], [3]) - else - @test getextra(model[node], :factorization_constraint_indices) == ([1, 3], [2]) - end - end - - constraints_4 = @constraints begin - q(prec[1, 3], y) = q(prec[1, 3])q(y) - end - model = create_model(with_plugins(uneven_matrix(), PluginsCollection(VariationalConstraintsPlugin(constraints_4)))) - ctx = getcontext(model) - for node in filter(as_node(Normal), model) - if any(x -> x ∈ GraphPPL.neighbors(model, node), ctx[:prec][1, 3]) - @test getextra(model[node], :factorization_constraint_indices) == ([1], [2], [3]) - else - @test getextra(model[node], :factorization_constraint_indices) == ([1, 3], [2]) - end - end - - constraints_5 = @constraints begin - q(prec, y) = q(prec[(1, 1):(3, 3)])q(y) - end - @test_throws GraphPPL.UnresolvableFactorizationConstraintError local model = create_model( - with_plugins(uneven_matrix(), PluginsCollection(VariationalConstraintsPlugin(constraints_5))) - ) - - @test_throws GraphPPL.NotImplementedError local constraints_5 = @constraints begin - q(prec, y) = q(prec[(1, 1)]) .. q(prec[(3, 3)])q(y) - end - - @model function inner_matrix(y, mat) - for i in 1:2 - for j in 1:2 - mat[i, j] ~ Normal(0, 1) - end - end - y ~ Normal(mat[1, 1], mat[2, 2]) - end - - @model function outer_matrix() - local mat - for i in 1:3 - for j in 1:3 - mat[i, j] ~ Normal(0, 1) - end - end - y ~ inner_matrix(mat = mat[2:3, 2:3]) - end - - constraints_7 = @constraints begin - for q in inner_matrix - q(mat, y) = q(mat)q(y) - end - end - @test_throws GraphPPL.UnresolvableFactorizationConstraintError local model = create_model( - with_plugins(outer_matrix(), PluginsCollection(VariationalConstraintsPlugin(constraints_7))) - ) - - @model function mixed_v(y, v) - for i in 1:3 - v[i] ~ Normal(0, 1) - end - y ~ Normal(v[1], v[2]) - end - - @model function mixed_m() - v1 ~ Normal(0, 1) - v2 ~ Normal(0, 1) - v3 ~ Normal(0, 1) - y ~ mixed_v(v = [v1, v2, v3]) - end - - constraints_8 = @constraints begin - for q in mixed_v - q(v, y) = q(v)q(y) - end - end - - @test_throws GraphPPL.UnresolvableFactorizationConstraintError local model = create_model( - with_plugins(mixed_m(), PluginsCollection(VariationalConstraintsPlugin(constraints_8))) - ) - - @model function ordinary_v() - local v - for i in 1:3 - v[i] ~ Normal(0, 1) - end - y ~ Normal(v[1], v[2]) - end - - constraints_9 = @constraints begin - q(v[1:2]) = q(v[1])q(v[2]) - q(v, y) = q(v)q(y) - end - - model = create_model(with_plugins(ordinary_v(), PluginsCollection(VariationalConstraintsPlugin(constraints_9)))) - ctx = getcontext(model) - for node in filter(as_node(Normal), model) - @test getextra(model[node], :factorization_constraint_indices) == ([1], [2], [3]) - end - - @model function operate_slice(y, v) - local v - for i in 1:3 - v[i] ~ Normal(0, 1) - end - y ~ Normal(v[1], v[2]) - end - - @model function pass_slice() - local m - for i in 1:3 - for j in 1:3 - m[i, j] ~ Normal(0, 1) - end - end - v = GraphPPL.ResizableArray(m[:, 1]) - y ~ operate_slice(v = v) - end - - constraints_10 = @constraints begin - for q in operate_slice - q(v, y) = q(v[begin]) .. q(v[end])q(y) - end - end - - @test_throws GraphPPL.NotImplementedError local model = create_model( - with_plugins(pass_slice(), PluginsCollection(VariationalConstraintsPlugin(constraints_10))) - ) - - constraints_11 = @constraints begin - q(x, z, y) = q(z)(q(x[begin + 1]) .. q(x[end]))(q(y[begin + 1]) .. q(y[end])) - end - - model = create_model(with_plugins(TestUtils.vector_model(), PluginsCollection(VariationalConstraintsPlugin(constraints_11)))) - - ctx = getcontext(model) - for node in filter(as_node(Normal), model) - if any(x -> x ∈ GraphPPL.neighbors(model, node), ctx[:y][1]) - @test getextra(model[node], :factorization_constraint_indices) == ([1], [2, 3]) - else - @test getextra(model[node], :factorization_constraint_indices) == ([1], [2], [3]) - end - end - - constraints_12 = @constraints begin - q(mat) = q(mat[begin]) .. q(mat[end]) - end - @test_throws NotImplementedError local model = create_model( - with_plugins(outer_matrix(), PluginsCollection(VariationalConstraintsPlugin(constraints_12))) - ) - - @model function some_matrix() - local mat - for i in 1:3 - for j in 1:3 - mat[i, j] ~ Normal(0, 1) - end - end - y ~ Normal(mat[1, 1], mat[2, 2]) - end - - constraints_13 = @constraints begin - q(mat) = MeanField() - q(mat, y) = q(mat)q(y) - end - model = create_model(with_plugins(some_matrix(), PluginsCollection(VariationalConstraintsPlugin(constraints_13)))) - ctx = getcontext(model) - for node in filter(as_node(Normal), model) - @test getextra(model[node], :factorization_constraint_indices) == ([1], [2], [3]) - end -end - -@testitem "Test factorization constraint with automatically folded data/const variables" begin - using Distributions - import GraphPPL: - getproperties, - PluginsCollection, - VariationalConstraintsPlugin, - NodeCreationOptions, - getorcreate!, - with_plugins, - create_model, - getextra, - VariationalConstraintsFactorizationIndicesKey, - @model - - @model function fold_datavars(f, a, b) - y ~ Normal(f(f(a, b), f(a, b)), 0.5) - end - - @testset for f in (+, *, (a, b) -> a + b, (a, b) -> a * b), case in (1, 2, 3) - model = create_model(with_plugins(fold_datavars(f = f), PluginsCollection(VariationalConstraintsPlugin()))) do model, ctx - if case === 1 - return ( - a = getorcreate!(model, ctx, NodeCreationOptions(kind = :constant, value = 0.35), :a, nothing), - b = getorcreate!(model, ctx, NodeCreationOptions(kind = :constant, value = 0.54), :b, nothing) - ) - elseif case === 2 - return ( - a = getorcreate!(model, ctx, NodeCreationOptions(kind = :data, factorized = true), :a, nothing), - b = getorcreate!(model, ctx, NodeCreationOptions(kind = :constant, value = 0.54), :b, nothing) - ) - elseif case === 3 - return ( - a = getorcreate!(model, ctx, NodeCreationOptions(kind = :data, factorized = true), :a, nothing), - b = getorcreate!(model, ctx, NodeCreationOptions(kind = :data, factorized = true), :b, nothing) - ) - end - end - - @test length(collect(filter(as_node(Normal), model))) === 1 - @test length(collect(filter(as_node(f), model))) === 0 - - foreach(collect(filter(as_node(Normal), model))) do node - @test getextra(model[node], VariationalConstraintsFactorizationIndicesKey) == ([1], [2], [3]) - end - end -end - -@testitem "show constraints" begin - using Distributions - using GraphPPL - - constraint = @constraints begin - q(x)::Normal - end - @test occursin(r"q\(x\) ::(.*?)Normal", repr(constraint)) - - constraint = @constraints begin - q(x, y) = q(x)q(y) - end - @test occursin(r"q\(x, y\) = q\(x\)q\(y\)", repr(constraint)) - - constraint = @constraints begin - μ(x)::Normal - end - @test occursin(r"μ\(x\) ::(.*?)Normal", repr(constraint)) - - constraint = @constraints begin - q(x, y) = q(x)q(y) - μ(x)::Normal - end - @test occursin(r"q\(x, y\) = q\(x\)q\(y\)", repr(constraint)) - @test occursin(r"μ\(x\) ::(.*?)Normal", repr(constraint)) -end