Skip to content

Commit 63999c7

Browse files
committed
add checkpoints
1 parent fa5a945 commit 63999c7

File tree

8 files changed

+386
-169
lines changed

8 files changed

+386
-169
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ EquivariantTensors = "5e107534-7145-4f8f-b06f-47a52840c895"
1212
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1313
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1414
HyperDualNumbers = "50ceba7f-c3ee-5a84-a6e8-3ad40456ec97"
15+
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
1516
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
1617
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1718
LowRankApprox = "898213cb-b102-5a47-900c-97e73b919f73"
@@ -29,9 +30,10 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
2930
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3031

3132
[compat]
32-
julia = "1.10"
3333
EquivariantTensors = "0.1.2"
3434
Polynomials4ML = "0.5.0"
35+
julia = "1.10"
36+
3537
[extras]
3638
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3739

example.jl

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,34 @@
11
# add https://github.com/ACEsuit/ChemBasisSets.jl.git
22
# add ACEpsi #lux
3-
using Distributed
4-
N_procs = 8
5-
6-
import Pkg
7-
Pkg.activate(Base.current_project())
8-
9-
if nprocs() == 1
10-
addprocs(N_procs - 1, exeflags="--project=$(Base.current_project())")
11-
end
12-
133
using ACEpsi
144

5+
method = 1
156
mol = ACEpsi.molecules.Be
7+
mol_name = "Be"
8+
setup(mol, mol_name, method)
169

17-
basis_set = "cc-pvtz"
18-
totdeg = [["1f", "1f"]]
19-
ν = 2
10+
x0, optimizer, model_list, ps_list, st_list, spec_list, spec1p_list, totdeg_list, ν_list = load_setup(mol_name);
11+
solver = (SPRINGSolver(), SketchSolver(800, 50, 50, 1.4), SVDSolver(800, 50, 50, 1.4))
12+
optimizer.sr_method = solver[method]
13+
string = method == 1 ? "SPRING" :
14+
method == 2 ? "SKETCH" :
15+
method == 3 ? "WSSR" : error("Invalid method")
16+
optimizer.res_path = "$mol_name/$string/"
2017

21-
model_list, ps_list, st_list, spec_list, spec1p_list, totdeg_list, ν_list = model_generator(mol, basis_set, totdeg, ν; ratio = 0.5);
22-
test_wavefunction(model_list, ps_list, st_list, spec_list, spec1p_list, mol)
23-
24-
solver = (SPRINGSolver(), SketchSolver(800, 50, 50, 1.4), SVDSolver(800, 20, 20, 1.4))
18+
N_procs = 8
19+
using Pkg
20+
using Distributed
21+
Pkg.activate(Base.current_project())
2522

26-
iterations = 1000 * ones(Int64, length(spec_list))
27-
iterations[end] = 30000
28-
optimizer = OPTSETTING(solver[1], iterations = iterations, burnin = 1000, lag = 10, nchains = 2^8,
29-
Δt = 0.08, acc_step = 10, acc_range = [0.45, 0.80],
30-
clip = 5.0, lr = 0.02, lr_dc = 10000, m = 0.99,
31-
damping = 0.001, damping_decay = 100, damping_min = 0.001,
32-
norm_constrain = 0.001, η = 0.95, res_path = "Be/SPRING/")
23+
if nprocs() == 1
24+
addprocs(N_procs - 1, exeflags="--project=$(Base.current_project())")
25+
end
3326

3427
@everywhere begin
3528
using ACEpsi
3629
using Statistics
3730
using LinearAlgebra
3831
using Optimisers: destructure
3932
end
40-
model_list, ps_list, st_list, val_list, var_list, rank_list = train(mol, model_list, ps_list, st_list, spec_list, spec1p_list, totdeg_list, ν_list, optimizer);
33+
34+
model_list, ps_list, st_list, val_list, var_list, rank_list = train(x0, mol, model_list, ps_list, st_list, spec_list, spec1p_list, totdeg_list, ν_list, optimizer);

src/model/spec.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,12 +146,12 @@ end
146146

147147

148148
function build_totdeglevels(mol, basis_set, totdeg, ν; ratio = 0.5, max_level::Union{Nothing, Int} = nothing)
149-
_, orbital = ACEpsi.auto_load_basis(mol, basis_set; return_spec = true)
149+
_, orbital = auto_load_basis(mol, basis_set; return_spec = true)
150150
n_atom = length(orbital)
151151

152-
lj = maximum(length.(ACEpsi.iteratespec1p.(ACEpsi.orbital_for_iatom_jord.(Ref(orbital), Ref(totdeg), 1:n_atom, 1), Ref(1))))
152+
lj = maximum(length.(iteratespec1p.(orbital_for_iatom_jord.(Ref(orbital), Ref(totdeg), 1:n_atom, 1), Ref(1))))
153153
_lj = max(Int(ceil(lj * ratio)), 1)
154-
spec1pl = [ACEpsi.iteratespec1p(ACEpsi.orbital_for_iatom_jord(orbital, totdeg, i, 1), lj) for i = 1:n_atom]
154+
spec1pl = [iteratespec1p(orbital_for_iatom_jord(orbital, totdeg, i, 1), lj) for i = 1:n_atom]
155155

156156
totdeglevels = Vector{Vector{Vector{String}}}()
157157
νlevels = Int[]

src/molecule/molecules.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,15 @@ CH4 = create_molecule(SVector(
2424
Nuc("H", [-1.18886, 1.18886, -1.18886])
2525
))
2626

27-
# === Diatomic and small molecules with spacing parameter ===
28-
Li2(spacing) = create_molecule(SVector(
29-
Nuc("Li", [-spacing/2, 0.0, 0.0]),
30-
Nuc("Li", [ spacing/2, 0.0, 0.0])
27+
Li2 = create_molecule(SVector(
28+
Nuc("Li", [-5.051/2, 0.0, 0.0]),
29+
Nuc("Li", [ 5.051/2, 0.0, 0.0])
3130
))
32-
LiH(spacing) = create_molecule(SVector(
33-
Nuc("Li", [-spacing/2, 0.0, 0.0]),
34-
Nuc("H", [ spacing/2, 0.0, 0.0])
31+
LiH = create_molecule(SVector(
32+
Nuc("Li", [-3.015/2, 0.0, 0.0]),
33+
Nuc("H", [ 3.015/2, 0.0, 0.0])
3534
))
35+
# === Diatomic and small molecules with spacing parameter ===
3636
N2(spacing) = create_molecule(SVector(
3737
Nuc("N", [-spacing/2, 0.0, 0.0]),
3838
Nuc("N", [ spacing/2, 0.0, 0.0])

src/train/opt/sr.jl

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
export opts!
22

33
function opts!(i, OptParams::OPTPARAMS, optimizer::DirectSolver, Eloc, o::Matrix{T}, nchains, damping::Float64, dim_ps::Int64, η::Float64, norm_constrain, γ, m) where {T}
4-
s = o * o'
4+
s = o * o' # O * O'
55
@inbounds @simd for i = 1:dim_ps
66
s[i, i] += damping
77
end
@@ -19,45 +19,41 @@ end
1919

2020
function opts!(i, OptParams::OPTPARAMSPRING, optimizer::SPRINGSolver, Eloc, o::Matrix{T}, nchains, damping::Float64, dim_ps::Int64, η::Float64, norm_constrain, γ, m) where {T}
2121
res = norm(OptParams.f)
22-
lmul!(-γ, Eloc)
22+
lmul!(-γ, Eloc) # -delta tau * (E - E_mean)
2323
s = o' * o
2424

2525
s .+= 1/(nchains)
26-
Tvecs, Tvals = svd(Symmetric(s))
26+
Tvals, Tvecs = eigen(Symmetric(s))
2727
Tvals = max.(Tvals, 0.0) .+ damping
2828

29-
if OptParams.dw_tot[1] !== 0.0
30-
mul!(OptParams.dow, transpose(o), OptParams.dw_tot)
31-
epsilon_tilde = Eloc .- η * OptParams.dow
32-
else
33-
epsilon_tilde = Eloc
34-
end
29+
mul!(OptParams.dow, transpose(o), OptParams.dw_tot)
30+
epsilon_tilde = Eloc .- η * OptParams.dow
3531

3632
mul!(OptParams.dow, Tvecs', epsilon_tilde)
37-
OptParams.dow = Diagonal(1 ./ Tvals) * OptParams.dow
33+
ldiv!(Diagonal(Tvals), OptParams.dow)
3834
OptParams.dow = Tvecs * OptParams.dow
39-
mul!(OptParams.f, o, OptParams.dow)
40-
OptParams.dw_tot .*= η
41-
OptParams.dw_tot .+= OptParams.f
35+
OptParams.dow .-= mean(OptParams.dow)
36+
37+
mul!(OptParams.f, o, OptParams.dow)
38+
OptParams.dw_tot .= η * OptParams.dw_tot .+ OptParams.f / sqrt(nchains)
4239
OptParams.dw_tot .*= min(1, sqrt(norm_constrain)/norm(OptParams.dw_tot))
4340
return OptParams.dw_tot, length(OptParams.f), res
4441
end
4542

4643
function opts!(i, OptParams::OPTPARAMMINSR, optimizer::MINSRSolver, Eloc, o::Matrix{T}, nchains, damping::Float64, dim_ps::Int64, η::Float64, norm_constrain, γ, m) where {T}
47-
lmul!(sqrt(nchains), Eloc)
48-
ldiv!(sqrt(nchains), o)
49-
res = norm(o * Eloc)
50-
lmul!(-γ, Eloc)
44+
res = norm(OptParams.f)
45+
lmul!(-γ, Eloc) # -delta tau * (E - E_mean)
5146
s = o' * o
5247

53-
Tvecs, Tvals = svd(Symmetric(s))
48+
Tvals, Tvecs = eigen(Symmetric(s))
5449
Tvals = max.(Tvals, 0.0) .+ damping
50+
5551
mul!(OptParams.dow, Tvecs', Eloc)
56-
OptParams.dow = Diagonal(1 ./ Tvals) * OptParams.dow
52+
ldiv!(Diagonal(Tvals), OptParams.dow)
5753
OptParams.dow = Tvecs * OptParams.dow
58-
mul!(OptParams.f, o, OptParams.dow)
59-
OptParams.dw_tot .*= η
60-
OptParams.dw_tot .+= (1-η) * OptParams.f
54+
55+
mul!(OptParams.f, o, OptParams.dow)
56+
OptParams.dw_tot .= η * OptParams.dw_tot .+ (1-η) * OptParams.f / sqrt(nchains)
6157
OptParams.dw_tot .*= min(1, sqrt(norm_constrain)/norm(OptParams.dw_tot))
6258
return OptParams.dw_tot, length(OptParams.f), res
6359
end

src/train/opt/struct.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ export MINSRSolver, SPRINGSolver, DirectSolver, SketchSolver, SVDSolver
44

55
abstract type SR_method end
66

7-
struct OPTSETTING
7+
mutable struct OPTSETTING
88
sr_method::SR_method
99
iterations::Vector{Int64}
1010
burnin::Int
@@ -13,6 +13,7 @@ struct OPTSETTING
1313
Δt::Float64
1414
acc_step::Int
1515
acc_range::Vector{Float64}
16+
acc_opt
1617
clip::Float64
1718
lr::Float64
1819
lr_dc::Int
@@ -23,16 +24,17 @@ struct OPTSETTING
2324
norm_constrain::Float64
2425
η::Float64
2526
res_path
27+
checkpoints
2628
end
2729

2830
function OPTSETTING(sr_method::SR_method; iterations::Vector{Int}, burnin::Int, lag::Int, nchains::Int,
29-
Δt::Float64, acc_step::Int, acc_range::Vector{Float64},
31+
Δt::Float64, acc_step::Int, acc_range::Vector{Float64}, acc_opt,
3032
clip::Float64, lr::Float64, lr_dc::Int, m::Float64,
3133
damping::Float64, damping_decay::Int, damping_min::Float64,
32-
norm_constrain::Float64, η::Float64, res_path)
33-
return OPTSETTING(sr_method, iterations, burnin, lag, nchains, Δt, acc_step, acc_range,
34+
norm_constrain::Float64, η::Float64, res_path, checkpoints)
35+
return OPTSETTING(sr_method, iterations, burnin, lag, nchains, Δt, acc_step, acc_range, acc_opt,
3436
clip, lr, lr_dc, m, damping, damping_decay, damping_min,
35-
norm_constrain, η, res_path)
37+
norm_constrain, η, res_path, checkpoints)
3638
end
3739

3840
struct MINSRSolver <: SR_method

0 commit comments

Comments
 (0)