66Implements the Model class, which describes a tight-binding model.
77"""
88
9+ from __future__ import annotations
10+
911import re
1012import os
1113import copy
2426from fsc .hdf5_io import subscribe_hdf5 , HDF5Enabled
2527
2628if ty .TYPE_CHECKING :
29+ # Replace with typing.Literal once Python 3.7 support is dropped.
30+ from typing_extensions import Literal
2731 import symmetry_representation # pylint: disable=unused-import
2832
2933from .kdotp import KdotpModel
@@ -334,7 +338,7 @@ def from_hop_list(
334338 hop_list : ty .Iterable [ty .Tuple [complex , int , int , ty .Tuple [int , ...]]] = (),
335339 size : ty .Optional [int ] = None ,
336340 ** kwargs ,
337- ) -> " Model" :
341+ ) -> Model :
338342 """
339343 Create a :class:`.Model` from a list of hopping terms.
340344
@@ -529,7 +533,7 @@ def _mat_to_hr(R, mat):
529533 @classmethod
530534 def from_wannier_folder (
531535 cls , folder : str = "." , prefix : str = "wannier" , ** kwargs
532- ) -> " Model" :
536+ ) -> Model :
533537 """
534538 Create a :class:`.Model` instance from Wannier90 output files,
535539 given the folder containing the files and file prefix.
@@ -573,7 +577,7 @@ def from_wannier_files( # pylint: disable=too-many-locals
573577 pos_kind : str = "wannier" ,
574578 distance_ratio_threshold : float = 3.0 ,
575579 ** kwargs ,
576- ) -> " Model" :
580+ ) -> Model :
577581 """
578582 Create a :class:`.Model` instance from Wannier90 output files.
579583
@@ -988,7 +992,7 @@ def construct_kdotp(self, k: ty.Sequence[float], order: int):
988992 @classmethod
989993 def from_hdf5_file ( # pylint: disable=arguments-differ
990994 cls , hdf5_file : str , ** kwargs
991- ) -> " Model" :
995+ ) -> Model :
992996 """
993997 Returns a :class:`.Model` instance read from a file in HDF5
994998 format.
@@ -1013,7 +1017,7 @@ def from_hdf5_file( # pylint: disable=arguments-differ
10131017 @classmethod
10141018 def from_hdf5 ( # pylint: disable=arguments-differ
10151019 cls , hdf5_handle , ** kwargs
1016- ) -> " Model" :
1020+ ) -> Model :
10171021 # For compatibility with a development version which created a top-level
10181022 # 'tb_model' attribute.
10191023 try :
@@ -1349,10 +1353,10 @@ def _input_kwargs(self):
13491353
13501354 def symmetrize (
13511355 self ,
1352- symmetries : ty .Sequence [" symmetry_representation.SymmetryOperation" ],
1356+ symmetries : ty .Sequence [symmetry_representation .SymmetryOperation ],
13531357 full_group : bool = False ,
13541358 position_tolerance : float = 1e-5 ,
1355- ) -> " Model" :
1359+ ) -> Model :
13561360 """
13571361 Returns a model which is symmetrized w.r.t. the given
13581362 symmetries. This is done by performing a group average over the
@@ -1404,7 +1408,7 @@ def symmetrize(
14041408
14051409 def _apply_operation ( # pylint: disable=too-many-locals
14061410 self , symmetry_operation , position_tolerance
1407- ) -> " Model" :
1411+ ) -> Model :
14081412 """
14091413 Helper function to apply a symmetry operation to the model.
14101414 """
@@ -1473,7 +1477,7 @@ def _apply_operation( # pylint: disable=too-many-locals
14731477
14741478 return Model (** co .ChainMap (dict (hop = new_hop ), self ._input_kwargs ))
14751479
1476- def slice_orbitals (self , slice_idx : ty .List [int ]) -> " Model" :
1480+ def slice_orbitals (self , slice_idx : ty .List [int ]) -> Model :
14771481 """
14781482 Returns a new model with only the orbitals as given in the
14791483 ``slice_idx``. This can also be used to re-order the orbitals.
@@ -1491,7 +1495,7 @@ def slice_orbitals(self, slice_idx: ty.List[int]) -> "Model":
14911495 return Model (** co .ChainMap (dict (hop = new_hop , pos = new_pos ), self ._input_kwargs ))
14921496
14931497 @classmethod
1494- def join_models (cls , * models : " Model" ) -> " Model" :
1498+ def join_models (cls , * models : Model ) -> Model :
14951499 """
14961500 Creates a tight-binding model which contains all orbitals of the
14971501 given input models. The orbitals are ordered by model, such that
@@ -1559,7 +1563,7 @@ def change_unit_cell( # pylint: disable=too-many-branches
15591563 uc : ty .Optional [ty .Sequence [ty .Sequence [float ]]] = None ,
15601564 offset : ty .Sequence [float ] = (0 , 0 , 0 ),
15611565 cartesian : bool = False ,
1562- ) -> " Model" :
1566+ ) -> Model :
15631567 """Return a model with a different unit cell of the same volume.
15641568
15651569 Creates a model with a changed unit cell - with a different
@@ -1646,7 +1650,7 @@ def change_unit_cell( # pylint: disable=too-many-branches
16461650
16471651 def supercell ( # pylint: disable=too-many-locals
16481652 self , size : ty .Sequence [int ]
1649- ) -> " Model" :
1653+ ) -> Model :
16501654 """Generate a model for a supercell of the current unit cell.
16511655
16521656 Parameters
@@ -1738,7 +1742,8 @@ def fold_model( # pylint: disable=too-many-locals,too-many-branches,too-many-st
17381742 check_uc_volume : bool = True ,
17391743 uc_volume_tolerance : float = 1e-6 ,
17401744 check_orbital_ratio : bool = True ,
1741- ) -> "Model" :
1745+ order_by : Literal ["label" , "index" ] = "label" ,
1746+ ) -> Model :
17421747 """
17431748 Returns a model with a smaller unit cell. Orbitals which are related
17441749 by a lattice vector of the new unit cell are "folded" into a single
@@ -1789,6 +1794,15 @@ def fold_model( # pylint: disable=too-many-locals,too-many-branches,too-many-st
17891794 should be checked to be the same as the initial ratio. If
17901795 this is set to False, the resulting model will always
17911796 have ``occ=None``.
1797+ order_by :
1798+ Determines how the orbitals in the new model are ordered.
1799+ For ``order_by="index"``, the orbitals are ordered exactly the
1800+ same as in the original model. Note that this order will
1801+ depend on which orbitals end up inside the unit cell.
1802+ For ``order_by="label"``, the orbitals are ordered
1803+ according to the **occurrence** (not the value) of the
1804+ labels in the ``orbital_labels`` input. Orbitals with the
1805+ same label are again ordered by index.
17921806 """
17931807 if len (orbital_labels ) != self .size :
17941808 raise ValueError (
@@ -1863,6 +1877,24 @@ def fold_model( # pylint: disable=too-many-locals,too-many-branches,too-many-st
18631877 axis = - 1 ,
18641878 )
18651879 ).flatten ()
1880+ if order_by == "label" :
1881+ idx = 0
1882+ orbital_sort_idx = {}
1883+ for label in orbital_labels :
1884+ if label not in orbital_sort_idx :
1885+ orbital_sort_idx [label ] = idx
1886+ idx += 1
1887+ in_uc_sort_idx = [
1888+ orbital_sort_idx [orbital_labels [i ]] for i in in_uc_indices
1889+ ]
1890+ in_uc_indices = in_uc_indices [
1891+ np .argsort (in_uc_sort_idx , kind = "mergesort" ) # need stable sorting
1892+ ]
1893+ else :
1894+ if order_by != "index" :
1895+ raise ValueError (
1896+ f"Invalid input '{ order_by } ' for 'order_by', must be either 'label' or 'index'."
1897+ )
18661898 if target_indices is not None :
18671899 if not np .all (target_indices == in_uc_indices ):
18681900 raise ValueError (
@@ -2001,7 +2033,7 @@ def get_matching_idx_and_offset(pos_reduced, orbital_label):
20012033 )
20022034 )
20032035
2004- def __add__ (self , model : " Model" ) -> " Model" :
2036+ def __add__ (self , model : Model ) -> Model :
20052037 """
20062038 Adds two models together by adding their hopping terms.
20072039 """
@@ -2059,19 +2091,19 @@ def __add__(self, model: "Model") -> "Model":
20592091 # -------------------
20602092 return Model (** co .ChainMap (dict (hop = new_hop ), self ._input_kwargs ))
20612093
2062- def __sub__ (self , model : " Model" ) -> " Model" :
2094+ def __sub__ (self , model : Model ) -> Model :
20632095 """
20642096 Substracts one model from another by substracting all hopping terms.
20652097 """
20662098 return self + - model
20672099
2068- def __neg__ (self ) -> " Model" :
2100+ def __neg__ (self ) -> Model :
20692101 """
20702102 Changes the sign of all hopping terms.
20712103 """
20722104 return - 1 * self
20732105
2074- def __mul__ (self , x : float ) -> " Model" :
2106+ def __mul__ (self , x : float ) -> Model :
20752107 """
20762108 Multiplies hopping terms by x.
20772109 """
@@ -2081,13 +2113,13 @@ def __mul__(self, x: float) -> "Model":
20812113
20822114 return Model (** co .ChainMap (dict (hop = new_hop ), self ._input_kwargs ))
20832115
2084- def __rmul__ (self , x : float ) -> " Model" :
2116+ def __rmul__ (self , x : float ) -> Model :
20852117 """
20862118 Multiplies hopping terms by x.
20872119 """
20882120 return self .__mul__ (x )
20892121
2090- def __truediv__ (self , x : float ) -> " Model" :
2122+ def __truediv__ (self , x : float ) -> Model :
20912123 """
20922124 Divides hopping terms by x.
20932125 """
0 commit comments