Skip to content

Commit b11a06a

Browse files
committed
record shell normalization and allow basisset renormalization
1 parent f87f79f commit b11a06a

File tree

2 files changed

+422
-0
lines changed

2 files changed

+422
-0
lines changed

qcelemental/models/basis.py

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1+
import math
12
from enum import Enum
3+
from functools import lru_cache
24
from typing import Dict, List, Optional
35

46
from pydantic import Field, constr, validator
57

68
from ..exceptions import ValidationError
79
from .basemodels import ProtoModel
810

11+
M_SQRTPI_CUBED = math.pi * math.sqrt(math.pi)
12+
913

1014
class HarmonicType(str, Enum):
1115
"""
@@ -16,6 +20,13 @@ class HarmonicType(str, Enum):
1620
cartesian = "cartesian"
1721

1822

23+
class NormalizationScheme(str, Enum):
24+
25+
cca = "cca"
26+
erd = "erd"
27+
bse = "bse"
28+
29+
1930
class ElectronShell(ProtoModel):
2031
"""
2132
Information for a single electronic shell
@@ -28,6 +39,9 @@ class ElectronShell(ProtoModel):
2839
...,
2940
description="General contraction coefficients for this shell, individual list components will be the individual segment contraction coefficients.",
3041
)
42+
normalization_scheme: Optional[NormalizationScheme] = Field(
43+
None, description="Normalization scheme for this shell."
44+
)
3145

3246
@validator("coefficients")
3347
def _check_coefficient_length(cls, v, values):
@@ -46,6 +60,45 @@ def _check_general_contraction_or_fused(cls, v, values):
4660

4761
return v
4862

63+
@validator("normalization_scheme", always=True)
64+
def _check_normsch(cls, v, values):
65+
66+
# Bad construction, pass on errors
67+
try:
68+
normsch = cls._calculate_normsch(values["angular_momentum"], values["exponents"], values["coefficients"])
69+
except KeyError:
70+
return v
71+
72+
if v is None:
73+
v = normsch
74+
else:
75+
if v != normsch:
76+
raise ValidationError(f"Calculated normalization_scheme ({normsch}) does not match supplied ({v}).")
77+
78+
return v
79+
80+
@classmethod
81+
def _calculate_normsch(cls, am: int, exps: List[float], coefs: List[List[float]]) -> NormalizationScheme:
82+
def single_am(idx, l):
83+
m = l + 1.5
84+
85+
# try CCA
86+
candidate_already_normalized_coefs = coefs[idx]
87+
norm = cls._cca_contraction_normalization(l, exps, candidate_already_normalized_coefs)
88+
if abs(norm - 1) < 1.0e-10:
89+
return NormalizationScheme.cca
90+
91+
# try ERD
92+
candidate_already_normalized_coefs = [coefs[idx][i] / pow(exps[i], 0.5 * m) for i in range(len(exps))]
93+
norm = cls._erd_normalization(l, exps, candidate_already_normalized_coefs)
94+
if abs(norm - 1) < 1.0e-10:
95+
return NormalizationScheme.erd
96+
97+
# BSE not confirmable
98+
return NormalizationScheme.bse
99+
100+
return _collapse_equal_list(single_am(idx, l) for idx, l in enumerate(am))
101+
49102
def nfunctions(self) -> int:
50103
"""
51104
Computes the number of basis functions on this shell.
@@ -73,6 +126,126 @@ def is_contracted(self) -> bool:
73126

74127
return (len(self.coefficients) != 1) and (len(self.angular_momentum) == 1)
75128

129+
def normalize_shell(self, dtype: NormalizationScheme) -> "ElectronShell":
130+
"""Construct new ElectronShell with coefficients normalized by ``dtype``."""
131+
132+
naive_coefs = self._denormalize_to_bse()
133+
134+
bse_shell = self.dict()
135+
bse_shell["coefficients"] = naive_coefs
136+
bse_shell["normalization_scheme"] = "bse"
137+
bse_shell = ElectronShell(**bse_shell)
138+
139+
if dtype == "bse":
140+
return bse_shell
141+
elif dtype in ["cca", "psi4"]:
142+
norm_coef = bse_shell._cca_normalize_shell()
143+
elif dtype == "erd":
144+
norm_coef = bse_shell._erd_normalize_shell()
145+
146+
new_shell = self.dict()
147+
new_shell["coefficients"] = norm_coef
148+
new_shell["normalization_scheme"] = dtype
149+
return ElectronShell(**new_shell)
150+
151+
def _denormalize_to_bse(self) -> List[List[float]]:
152+
"""Compute replacement coefficients for any-normalization shell ``self`` that are within a scale factor of BSE unnormalization."""
153+
154+
def single_am(idx, l):
155+
156+
if self.normalization_scheme == "cca":
157+
prim_norm = self._cca_primitive_normalization(l, self.exponents)
158+
return [self.coefficients[idx][i] / prim_norm[i] for i in range(len(self.exponents))]
159+
160+
elif self.normalization_scheme == "erd":
161+
m = l + 1.5
162+
return [self.coefficients[idx][i] / pow(self.exponents[i], 0.5 * m) for i in range(len(self.exponents))]
163+
164+
elif self.normalization_scheme == "bse":
165+
return self.coefficients[idx]
166+
167+
return [single_am(idx, l) for idx, l in enumerate(self.angular_momentum)]
168+
169+
@staticmethod
170+
def _cca_primitive_normalization(l: int, exps: List[float]) -> List[float]:
171+
"""Compute CCA normalization factor for primitive shell using angular momentum ``l`` and exponents ``exps``."""
172+
m = l + 1.5
173+
prim_norm = [
174+
math.sqrt((pow(2.0, l) * pow(2.0 * exps[p], m)) / (M_SQRTPI_CUBED * _df(2 * l))) for p in range(len(exps))
175+
]
176+
177+
return prim_norm
178+
179+
@staticmethod
180+
def _cca_contraction_normalization(l: int, exps: List[float], coefs: List[List[float]]) -> float:
181+
"""Compute CCA normalization factor for coefficients ``coefs`` using angular momentum ``l`` and exponents ``exps``."""
182+
183+
m = l + 1.5
184+
summ = 0.0
185+
for i in range(len(exps)):
186+
for j in range(len(exps)):
187+
z = pow(exps[i] + exps[j], m)
188+
summ += (coefs[i] * coefs[j]) / z
189+
190+
tmp = (M_SQRTPI_CUBED * _df(2 * l)) / pow(2.0, l)
191+
norm = math.sqrt(1.0 / (tmp * summ))
192+
# except (ZeroDivisionError, ValueError): [idx][i] = 1.0
193+
194+
return norm
195+
196+
def _cca_normalize_shell(self) -> List[List[float]]:
197+
"""Compute replacement coefficients for unnormalized (BSE-normalized) shell ``self`` that fulfill CCA normalization."""
198+
199+
if self.normalization_scheme != "bse":
200+
raise TypeError('Unnormalized shells expected. Use ``normalize_shell(dtype="cca")`` for flexibility.')
201+
202+
def single_am(idx, l):
203+
prim_norm = ElectronShell._cca_primitive_normalization(l, self.exponents)
204+
norm = ElectronShell._cca_contraction_normalization(
205+
l, self.exponents, [self.coefficients[idx][i] * prim_norm[i] for i in range(len(self.exponents))]
206+
)
207+
return [self.coefficients[idx][i] * norm * prim_norm[i] for i in range(len(self.exponents))]
208+
209+
return [single_am(idx, l) for idx, l in enumerate(self.angular_momentum)]
210+
211+
def _erd_normalization(l: int, exps: List[float], coefs: List[List[float]]) -> float:
212+
"""Compute ERD normalization factor for coefficients ``coefs`` using angular momentum ``l`` and exponents ``exps``."""
213+
214+
m = l + 1.5
215+
summ = 0.0
216+
for i in range(len(exps)):
217+
for j in range(i + 1):
218+
temp = coefs[i] * coefs[j]
219+
temp3 = 2.0 * math.sqrt(exps[i] * exps[j]) / (exps[i] + exps[j])
220+
temp *= pow(temp3, m)
221+
222+
summ += temp
223+
if i != j:
224+
summ += temp
225+
226+
prefac = 1.0
227+
if l > 1:
228+
prefac = pow(2.0, 2 * l) / _df(2 * l)
229+
norm = math.sqrt(prefac / summ)
230+
231+
return norm
232+
233+
def _erd_normalize_shell(self) -> List[List[float]]:
234+
"""Compute replacement coefficients for unnormalized (BSE-normalized) shell ``self`` that fulfill ERD normalization."""
235+
236+
if self.normalization_scheme != "bse":
237+
raise TypeError('Unnormalized shells expected. Use ``normalize_shell(dtype="erd")`` for flexibility.')
238+
239+
def single_am(idx, l):
240+
m = l + 1.5
241+
242+
norm = ElectronShell._erd_normalization(l, self.exponents, self.coefficients[idx])
243+
return [
244+
self.coefficients[idx][i] * norm * pow(self.exponents[i], 0.5 * m) for i in range(len(self.exponents))
245+
]
246+
247+
return [single_am(idx, l) for idx, l in enumerate(self.angular_momentum)]
248+
76249

77250
class ECPType(str, Enum):
78251
"""
@@ -189,3 +362,52 @@ def _calculate_nbf(cls, atom_map, center_data) -> int:
189362
ret += center_count[center]
190363

191364
return ret
365+
366+
def normalize_shell(self, dtype: NormalizationScheme) -> "BasisSet":
367+
"""Construct new BasisSet with coefficients of all shells in center_data normalized by ``dtype``."""
368+
369+
new_bs = self.dict()
370+
371+
for lbl, center in self.center_data.items():
372+
for ish, sh in enumerate(center.electron_shells):
373+
new_bs["center_data"][lbl]["electron_shells"][ish] = sh.normalize_shell(dtype)
374+
375+
return BasisSet(**new_bs)
376+
377+
def normalization_scheme(self) -> NormalizationScheme:
378+
"""Identify probable normalization scheme governing shell ``coefficients`` in center_data.
379+
380+
Returns
381+
-------
382+
NormalizationScheme
383+
Satisfied by all ElectronShells.
384+
385+
Raises
386+
------
387+
TypeError
388+
If the BasisSet's ElectronShells are detected to have inconsistent normalization schemes.
389+
390+
"""
391+
shell_norm = []
392+
for lbl, center in self.center_data.items():
393+
for ish, sh in enumerate(center.electron_shells):
394+
shell_norm.append(sh.normalization_scheme)
395+
396+
return _collapse_equal_list(shell_norm)
397+
398+
399+
@lru_cache(maxsize=500)
400+
def _df(i):
401+
if i in [0, 1, 2]:
402+
return 1.0
403+
else:
404+
return (i - 1) * _df(i - 2)
405+
406+
407+
def _collapse_equal_list(lst):
408+
lst = list(lst)
409+
first = lst[0]
410+
if lst.count(first) == len(lst):
411+
return first
412+
else:
413+
raise TypeError(f"Inconsistent members in list: {lst}")

0 commit comments

Comments
 (0)