Skip to content

Commit 5871ce3

Browse files
committed
add g_tol parameter to gradient modes
1 parent c6a5fb4 commit 5871ce3

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

src/optimizers/gradient.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ function train_with_gradient!(
3434
verbose = get(params, :verbose, true)
3535
compute_cost_every = get(params, :compute_cost_every, 1)
3636
time_limit = get(params, :time_limit, Inf)
37+
g_tol = get(params, :g_tol, 0)
3738

3839
# init parameters
3940
start_time = time()
@@ -87,6 +88,17 @@ function train_with_gradient!(
8788

8889
# check time limit reach
8990
if time() - start_time > time_limit
91+
if verbose
92+
println("Time limit reached.")
93+
end
94+
break
95+
end
96+
97+
# check gradient tolerance
98+
if maximum(abs.(dC)) < g_tol
99+
if verbose
100+
println("Gradient tolerance reached.")
101+
end
90102
break
91103
end
92104

src/optimizers/gradient_mpi.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@ function train_with_gradient_mpi!(
1717
compute_cost_every = get(params, :compute_cost_every, 1)
1818
mpi_finalize = get(params, :mpi_finalize, true)
1919
time_limit = get(params, :time_limit, Inf)
20+
g_tol = get(params, :g_tol, 0)
2021

2122
JQM.mpi_init()
2223

2324
# init parameters
2425
start_time = time()
2526
is_done = false
2627
best_C = Inf
27-
best_θ = []
28+
best_θ = extract_params(model.forecast)
2829
curr_C = 0.0
2930
trace = Array{Float64}(undef, epochs)
3031
dCdz = Vector{Float32}(undef, size(model.policy_vars, 1))
@@ -121,6 +122,17 @@ function train_with_gradient_mpi!(
121122

122123
# check time limit reach
123124
if time() - start_time > time_limit
125+
if verbose
126+
println("Time limit reached.")
127+
end
128+
break
129+
end
130+
131+
# check gradient tolerance
132+
if maximum(abs.(dCdy)) < g_tol
133+
if verbose
134+
println("Gradient tolerance reached.")
135+
end
124136
break
125137
end
126138

0 commit comments

Comments
 (0)