diff --git a/Project.toml b/Project.toml index 2d191e82..46ddf258 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJBase" uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" authors = ["Anthony D. Blaom "] -version = "1.8.1" +version = "1.8.2" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" diff --git a/src/composition/learning_networks/inspection.jl b/src/composition/learning_networks/inspection.jl index dbcd0a31..f830f2e2 100644 --- a/src/composition/learning_networks/inspection.jl +++ b/src/composition/learning_networks/inspection.jl @@ -114,17 +114,6 @@ train_args(::Source) = [] train_args(N::Node{<:Machine}) = N.machine.args train_args(N::Node{Nothing}) = [] -""" - children(N::AbstractNode, y::AbstractNode) - -List all (immediate) children of node `N` in the ancestor graph of `y` -(training edges included). - -""" -children(N::AbstractNode, y::AbstractNode) = filter(nodes(y)) do W - N in args(W) || N in train_args(W) -end |> unique - """ lower_bound(type_itr) diff --git a/src/sources.jl b/src/sources.jl index 083e269e..b126c0ba 100644 --- a/src/sources.jl +++ b/src/sources.jl @@ -30,6 +30,9 @@ mutable struct Source <: AbstractNode scitype::DataType end +# To ensure `source() != source()`: +MMI.is_same_except(s1::Source, s2::Source; kwargs...) = s1 === s2 + """ Xs = source(X=nothing) diff --git a/test/composition/learning_networks/inspection.jl b/test/composition/learning_networks/inspection.jl index d87081b2..eed5f6f3 100644 --- a/test/composition/learning_networks/inspection.jl +++ b/test/composition/learning_networks/inspection.jl @@ -3,8 +3,19 @@ module TestLearningCompositesInspection using Test using MLJBase using ..Models +import MLJModelInterface as MMI -KNNRegressor() +""" + children(N::AbstractNode, y::AbstractNode) + +List all (immediate) children of node `N` in the ancestor graph of `y` +(training edges included). + +""" +children(N::AbstractNode, y::AbstractNode) = filter(nodes(y)) do Z + t = N in MLJBase.args(Z) || + N in MLJBase.train_args(Z) +end |> unique @constant X = source() @constant y = source() @@ -36,12 +47,12 @@ knnM = machine(knn, W, y) @test Set(machines(yhat)) == Set([knnM, hotM]) @test Set(MLJBase.args(yhat)) == Set([W, ]) @test Set(MLJBase.train_args(yhat)) == Set([W, y]) -@test Set(MLJBase.children(X, all)) == Set([W, K]) +@test Set(children(X, all)) == Set([W, K]) @constant Q = 2X @constant R = 3X @constant S = glb(X, Q, R) -@test Set(MLJBase.children(X, S)) == Set([Q, R, S]) +@test Set(children(X, S)) == Set([Q, R, S]) @test MLJBase.lower_bound([Int, Float64]) == Union{} @test MLJBase.lower_bound([Int, Integer]) == Int @test MLJBase.lower_bound([Int, Integer]) == Int diff --git a/test/sources.jl b/test/sources.jl index a3836244..dfcd66c9 100644 --- a/test/sources.jl +++ b/test/sources.jl @@ -13,6 +13,7 @@ Xs = source(X) rebind!(Xs, nothing) @test isempty(Xs) @test Xs.scitype == Nothing +@test source() != source() end true