33import logging
44from typing import TYPE_CHECKING , Any
55
6+ import numpy as np
7+
68from rydstate .angular .angular_ket import julia_qn_to_dict
79from rydstate .basis .basis_base import BasisBase
810from rydstate .rydberg .rydberg_mqdt import RydbergStateMQDT
911from rydstate .rydberg .rydberg_sqdt import RydbergStateSQDT
1012
1113if TYPE_CHECKING :
14+ from juliacall import (
15+ JuliaError ,
16+ Main as jl , # noqa: N813
17+ convert ,
18+ )
19+
1220 from rydstate .species import SpeciesObject
1321
1422logger = logging .getLogger (__name__ )
3240 logger .exception ("Failed to load Julia MQDT or CGcoefficient package" )
3341 USE_JULIACALL = False
3442
35- FMODEL_MAX_L = {"Sr87" : 2 , "Sr88" : 2 , "Yb171" : 4 , "Yb173" : 1 , "Yb174" : 4 }
36-
3743
3844class BasisMQDT (BasisBase [RydbergStateMQDT [Any ]]):
3945 def __init__ (
@@ -43,54 +49,50 @@ def __init__(
4349 n_max : int | None = None ,
4450 * ,
4551 skip_high_l : bool = True ,
46- model_names : list [str ] | None = None ,
4752 ) -> None :
4853 super ().__init__ (species )
4954
50- if not USE_JULIACALL :
51- raise ImportError ("JuliaCall or the MQDT Julia package is not available." )
52-
53- try :
54- self .jl_species = getattr (jl .MQDT , self .species .name )
55- parameters = self .jl_species .PARA
56- except AttributeError as e :
57- raise ValueError (f"Species '{ species } ' is not supported in the MQDT Julia package." ) from e
58-
59- # TODO use n_min and n_max of the different models
60-
6155 if n_max is None :
6256 raise ValueError ("n_max must be given" )
6357
58+ if not USE_JULIACALL :
59+ raise ImportError ("JuliaCall or the MQDT Julia package is not available." )
60+
6461 # initialize Wigner symbol calculation
6562 if skip_high_l :
66- jl .CGcoefficient .wigner_init_float (5 , "Jmax" , 9 )
63+ jl .CGcoefficient .wigner_init_float (10 , "Jmax" , 9 )
6764 else :
6865 jl .CGcoefficient .wigner_init_float (n_max - 1 , "Jmax" , 9 )
6966
70- logger .debug ("Calculating low l MQDT states..." )
71-
72- jl_species_attr_names = [str (name ) for name in jl .seval (f"names(MQDT.{ self .species .name } , all=true)" )]
73- self .models = {name : getattr (self .jl_species , name ) for name in jl_species_attr_names }
74- self .models = {k : v for k , v in self .models .items () if str (v ).startswith ("fModel" )}
75- if model_names is not None :
76- self .models = {k : v for k , v in self .models .items () if k in model_names }
77-
78- if skip_high_l :
79- logger .debug ("Skipping high l states." )
80- else :
81- logger .debug ("Calculating high l SQDT states..." )
82- l_start = FMODEL_MAX_L [self .species .name ] + 1
83- high_l_models = {
84- f"high_l_{ l_ryd } " : jl .single_channel_models (species , l_ryd , parameters )
85- for l_ryd in range (l_start , n_max )
86- }
87- self .models .update (high_l_models )
88-
89- model_names = list (self .models .keys ())
90- jl_states = {name : jl .eigenstates (n_min , n_max , model , parameters ) for name , model in self .models .items ()}
91- _models_vector = convert (jl .Vector , [self .models [name ] for name in model_names ])
92- _jl_states_vector = convert (jl .Vector , [jl_states [name ] for name in model_names ])
93- jl_basis = jl .basisarray (_jl_states_vector , _models_vector )
67+ jl_species = jl .Symbol (self .species .name )
68+ parameters = jl .MQDT .get_parameters (jl_species )
69+
70+ self .models = []
71+ i_c = self .species .i_c if self .species .i_c is not None else 0
72+ for l in range (n_max ):
73+ jtot_min = min (l , abs (l - 1 ))
74+ jtot_max = l + 1
75+ for f_tot in np .arange (abs (jtot_min - i_c ), jtot_max + i_c + 1 ):
76+ models = jl .MQDT .get_fmodels (jl_species , l , f_tot )
77+ self .models .extend (models )
78+
79+ n_min_high_l = 25
80+
81+ logger .debug ("Calculating MQDT states..." )
82+ jl_states = []
83+ for model in self .models :
84+ _n_min = n_min
85+ if model .name .startswith ("SQDT" ):
86+ if skip_high_l :
87+ continue
88+ _n_min = n_min_high_l
89+
90+ logger .debug (f"{ model .name } :" )
91+ states = jl .MQDT .eigenstates (_n_min , n_max , model , parameters )
92+ jl_states .append (states )
93+ logger .debug (f" found nu_min={ min (states .n )} , nu_max={ max (states .n )} , total states={ len (states .n )} " )
94+
95+ jl_basis = jl .basisarray (convert (jl .Vector , jl_states ), convert (jl .Vector , self .models ))
9496
9597 logger .debug ("Generated state table with %d states" , len (jl_basis .states ))
9698
0 commit comments