Skip to content

Commit 78b7ad4

Browse files
committed
add tensor
1 parent 63999c7 commit 78b7ad4

File tree

13 files changed

+436
-50
lines changed

13 files changed

+436
-50
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
2626
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2727
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2828
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
29+
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
2930
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
3031
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3132

example.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,20 @@ using ACEpsi
55
method = 1
66
mol = ACEpsi.molecules.Be
77
mol_name = "Be"
8-
setup(mol, mol_name, method)
8+
TD = SCPMultipleW(14) #No_Decomposition()
9+
worldsize = N_procs = 8
10+
setup(mol, mol_name, method, TD, worldsize)
911

10-
x0, optimizer, model_list, ps_list, st_list, spec_list, spec1p_list, totdeg_list, ν_list = load_setup(mol_name);
12+
x0, optimizer, model_list, ps_list, st_list, spec_list, spec1p_list, totdeg_list, ν_list = load_setup(mol_name, TD);
1113
solver = (SPRINGSolver(), SketchSolver(800, 50, 50, 1.4), SVDSolver(800, 50, 50, 1.4))
1214
optimizer.sr_method = solver[method]
1315
string = method == 1 ? "SPRING" :
1416
method == 2 ? "SKETCH" :
1517
method == 3 ? "WSSR" : error("Invalid method")
16-
optimizer.res_path = "$mol_name/$string/"
18+
clean_TD = replace("$(TD)", r"[^A-Za-z0-9]" => "")
19+
optimizer.res_path = "$mol_name/$string/$clean_TD/"
20+
1721

18-
N_procs = 8
1922
using Pkg
2023
using Distributed
2124
Pkg.activate(Base.current_project())

src/ACEpsi.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ include("molecule/molecule.jl")
66
include("molecule/molecules.jl")
77
include("molecule/hamiltonian.jl")
88
include("model/layers.jl")
9+
include("model/tensorlayers.jl")
910
include("model/spec.jl")
1011
include("model/wavefunction.jl")
1112
include("model/multilevel.jl")

src/model/layers.jl

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,16 @@ function evaluate(jl::JastrowLayer, X::Vector{SVector{3, TX}}, Σ::Vector{Char})
5454

5555
γ += -c_ij / (1 + dist)
5656
end
57-
return γ
57+
return exp(γ)
5858
end
5959

60+
struct BackflowPoolingLayer_TD{TT}<: AbstractLuxLayer
61+
spec::Vector{TT}
62+
Σ::Vector{Char}
63+
end
64+
65+
(l::BackflowPoolingLayer_TD)(x, ps, st) = evaluate(l, x, l.Σ), ps, st
66+
6067
function evaluate(l::BackflowPoolingLayer, x, Σ::Vector{Char})
6168
T = promote_type(eltype(x[1]))
6269
Nnlm = length(l.spec)
@@ -93,6 +100,43 @@ function evaluate(l::BackflowPoolingLayer, x, Σ::Vector{Char})
93100
return A # shape: (Nel, 3 * Nnlm)
94101
end
95102

103+
function evaluate(l::BackflowPoolingLayer_TD, x, Σ::Vector{Char})
104+
T = promote_type(eltype(x[1]))
105+
Nnlm = length(l.spec)
106+
Nel = length(Σ)
107+
108+
@assert spin2idx() == 1
109+
@assert spin2idx() == 2
110+
@assert spin2idx(∅) == 3
111+
112+
A = zeros(T, Nel, 3, Nnlm) # (Nel, 3 channel, Nnlm)
113+
Aall = zeros(T, 2, Nnlm) # (spin channel ∈ {↑, ↓}, Nnlm)
114+
115+
@inbounds begin
116+
for k = 1:Nnlm
117+
@simd ivdep for i = 1:Nel
118+
= spin2idx(Σ[i])
119+
if 2
120+
Aall[iσ, k] += x[i, k]
121+
end
122+
A[i, 3, k] = x[i, k]
123+
end
124+
end
125+
126+
for k = 1:Nnlm
127+
@simd ivdep for= 1:2
128+
σ = idx2spin(iσ)
129+
for i = 1:Nel
130+
A[i, iσ, k] = Aall[iσ, k] - (Σ[i] == σ ? x[i, k] : zero(T))
131+
end
132+
end
133+
end
134+
end
135+
136+
return A # shape: (Nel, 3, Nnlm)
137+
end
138+
139+
96140
function ChainRulesCore.rrule(::typeof(evaluate), l::Diff_layer{Nnuc}, X::Vector{SVector{3, TX}}, ps::NamedTuple, st::NamedTuple) where {Nnuc, TX}
97141
val = ntuple(i -> _getdiff(X, l.nuc[i].rr), Val(Nnuc))
98142
function pb(dA)
@@ -119,6 +163,14 @@ function ChainRulesCore.rrule(::typeof(evaluate), pooling::BackflowPoolingLayer,
119163
return A, pb
120164
end
121165

166+
function ChainRulesCore.rrule(::typeof(evaluate), pooling::BackflowPoolingLayer_TD, x, Σ::Vector{Char})
167+
A = evaluate(pooling, x, pooling.Σ)
168+
function pb(∂A)
169+
return NoTangent(), NoTangent(), _pullback_evaluate(∂A, pooling, x, Σ), NoTangent()
170+
end
171+
return A, pb
172+
end
173+
122174
# helper function
123175

124176
function _getdiff(X::AbstractArray{SVector{3, T}}, d::SVector{3, TT}) where {T, TT}
@@ -156,6 +208,35 @@ function _pullback_evaluate(∂A, l::BackflowPoolingLayer, x, Σ::Vector{Char})
156208
return ∂x
157209
end
158210

211+
function _pullback_evaluate(∂A, l::BackflowPoolingLayer_TD, x, Σ::Vector{Char})
212+
TA = eltype(x[1])
213+
Nel = length(Σ)
214+
Nnlm = length(l.spec)
215+
216+
∂x = zeros(TA, Nel, Nnlm)
217+
218+
@inbounds begin
219+
for k = 1:Nnlm
220+
for i = 1:Nel
221+
σi = Σ[i]
222+
= spin2idx(σi)
223+
224+
∂x[i, k] += ∂A[i, 3, k]
225+
226+
if 2
227+
for j = 1:Nel
228+
if j != i && Σ[j] == σi
229+
∂x[i, k] += ∂A[j, iσ, k]
230+
end
231+
end
232+
end
233+
end
234+
end
235+
end
236+
237+
return ∂x
238+
end
239+
159240

160241
function LuxCore.initialparameters(rng::AbstractRNG, d::Dense)
161242
weight = if d.init_weight === nothing

src/model/multilevel.jl

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,35 @@
11
export transfer_weights!, model_generator
22

33
function transfer_weights!(ps1, ps2, spec1, spec2, spec1p1, spec1p2)
4-
readable_spec1 = displayspec(spec1, spec1p1)
5-
readable_spec2 = displayspec(spec2, spec1p2)
4+
readable_spec1 = displayspec(spec1, spec1p1, ps1)
5+
readable_spec2 = displayspec(spec2, spec1p2, ps2)
66
_map = _invmap(readable_spec2)
7-
ps2.branch.bf.linear.weight .= 0.0
8-
for (idx, t) in enumerate(readable_spec1)
9-
ps2.branch.bf.linear.weight[:, _map[t]] = ps1.branch.bf.linear.weight[:, idx]
7+
if :TK in keys(ps2.branch.bf)
8+
_map2 = _invmap(spec1p2)
9+
ps2.branch.bf.TK.W .= 0.0
10+
W = ps1.branch.bf.TK.W
11+
if length(size(W)) == 3
12+
for (idx, t) in enumerate(spec1p1)
13+
ps2.branch.bf.TK.W[:,1:size(W)[2], _map2[t]] .= W[:, 1:size(W)[2], idx]
14+
end
15+
elseif length(size(W)) == 4
16+
for (idx, t) in enumerate(spec1p1)
17+
ps2.branch.bf.TK.W[:, :, 1:size(W)[3], _map2[t]] .= W[:, :, 1:size(W)[3], idx]
18+
end
19+
end
20+
end
21+
if :linear in keys(ps2.branch.bf)
22+
ps2.branch.bf.linear.weight .= 0.0
23+
for (idx, t) in enumerate(readable_spec1)
24+
ps2.branch.bf.linear.weight[:, _map[t]] = ps1.branch.bf.linear.weight[:, idx]
25+
end
26+
else
27+
for i in keys(ps2.branch.bf.bAA)
28+
ps2.branch.bf.bAA[i].layer_2.weight .= 0.0
29+
for (idx, t) in enumerate(readable_spec1)
30+
ps2.branch.bf.bAA[i].layer_2.weight[:, _map[t]] = ps1.branch.bf.bAA[i].layer_2.weight[:, idx]
31+
end
32+
end
1033
end
1134
return ps2
1235
end
@@ -19,18 +42,18 @@ function transfer_weights_idx!(ps1, ps2, spec1, spec2, spec1p1, spec1p2)
1942
return index
2043
end
2144

22-
function model_generator(mol, basis_set, totdeg, ν; ratio = 0.5, multilevel = true, filename = "basis.json")
45+
function model_generator(mol, basis_set, totdeg, ν; TD = No_Decomposition(), ratio = 0.5, multilevel = true, filename = "basis.json")
2346
if multilevel
24-
totdeg_list, ν_list = build_totdeglevels(mol, basis_set, totdeg, ν; ratio = ratio)
25-
results = [build_wavefunction(mol, basis_set, totdeg_list[i], ν_list[i]; filename = filename) for i = 1:length(totdeg_list)]
47+
totdeg_list, ν_list = build_totdeglevels(mol, basis_set, totdeg, ν, TD; ratio = ratio)
48+
results = [build_wavefunction(mol, basis_set, totdeg_list[i], ν_list[i], TD; filename = filename) for i = 1:length(totdeg_list)]
2649
model_list = getindex.(results, 1)
2750
ps_list = getindex.(results, 2)
2851
st_list = getindex.(results, 3)
2952
spec_list = getindex.(results, 4)
3053
spec1p_list = getindex.(results, 5)
3154
return model_list, ps_list, st_list, spec_list, spec1p_list, totdeg_list, ν_list
3255
else
33-
model, ps, st, spec, spec1p = build_wavefunction(mol, basis_set, totdeg, ν; filename = fieldname)
56+
model, ps, st, spec, spec1p = build_wavefunction(mol, basis_set, totdeg, ν, TD; filename = fieldname)
3457
return [model], [ps], [st], [spec], [spec1p], [totdeg], [ν]
3558
end
3659
end

src/model/spec.jl

Lines changed: 73 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,24 @@ function parseTotdegToInt(
5454
return _totdegn
5555
end
5656

57-
function displayspec(spec, spec1p)
58-
nicespec = []
59-
for k = 1:length(spec)
60-
push!(nicespec, [spec1p[spec[k][j]] for j = 1:length(spec[k])])
57+
function displayspec(spec, spec1p, ps)
58+
if :TK keys(ps.branch.bf)
59+
nicespec = []
60+
for k = 1:length(spec)
61+
push!(nicespec, [spec1p[spec[k][j]] for j = 1:length(spec[k])])
62+
end
63+
else
64+
P = 0
65+
if length(size(ps.branch.bf.TK.W)) == 4
66+
P = size(ps.branch.bf.TK.W)[3]
67+
elseif length(size(ps.branch.bf.TK.W)) == 3
68+
P = size(ps.branch.bf.TK.W)[2]
69+
end
70+
spec1p = get_spec1p(P)
71+
nicespec = []
72+
for k = 1:length(spec)
73+
push!(nicespec, [spec1p[spec[k][j]] for j = 1:length(spec[k])])
74+
end
6175
end
6276
return nicespec
6377
end
@@ -95,6 +109,18 @@ function get_spec1p(basis::Vector{TS}; spin = false) where {TS}
95109
return spec[:]
96110
end
97111

112+
function get_spec1p(P)
113+
spec = Array{Any}(undef, (3, P))
114+
115+
for k = 1:P
116+
for (is, s) in enumerate(extspins())
117+
spec[is, k] = (s=s, P = k)
118+
end
119+
end
120+
121+
return spec[:]
122+
end
123+
98124
function _invmap(a)
99125
inva = Dict{eltype(a), Int}()
100126
for i = 1:length(a)
@@ -145,7 +171,7 @@ function sample_evenly(arr::AbstractVector; N = 10)
145171
end
146172

147173

148-
function build_totdeglevels(mol, basis_set, totdeg, ν; ratio = 0.5, max_level::Union{Nothing, Int} = nothing)
174+
function build_totdeglevels(mol, basis_set, totdeg, ν, TD::No_Decomposition; ratio = 0.5, max_level::Union{Nothing, Int} = nothing)
149175
_, orbital = auto_load_basis(mol, basis_set; return_spec = true)
150176
n_atom = length(orbital)
151177

@@ -203,3 +229,45 @@ function build_totdeglevels(mol, basis_set, totdeg, ν; ratio = 0.5, max_level::
203229

204230
return totdeglevels, νlevels
205231
end
232+
233+
function build_totdeglevels(mol, basis_set, totdeg, ν, TD; ratio = 0.5, max_level::Union{Nothing, Int} = nothing)
234+
maxdim = length(totdeg)
235+
levels = Vector{Vector{Int}}()
236+
237+
d = mol.Nel
238+
deg_split = floor(Int, totdeg[1] * ratio)
239+
for x = d:deg_split
240+
push!(levels, [x])
241+
end
242+
243+
function extend_levels(levels, curdim)
244+
result = Vector{Vector{Int}}()
245+
for v in levels
246+
if length(v) == curdim &&
247+
all(v[i] floor(Int, totdeg[i] * ratio) for i in 1:curdim)
248+
for j = mol.Nel:floor(Int, totdeg[curdim+1] * ratio)
249+
push!(result, vcat(v, j))
250+
end
251+
end
252+
end
253+
return result
254+
end
255+
256+
for curdim = 1:maxdim-1
257+
new_levels = extend_levels(levels, curdim)
258+
append!(levels, new_levels)
259+
end
260+
261+
last_diag = [floor(Int, totdeg[i] * ratio) for i in 1:maxdim]
262+
while all(last_diag[i] totdeg[i] for i in 1:maxdim)
263+
push!(levels, copy(last_diag))
264+
for i in 1:maxdim
265+
last_diag[i] += 1
266+
end
267+
end
268+
269+
νlevels = [length(v) for v in levels]
270+
return levels, νlevels
271+
end
272+
273+

0 commit comments

Comments
 (0)