Skip to content

Commit eeafb34

Browse files
committed
style: cleaner test code
1 parent f948ddb commit eeafb34

File tree

2 files changed

+24
-24
lines changed

2 files changed

+24
-24
lines changed

test/test_optim.jl

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -80,29 +80,26 @@ end
8080

8181
tree = copy(original_tree)
8282
did_i_run_2 = Ref(false)
83-
fg!(F, G, tree) =
84-
let
85-
if G !== nothing
86-
ŷ, dŷ_dconstants, _ = eval_grad_tree_array(
87-
tree, X, operators; variable=false
83+
function my_fg!(F, G, tree)
84+
if G !== nothing
85+
ŷ, dŷ_dconstants, _ = eval_grad_tree_array(tree, X, operators; variable=false)
86+
dresult_dŷ = @. 2 * (ŷ - y)
87+
for i in eachindex(G)
88+
G[i] = sum(
89+
j -> dresult_dŷ[j] * dŷ_dconstants[i, j],
90+
eachindex(axes(dŷ_dconstants, 2), axes(dresult_dŷ, 1)),
8891
)
89-
dresult_dŷ = @. 2 * (ŷ - y)
90-
for i in eachindex(G)
91-
G[i] = sum(
92-
j -> dresult_dŷ[j] * dŷ_dconstants[i, j],
93-
eachindex(axes(dŷ_dconstants, 2), axes(dresult_dŷ, 1)),
94-
)
95-
end
96-
if F !== nothing
97-
did_i_run_2[] = true
98-
return sum(abs2, ŷ .- y)
99-
end
100-
elseif F !== nothing
101-
# Only f
102-
return sum(abs2, tree(X, operators) .- y)
10392
end
93+
if F !== nothing
94+
did_i_run_2[] = true
95+
return sum(abs2, ŷ .- y)
96+
end
97+
elseif F !== nothing
98+
# Only f
99+
return sum(abs2, tree(X, operators) .- y)
104100
end
105-
res = optimize(Optim.only_fg!(fg!), tree, BFGS())
101+
end
102+
res = optimize(Optim.only_fg!(my_fg!), tree, BFGS())
106103

107104
@test did_i_run_2[]
108105
@test isapprox(

test/test_optim_setup.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
using DynamicExpressions
22
using Random: MersenneTwister as RNG
33

4-
operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(exp,))
4+
const operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(exp,))
55
x1, x2 = (i -> Node(Float64; feature=i)).(1:2)
66

7-
X = rand(RNG(0), Float64, 2, 100)
8-
y = @. exp(X[1, :] * 2.1 - 0.9) + X[2, :] * -0.9
7+
const X = rand(RNG(0), Float64, 2, 100)
8+
const y = @. exp(X[1, :] * 2.1 - 0.9) + X[2, :] * -0.9
99

1010
original_tree = exp(x1 * 0.8 - 0.0) + 5.2 * x2
1111
target_tree = exp(x1 * 2.1 - 0.9) + -0.9 * x2
1212

13-
f(tree) = sum(abs2, tree(X, operators) .- y)
13+
function f(tree)
14+
out = tree(X, operators)
15+
return sum(i -> abs2(out[i] - y[i]), eachindex(out, y))
16+
end
1417
function g!(G, tree)
1518
dy = only(gradient(f, tree))
1619
G .= dy.gradient

0 commit comments

Comments
 (0)