Skip to content

Commit 0dab8af

Browse files
authored
Merge pull request #38 from TuringLang/mt/fix_zygote_segfault
Call Zygote.refresh every 10 distributions in tests
2 parents 8fd1716 + 1cd1ebf commit 0dab8af

File tree

5 files changed

+10
-4
lines changed

5 files changed

+10
-4
lines changed

.github/workflows/ForwardDiff_Tracker.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ jobs:
1212
runs-on: ${{ matrix.os }}
1313
strategy:
1414
matrix:
15-
julia-version: [1.0.5, 1.2.0, 1.3]
15+
julia-version: [1.0, 1.3]
1616
julia-arch: [x64, x86]
1717
os: [ubuntu-latest, macOS-latest]
1818
exclude:

.github/workflows/Others.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ jobs:
1212
runs-on: ${{ matrix.os }}
1313
strategy:
1414
matrix:
15-
julia-version: [1.0.5, 1.2.0, 1.3]
15+
julia-version: [1.0, 1.3]
1616
julia-arch: [x64, x86]
1717
os: [ubuntu-latest, macOS-latest]
1818
exclude:

.github/workflows/Zygote.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ jobs:
1212
runs-on: ${{ matrix.os }}
1313
strategy:
1414
matrix:
15-
julia-version: [1.0.5, 1.2.0, 1.3]
15+
julia-version: [1.0, 1.3]
1616
julia-arch: [x64, x86]
1717
os: [ubuntu-latest, macOS-latest]
1818
exclude:

test/others.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ if get_stage() in ("Others", "all")
132132
d1 = TuringDiagMvNormal(zeros(10), sigmas)
133133
d2 = MvNormal(zeros(10), sigmas)
134134

135-
@test entropy(d1) == entropy(d2)
135+
@test isapprox(entropy(d1), entropy(d2), rtol = 1e-6)
136136
end
137137

138138
@testset "Params" begin

test/test_utils.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ function get_stage()
147147
return "all"
148148
end
149149

150+
const zygote_counter = Ref(0)
151+
150152
function test_ad(f, at = 0.5; rtol = 1e-8, atol = 1e-8)
151153
stg = get_stage()
152154
if stg == "all"
@@ -177,7 +179,11 @@ function test_ad(f, at = 0.5; rtol = 1e-8, atol = 1e-8)
177179
@test isapprox(reverse_tracker, finite_diff, rtol=rtol, atol=atol)
178180
end
179181
elseif stg == "Zygote"
182+
zygote_counter[] += 1
180183
isarr = isa(at, AbstractArray)
184+
if mod(zygote_counter[], 50) == 0
185+
Zygote.refresh()
186+
end
181187
reverse_zygote = Zygote.gradient(f, at)[1]
182188
if isarr
183189
forward = ForwardDiff.gradient(f, at)

0 commit comments

Comments
 (0)