Skip to content

Commit e442880

Browse files
committed
fixed initial_params and initial_state for MCMCDistributed
1 parent ca4f4b9 commit e442880

File tree

3 files changed

+7
-1
lines changed

3 files changed

+7
-1
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ version = "4.5.0"
99
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
1010
ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
1111
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
12+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1213
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
1314
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
1415
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
@@ -21,6 +22,7 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
2122
[compat]
2223
BangBang = "0.3.19"
2324
ConsoleProgressMonitor = "0.1"
25+
FillArrays = "1"
2426
LogDensityProblems = "2"
2527
LoggingExtras = "0.4, 0.5, 1"
2628
ProgressLogging = "0.1"

src/AbstractMCMC.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using ProgressLogging: ProgressLogging
88
using StatsBase: StatsBase
99
using TerminalLoggers: TerminalLoggers
1010
using Transducers: Transducers
11+
using FillArrays: FillArrays
1112

1213
using Distributed: Distributed
1314
using Logging: Logging

src/sample.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,9 @@ function mcmcsample(
432432
check_initial_params(initial_params, nchains)
433433
check_initial_state(initial_state, nchains)
434434

435+
_initial_params = initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params
436+
_initial_state = initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state
437+
435438
# Create a seed for each chain using the provided random number generator.
436439
seeds = rand(rng, UInt, nchains)
437440

@@ -490,7 +493,7 @@ function mcmcsample(
490493
return chain
491494
end
492495
chains = Distributed.pmap(
493-
sample_chain, pool, seeds, initial_params, initial_state
496+
sample_chain, pool, seeds, _initial_params, _initial_state
494497
)
495498
finally
496499
# Stop updating the progress bar.

0 commit comments

Comments
 (0)