1+ import math
12from enum import Enum
3+ from functools import lru_cache
24from typing import Dict , List , Optional
35
46from pydantic import Field , constr , validator
57
68from ..exceptions import ValidationError
79from .basemodels import ProtoModel
810
11+ M_SQRTPI_CUBED = math .pi * math .sqrt (math .pi )
12+
913
1014class HarmonicType (str , Enum ):
1115 """
@@ -16,6 +20,13 @@ class HarmonicType(str, Enum):
1620 cartesian = "cartesian"
1721
1822
23+ class NormalizationScheme (str , Enum ):
24+
25+ cca = "cca"
26+ erd = "erd"
27+ bse = "bse"
28+
29+
1930class ElectronShell (ProtoModel ):
2031 """
2132 Information for a single electronic shell
@@ -28,6 +39,9 @@ class ElectronShell(ProtoModel):
2839 ...,
2940 description = "General contraction coefficients for this shell, individual list components will be the individual segment contraction coefficients." ,
3041 )
42+ normalization_scheme : Optional [NormalizationScheme ] = Field (
43+ None , description = "Normalization scheme for this shell."
44+ )
3145
3246 @validator ("coefficients" )
3347 def _check_coefficient_length (cls , v , values ):
@@ -46,6 +60,45 @@ def _check_general_contraction_or_fused(cls, v, values):
4660
4761 return v
4862
63+ @validator ("normalization_scheme" , always = True )
64+ def _check_normsch (cls , v , values ):
65+
66+ # Bad construction, pass on errors
67+ try :
68+ normsch = cls ._calculate_normsch (values ["angular_momentum" ], values ["exponents" ], values ["coefficients" ])
69+ except KeyError :
70+ return v
71+
72+ if v is None :
73+ v = normsch
74+ else :
75+ if v != normsch :
76+ raise ValidationError (f"Calculated normalization_scheme ({ normsch } ) does not match supplied ({ v } )." )
77+
78+ return v
79+
80+ @classmethod
81+ def _calculate_normsch (cls , am : int , exps : List [float ], coefs : List [List [float ]]) -> NormalizationScheme :
82+ def single_am (idx , l ):
83+ m = l + 1.5
84+
85+ # try CCA
86+ candidate_already_normalized_coefs = coefs [idx ]
87+ norm = cls ._cca_contraction_normalization (l , exps , candidate_already_normalized_coefs )
88+ if abs (norm - 1 ) < 1.0e-10 :
89+ return NormalizationScheme .cca
90+
91+ # try ERD
92+ candidate_already_normalized_coefs = [coefs [idx ][i ] / pow (exps [i ], 0.5 * m ) for i in range (len (exps ))]
93+ norm = cls ._erd_normalization (l , exps , candidate_already_normalized_coefs )
94+ if abs (norm - 1 ) < 1.0e-10 :
95+ return NormalizationScheme .erd
96+
97+ # BSE not confirmable
98+ return NormalizationScheme .bse
99+
100+ return _collapse_equal_list (single_am (idx , l ) for idx , l in enumerate (am ))
101+
49102 def nfunctions (self ) -> int :
50103 """
51104 Computes the number of basis functions on this shell.
@@ -73,6 +126,126 @@ def is_contracted(self) -> bool:
73126
74127 return (len (self .coefficients ) != 1 ) and (len (self .angular_momentum ) == 1 )
75128
129+ def normalize_shell (self , dtype : NormalizationScheme ) -> "ElectronShell" :
130+ """Construct new ElectronShell with coefficients normalized by ``dtype``."""
131+
132+ naive_coefs = self ._denormalize_to_bse ()
133+
134+ bse_shell = self .dict ()
135+ bse_shell ["coefficients" ] = naive_coefs
136+ bse_shell ["normalization_scheme" ] = "bse"
137+ bse_shell = ElectronShell (** bse_shell )
138+
139+ if dtype == "bse" :
140+ return bse_shell
141+ elif dtype in ["cca" , "psi4" ]:
142+ norm_coef = bse_shell ._cca_normalize_shell ()
143+ elif dtype == "erd" :
144+ norm_coef = bse_shell ._erd_normalize_shell ()
145+
146+ new_shell = self .dict ()
147+ new_shell ["coefficients" ] = norm_coef
148+ new_shell ["normalization_scheme" ] = dtype
149+ return ElectronShell (** new_shell )
150+
151+ def _denormalize_to_bse (self ) -> List [List [float ]]:
152+ """Compute replacement coefficients for any-normalization shell ``self`` that are within a scale factor of BSE unnormalization."""
153+
154+ def single_am (idx , l ):
155+
156+ if self .normalization_scheme == "cca" :
157+ prim_norm = self ._cca_primitive_normalization (l , self .exponents )
158+ return [self .coefficients [idx ][i ] / prim_norm [i ] for i in range (len (self .exponents ))]
159+
160+ elif self .normalization_scheme == "erd" :
161+ m = l + 1.5
162+ return [self .coefficients [idx ][i ] / pow (self .exponents [i ], 0.5 * m ) for i in range (len (self .exponents ))]
163+
164+ elif self .normalization_scheme == "bse" :
165+ return self .coefficients [idx ]
166+
167+ return [single_am (idx , l ) for idx , l in enumerate (self .angular_momentum )]
168+
169+ @staticmethod
170+ def _cca_primitive_normalization (l : int , exps : List [float ]) -> List [float ]:
171+ """Compute CCA normalization factor for primitive shell using angular momentum ``l`` and exponents ``exps``."""
172+ m = l + 1.5
173+ prim_norm = [
174+ math .sqrt ((pow (2.0 , l ) * pow (2.0 * exps [p ], m )) / (M_SQRTPI_CUBED * _df (2 * l ))) for p in range (len (exps ))
175+ ]
176+
177+ return prim_norm
178+
179+ @staticmethod
180+ def _cca_contraction_normalization (l : int , exps : List [float ], coefs : List [List [float ]]) -> float :
181+ """Compute CCA normalization factor for coefficients ``coefs`` using angular momentum ``l`` and exponents ``exps``."""
182+
183+ m = l + 1.5
184+ summ = 0.0
185+ for i in range (len (exps )):
186+ for j in range (len (exps )):
187+ z = pow (exps [i ] + exps [j ], m )
188+ summ += (coefs [i ] * coefs [j ]) / z
189+
190+ tmp = (M_SQRTPI_CUBED * _df (2 * l )) / pow (2.0 , l )
191+ norm = math .sqrt (1.0 / (tmp * summ ))
192+ # except (ZeroDivisionError, ValueError): [idx][i] = 1.0
193+
194+ return norm
195+
196+ def _cca_normalize_shell (self ) -> List [List [float ]]:
197+ """Compute replacement coefficients for unnormalized (BSE-normalized) shell ``self`` that fulfill CCA normalization."""
198+
199+ if self .normalization_scheme != "bse" :
200+ raise TypeError ('Unnormalized shells expected. Use ``normalize_shell(dtype="cca")`` for flexibility.' )
201+
202+ def single_am (idx , l ):
203+ prim_norm = ElectronShell ._cca_primitive_normalization (l , self .exponents )
204+ norm = ElectronShell ._cca_contraction_normalization (
205+ l , self .exponents , [self .coefficients [idx ][i ] * prim_norm [i ] for i in range (len (self .exponents ))]
206+ )
207+ return [self .coefficients [idx ][i ] * norm * prim_norm [i ] for i in range (len (self .exponents ))]
208+
209+ return [single_am (idx , l ) for idx , l in enumerate (self .angular_momentum )]
210+
211+ def _erd_normalization (l : int , exps : List [float ], coefs : List [List [float ]]) -> float :
212+ """Compute ERD normalization factor for coefficients ``coefs`` using angular momentum ``l`` and exponents ``exps``."""
213+
214+ m = l + 1.5
215+ summ = 0.0
216+ for i in range (len (exps )):
217+ for j in range (i + 1 ):
218+ temp = coefs [i ] * coefs [j ]
219+ temp3 = 2.0 * math .sqrt (exps [i ] * exps [j ]) / (exps [i ] + exps [j ])
220+ temp *= pow (temp3 , m )
221+
222+ summ += temp
223+ if i != j :
224+ summ += temp
225+
226+ prefac = 1.0
227+ if l > 1 :
228+ prefac = pow (2.0 , 2 * l ) / _df (2 * l )
229+ norm = math .sqrt (prefac / summ )
230+
231+ return norm
232+
233+ def _erd_normalize_shell (self ) -> List [List [float ]]:
234+ """Compute replacement coefficients for unnormalized (BSE-normalized) shell ``self`` that fulfill ERD normalization."""
235+
236+ if self .normalization_scheme != "bse" :
237+ raise TypeError ('Unnormalized shells expected. Use ``normalize_shell(dtype="erd")`` for flexibility.' )
238+
239+ def single_am (idx , l ):
240+ m = l + 1.5
241+
242+ norm = ElectronShell ._erd_normalization (l , self .exponents , self .coefficients [idx ])
243+ return [
244+ self .coefficients [idx ][i ] * norm * pow (self .exponents [i ], 0.5 * m ) for i in range (len (self .exponents ))
245+ ]
246+
247+ return [single_am (idx , l ) for idx , l in enumerate (self .angular_momentum )]
248+
76249
77250class ECPType (str , Enum ):
78251 """
@@ -189,3 +362,52 @@ def _calculate_nbf(cls, atom_map, center_data) -> int:
189362 ret += center_count [center ]
190363
191364 return ret
365+
366+ def normalize_shell (self , dtype : NormalizationScheme ) -> "BasisSet" :
367+ """Construct new BasisSet with coefficients of all shells in center_data normalized by ``dtype``."""
368+
369+ new_bs = self .dict ()
370+
371+ for lbl , center in self .center_data .items ():
372+ for ish , sh in enumerate (center .electron_shells ):
373+ new_bs ["center_data" ][lbl ]["electron_shells" ][ish ] = sh .normalize_shell (dtype )
374+
375+ return BasisSet (** new_bs )
376+
377+ def normalization_scheme (self ) -> NormalizationScheme :
378+ """Identify probable normalization scheme governing shell ``coefficients`` in center_data.
379+
380+ Returns
381+ -------
382+ NormalizationScheme
383+ Satisfied by all ElectronShells.
384+
385+ Raises
386+ ------
387+ TypeError
388+ If the BasisSet's ElectronShells are detected to have inconsistent normalization schemes.
389+
390+ """
391+ shell_norm = []
392+ for lbl , center in self .center_data .items ():
393+ for ish , sh in enumerate (center .electron_shells ):
394+ shell_norm .append (sh .normalization_scheme )
395+
396+ return _collapse_equal_list (shell_norm )
397+
398+
399+ @lru_cache (maxsize = 500 )
400+ def _df (i ):
401+ if i in [0 , 1 , 2 ]:
402+ return 1.0
403+ else :
404+ return (i - 1 ) * _df (i - 2 )
405+
406+
407+ def _collapse_equal_list (lst ):
408+ lst = list (lst )
409+ first = lst [0 ]
410+ if lst .count (first ) == len (lst ):
411+ return first
412+ else :
413+ raise TypeError (f"Inconsistent members in list: { lst } " )
0 commit comments