Skip to content

Commit 75bad6e

Browse files
authored
Merge pull request #120 from JuliaOpt/naive-parallel-sddp
Naive parallelization
2 parents adc6681 + ff817da commit 75bad6e

File tree

3 files changed

+78
-0
lines changed

3 files changed

+78
-0
lines changed

examples/parallel_sddp.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
2+
import StochDynamicProgramming
3+
4+
5+
"""
6+
Solve SDDP in parallel, dispatching both forward and backward passes to process,
7+
which is not the most standard parallelization of SDDP.
8+
9+
# Arguments
10+
* `model::SPmodel`:
11+
the stochastic problem we want to optimize
12+
* `param::SDDPparameters`:
13+
the parameters of the SDDP algorithm
14+
* `V::Array{PolyhedralFunction}`:
15+
the current estimation of Bellman's functions
16+
* `n_parallel_pass::Int`: default is 4
17+
Number of parallel pass to compute
18+
* `synchronize::Int`: default is 5
19+
Synchronize the cuts between the different processes every "synchronise" iterations
20+
* `display::Int`: default is 0
21+
Says whether to display results or not
22+
23+
# Return
24+
* `V::Array{PolyhedralFunction}`:
25+
the collection of approximation of the bellman functions
26+
"""
27+
function psolve_sddp(model, params, V; n_parallel_pass=4,
28+
synchronize=5, display=0)
29+
# Redefine seeds in every processes to maximize randomness:
30+
@everywhere srand()
31+
32+
mitn = params.maxItNumber
33+
params.maxItNumber = synchronize
34+
35+
# Count number of available CPU:
36+
ncpu = nprocs() - 1
37+
(display > 0) && println("\nLaunch simulation on ", ncpu, " processes")
38+
workers = procs()[2:end]
39+
40+
fpn = params.forwardPassNumber
41+
# As we distribute computation in n process, we perform forward pass in parallel:
42+
params.forwardPassNumber = max(1, round(Int, params.forwardPassNumber/ncpu))
43+
44+
# Start parallel computation:
45+
for i in 1:n_parallel_pass
46+
# Distribute computation of SDDP in each process:
47+
refs = [@spawnat w StochDynamicProgramming.solve_SDDP(model, params, V, display)[1] for w in workers]
48+
# Catch the result in the main process:
49+
V = StochDynamicProgramming.catcutsarray([fetch(r) for r in refs]...)
50+
# We clean the resultant cuts:
51+
StochDynamicProgramming.remove_redundant_cuts!(V)
52+
(display > 0) && println("Lower bound at pass ", i, ": ", StochDynamicProgramming.get_lower_bound(model, params, V))
53+
end
54+
55+
params.forwardPassNumber = fpn
56+
params.maxItNumber = mitn
57+
return V
58+
end

src/SDDPoptimize.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,8 @@ function initialize_value_functions(model::SPModel,
311311
V[end] = model.finalCost
312312
build_terminal_cost!(model, solverProblems[end], V[end])
313313
elseif isa(model.finalCost, Function)
314+
# In this case, define a trivial value functions for final cost to avoid problem:
315+
V[end] = PolyhedralFunction(zeros(1), zeros(1, model.dimStates), 1)
314316
model.finalCost(model, solverProblems[end])
315317
end
316318

src/utils.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,24 @@ function read_polyhedral_functions(dump::AbstractString)
7272
end
7373

7474

75+
"""Concatenate collection of arrays of PolyhedralFunction."""
76+
function catcutsarray(polyfunarray::Vector{StochDynamicProgramming.PolyhedralFunction}...)
77+
assert(length(polyfunarray) > 0)
78+
ntimes = length(polyfunarray[1])
79+
# Concatenate cuts in polyfunarray, and discard final time as we do not add cuts at final time:
80+
concatcuts = StochDynamicProgramming.PolyhedralFunction[catcuts([V[t] for V in polyfunarray]...) for t in 1:ntimes-1]
81+
return vcat(concatcuts, polyfunarray[1][end])
82+
end
83+
84+
85+
"""Concatenate collection of PolyhedralFunction."""
86+
function catcuts(Vts::StochDynamicProgramming.PolyhedralFunction...)
87+
betas = vcat([V.betas for V in Vts]...)
88+
lambdas = vcat([V.lambdas for V in Vts]...)
89+
numcuts = sum([V.numCuts for V in Vts])
90+
return StochDynamicProgramming.PolyhedralFunction(betas, lambdas, numcuts)
91+
end
92+
7593
"""
7694
Extract a vector stored in a 3D Array
7795

0 commit comments

Comments
 (0)