Skip to content

Commit aeb2125

Browse files
update MQDT improved fModel
1 parent 29b7369 commit aeb2125

1 file changed

Lines changed: 41 additions & 39 deletions

File tree

src/rydstate/basis/basis_mqdt.py

Lines changed: 41 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,20 @@
33
import logging
44
from typing import TYPE_CHECKING, Any
55

6+
import numpy as np
7+
68
from rydstate.angular.angular_ket import julia_qn_to_dict
79
from rydstate.basis.basis_base import BasisBase
810
from rydstate.rydberg.rydberg_mqdt import RydbergStateMQDT
911
from rydstate.rydberg.rydberg_sqdt import RydbergStateSQDT
1012

1113
if TYPE_CHECKING:
14+
from juliacall import (
15+
JuliaError,
16+
Main as jl, # noqa: N813
17+
convert,
18+
)
19+
1220
from rydstate.species import SpeciesObject
1321

1422
logger = logging.getLogger(__name__)
@@ -32,8 +40,6 @@
3240
logger.exception("Failed to load Julia MQDT or CGcoefficient package")
3341
USE_JULIACALL = False
3442

35-
FMODEL_MAX_L = {"Sr87": 2, "Sr88": 2, "Yb171": 4, "Yb173": 1, "Yb174": 4}
36-
3743

3844
class BasisMQDT(BasisBase[RydbergStateMQDT[Any]]):
3945
def __init__(
@@ -43,54 +49,50 @@ def __init__(
4349
n_max: int | None = None,
4450
*,
4551
skip_high_l: bool = True,
46-
model_names: list[str] | None = None,
4752
) -> None:
4853
super().__init__(species)
4954

50-
if not USE_JULIACALL:
51-
raise ImportError("JuliaCall or the MQDT Julia package is not available.")
52-
53-
try:
54-
self.jl_species = getattr(jl.MQDT, self.species.name)
55-
parameters = self.jl_species.PARA
56-
except AttributeError as e:
57-
raise ValueError(f"Species '{species}' is not supported in the MQDT Julia package.") from e
58-
59-
# TODO use n_min and n_max of the different models
60-
6155
if n_max is None:
6256
raise ValueError("n_max must be given")
6357

58+
if not USE_JULIACALL:
59+
raise ImportError("JuliaCall or the MQDT Julia package is not available.")
60+
6461
# initialize Wigner symbol calculation
6562
if skip_high_l:
66-
jl.CGcoefficient.wigner_init_float(5, "Jmax", 9)
63+
jl.CGcoefficient.wigner_init_float(10, "Jmax", 9)
6764
else:
6865
jl.CGcoefficient.wigner_init_float(n_max - 1, "Jmax", 9)
6966

70-
logger.debug("Calculating low l MQDT states...")
71-
72-
jl_species_attr_names = [str(name) for name in jl.seval(f"names(MQDT.{self.species.name}, all=true)")]
73-
self.models = {name: getattr(self.jl_species, name) for name in jl_species_attr_names}
74-
self.models = {k: v for k, v in self.models.items() if str(v).startswith("fModel")}
75-
if model_names is not None:
76-
self.models = {k: v for k, v in self.models.items() if k in model_names}
77-
78-
if skip_high_l:
79-
logger.debug("Skipping high l states.")
80-
else:
81-
logger.debug("Calculating high l SQDT states...")
82-
l_start = FMODEL_MAX_L[self.species.name] + 1
83-
high_l_models = {
84-
f"high_l_{l_ryd}": jl.single_channel_models(species, l_ryd, parameters)
85-
for l_ryd in range(l_start, n_max)
86-
}
87-
self.models.update(high_l_models)
88-
89-
model_names = list(self.models.keys())
90-
jl_states = {name: jl.eigenstates(n_min, n_max, model, parameters) for name, model in self.models.items()}
91-
_models_vector = convert(jl.Vector, [self.models[name] for name in model_names])
92-
_jl_states_vector = convert(jl.Vector, [jl_states[name] for name in model_names])
93-
jl_basis = jl.basisarray(_jl_states_vector, _models_vector)
67+
jl_species = jl.Symbol(self.species.name)
68+
parameters = jl.MQDT.get_parameters(jl_species)
69+
70+
self.models = []
71+
i_c = self.species.i_c if self.species.i_c is not None else 0
72+
for l in range(n_max):
73+
jtot_min = min(l, abs(l - 1))
74+
jtot_max = l + 1
75+
for f_tot in np.arange(abs(jtot_min - i_c), jtot_max + i_c + 1):
76+
models = jl.MQDT.get_fmodels(jl_species, l, f_tot)
77+
self.models.extend(models)
78+
79+
n_min_high_l = 25
80+
81+
logger.debug("Calculating MQDT states...")
82+
jl_states = []
83+
for model in self.models:
84+
_n_min = n_min
85+
if model.name.startswith("SQDT"):
86+
if skip_high_l:
87+
continue
88+
_n_min = n_min_high_l
89+
90+
logger.debug(f"{model.name}:")
91+
states = jl.MQDT.eigenstates(_n_min, n_max, model, parameters)
92+
jl_states.append(states)
93+
logger.debug(f" found nu_min={min(states.n)}, nu_max={max(states.n)}, total states={len(states.n)}")
94+
95+
jl_basis = jl.basisarray(convert(jl.Vector, jl_states), convert(jl.Vector, self.models))
9496

9597
logger.debug("Generated state table with %d states", len(jl_basis.states))
9698

0 commit comments

Comments
 (0)