Skip to content

Commit 75be3bb

Browse files
author
Giovanni Amorim
committed
fix mpi modes
1 parent ce2b8f6 commit 75be3bb

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

src/optimizers/gradient_mpi.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ function train_with_gradient_mpi!(
7373
(v) -> compute_cost_and_gradients(v[1], v[2], true),
7474
[[curr_θ, i] for i in batches[epoch, :]]
7575
)
76-
dCdy = sum([r[2] for r in pmap_result_with_gradients])
76+
dCdy = sum([r[2] for r in pmap_result_with_gradients]) ./ batch_size
7777

7878
if compute_full_cost
7979
# broadcast `is_done = false` again
@@ -84,7 +84,7 @@ function train_with_gradient_mpi!(
8484
(v) -> compute_cost_and_gradients(v[1], v[2], false),
8585
[[curr_θ, i] for i=1:T]
8686
)
87-
curr_C = sum([r[1] for r in pmap_result_without_gradients])
87+
curr_C = sum([r[1] for r in pmap_result_without_gradients]) ./ T
8888
end
8989

9090
else
@@ -94,8 +94,8 @@ function train_with_gradient_mpi!(
9494
(v) -> compute_cost_and_gradients(v[1], v[2], true),
9595
[[curr_θ, i] for i=1:T]
9696
)
97-
curr_C = sum([r[1] for r in pmap_result])
98-
dCdy = sum([r[2] for r in pmap_result])
97+
curr_C = sum([r[1] for r in pmap_result]) ./ T
98+
dCdy = sum([r[2] for r in pmap_result]) ./ T
9999
end
100100

101101
if compute_full_cost

src/optimizers/nelder_mead_mpi.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,15 @@ function train_with_nelder_mead_mpi!(
1212
JQM.mpi_init()
1313

1414
# extract params
15+
mpi_finalize = get(params, :mpi_finalize, true)
16+
delete!(params, :mpi_finalize)
1517
optim_options = Optim.Options(;params...)
1618

1719
is_done = false
1820
res = nothing
1921
final_sol = []
2022
final_cost = 0.0
23+
T = size(X)[1]
2124
function compute_cost(θ, i)
2225
apply_params(model.forecast, θ)
2326
yhat = model.forecast(X[i,:])
@@ -31,8 +34,8 @@ function train_with_nelder_mead_mpi!(
3134
initial_sol = extract_params(model.forecast)
3235
res = Optim.optimize(initial_sol, NelderMead(), optim_options) do θ
3336
MPI.bcast(is_done, MPI.COMM_WORLD)
34-
c_θ = JQM.pmap((v) -> compute_cost(v[1], v[2]), [[θ, i] for i in eachindex(X)])
35-
return sum(c_θ)
37+
c_θ = JQM.pmap((v) -> compute_cost(v[1], v[2]), [[θ, i] for i=1:T])
38+
return sum(c_θ) ./ T
3639
end
3740

3841
# print solution
@@ -62,7 +65,9 @@ function train_with_nelder_mead_mpi!(
6265
end
6366

6467
JQM.mpi_barrier()
65-
JQM.mpi_finalize()
68+
if mpi_finalize
69+
JQM.mpi_finalize()
70+
end
6671

6772
return Solution(final_cost, final_sol)
6873
end

0 commit comments

Comments
 (0)