Skip to content

Commit d2ea94a

Browse files
authored
Various CI improvements (#2375)
* Remove check on Julia>1.2 * Upgrade julia-actions/cache to v2 * Refactor CI matrix to make intent clearer * Reinsert 1.11/2 threads combination
1 parent 42189fd commit d2ea94a

File tree

2 files changed

+100
-96
lines changed

2 files changed

+100
-96
lines changed

.github/workflows/Tests.yml

Lines changed: 49 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ on:
99
jobs:
1010
test:
1111
# Use matrix.test.name here to avoid it taking up the entire window width
12-
name: test ${{matrix.test.name}} (${{ matrix.os }}, ${{ matrix.version }}, ${{ matrix.arch }}, ${{ matrix.num_threads }})
13-
runs-on: ${{ matrix.os }}
14-
continue-on-error: ${{ matrix.version == 'nightly' }}
12+
name: test ${{matrix.test.name}} (${{ matrix.runner.os }}, ${{ matrix.runner.version }}, ${{ matrix.runner.arch }}, ${{ matrix.runner.num_threads }})
13+
runs-on: ${{ matrix.runner.os }}
14+
continue-on-error: ${{ matrix.runner.version == 'pre' }}
1515

1616
strategy:
1717
fail-fast: false
@@ -35,63 +35,71 @@ jobs:
3535
args: "mcmc/ess.jl"
3636
- name: "everything else"
3737
args: "--skip essential/ad.jl mcmc/gibbs.jl mcmc/hmc.jl mcmc/abstractmcmc.jl mcmc/Inference.jl experimental/gibbs.jl mcmc/ess.jl"
38-
version:
39-
- '1.10'
40-
- '1'
41-
os:
42-
- ubuntu-latest
43-
- windows-latest
44-
- macOS-latest
45-
arch:
46-
- x64
47-
- x86
48-
num_threads:
49-
- 1
50-
- 2
51-
exclude:
52-
# With Windows and macOS, only run x64, 2 threads. We just want to see
53-
# some combination work on OSes other than Ubuntu.
54-
- os: windows-latest
55-
version: '1'
56-
- os: macOS-latest
57-
version: '1'
58-
- os: windows-latest
59-
arch: x86
60-
- os: macOS-latest
38+
runner:
39+
# Default
40+
- version: '1'
41+
os: ubuntu-latest
42+
arch: x64
43+
num_threads: 1
44+
# x86
45+
- version: '1'
46+
os: ubuntu-latest
6147
arch: x86
62-
- os: windows-latest
6348
num_threads: 1
64-
- os: macOS-latest
49+
# Multithreaded
50+
- version: '1'
51+
os: ubuntu-latest
52+
arch: x64
53+
num_threads: 2
54+
# Windows
55+
- version: '1'
56+
os: windows-latest
57+
arch: x64
58+
num_threads: 1
59+
# macOS
60+
- version: '1'
61+
os: macos-latest
62+
arch: x64
63+
num_threads: 1
64+
# Minimum supported Julia version
65+
- version: 'min'
66+
os: ubuntu-latest
67+
arch: x64
6568
num_threads: 1
66-
# It's sufficient to test x86 with only Julia 1.10 and 1 thread.
67-
- arch: x86
68-
version: '1'
69-
- arch: x86
69+
# Minimum supported Julia version, multithreaded
70+
- version: 'min'
71+
os: ubuntu-latest
72+
arch: x64
7073
num_threads: 2
74+
# Pre-release Julia version
75+
- version: 'pre'
76+
os: ubuntu-latest
77+
arch: x64
78+
num_threads: 1
7179

7280
steps:
7381
- name: Print matrix variables
7482
run: |
75-
echo "OS: ${{ matrix.os }}"
76-
echo "Architecture: ${{ matrix.arch }}"
77-
echo "Julia version: ${{ matrix.version }}"
78-
echo "Number of threads: ${{ matrix.num_threads }}"
83+
echo "OS: ${{ matrix.runner.os }}"
84+
echo "Architecture: ${{ matrix.runner.arch }}"
85+
echo "Julia version: ${{ matrix.runner.version }}"
86+
echo "Number of threads: ${{ matrix.runner.num_threads }}"
7987
echo "Test arguments: ${{ matrix.test.args }}"
8088
- name: (De)activate coverage analysis
81-
run: echo "COVERAGE=${{ matrix.version == '1' && matrix.os == 'ubuntu-latest' && matrix.num_threads == 2 }}" >> "$GITHUB_ENV"
89+
run: echo "COVERAGE=${{ matrix.runner.version == '1' && matrix.runner.os == 'ubuntu-latest' && matrix.runner.num_threads == 2 }}" >> "$GITHUB_ENV"
8290
shell: bash
8391
- uses: actions/checkout@v4
8492
- uses: julia-actions/setup-julia@v2
8593
with:
86-
version: '${{ matrix.version }}'
87-
arch: ${{ matrix.arch }}
88-
- uses: julia-actions/cache@v1
94+
version: '${{ matrix.runner.version }}'
95+
arch: ${{ matrix.runner.arch }}
96+
- uses: julia-actions/cache@v2
8997
- uses: julia-actions/julia-buildpkg@v1
9098
# TODO: Use julia-actions/julia-runtest when test_args are supported
9199
# Custom calls of Pkg.test tend to miss features such as e.g. adjustments for CompatHelper PRs
92100
# Ref https://github.com/julia-actions/julia-runtest/pull/73
93101
- name: Call Pkg.test
94-
run: julia --color=yes --inline=yes --depwarn=yes --check-bounds=yes --threads=${{ matrix.num_threads }} --project=@. -e 'import Pkg; Pkg.test(; coverage=parse(Bool, ENV["COVERAGE"]), test_args=ARGS)' -- ${{ matrix.test.args }}
102+
run: julia --color=yes --inline=yes --depwarn=yes --check-bounds=yes --threads=${{ matrix.runner.num_threads }} --project=@. -e 'import Pkg; Pkg.test(; coverage=parse(Bool, ENV["COVERAGE"]), test_args=ARGS)' -- ${{ matrix.test.args }}
95103
- uses: julia-actions/julia-processcoverage@v1
96104
if: ${{ env.COVERAGE }}
97105
- uses: codecov/codecov-action@v4

test/mcmc/Inference.jl

Lines changed: 51 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -17,70 +17,66 @@ using Test: @test, @test_throws, @testset
1717
using Turing
1818

1919
@testset "Testing inference.jl with $adbackend" for adbackend in ADUtils.adbackends
20-
# Only test threading if 1.3+.
21-
if VERSION > v"1.2"
22-
@testset "threaded sampling" begin
23-
# Test that chains with the same seed will sample identically.
24-
@testset "rng" begin
25-
model = gdemo_default
26-
27-
# multithreaded sampling with PG causes segfaults on Julia 1.5.4
28-
# https://github.com/TuringLang/Turing.jl/issues/1571
29-
samplers = @static if VERSION <= v"1.5.3" || VERSION >= v"1.6.0"
30-
(
31-
HMC(0.1, 7; adtype=adbackend),
32-
PG(10),
33-
IS(),
34-
MH(),
35-
Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)),
36-
Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)),
37-
)
38-
else
39-
(
40-
HMC(0.1, 7; adtype=adbackend),
41-
IS(),
42-
MH(),
43-
Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)),
44-
)
45-
end
46-
for sampler in samplers
47-
Random.seed!(5)
48-
chain1 = sample(model, sampler, MCMCThreads(), 1000, 4)
20+
@testset "threaded sampling" begin
21+
# Test that chains with the same seed will sample identically.
22+
@testset "rng" begin
23+
model = gdemo_default
24+
25+
# multithreaded sampling with PG causes segfaults on Julia 1.5.4
26+
# https://github.com/TuringLang/Turing.jl/issues/1571
27+
samplers = @static if VERSION <= v"1.5.3" || VERSION >= v"1.6.0"
28+
(
29+
HMC(0.1, 7; adtype=adbackend),
30+
PG(10),
31+
IS(),
32+
MH(),
33+
Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)),
34+
Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)),
35+
)
36+
else
37+
(
38+
HMC(0.1, 7; adtype=adbackend),
39+
IS(),
40+
MH(),
41+
Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)),
42+
)
43+
end
44+
for sampler in samplers
45+
Random.seed!(5)
46+
chain1 = sample(model, sampler, MCMCThreads(), 1000, 4)
4947

50-
Random.seed!(5)
51-
chain2 = sample(model, sampler, MCMCThreads(), 1000, 4)
48+
Random.seed!(5)
49+
chain2 = sample(model, sampler, MCMCThreads(), 1000, 4)
5250

53-
@test chain1.value == chain2.value
54-
end
51+
@test chain1.value == chain2.value
52+
end
5553

56-
# Should also be stable with am explicit RNG
57-
seed = 5
58-
rng = Random.MersenneTwister(seed)
59-
for sampler in samplers
60-
Random.seed!(rng, seed)
61-
chain1 = sample(rng, model, sampler, MCMCThreads(), 1000, 4)
54+
# Should also be stable with am explicit RNG
55+
seed = 5
56+
rng = Random.MersenneTwister(seed)
57+
for sampler in samplers
58+
Random.seed!(rng, seed)
59+
chain1 = sample(rng, model, sampler, MCMCThreads(), 1000, 4)
6260

63-
Random.seed!(rng, seed)
64-
chain2 = sample(rng, model, sampler, MCMCThreads(), 1000, 4)
61+
Random.seed!(rng, seed)
62+
chain2 = sample(rng, model, sampler, MCMCThreads(), 1000, 4)
6563

66-
@test chain1.value == chain2.value
67-
end
64+
@test chain1.value == chain2.value
6865
end
66+
end
6967

70-
# Smoke test for default sample call.
71-
Random.seed!(100)
72-
chain = sample(
73-
gdemo_default, HMC(0.1, 7; adtype=adbackend), MCMCThreads(), 1000, 4
74-
)
75-
check_gdemo(chain)
68+
# Smoke test for default sample call.
69+
Random.seed!(100)
70+
chain = sample(gdemo_default, HMC(0.1, 7; adtype=adbackend), MCMCThreads(), 1000, 4)
71+
check_gdemo(chain)
7672

77-
# run sampler: progress logging should be disabled and
78-
# it should return a Chains object
79-
sampler = Sampler(HMC(0.1, 7; adtype=adbackend), gdemo_default)
80-
chains = sample(gdemo_default, sampler, MCMCThreads(), 1000, 4)
81-
@test chains isa MCMCChains.Chains
82-
end
73+
# run sampler: progress logging should be disabled and
74+
# it should return a Chains object
75+
sampler = Sampler(HMC(0.1, 7; adtype=adbackend), gdemo_default)
76+
chains = sample(gdemo_default, sampler, MCMCThreads(), 1000, 4)
77+
@test chains isa MCMCChains.Chains
8378
end
79+
8480
@testset "chain save/resume" begin
8581
Random.seed!(1234)
8682

0 commit comments

Comments
 (0)