Skip to content

Commit e53409c

Browse files
improve basis mqdt
1 parent 9e58cd3 commit e53409c

1 file changed

Lines changed: 20 additions & 16 deletions

File tree

src/rydstate/basis/basis_mqdt.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,22 @@
3434
USE_JULIACALL = False
3535

3636

37-
if USE_JULIACALL:
38-
try:
39-
jl.seval("using MQDT")
40-
jl.seval("using CGcoefficient")
41-
except JuliaError:
42-
logger.exception("Failed to load Julia MQDT or CGcoefficient package")
43-
USE_JULIACALL = False
37+
IS_MQDT_IMPORTED = False
38+
39+
40+
def import_mqdt() -> bool:
41+
"""Load the MQDT Julia package.
42+
43+
Since this might be time-consuming, we only do it if needed and ensure it is called only once.
44+
"""
45+
global IS_MQDT_IMPORTED # noqa: PLW0603
46+
if not IS_MQDT_IMPORTED:
47+
try:
48+
jl.seval("using MQDT")
49+
IS_MQDT_IMPORTED = True
50+
except JuliaError:
51+
logger.exception("Failed to load Julia MQDT package")
52+
return IS_MQDT_IMPORTED
4453

4554

4655
class BasisMQDT(BasisBase[RydbergStateMQDT[Any]]):
@@ -52,20 +61,15 @@ def __init__( # noqa: PLR0915, C901, PLR0912
5261
*,
5362
skip_high_l: bool = True,
5463
) -> None:
64+
if not USE_JULIACALL:
65+
raise ImportError("JuliaCall is not available, try `pip install rydstate[mqdt]`.")
66+
if not import_mqdt():
67+
raise ImportError("Failed to load the MQDT Julia package, try `pip install rydstate[mqdt]`.")
5568
super().__init__(species)
5669

5770
if n_max is None:
5871
raise ValueError("n_max must be given")
5972

60-
if not USE_JULIACALL:
61-
raise ImportError("JuliaCall or the MQDT Julia package is not available.")
62-
63-
# initialize Wigner symbol calculation
64-
if skip_high_l:
65-
jl.CGcoefficient.wigner_init_float(10, "Jmax", 9)
66-
else:
67-
jl.CGcoefficient.wigner_init_float(n_max - 1, "Jmax", 9)
68-
6973
jl_species = jl.Symbol(self.species.name)
7074
parameters = jl.MQDT.get_parameters(jl_species)
7175

0 commit comments

Comments
 (0)