|
| 1 | +using DynamicExpressions |
| 2 | +using Random |
| 3 | +using Test |
| 4 | + |
| 5 | +operators = OperatorEnum(; |
| 6 | + binary_operators=[+, -, *, /], unary_operators=[cos, sin], enable_autodiff=true |
| 7 | +); |
| 8 | +x1, x2, x3 = (i -> Node(Float64; feature=i)).(1:3) |
| 9 | +tree = cos(x1 * 3.2 - 5.8) * 0.2 - 0.5 * x2 * x3 * x3 + 0.9 / (x1 * x1 + 1); |
| 10 | + |
| 11 | +@testset "all" begin |
| 12 | + ctree = copy(tree) |
| 13 | + @test all(t -> t.degree != -1, ctree) |
| 14 | + @test all( |
| 15 | + t -> t.degree != 0 || !t.constant || t.val in (3.2, 5.8, 0.2, 0.5, 0.9, 1.0), ctree |
| 16 | + ) |
| 17 | + @test !all( |
| 18 | + t -> t.degree != 0 || !t.constant || t.val in (3.2, 5.8, 0.2, 0.9, 1.0), ctree |
| 19 | + ) |
| 20 | + @test all(t -> t.degree != 0 || t.constant || t.feature in (1, 2, 3), ctree) |
| 21 | + @test !all(t -> t.degree != 0 || t.constant || t.feature in (2, 3), ctree) |
| 22 | + @test all(t -> t.degree != 1 || t.op == 1, ctree) |
| 23 | + @test !all(t -> t.degree != 1 || t.op == 2, ctree) |
| 24 | +end |
| 25 | + |
| 26 | +@testset "any" begin |
| 27 | + ctree = copy(tree) |
| 28 | + @test any(t.degree == 2, ctree) |
| 29 | + @test any(_ -> true, ctree) |
| 30 | + @test any(t -> t.degree == 0 && t.constant && t.val == 3.2, ctree) |
| 31 | + @test !any(t -> t.degree == 0 && t.constant && t.val == 3.3, ctree) |
| 32 | +end |
| 33 | + |
| 34 | +@testset "collect" begin |
| 35 | + ctree = copy(tree) |
| 36 | + @test first(collect(ctree)) == Node{Float64} |
| 37 | + @test objectid(first(collect(ctree))) == objectid(ctree) |
| 38 | + @test objectid(first(collect(ctree))) == objectid(ctree) |
| 39 | + @test objectid(first(collect(ctree))) == objectid(ctree) |
| 40 | + @test typeof(collect(ctree)) == Vector{Node{Float64}} |
| 41 | + @test length(collect(ctree)) == 24 |
| 42 | + @test sum((t -> (t.degree == 0 && t.constant) ? t.val : 0.0).(collect(ctree))) == 11.6 |
| 43 | +end |
| 44 | + |
| 45 | +@testset "count" begin |
| 46 | + ctree = copy(tree) |
| 47 | + @test count(_ -> true, ctree) == 24 |
| 48 | + @test count(t -> t.degree == 0, ctree) == 12 |
| 49 | + @test count(t -> t.degree == 1, ctree) == 1 |
| 50 | + @test count(t -> t.degree == 2, ctree) == 11 |
| 51 | + @test count(t -> t.degree == 0 && t.constant, ctree) == 6 |
| 52 | + @test count(t -> t.degree == 0 && t.constant && t.val == 1, ctree) == 1 |
| 53 | +end |
| 54 | + |
| 55 | +@testset "filter" begin |
| 56 | + ctree = copy(tree) |
| 57 | + @test filter(_ -> true, ctree) == collect(ctree) |
| 58 | + @test length(filter(t -> t.degree == 0 && !t.constant)) == 6 |
| 59 | + @test unique(filter(t -> t.degree == 0 && !t.constant)) == [x1, x2, x3] |
| 60 | + @test length(filter(t -> t.degree == 1, ctree)) == 1 |
| 61 | + @test length(filter(t -> t.degree == 2, ctree)) == 11 |
| 62 | + @test filter(==(x1), ctree) == [x1, x1, x1] |
| 63 | +end |
| 64 | + |
| 65 | +@testset "foreach" begin |
| 66 | + ctree = copy(tree) |
| 67 | + counter = Ref(0) |
| 68 | + foreach(ctree) do t |
| 69 | + counter.x += 1 |
| 70 | + end |
| 71 | + @test counter.x == 24 |
| 72 | + foreach(ctree) do t |
| 73 | + if t.degree == 0 && t.constant |
| 74 | + t.val *= 2 |
| 75 | + end |
| 76 | + end |
| 77 | + @test sum(t -> t.val, filter(t -> t.degree == 0 && t.constant, ctree)) == 11.6 * 2 |
| 78 | +end |
| 79 | + |
| 80 | +@testset "iterate" begin |
| 81 | + ctree = copy(tree) |
| 82 | + counter = Ref(0) |
| 83 | + for t in ctree |
| 84 | + counter.x += 1 |
| 85 | + end |
| 86 | + @test counter.x == 24 |
| 87 | + for t in ctree |
| 88 | + if t.degree == 0 && t.constant |
| 89 | + t.val *= 2 |
| 90 | + end |
| 91 | + end |
| 92 | + @test sum(t -> t.val, filter(t -> t.degree == 0 && t.constant, ctree)) == 11.6 * 2 |
| 93 | + |
| 94 | + # iterate within iterate: |
| 95 | + counter = Ref(0) |
| 96 | + for t in ctree |
| 97 | + for t2 in t |
| 98 | + counter.x += 1 |
| 99 | + end |
| 100 | + end |
| 101 | + @test counter.x == 104 |
| 102 | +end |
| 103 | + |
| 104 | +@testset "map" begin |
| 105 | + ctree = copy(tree) |
| 106 | + vals = map(t -> t.val, ctree) |
| 107 | + vals = [v for v in vals if v !== nothing] |
| 108 | + @test sum(vals) == 11.6 |
| 109 | + @test sum(map(_ -> 1, ctree)) == 24 |
| 110 | + @test sum(map(_ -> 2, ctree)) == 24 * 2 |
| 111 | + @test sum(map(t -> t.degree == 1, ctree)) == 1 |
| 112 | + @test length(unique(map(objectid, copy_node(tree)))) == 24 |
| 113 | + @test length(unique(map(objectid, copy_node(tree; preserve_sharing=true)))) == 24 - 3 |
| 114 | + map(t -> (t.degree == 0 && t.constant) ? (t.val *= 2) : nothing, ctree) |
| 115 | + @test sum(t -> t.val, filter(t -> t.degree == 0 && t.constant, ctree)) == 11.6 * 2 |
| 116 | +end |
| 117 | + |
| 118 | +@testset "in" begin |
| 119 | + ctree = copy(tree) |
| 120 | + @test x1 in ctree |
| 121 | + @test Node(Float64; val=1.0) ∈ ctree |
| 122 | + @test Node(Float32; val=1.0) ∈ ctree |
| 123 | + @test Node(Float64; val=1.1) ∉ ctree |
| 124 | + @test ctree.l ∈ ctree |
| 125 | + @test ctree.l * 2 ∉ ctree |
| 126 | +end |
| 127 | + |
| 128 | +@testset "isempty" begin |
| 129 | + @test !isempty(tree) |
| 130 | + @test !isempty(Node(Float32; val=1)) |
| 131 | +end |
| 132 | + |
| 133 | +@testset "length" begin |
| 134 | + @test length(tree) == 24 |
| 135 | + @test length(tree.l) == 16 |
| 136 | + @test length(tree.r) == 24 - 16 - 1 |
| 137 | +end |
| 138 | + |
| 139 | +@testset "mapreduce" begin |
| 140 | + @test mapreduce(_ -> 1, +, tree) == 24 |
| 141 | + @test mapreduce(_ -> 2, +, tree) == 48 |
| 142 | + @test mapreduce(_ -> 1, *, tree) == 1 |
| 143 | + @test mapreduce(_ -> 2, *, tree) == 2^24 |
| 144 | + @test mapreduce(t -> t.degree, *, tree) == 0 |
| 145 | + @test mapreduce(t -> t.degree + 1, *, tree) == 354294 |
| 146 | + @test mapreduce(t -> t.degree, (l, r) -> (max(l, 1) * max(r, 1)), tree) == 2048 |
| 147 | + @test mapreduce(+, tree) do t |
| 148 | + 1 |
| 149 | + end == 24 |
| 150 | +end |
| 151 | + |
| 152 | +@testset "sum" begin |
| 153 | + ctree = copy(tree) |
| 154 | + @test sum(t -> t.degree == 0 && t.constant ? t.val : 0.0, ctree) == 11.6 |
| 155 | + @test sum(t -> t.degree == 0 && !t.constant ? t.feature : 0, ctree) == |
| 156 | + 3 * 1 + 1 * 2 + 2 * 3 |
| 157 | + @test sum(t -> t.degree == 1 ? t.op : 0, ctree) == 1 |
| 158 | + @test sum(t -> (t.degree == 0 && t.constant) ? t.val * 2 : 0.0, ctree) == 11.6 * 2 |
| 159 | + for t in ctree |
| 160 | + if t.degree == 0 && t.constant |
| 161 | + t.val *= 1.5 |
| 162 | + end |
| 163 | + end |
| 164 | + @test sum(t -> (t.degree == 0 && t.constant) ? t.val : 0.0, ctree) ≈ 11.6 * 1.5 |
| 165 | +end |
0 commit comments