Skip to content

Commit 5c03421

Browse files
authored
Merge pull request #1003 from JuliaAI/children
Ensure `source() != source()` and clean up some tests
2 parents b7d47a7 + f8044b3 commit 5c03421

File tree

5 files changed

+19
-15
lines changed

5 files changed

+19
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJBase"
22
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "1.8.1"
4+
version = "1.8.2"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"

src/composition/learning_networks/inspection.jl

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -114,17 +114,6 @@ train_args(::Source) = []
114114
train_args(N::Node{<:Machine}) = N.machine.args
115115
train_args(N::Node{Nothing}) = []
116116

117-
"""
118-
children(N::AbstractNode, y::AbstractNode)
119-
120-
List all (immediate) children of node `N` in the ancestor graph of `y`
121-
(training edges included).
122-
123-
"""
124-
children(N::AbstractNode, y::AbstractNode) = filter(nodes(y)) do W
125-
N in args(W) || N in train_args(W)
126-
end |> unique
127-
128117
"""
129118
lower_bound(type_itr)
130119

src/sources.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ mutable struct Source <: AbstractNode
3030
scitype::DataType
3131
end
3232

33+
# To ensure `source() != source()`:
34+
MMI.is_same_except(s1::Source, s2::Source; kwargs...) = s1 === s2
35+
3336
"""
3437
Xs = source(X=nothing)
3538

test/composition/learning_networks/inspection.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,19 @@ module TestLearningCompositesInspection
33
using Test
44
using MLJBase
55
using ..Models
6+
import MLJModelInterface as MMI
67

7-
KNNRegressor()
8+
"""
9+
children(N::AbstractNode, y::AbstractNode)
10+
11+
List all (immediate) children of node `N` in the ancestor graph of `y`
12+
(training edges included).
13+
14+
"""
15+
children(N::AbstractNode, y::AbstractNode) = filter(nodes(y)) do Z
16+
t = N in MLJBase.args(Z) ||
17+
N in MLJBase.train_args(Z)
18+
end |> unique
819

920
@constant X = source()
1021
@constant y = source()
@@ -36,12 +47,12 @@ knnM = machine(knn, W, y)
3647
@test Set(machines(yhat)) == Set([knnM, hotM])
3748
@test Set(MLJBase.args(yhat)) == Set([W, ])
3849
@test Set(MLJBase.train_args(yhat)) == Set([W, y])
39-
@test Set(MLJBase.children(X, all)) == Set([W, K])
50+
@test Set(children(X, all)) == Set([W, K])
4051

4152
@constant Q = 2X
4253
@constant R = 3X
4354
@constant S = glb(X, Q, R)
44-
@test Set(MLJBase.children(X, S)) == Set([Q, R, S])
55+
@test Set(children(X, S)) == Set([Q, R, S])
4556
@test MLJBase.lower_bound([Int, Float64]) == Union{}
4657
@test MLJBase.lower_bound([Int, Integer]) == Int
4758
@test MLJBase.lower_bound([Int, Integer]) == Int

test/sources.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ Xs = source(X)
1313
rebind!(Xs, nothing)
1414
@test isempty(Xs)
1515
@test Xs.scitype == Nothing
16+
@test source() != source()
1617

1718
end
1819
true

0 commit comments

Comments
 (0)