Skip to content

Commit c9df846

Browse files
authored
Merge pull request #36 from LAMPSPUC/dev
apply best param on GradientMPIMode and improve epochs print
2 parents 30a9de7 + 66956cc commit c9df846

File tree

3 files changed

+12
-3
lines changed

3 files changed

+12
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ApplicationDrivenLearning"
22
uuid = "0856f1c8-ef17-4e14-9230-2773e47a789e"
33
authors = ["Giovanni Amorim", "Joaquim Garcia"]
4-
version = "0.1.2"
4+
version = "0.1.3"
55

66
[deps]
77
BilevelJuMP = "485130c0-026e-11ea-0f1a-6992cd14145c"

src/optimizers/gradient.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,10 @@ function train_with_gradient!(
7272
# store and print cost
7373
trace[epoch] = C
7474
if verbose
75-
println("Epoch $epoch | Cost = $(round(C, digits=2))")
75+
dtime = time() - start_time
76+
println(
77+
"Epoch $epoch | Time = $(round(dtime, digits=1))s | Cost = $(round(C, digits=2))",
78+
)
7679
end
7780

7881
# evaluate if best model

src/optimizers/gradient_mpi.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,10 @@ function train_with_gradient_mpi!(
106106
# store and print cost
107107
trace[epoch] = curr_C
108108
if verbose
109-
println("Epoch $epoch | Cost = $(round(curr_C, digits=2))")
109+
dtime = time() - start_time
110+
println(
111+
"Epoch $epoch | Time = $(round(dtime, digits=1))s | Cost = $(round(curr_C, digits=2))",
112+
)
110113
end
111114

112115
# evaluate if best model
@@ -129,6 +132,9 @@ function train_with_gradient_mpi!(
129132
is_done = true
130133
MPI.bcast(is_done, MPI.COMM_WORLD)
131134

135+
# fix best model
136+
apply_params(model.forecast, best_θ)
137+
132138
elseif JQM.is_worker_process()
133139
# continuoslly call pmap until controller is done
134140
while true

0 commit comments

Comments
 (0)