Skip to content

Commit 687d2eb

Browse files
authored
Merge branch 'master' into csp/infer
2 parents 336b7a4 + 4bfe617 commit 687d2eb

File tree

4 files changed

+165
-68
lines changed

4 files changed

+165
-68
lines changed

Project.toml

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,31 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
keywords = ["markov chain monte carlo", "probablistic programming"]
44
license = "MIT"
55
desc = "A lightweight interface for common MCMC methods."
6-
version = "0.5.0"
6+
version = "0.5.2"
77

88
[deps]
9+
ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
910
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1011
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
12+
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
1113
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
1214
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1315
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
14-
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
16+
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
1517

1618
[compat]
19+
ConsoleProgressMonitor = "0.1"
20+
LoggingExtras = "0.4"
1721
ProgressLogging = "0.1"
1822
StatsBase = "0.32"
23+
TerminalLoggers = "0.1"
1924
julia = "1"
2025

2126
[extras]
27+
Atom = "c52e3926-4ff0-5f6e-af25-54175e0327b1"
28+
IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a"
2229
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
23-
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
2430
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2531

2632
[targets]
27-
test = ["Statistics", "Test", "TerminalLoggers"]
33+
test = ["Atom", "IJulia", "Statistics", "Test"]

src/AbstractMCMC.jl

Lines changed: 66 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,61 @@
11
module AbstractMCMC
22

3+
import ConsoleProgressMonitor
4+
import LoggingExtras
35
import ProgressLogging
46
import StatsBase
57
using StatsBase: sample
8+
import TerminalLoggers
69

710
import Distributed
811
import Logging
912
using Random: GLOBAL_RNG, AbstractRNG, seed!
10-
import UUIDs
13+
14+
# avoid creating a progress bar with @withprogress if progress logging is disabled
15+
# and add a custom progress logger if the current logger does not seem to be able to handle
16+
# progress logs
17+
macro ifwithprogresslogger(progress, exprs...)
18+
return quote
19+
if $progress
20+
if $hasprogresslevel($Logging.current_logger())
21+
$ProgressLogging.@withprogress $(exprs...)
22+
else
23+
$with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do
24+
$ProgressLogging.@withprogress $(exprs...)
25+
end
26+
end
27+
else
28+
$(exprs[end])
29+
end
30+
end |> esc
31+
end
32+
33+
# improved checks?
34+
function hasprogresslevel(logger)
35+
return Logging.min_enabled_level(logger) ProgressLogging.ProgressLevel
36+
end
37+
38+
# filter better, e.g., according to group?
39+
function with_progresslogger(f, _module, logger)
40+
logger1 = LoggingExtras.EarlyFilteredLogger(progresslogger()) do log
41+
log._module === _module && log.level == ProgressLogging.ProgressLevel
42+
end
43+
logger2 = LoggingExtras.EarlyFilteredLogger(logger) do log
44+
log._module !== _module || log.level != ProgressLogging.ProgressLevel
45+
end
46+
47+
Logging.with_logger(f, LoggingExtras.TeeLogger(logger1, logger2))
48+
end
49+
50+
function progresslogger()
51+
# detect if code is running under IJulia since TerminalLogger does not work with IJulia
52+
# https://github.com/JuliaLang/IJulia.jl#detecting-that-code-is-running-under-ijulia
53+
if isdefined(Main, :IJulia) && Main.IJulia.inited
54+
return ConsoleProgressMonitor.ProgressLogger()
55+
else
56+
return TerminalLoggers.TerminalLogger()
57+
end
58+
end
1159

1260
"""
1361
AbstractChains
@@ -44,7 +92,7 @@ abstract type AbstractModel end
4492
4593
Return `N` samples from the MCMC `sampler` for the provided `model`.
4694
47-
If a callback function `f` with type signature
95+
If a callback function `f` with type signature
4896
```julia
4997
f(rng::AbstractRNG, model::AbstractModel, sampler::AbstractSampler, N::Integer,
5098
iteration::Integer, transition; kwargs...)
@@ -77,15 +125,7 @@ function StatsBase.sample(
77125
# Perform any necessary setup.
78126
sample_init!(rng, model, sampler, N; kwargs...)
79127

80-
# Create a progress bar.
81-
if progress
82-
progressid = UUIDs.uuid4()
83-
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname, progress=NaN,
84-
_id=progressid)
85-
end
86-
87-
local transitions
88-
try
128+
@ifwithprogresslogger progress name=progressname begin
89129
# Obtain the initial transition.
90130
transition = step!(rng, model, sampler, N; iteration=1, kwargs...)
91131

@@ -97,10 +137,7 @@ function StatsBase.sample(
97137
transitions_save!(transitions, 1, transition, model, sampler, N; kwargs...)
98138

99139
# Update the progress bar.
100-
if progress
101-
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname, progress=1/N,
102-
_id=progressid)
103-
end
140+
progress && ProgressLogging.@logprogress 1/N
104141

105142
# Step through the sampler.
106143
for i in 2:N
@@ -114,16 +151,7 @@ function StatsBase.sample(
114151
transitions_save!(transitions, i, transition, model, sampler, N; kwargs...)
115152

116153
# Update the progress bar.
117-
if progress
118-
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname, progress=i/N,
119-
_id=progressid)
120-
end
121-
end
122-
finally
123-
# Close the progress bar.
124-
if progress
125-
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname, progress="done",
126-
_id=progressid)
154+
progress && ProgressLogging.@logprogress i/N
127155
end
128156
end
129157

@@ -178,12 +206,12 @@ function sample_end!(
178206
end
179207

180208
function bundle_samples(
181-
::AbstractRNG,
182-
::AbstractModel,
183-
::AbstractSampler,
184-
::Integer,
209+
::AbstractRNG,
210+
::AbstractModel,
211+
::AbstractSampler,
212+
::Integer,
185213
transitions,
186-
::Type{Any};
214+
::Type{Any};
187215
kwargs...
188216
)
189217
return transitions
@@ -259,7 +287,7 @@ end
259287
Sample `nchains` chains using the available threads, and combine them into a single chain.
260288
261289
By default, the random number generator, the model and the samplers are deep copied for each
262-
thread to prevent contamination between threads.
290+
thread to prevent contamination between threads.
263291
"""
264292
function psample(
265293
model::AbstractModel,
@@ -292,24 +320,20 @@ function psample(
292320
# Set up a chains vector.
293321
chains = Vector{Any}(undef, nchains)
294322

295-
# Create a progress bar and a channel for progress logging.
296-
if progress
297-
progressid = UUIDs.uuid4()
298-
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname, progress=NaN,
299-
_id=progressid)
300-
channel = Distributed.RemoteChannel(() -> Channel{Bool}(nchains), 1)
301-
end
323+
@ifwithprogresslogger progress name=progressname begin
324+
# Create a channel for progress logging.
325+
if progress
326+
channel = Distributed.RemoteChannel(() -> Channel{Bool}(nchains), 1)
327+
end
302328

303-
try
304329
Distributed.@sync begin
305330
if progress
306331
Distributed.@async begin
307332
# Update the progress bar.
308333
progresschains = 0
309334
while take!(channel)
310335
progresschains += 1
311-
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname,
312-
progress=progresschains/nchains, _id=progressid)
336+
ProgressLogging.@logprogress progresschains/nchains
313337
end
314338
end
315339
end
@@ -322,7 +346,7 @@ function psample(
322346
# Seed the thread-specific random number generator with the pre-made seed.
323347
subrng = rngs[id]
324348
seed!(subrng, seeds[i])
325-
349+
326350
# Sample a chain and save it to the vector.
327351
chains[i] = sample(subrng, models[id], samplers[id], N;
328352
progress = false, kwargs...)
@@ -335,12 +359,6 @@ function psample(
335359
progress && put!(channel, false)
336360
end
337361
end
338-
finally
339-
# Close the progress bar.
340-
if progress
341-
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname,
342-
progress="done", _id=progressid)
343-
end
344362
end
345363

346364
# Concatenate the chains together.

test/interface.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@ function AbstractMCMC.step!(
2020
N::Integer,
2121
transition::Union{Nothing,MyTransition};
2222
sleepy = false,
23+
loggers = false,
2324
kwargs...
2425
)
2526
a = rand(rng)
2627
b = randn(rng)
2728

29+
loggers && push!(LOGGERS, Logging.current_logger())
2830
sleepy && sleep(0.001)
2931

3032
return MyTransition(a, b)

test/runtests.jl

Lines changed: 87 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,104 @@
11
using AbstractMCMC
22
using AbstractMCMC: sample, psample, steps!
3-
import TerminalLoggers
3+
using Atom.Progress: JunoProgressLogger
4+
using ConsoleProgressMonitor: ProgressLogger
5+
using IJulia
6+
using LoggingExtras: TeeLogger, EarlyFilteredLogger
7+
using TerminalLoggers: TerminalLogger
48

59
import Logging
610
using Random
711
using Statistics
812
using Test
913
using Test: collect_test_logs
1014

11-
# install progress logger
12-
Logging.global_logger(TerminalLoggers.TerminalLogger(right_justify=120))
15+
const LOGGERS = Set()
16+
const CURRENT_LOGGER = Logging.current_logger()
1317

1418
include("interface.jl")
1519

1620
@testset "AbstractMCMC" begin
1721
@testset "Basic sampling" begin
18-
Random.seed!(1234)
19-
N = 1_000
20-
chain = sample(MyModel(), MySampler(), N; sleepy = true)
21-
22-
# test output type and size
23-
@test chain isa Vector{MyTransition}
24-
@test length(chain) == N
25-
26-
# test some statistical properties
27-
@test mean(x.a for x in chain) 0.5 atol=6e-2
28-
@test var(x.a for x in chain) 1 / 12 atol=5e-3
29-
@test mean(x.b for x in chain) 0.0 atol=5e-2
30-
@test var(x.b for x in chain) 1 atol=6e-2
22+
@testset "REPL" begin
23+
empty!(LOGGERS)
24+
25+
Random.seed!(1234)
26+
N = 1_000
27+
chain = sample(MyModel(), MySampler(), N; sleepy = true, loggers = true)
28+
29+
@test length(LOGGERS) == 1
30+
logger = first(LOGGERS)
31+
@test logger isa TeeLogger
32+
@test logger.loggers[1].logger isa TerminalLogger
33+
@test logger.loggers[2].logger === CURRENT_LOGGER
34+
@test Logging.current_logger() === CURRENT_LOGGER
35+
36+
# test output type and size
37+
@test chain isa Vector{MyTransition}
38+
@test length(chain) == N
39+
40+
# test some statistical properties
41+
@test mean(x.a for x in chain) 0.5 atol=6e-2
42+
@test var(x.a for x in chain) 1 / 12 atol=5e-3
43+
@test mean(x.b for x in chain) 0.0 atol=5e-2
44+
@test var(x.b for x in chain) 1 atol=6e-2
45+
end
46+
47+
@testset "Juno" begin
48+
empty!(LOGGERS)
49+
50+
Random.seed!(1234)
51+
N = 10
52+
53+
logger = JunoProgressLogger()
54+
Logging.with_logger(logger) do
55+
sample(MyModel(), MySampler(), N; sleepy = true, loggers = true)
56+
end
57+
58+
@test length(LOGGERS) == 1
59+
@test first(LOGGERS) === logger
60+
@test Logging.current_logger() === CURRENT_LOGGER
61+
end
62+
63+
@testset "IJulia" begin
64+
# emulate running IJulia kernel
65+
@eval IJulia begin
66+
inited = true
67+
end
68+
69+
empty!(LOGGERS)
70+
71+
Random.seed!(1234)
72+
N = 10
73+
sample(MyModel(), MySampler(), N; sleepy = true, loggers = true)
74+
75+
@test length(LOGGERS) == 1
76+
logger = first(LOGGERS)
77+
@test logger isa TeeLogger
78+
@test logger.loggers[1].logger isa ProgressLogger
79+
@test logger.loggers[2].logger === CURRENT_LOGGER
80+
@test Logging.current_logger() === CURRENT_LOGGER
81+
82+
@eval IJulia begin
83+
inited = false
84+
end
85+
end
86+
87+
@testset "Custom logger" begin
88+
empty!(LOGGERS)
89+
90+
Random.seed!(1234)
91+
N = 10
92+
93+
logger = Logging.ConsoleLogger(stderr, Logging.LogLevel(-1))
94+
Logging.with_logger(logger) do
95+
sample(MyModel(), MySampler(), N; sleepy = true, loggers = true)
96+
end
97+
98+
@test length(LOGGERS) == 1
99+
@test first(LOGGERS) === logger
100+
@test Logging.current_logger() === CURRENT_LOGGER
101+
end
31102
end
32103

33104
if VERSION v"1.3"

0 commit comments

Comments
 (0)