11export transfer_weights!, model_generator
22
33function transfer_weights! (ps1, ps2, spec1, spec2, spec1p1, spec1p2)
4- readable_spec1 = displayspec (spec1, spec1p1)
5- readable_spec2 = displayspec (spec2, spec1p2)
4+ readable_spec1 = displayspec (spec1, spec1p1, ps1 )
5+ readable_spec2 = displayspec (spec2, spec1p2, ps2 )
66 _map = _invmap (readable_spec2)
7- ps2. branch. bf. linear. weight .= 0.0
8- for (idx, t) in enumerate (readable_spec1)
9- ps2. branch. bf. linear. weight[:, _map[t]] = ps1. branch. bf. linear. weight[:, idx]
7+ if :TK in keys (ps2. branch. bf)
8+ _map2 = _invmap (spec1p2)
9+ ps2. branch. bf. TK. W .= 0.0
10+ W = ps1. branch. bf. TK. W
11+ if length (size (W)) == 3
12+ for (idx, t) in enumerate (spec1p1)
13+ ps2. branch. bf. TK. W[:,1 : size (W)[2 ], _map2[t]] .= W[:, 1 : size (W)[2 ], idx]
14+ end
15+ elseif length (size (W)) == 4
16+ for (idx, t) in enumerate (spec1p1)
17+ ps2. branch. bf. TK. W[:, :, 1 : size (W)[3 ], _map2[t]] .= W[:, :, 1 : size (W)[3 ], idx]
18+ end
19+ end
20+ end
21+ if :linear in keys (ps2. branch. bf)
22+ ps2. branch. bf. linear. weight .= 0.0
23+ for (idx, t) in enumerate (readable_spec1)
24+ ps2. branch. bf. linear. weight[:, _map[t]] = ps1. branch. bf. linear. weight[:, idx]
25+ end
26+ else
27+ for i in keys (ps2. branch. bf. bAA)
28+ ps2. branch. bf. bAA[i]. layer_2. weight .= 0.0
29+ for (idx, t) in enumerate (readable_spec1)
30+ ps2. branch. bf. bAA[i]. layer_2. weight[:, _map[t]] = ps1. branch. bf. bAA[i]. layer_2. weight[:, idx]
31+ end
32+ end
1033 end
1134 return ps2
1235end
@@ -19,18 +42,18 @@ function transfer_weights_idx!(ps1, ps2, spec1, spec2, spec1p1, spec1p2)
1942 return index
2043end
2144
22- function model_generator (mol, basis_set, totdeg, ν; ratio = 0.5 , multilevel = true , filename = " basis.json" )
45+ function model_generator (mol, basis_set, totdeg, ν; TD = No_Decomposition (), ratio = 0.5 , multilevel = true , filename = " basis.json" )
2346 if multilevel
24- totdeg_list, ν_list = build_totdeglevels (mol, basis_set, totdeg, ν; ratio = ratio)
25- results = [build_wavefunction (mol, basis_set, totdeg_list[i], ν_list[i]; filename = filename) for i = 1 : length (totdeg_list)]
47+ totdeg_list, ν_list = build_totdeglevels (mol, basis_set, totdeg, ν, TD ; ratio = ratio)
48+ results = [build_wavefunction (mol, basis_set, totdeg_list[i], ν_list[i], TD ; filename = filename) for i = 1 : length (totdeg_list)]
2649 model_list = getindex .(results, 1 )
2750 ps_list = getindex .(results, 2 )
2851 st_list = getindex .(results, 3 )
2952 spec_list = getindex .(results, 4 )
3053 spec1p_list = getindex .(results, 5 )
3154 return model_list, ps_list, st_list, spec_list, spec1p_list, totdeg_list, ν_list
3255 else
33- model, ps, st, spec, spec1p = build_wavefunction (mol, basis_set, totdeg, ν; filename = fieldname)
56+ model, ps, st, spec, spec1p = build_wavefunction (mol, basis_set, totdeg, ν, TD ; filename = fieldname)
3457 return [model], [ps], [st], [spec], [spec1p], [totdeg], [ν]
3558 end
3659end
0 commit comments