Skip to content

Commit 428d819

Browse files
committed
call Zygote.refresh every 10 distributions in tests
1 parent 8fd1716 commit 428d819

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

test/test_utils.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ function get_stage()
147147
return "all"
148148
end
149149

150+
const zygote_counter = Ref(0)
151+
150152
function test_ad(f, at = 0.5; rtol = 1e-8, atol = 1e-8)
151153
stg = get_stage()
152154
if stg == "all"
@@ -177,7 +179,11 @@ function test_ad(f, at = 0.5; rtol = 1e-8, atol = 1e-8)
177179
@test isapprox(reverse_tracker, finite_diff, rtol=rtol, atol=atol)
178180
end
179181
elseif stg == "Zygote"
182+
zygote_counter[] += 1
180183
isarr = isa(at, AbstractArray)
184+
if mod(zygote_counter[], 10) == 0
185+
Zygote.refresh()
186+
end
181187
reverse_zygote = Zygote.gradient(f, at)[1]
182188
if isarr
183189
forward = ForwardDiff.gradient(f, at)

0 commit comments

Comments
 (0)