Skip to content

Commit 8ba9e91

Browse files
committed
Merge branch 'parallel-sddp' into naive-parallel-sddp
Conflicts: src/utils.jl
2 parents 0cc03ba + 774da6a commit 8ba9e91

File tree

3 files changed

+77
-0
lines changed

3 files changed

+77
-0
lines changed

examples/parallel_sddp.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
2+
import StochDynamicProgramming
3+
4+
5+
"""
6+
Solve SDDP in parallel.
7+
8+
# Arguments
9+
* `model::SPmodel`:
10+
the stochastic problem we want to optimize
11+
* `param::SDDPparameters`:
12+
the parameters of the SDDP algorithm
13+
* `V::Array{PolyhedralFunction}`:
14+
the current estimation of Bellman's functions
15+
* `n_parallel_pass::Int`: default is 4
16+
Number of parallel pass to compute
17+
* `synchronize::Int`: default is 5
18+
Set when to synchronize the cuts between the different processes.
19+
* `display::Int`: default is 0
20+
Says whether to display results or not
21+
22+
# Return
23+
* `V::Array{PolyhedralFunction}`:
24+
the collection of approximation of the bellman functions
25+
"""
26+
function psolve_sddp(model, params, V; n_parallel_pass=4,
27+
synchronize=5, display=0)
28+
# Redefine seeds in every processes to maximize randomness:
29+
@everywhere srand()
30+
31+
mitn = params.maxItNumber
32+
params.maxItNumber = synchronize
33+
34+
# Count number of available CPU:
35+
ncpu = nprocs() - 1
36+
(display > 0) && println("\nLaunch simulation on ", ncpu, " processes")
37+
workers = procs()[2:end]
38+
39+
fpn = params.forwardPassNumber
40+
# As we distribute computation in n process, we perform forward pass in parallel:
41+
params.forwardPassNumber = max(1, round(Int, params.forwardPassNumber/ncpu))
42+
43+
# Start parallel computation:
44+
for i in 1:n_parallel_pass
45+
# Distribute computation of SDDP in each process:
46+
refs = [@spawnat w StochDynamicProgramming.solve_SDDP(model, params, V, display)[1] for w in workers]
47+
# Catch the result in the main process:
48+
V = StochDynamicProgramming.catcutsarray([fetch(r) for r in refs]...)
49+
# We clean the resultant cuts:
50+
StochDynamicProgramming.remove_redundant_cuts!(V)
51+
(display > 0) && println("Lower bound at pass ", i, ": ", StochDynamicProgramming.get_lower_bound(model, params, V))
52+
end
53+
54+
params.forwardPassNumber = fpn
55+
params.maxItNumber = mitn
56+
return V
57+
end

src/SDDPoptimize.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,8 @@ function initialize_value_functions(model::SPModel,
311311
V[end] = model.finalCost
312312
build_terminal_cost!(model, solverProblems[end], V[end])
313313
elseif isa(model.finalCost, Function)
314+
# In this case, define a trivial value functions for final cost to avoid problem:
315+
V[end] = PolyhedralFunction(zeros(1), zeros(1, model.dimStates), 1)
314316
model.finalCost(model, solverProblems[end])
315317
end
316318

src/utils.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,24 @@ function read_polyhedral_functions(dump::AbstractString)
7272
end
7373

7474

75+
"""Concatenate collection of arrays of PolyhedralFunction."""
76+
function catcutsarray(polyfunarray::Vector{StochDynamicProgramming.PolyhedralFunction}...)
77+
assert(length(polyfunarray) > 0)
78+
ntimes = length(polyfunarray[1])
79+
# Concatenate cuts in polyfunarray, and discard final time as we do not add cuts at final time:
80+
concatcuts = StochDynamicProgramming.PolyhedralFunction[catcuts([V[t] for V in polyfunarray]...) for t in 1:ntimes-1]
81+
return vcat(concatcuts, polyfunarray[1][end])
82+
end
83+
84+
85+
"""Concatenate collection of PolyhedralFunction."""
86+
function catcuts(Vts::StochDynamicProgramming.PolyhedralFunction...)
87+
betas = vcat([V.betas for V in Vts]...)
88+
lambdas = vcat([V.lambdas for V in Vts]...)
89+
numcuts = sum([V.numCuts for V in Vts])
90+
return StochDynamicProgramming.PolyhedralFunction(betas, lambdas, numcuts)
91+
end
92+
7593
"""
7694
Extract a vector stored in a 3D Array
7795

0 commit comments

Comments
 (0)