Skip to content
This repository was archived by the owner on Aug 25, 2025. It is now read-only.

Commit 9d1354f

Browse files
mode in stoch ad tests for enzyme
1 parent 7e75714 commit 9d1354f

File tree

2 files changed

+14
-13
lines changed

2 files changed

+14
-13
lines changed

ext/OptimizationEnzymeExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
9898
fmode = if adtype.mode isa Nothing
9999
Enzyme.Forward
100100
else
101-
set_runtime_activity2(Enzyme.Forward.adtype.mode)
101+
set_runtime_activity2(Enzyme.Forward, adtype.mode)
102102
end
103103

104104
if g == true && f.grad === nothing

test/adtests.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1172,17 +1172,18 @@ using MLUtils
11721172

11731173
optf = OptimizationFunction(loss, AutoEnzyme())
11741174
optf = OptimizationBase.instantiate_function(
1175-
optf, rand(3), AutoEnzyme(), iterate(data)[1], g = true, fg = true)
1175+
optf, rand(3), AutoEnzyme(mode = set_runtime_activity(Reverse)),
1176+
iterate(data)[1], g = true, fg = true)
11761177
G0 = zeros(3)
1177-
@test_broken optf.grad(G0, ones(3), (x0, y0))
1178-
# stochgrads = []
1179-
# for (x,y) in data
1180-
# G = zeros(3)
1181-
# optf.grad(G, ones(3), (x,y))
1182-
# push!(stochgrads, copy(G))
1183-
# G1 = zeros(3)
1184-
# optf.fg(G1, ones(3), (x,y))
1185-
# @test GG1 rtol=1e-6
1186-
# end
1187-
# @test G0sum(stochgrads)/length(stochgrads) rtol=1e-1
1178+
optf.grad(G0, ones(3), (x0, y0))
1179+
stochgrads = []
1180+
for (x, y) in data
1181+
G = zeros(3)
1182+
optf.grad(G, ones(3), (x, y))
1183+
push!(stochgrads, copy(G))
1184+
G1 = zeros(3)
1185+
optf.fg(G1, ones(3), (x, y))
1186+
@test GG1 rtol=1e-6
1187+
end
1188+
@test G0sum(stochgrads) rtol=1e-1
11881189
end

0 commit comments

Comments
 (0)