Skip to content

Commit e1034c3

Browse files
authored
Add an option to change the orbital order when folding. (#101)
Add an option to choose between different ways of sorting orbitals in the `fold_model` method: Either by their index in the given model (the previous default), or by the occurrence of keys in the `orbital_labels` input (the new default). The reason for this new feature is that the sort order can be unpredictable when it is unclear which orbitals end up inside or outside of the new unit cell.
1 parent eb5393c commit e1034c3

File tree

3 files changed

+106
-20
lines changed

3 files changed

+106
-20
lines changed

setup.cfg

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,9 @@ dev =
5656
pre-commit
5757
pylint==2.6.0
5858
isort==5.5.1
59-
mypy==0.782
59+
mypy==0.812
6060
ruamel.yaml
61+
typing-extensions
6162

6263
[options.entry_points]
6364
console_scripts =

tbmodels/_tb_model.py

Lines changed: 51 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
Implements the Model class, which describes a tight-binding model.
77
"""
88

9+
from __future__ import annotations
10+
911
import re
1012
import os
1113
import copy
@@ -24,6 +26,8 @@
2426
from fsc.hdf5_io import subscribe_hdf5, HDF5Enabled
2527

2628
if ty.TYPE_CHECKING:
29+
# Replace with typing.Literal once Python 3.7 support is dropped.
30+
from typing_extensions import Literal
2731
import symmetry_representation # pylint: disable=unused-import
2832

2933
from .kdotp import KdotpModel
@@ -334,7 +338,7 @@ def from_hop_list(
334338
hop_list: ty.Iterable[ty.Tuple[complex, int, int, ty.Tuple[int, ...]]] = (),
335339
size: ty.Optional[int] = None,
336340
**kwargs,
337-
) -> "Model":
341+
) -> Model:
338342
"""
339343
Create a :class:`.Model` from a list of hopping terms.
340344
@@ -529,7 +533,7 @@ def _mat_to_hr(R, mat):
529533
@classmethod
530534
def from_wannier_folder(
531535
cls, folder: str = ".", prefix: str = "wannier", **kwargs
532-
) -> "Model":
536+
) -> Model:
533537
"""
534538
Create a :class:`.Model` instance from Wannier90 output files,
535539
given the folder containing the files and file prefix.
@@ -573,7 +577,7 @@ def from_wannier_files( # pylint: disable=too-many-locals
573577
pos_kind: str = "wannier",
574578
distance_ratio_threshold: float = 3.0,
575579
**kwargs,
576-
) -> "Model":
580+
) -> Model:
577581
"""
578582
Create a :class:`.Model` instance from Wannier90 output files.
579583
@@ -988,7 +992,7 @@ def construct_kdotp(self, k: ty.Sequence[float], order: int):
988992
@classmethod
989993
def from_hdf5_file( # pylint: disable=arguments-differ
990994
cls, hdf5_file: str, **kwargs
991-
) -> "Model":
995+
) -> Model:
992996
"""
993997
Returns a :class:`.Model` instance read from a file in HDF5
994998
format.
@@ -1013,7 +1017,7 @@ def from_hdf5_file( # pylint: disable=arguments-differ
10131017
@classmethod
10141018
def from_hdf5( # pylint: disable=arguments-differ
10151019
cls, hdf5_handle, **kwargs
1016-
) -> "Model":
1020+
) -> Model:
10171021
# For compatibility with a development version which created a top-level
10181022
# 'tb_model' attribute.
10191023
try:
@@ -1349,10 +1353,10 @@ def _input_kwargs(self):
13491353

13501354
def symmetrize(
13511355
self,
1352-
symmetries: ty.Sequence["symmetry_representation.SymmetryOperation"],
1356+
symmetries: ty.Sequence[symmetry_representation.SymmetryOperation],
13531357
full_group: bool = False,
13541358
position_tolerance: float = 1e-5,
1355-
) -> "Model":
1359+
) -> Model:
13561360
"""
13571361
Returns a model which is symmetrized w.r.t. the given
13581362
symmetries. This is done by performing a group average over the
@@ -1404,7 +1408,7 @@ def symmetrize(
14041408

14051409
def _apply_operation( # pylint: disable=too-many-locals
14061410
self, symmetry_operation, position_tolerance
1407-
) -> "Model":
1411+
) -> Model:
14081412
"""
14091413
Helper function to apply a symmetry operation to the model.
14101414
"""
@@ -1473,7 +1477,7 @@ def _apply_operation( # pylint: disable=too-many-locals
14731477

14741478
return Model(**co.ChainMap(dict(hop=new_hop), self._input_kwargs))
14751479

1476-
def slice_orbitals(self, slice_idx: ty.List[int]) -> "Model":
1480+
def slice_orbitals(self, slice_idx: ty.List[int]) -> Model:
14771481
"""
14781482
Returns a new model with only the orbitals as given in the
14791483
``slice_idx``. This can also be used to re-order the orbitals.
@@ -1491,7 +1495,7 @@ def slice_orbitals(self, slice_idx: ty.List[int]) -> "Model":
14911495
return Model(**co.ChainMap(dict(hop=new_hop, pos=new_pos), self._input_kwargs))
14921496

14931497
@classmethod
1494-
def join_models(cls, *models: "Model") -> "Model":
1498+
def join_models(cls, *models: Model) -> Model:
14951499
"""
14961500
Creates a tight-binding model which contains all orbitals of the
14971501
given input models. The orbitals are ordered by model, such that
@@ -1559,7 +1563,7 @@ def change_unit_cell( # pylint: disable=too-many-branches
15591563
uc: ty.Optional[ty.Sequence[ty.Sequence[float]]] = None,
15601564
offset: ty.Sequence[float] = (0, 0, 0),
15611565
cartesian: bool = False,
1562-
) -> "Model":
1566+
) -> Model:
15631567
"""Return a model with a different unit cell of the same volume.
15641568
15651569
Creates a model with a changed unit cell - with a different
@@ -1646,7 +1650,7 @@ def change_unit_cell( # pylint: disable=too-many-branches
16461650

16471651
def supercell( # pylint: disable=too-many-locals
16481652
self, size: ty.Sequence[int]
1649-
) -> "Model":
1653+
) -> Model:
16501654
"""Generate a model for a supercell of the current unit cell.
16511655
16521656
Parameters
@@ -1738,7 +1742,8 @@ def fold_model( # pylint: disable=too-many-locals,too-many-branches,too-many-st
17381742
check_uc_volume: bool = True,
17391743
uc_volume_tolerance: float = 1e-6,
17401744
check_orbital_ratio: bool = True,
1741-
) -> "Model":
1745+
order_by: Literal["label", "index"] = "label",
1746+
) -> Model:
17421747
"""
17431748
Returns a model with a smaller unit cell. Orbitals which are related
17441749
by a lattice vector of the new unit cell are "folded" into a single
@@ -1789,6 +1794,15 @@ def fold_model( # pylint: disable=too-many-locals,too-many-branches,too-many-st
17891794
should be checked to be the same as the initial ratio. If
17901795
this is set to False, the resulting model will always
17911796
have ``occ=None``.
1797+
order_by :
1798+
Determines how the orbitals in the new model are ordered.
1799+
For ``order_by="index"``, the orbitals are ordered exactly the
1800+
same as in the original model. Note that this order will
1801+
depend on which orbitals end up inside the unit cell.
1802+
For ``order_by="label"``, the orbitals are ordered
1803+
according to the **occurrence** (not the value) of the
1804+
labels in the ``orbital_labels`` input. Orbitals with the
1805+
same label are again ordered by index.
17921806
"""
17931807
if len(orbital_labels) != self.size:
17941808
raise ValueError(
@@ -1863,6 +1877,24 @@ def fold_model( # pylint: disable=too-many-locals,too-many-branches,too-many-st
18631877
axis=-1,
18641878
)
18651879
).flatten()
1880+
if order_by == "label":
1881+
idx = 0
1882+
orbital_sort_idx = {}
1883+
for label in orbital_labels:
1884+
if label not in orbital_sort_idx:
1885+
orbital_sort_idx[label] = idx
1886+
idx += 1
1887+
in_uc_sort_idx = [
1888+
orbital_sort_idx[orbital_labels[i]] for i in in_uc_indices
1889+
]
1890+
in_uc_indices = in_uc_indices[
1891+
np.argsort(in_uc_sort_idx, kind="mergesort") # need stable sorting
1892+
]
1893+
else:
1894+
if order_by != "index":
1895+
raise ValueError(
1896+
f"Invalid input '{order_by}' for 'order_by', must be either 'label' or 'index'."
1897+
)
18661898
if target_indices is not None:
18671899
if not np.all(target_indices == in_uc_indices):
18681900
raise ValueError(
@@ -2001,7 +2033,7 @@ def get_matching_idx_and_offset(pos_reduced, orbital_label):
20012033
)
20022034
)
20032035

2004-
def __add__(self, model: "Model") -> "Model":
2036+
def __add__(self, model: Model) -> Model:
20052037
"""
20062038
Adds two models together by adding their hopping terms.
20072039
"""
@@ -2059,19 +2091,19 @@ def __add__(self, model: "Model") -> "Model":
20592091
# -------------------
20602092
return Model(**co.ChainMap(dict(hop=new_hop), self._input_kwargs))
20612093

2062-
def __sub__(self, model: "Model") -> "Model":
2094+
def __sub__(self, model: Model) -> Model:
20632095
"""
20642096
Substracts one model from another by substracting all hopping terms.
20652097
"""
20662098
return self + -model
20672099

2068-
def __neg__(self) -> "Model":
2100+
def __neg__(self) -> Model:
20692101
"""
20702102
Changes the sign of all hopping terms.
20712103
"""
20722104
return -1 * self
20732105

2074-
def __mul__(self, x: float) -> "Model":
2106+
def __mul__(self, x: float) -> Model:
20752107
"""
20762108
Multiplies hopping terms by x.
20772109
"""
@@ -2081,13 +2113,13 @@ def __mul__(self, x: float) -> "Model":
20812113

20822114
return Model(**co.ChainMap(dict(hop=new_hop), self._input_kwargs))
20832115

2084-
def __rmul__(self, x: float) -> "Model":
2116+
def __rmul__(self, x: float) -> Model:
20852117
"""
20862118
Multiplies hopping terms by x.
20872119
"""
20882120
return self.__mul__(x)
20892121

2090-
def __truediv__(self, x: float) -> "Model":
2122+
def __truediv__(self, x: float) -> Model:
20912123
"""
20922124
Divides hopping terms by x.
20932125
"""

tests/test_folding.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,56 @@ def test_new_unit_cell_check(get_model):
202202
orbital_labels=["a", "b"] * 2,
203203
)
204204
assert "new unit cell is not contained" in str(excinfo.value)
205+
206+
207+
def test_fold_by_label_order(get_model, models_close):
208+
"""
209+
Test that the ``order_by`` input creates the expected result.
210+
"""
211+
model = get_model(
212+
0.1,
213+
0.4,
214+
uc=[[1, 0, 0.5], [0.1, 0.4, 0.0], [0.0, 0.0, 1.2]],
215+
pos=[[0.1, 0.05, 0.2], [0.62, 0.3, 0.7]],
216+
)
217+
supercell_model = model.supercell(size=(1, 1, 3))
218+
219+
# switch the order of the two target orbitals
220+
supercell_model_resliced = supercell_model.slice_orbitals(
221+
slice_idx=[0, 1, 3, 2, 4, 5]
222+
)
223+
orbital_labels = ["a", "b", "b", "a", "a", "b"]
224+
folded_model_by_label = supercell_model_resliced.fold_model(
225+
new_unit_cell=model.uc,
226+
unit_cell_offset=supercell_model_resliced.uc.T @ [0, 0, 1 / 3],
227+
orbital_labels=orbital_labels,
228+
target_indices=[3, 2],
229+
order_by="label",
230+
)
231+
folded_model_by_index = supercell_model_resliced.fold_model(
232+
new_unit_cell=model.uc,
233+
unit_cell_offset=supercell_model_resliced.uc.T @ [0, 0, 1 / 3],
234+
orbital_labels=orbital_labels,
235+
target_indices=[2, 3],
236+
order_by="index",
237+
)
238+
assert models_close(model.slice_orbitals([1, 0]), folded_model_by_index)
239+
assert models_close(model, folded_model_by_label)
240+
241+
242+
def test_invalid_order_by(get_model):
243+
"""Test that passing an invalid ``order_by`` raises."""
244+
model = get_model(
245+
0.1,
246+
0.3,
247+
uc=np.diag([1, 2, 3]),
248+
pos=[[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]],
249+
)
250+
supercell_model = model.supercell(size=(1, 1, 2))
251+
with pytest.raises(ValueError) as excinfo:
252+
supercell_model.fold_model(
253+
new_unit_cell=model.uc, orbital_labels=["a", "b"] * 2, order_by="spam"
254+
)
255+
assert "Invalid input" in str(excinfo.value)
256+
assert "order_by" in str(excinfo.value)
257+
assert "spam" in str(excinfo.value)

0 commit comments

Comments
 (0)