@@ -3,8 +3,19 @@ module TestLearningCompositesInspection
3
3
using Test
4
4
using MLJBase
5
5
using .. Models
6
+ import MLJModelInterface as MMI
6
7
7
- KNNRegressor ()
8
+ """
9
+ children(N::AbstractNode, y::AbstractNode)
10
+
11
+ List all (immediate) children of node `N` in the ancestor graph of `y`
12
+ (training edges included).
13
+
14
+ """
15
+ children (N:: AbstractNode , y:: AbstractNode ) = filter (nodes (y)) do Z
16
+ t = N in MLJBase. args (Z) ||
17
+ N in MLJBase. train_args (Z)
18
+ end |> unique
8
19
9
20
@constant X = source ()
10
21
@constant y = source ()
@@ -36,12 +47,12 @@ knnM = machine(knn, W, y)
36
47
@test Set (machines (yhat)) == Set ([knnM, hotM])
37
48
@test Set (MLJBase. args (yhat)) == Set ([W, ])
38
49
@test Set (MLJBase. train_args (yhat)) == Set ([W, y])
39
- @test Set (MLJBase . children (X, all)) == Set ([W, K])
50
+ @test Set (children (X, all)) == Set ([W, K])
40
51
41
52
@constant Q = 2 X
42
53
@constant R = 3 X
43
54
@constant S = glb (X, Q, R)
44
- @test Set (MLJBase . children (X, S)) == Set ([Q, R, S])
55
+ @test Set (children (X, S)) == Set ([Q, R, S])
45
56
@test MLJBase. lower_bound ([Int, Float64]) == Union{}
46
57
@test MLJBase. lower_bound ([Int, Integer]) == Int
47
58
@test MLJBase. lower_bound ([Int, Integer]) == Int
0 commit comments