Skip to content

Commit 1f05568

Browse files
update BasisState
1 parent 71a3954 commit 1f05568

2 files changed

Lines changed: 129 additions & 25 deletions

File tree

angular.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ function get_state(basisstate::BasisState)
3737
qn = channels[i]
3838
kets[i] = get_ket(qn, basisstate.species)
3939
end
40-
state = rydstate.angular.AngularState(basisstate.coeff, kets; warn_if_not_normalized = false)
40+
coeff = basisstate.coefficients[findall(basisstate.model.core)]
41+
state = rydstate.angular.AngularState(coeff, kets; warn_if_not_normalized = false)
4142
return state
4243
end
4344

utils.jl

Lines changed: 127 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import LinearAlgebra
12
import MQDT
23
include("angular.jl")
34

@@ -14,6 +15,16 @@ macro timelog(expr)
1415
end
1516

1617

18+
function get_relevant_lr(state::MQDT.BasisState)
19+
inds = findall(state.model.core)
20+
return state.lr_list[inds]
21+
end
22+
23+
function get_relevant_nu(state::MQDT.BasisState)
24+
inds = findall(state.model.core)
25+
return state.nu_list[inds]
26+
end
27+
1728
function all_matrix_element(B::MQDT.BasisArray, parameters::MQDT.Parameters)
1829
"""Calculate all relevant matrix elements for a given basis array B.
1930
@@ -29,25 +40,32 @@ function all_matrix_element(B::MQDT.BasisArray, parameters::MQDT.Parameters)
2940
)
3041

3142
states_indexed = [(ids - 1 + START_ID, state) for (ids, state) in enumerate(B.states)]
32-
states_sorted =
33-
sort(states_indexed, by = x -> (minimum(x[2].lr), minimum(x[2].nu), x[1]))
43+
states_sorted = sort(
44+
states_indexed,
45+
by = x ->
46+
(minimum(get_relevant_lr(x[2])), minimum(get_relevant_nu(x[2])), x[1]),
47+
)
3448

35-
for (i1, (id1, b1)) in enumerate(states_sorted)
36-
for (id2, b2) in states_sorted[i1:end]
49+
for (i1, (id1, s1)) in enumerate(states_sorted)
50+
lr1 = get_relevant_lr(s1)
51+
nus1 = get_relevant_nu(s1)
52+
for (id2, s2) in states_sorted[i1:end]
53+
lr2 = get_relevant_lr(s2)
54+
nus2 = get_relevant_nu(s2)
3755

3856
# Skip if all contributions of the two states are far apart in angular momentum
39-
if minimum(b2.lr) - maximum(b1.lr) > k_angular_max
57+
if minimum(lr2) - maximum(lr1) > k_angular_max
4058
continue
4159
end
4260

4361
# Skip if all contributions of the two states are far apart in n and None of them is low-n
44-
if all(abs(nu1-nu2) >= 11 for nu1 in b1.nu for nu2 in b2.nu) &&
45-
all(nu1 > 25 for nu1 in b1.nu) &&
46-
all(nu2 > 25 for nu2 in b2.nu)
62+
if all(abs(nu1-nu2) >= 11 for nu1 in nus1 for nu2 in nus2) &&
63+
all(nu1 > 25 for nu1 in nus1) &&
64+
all(nu2 > 25 for nu2 in nus2)
4765
continue
4866
end
4967

50-
m = MQDT.multipole_moments(b1, b2, parameters)
68+
m = MQDT.multipole_moments(s1, s2, parameters)
5169
# multipole_moments returns the matrix elements in the following order
5270
# electric dipole, electric quadrupole, diamagnetic, magnetic
5371
table_keys = [
@@ -56,7 +74,7 @@ function all_matrix_element(B::MQDT.BasisArray, parameters::MQDT.Parameters)
5674
"matrix_elements_q0",
5775
"matrix_elements_mu",
5876
]
59-
prefactor_transposed = (-1)^(b2.f - b1.f)
77+
prefactor_transposed = (-1)^(s2.f - s1.f)
6078

6179
for (i, key) in enumerate(table_keys)
6280
if m[i] != 0
@@ -84,37 +102,122 @@ function rcv_to_df(row_col_value::Vector{Tuple{Int64,Int64,Float64}})
84102
end
85103

86104

87-
function get_n(T::MQDT.BasisArray)
88-
# TODO for now just return round(nu)
89-
# later calculate radial overlap with different sqdt states and take the corresponding n
90-
nu = MQDT.get_nu(T)
91-
return round.(Int, nu)
92-
end
93-
94-
95105
function basis_to_df(T::MQDT.BasisArray, P::MQDT.Parameters)
96106
df = DataFrame(
97107
id = collect(START_ID:(size(T)-1+START_ID)),
98108
energy = MQDT.get_e(T, P) / 219474.6313632, # convert 1/cm to atomic units
99109
parity = MQDT.get_p(T),
100-
n = get_n(T),
110+
n = get_n(T, P),
101111
nu = MQDT.get_nu(T),
102112
f = MQDT.get_f(T),
103-
exp_nui = MQDT.exp_nui(T),
113+
exp_nui = exp_nui(T),
104114
exp_l = calc_exp_qn(T, "l_tot"),
105115
exp_j = calc_exp_qn(T, "j_tot"),
106116
exp_s = calc_exp_qn(T, "s_tot"),
107117
exp_l_ryd = calc_exp_qn(T, "l_r"),
108118
exp_j_ryd = calc_exp_qn(T, "j_r"),
109-
std_nui = MQDT.std_nui(T),
119+
std_nui = std_nui(T),
110120
std_l = calc_std_qn(T, "l_tot"),
111121
std_j = calc_std_qn(T, "j_tot"),
112122
std_s = calc_std_qn(T, "s_tot"),
113123
std_l_ryd = calc_std_qn(T, "l_r"),
114124
std_j_ryd = calc_std_qn(T, "j_r"),
115-
is_j_total_momentum = MQDT.is_J(T, P),
116-
is_calculated_with_mqdt = MQDT.is_mqdt(T),
117-
underspecified_channel_contribution = MQDT.get_neg(T),
125+
is_j_total_momentum = repeat([iszero(P.spin)], size(T)),
126+
is_calculated_with_mqdt = is_mqdt(T),
127+
underspecified_channel_contribution = get_neg(T),
118128
)
119129
return df
120130
end
131+
132+
133+
function exp_nui(T::MQDT.BasisArray)
134+
t = Vector{Float64}(undef, size(T))
135+
for (i, state) in enumerate(T.states)
136+
t[i] = exp_q(state.nu_list, state.coefficients)
137+
end
138+
return t
139+
end
140+
141+
function std_nui(T::MQDT.BasisArray)
142+
t = Vector{Float64}(undef, size(T))
143+
for (i, state) in enumerate(T.states)
144+
t[i] = std_q(state.nu_list, state.coefficients)
145+
end
146+
return t
147+
end
148+
149+
function is_mqdt(T::MQDT.BasisArray)
150+
t = Vector{Bool}(undef, size(T))
151+
for (i, state) in enumerate(T.states)
152+
t[i] = !isone(length(state.coefficients))
153+
end
154+
return t
155+
end
156+
157+
158+
function get_neg(T::MQDT.BasisArray)
159+
t = Vector{Float64}(undef, size(T))
160+
for (i, state) in enumerate(T.states)
161+
irrel = findall(iszero, state.model.core)
162+
t[i] = sum(state.coefficients[irrel] .^ 2)
163+
end
164+
return t
165+
end
166+
167+
function exp_q(q::Vector, n::Vector)
168+
if allequal(q)
169+
return Float64(q[1])
170+
else
171+
m = n .^ 2
172+
M = sum(m)
173+
if M > 1
174+
m /= M
175+
end
176+
return LinearAlgebra.dot(q, m)
177+
end
178+
end
179+
180+
function std_q(q::Vector, n::Vector)
181+
if allequal(q)
182+
return 0.0
183+
else
184+
m = n .^ 2
185+
M = sum(m)
186+
if M > 1
187+
m /= M
188+
end
189+
e1 = LinearAlgebra.dot(q, m)^2
190+
e2 = LinearAlgebra.dot(q .^ 2, m)
191+
if abs(e1 - e2) < 1e-11
192+
return 0.0
193+
else
194+
return sqrt(e2 - e1)
195+
end
196+
end
197+
end
198+
199+
function get_n(T::MQDT.BasisArray, P::MQDT.Parameters)
200+
nu = MQDT.get_nu(T)
201+
l = round.(Int, calc_exp_qn(T, "l_r"))
202+
return get_n(nu, l, P.species)
203+
end
204+
205+
function get_n(nu::Vector{Float64}, l::Vector{Int}, species::Symbol)
206+
i0 = findall(iszero, l)
207+
i1 = findall(iszero, l .- 1)
208+
i2 = findall(iszero, l .- 2)
209+
i3 = findall(iszero, l .- 3)
210+
j0 = findall(x->x<2, nu)
211+
nu[j0] .+= 1
212+
if occursin("Yb", String(species))
213+
nu[i0] .+= 4
214+
nu[i1] .+= 3
215+
nu[i2] .+= 2
216+
nu[i3] .+= 1
217+
else
218+
nu[i0] .+= 3
219+
nu[i1] .+= 2
220+
nu[i2] .+= 2
221+
end
222+
return ceil.(Int, nu)
223+
end

0 commit comments

Comments
 (0)