From ee73c919d4936c9db13d1e9f37b7eec8315e8680 Mon Sep 17 00:00:00 2001
From: zhanghao
Date: Wed, 11 Dec 2024 03:01:17 +0800
Subject: [PATCH 1/3] init spline decay branch

---
 dptb/data/dataset/_default_dataset.py |  13 ++-
 dptb/nn/rescale.py                    | 138 +++++++++++++++++++++++++-
 2 files changed, 147 insertions(+), 4 deletions(-)

diff --git a/dptb/data/dataset/_default_dataset.py b/dptb/data/dataset/_default_dataset.py
index 94d34b43..b0d93db1 100644
--- a/dptb/data/dataset/_default_dataset.py
+++ b/dptb/data/dataset/_default_dataset.py
@@ -447,12 +447,14 @@ def _E3edgespecies_stat(self, typed_dataset, decay):
 
         # calculate norm & mean
         typed_norm = {}
+        typed_scalar = {}
         typed_norm_ave = torch.ones(len(idp.bond_to_type), idp.orbpair_irreps.num_irreps)
         typed_norm_std = torch.zeros(len(idp.bond_to_type), idp.orbpair_irreps.num_irreps)
         typed_scalar_ave = torch.ones(len(idp.bond_to_type), n_scalar)
         typed_scalar_std = torch.zeros(len(idp.bond_to_type), n_scalar)
         for bt, tp in idp.bond_to_type.items():
             norms_per_irrep = []
+            scalar_per_irrep = []
             count_scalar = 0
             for ir, s in enumerate(irrep_slices):
                 sub_tensor = typed_hopping[bt][:, s]
@@ -472,13 +474,17 @@ def _E3edgespecies_stat(self, typed_dataset, decay):
                     norms = torch.ones_like(sub_tensor[:, 0])
 
                 if decay:
-                    norms_per_irrep.append(norms)
+                    if not torch.isnan(sub_tensor).all():
+                        if sub_tensor.shape[-1] == 1: # is scalar
+                            scalar_per_irrep.append(sub_tensor)
+                        norms_per_irrep.append(norms)
 
             assert count_scalar <= n_scalar
 
             # shape of typed_norm: (n_irreps, n_edges)
             if decay:
-                typed_norm[bt] = torch.stack(norms_per_irrep)
+                typed_scalar[bt] = torch.stack(scalar_per_irrep) # [n_scalar, n_edge]
+                typed_norm[bt] = torch.stack(norms_per_irrep) # [n_irreps, n_edge]
 
         edge_stats = {
             "norm_ave": typed_norm_ave,
@@ -495,11 +501,12 @@ def _E3edgespecies_stat(self, typed_dataset, decay):
                 lengths_bt = typed_dataset["edge_lengths"][typed_dataset["edge_type"].flatten().eq(tp)]
                 sorted_lengths, indices = lengths_bt.sort() # from small to large
                 # sort the norms by irrep l
-                sorted_norms = typed_norm[bt][idp.orbpair_irreps.sort().inv, :]
+                sorted_norms = typed_norm[bt]
                 # sort the norms by edge length
                 sorted_norms = sorted_norms[:, indices]
                 decay_bt["edge_length"] = sorted_lengths
                 decay_bt["norm_decay"] = sorted_norms
+                decay_bt["scalar_decay"] = typed_scalar[bt][:, indices]
                 decay[bt] = decay_bt
 
         edge_stats["decay"] = decay
diff --git a/dptb/nn/rescale.py b/dptb/nn/rescale.py
index 168e8351..e559cfef 100644
--- a/dptb/nn/rescale.py
+++ b/dptb/nn/rescale.py
@@ -6,6 +6,7 @@ from typing import Optional, List, Union
 import torch.nn.functional
 from e3nn.o3 import Linear
+from .sktb import HoppingFormula
 from e3nn.util.jit import compile_mode
 from dptb.data import AtomicDataDict
 import e3nn.o3 as o3
@@ -523,4 +524,139 @@ def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]=None):
         else:
             x = x
 
-        return x
\ No newline at end of file
+        return x
+
+
+class E3PerEdgeSpeciesRadialDpdtScaleShift(torch.nn.Module):
+    """Scale and shift per-edge features, with radially dependent shifts.
+
+    Includes optional per-species-pair edgewise scales, and shifts for the scalar channels that decay with bond length.
+    """
+
+    field: str
+    out_field: str
+    scales_trainable: bool
+    shifts_trainable: bool
+    has_scales: bool
+    has_shifts: bool
+
+    def __init__(
+        self,
+        field: str,
+        num_types: int,
+        irreps_in,
+        shifts: Optional[torch.Tensor],
+        scales: Optional[torch.Tensor],
+        out_field: Optional[str] = None,
+        scales_trainable: bool = False,
+        shifts_trainable: bool = False,
+        shift_func: Optional[str] = "poly3pow",
+        dtype: Union[str, torch.dtype] = torch.float32,
+        device: Union[str, torch.device] = torch.device("cpu"),
+        **kwargs,
+    ):
+        """Initialize per-edge-species scales and radially dependent shifts."""
+        super(E3PerEdgeSpeciesRadialDpdtScaleShift, self).__init__()
+        self.num_types = num_types
+        self.field = field
+        self.out_field = f"shifted_{field}" if out_field is None else out_field
+        self.irreps_in = irreps_in
+        self.num_scalar = 0
+        self.device = device
+        self.dtype = dtype
+        self.shift_index = []
+        self.scale_index = []
+        self.shift_func = shift_func
+
+        self.shift_func = HoppingFormula(functype=shift_func)
+
+        start = 0
+        start_scalar = 0
+        for mul, ir in irreps_in:
+            if str(ir) == "0e":
+                self.num_scalar += mul
+                self.shift_index += list(range(start_scalar, start_scalar + mul))
+                start_scalar += mul
+            else:
+                self.shift_index += [-1] * mul * ir.dim
+
+            for _ in range(mul):
+                self.scale_index += [start] * ir.dim
+                start += 1
+
+        self.shift_index = torch.as_tensor(self.shift_index, dtype=torch.long, device=device)
+        self.scale_index = torch.as_tensor(self.scale_index, dtype=torch.long, device=device)
+
+        self.has_shifts = shifts is not None
+        self.has_scales = scales is not None
+        if scales is not None:
+            scales = torch.as_tensor(scales, dtype=self.dtype, device=device)
+            if len(scales.reshape(-1)) == 1:
+                scales = scales * torch.ones(num_types*num_types, self.irreps_in.num_irreps, dtype=self.dtype, device=self.device)
+            assert scales.shape == (num_types*num_types, self.irreps_in.num_irreps), f"Invalid shape of scales {scales}"
+            self.scales_trainable = scales_trainable
+            if scales_trainable:
+                self.scales = torch.nn.Parameter(scales)
+            else:
+                self.register_buffer("scales", scales)
+
+        if shifts is not None:
+            shifts = torch.as_tensor(shifts, dtype=self.dtype, device=device)
+            if len(shifts.reshape(-1)) == 1:
+                shifts = shifts * torch.ones(num_types*num_types, self.num_scalar, dtype=self.dtype, device=self.device)
+            assert shifts.shape == (num_types*num_types, self.num_scalar), f"Invalid shape of shifts {shifts}"
+            self.shifts_trainable = shifts_trainable
+            if shifts_trainable:
+                self.shifts = torch.nn.Parameter(shifts)
+            else:
+                self.register_buffer("shifts", shifts)
+
+    def set_scale_shift(self, scales: torch.Tensor=None, shifts: torch.Tensor=None):
+        self.has_scales = scales is not None or self.has_scales
+        if scales is not None:
+            assert scales.shape == (self.num_types*self.num_types, self.irreps_in.num_irreps), f"Invalid shape of scales {scales}"
+            if self.scales_trainable:
+                self.scales = torch.nn.Parameter(scales)
+            else:
+                self.register_buffer("scales", scales)
+
+        self.has_shifts = shifts is not None or self.has_shifts
+        if shifts is not None:
+            assert shifts.shape == (self.num_types*self.num_types, self.num_scalar, self.shift_func.num_paras), f"Invalid shape of shifts {shifts}"
+            if self.shifts_trainable:
+                self.shifts = torch.nn.Parameter(shifts)
+            else:
+                self.register_buffer("shifts", shifts)
+
+    def fit_radialdpdt_shift(self, decay):
+        shifts = torch.randn(self.num_types*self.num_types, self.num_scalar, self.shift_func.num_paras, dtype=self.dtype, device=self.device)
+        shifts.requires_grad_()
+
+
+        return shifts
+
+
+
+    def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
+
+        if not (self.has_scales or self.has_shifts):
+            return data
+
+        edge_center = data[AtomicDataDict.EDGE_INDEX_KEY][0]
+
+        species_idx = data[AtomicDataDict.EDGE_TYPE_KEY].flatten()
+        in_field = data[self.field]
+
+        assert len(in_field) == len(
+            edge_center
+        ), "in_field doesn't seem to have correct per-edge shape"
+
+        if self.has_scales:
+            in_field = self.scales[species_idx][:,self.scale_index].view(-1, self.irreps_in.dim) * in_field
+        if self.has_shifts:
+            shifts = self.shifts[species_idx][:,self.shift_index[self.shift_index>=0]].view(-1, self.num_scalar)
+            in_field[:, self.shift_index>=0] = shifts + in_field[:, self.shift_index>=0]
+
+        data[self.out_field] = in_field
+
+        return data
\ No newline at end of file

From cb1d8655e0deddc701987b8e33ae0db124a5cf10 Mon Sep 17 00:00:00 2001
From: zhanghao
Date: Wed, 1 Jan 2025 02:37:45 +0800
Subject: [PATCH 2/3] update decay fitting and fix warnings

---
 dptb/data/dataset/_default_dataset.py |   6 +-
 dptb/entrypoints/train.py             |   2 +-
 dptb/nn/deeptb.py                     |  38 +++++++---
 dptb/nn/rescale.py                    | 103 +++++++++++++++++++++++---
 dptb/nn/tensor_product.py             |   2 +-
 dptb/utils/argcheck.py                |   2 +
 6 files changed, 126 insertions(+), 27 deletions(-)

diff --git a/dptb/data/dataset/_default_dataset.py b/dptb/data/dataset/_default_dataset.py
index b0d93db1..f804d284 100644
--- a/dptb/data/dataset/_default_dataset.py
+++ b/dptb/data/dataset/_default_dataset.py
@@ -422,6 +422,10 @@ def E3statistics(self, model: torch.nn.Module=None, decay=False):
             edge_scales = stats["edge"]["norm_ave"]
             edge_scales[:,scalar_mask] = stats["edge"]["scalar_std"]
             model.node_prediction_h.set_scale_shift(scales=node_scales, shifts=node_shifts)
+
+        if decay:
+            edge_shifts = model.edge_prediction_h.fit_radialdpdt_shift(stats["edge"]["decay"], self.type_mapper)
+            edge_scales = None
         model.edge_prediction_h.set_scale_shift(scales=edge_scales, shifts=edge_shifts)
 
         return stats
@@ -476,7 +480,7 @@ def _E3edgespecies_stat(self, typed_dataset, decay):
                 if decay:
                     if not torch.isnan(sub_tensor).all():
                         if sub_tensor.shape[-1] == 1: # is scalar
-                            scalar_per_irrep.append(sub_tensor)
+                            scalar_per_irrep.append(sub_tensor.squeeze(-1))
                         norms_per_irrep.append(norms)
 
             assert count_scalar <= n_scalar
diff --git a/dptb/entrypoints/train.py b/dptb/entrypoints/train.py
index d2bc3484..012c1eb7 100644
--- a/dptb/entrypoints/train.py
+++ b/dptb/entrypoints/train.py
@@ -192,7 +192,7 @@ def train(
     # build_model handles the init-model case where the provided model options do not equal the ones in the checkpoint.
     checkpoint = init_model if init_model else None
     model = build_model(checkpoint=checkpoint, model_options=jdata["model_options"], common_options=jdata["common_options"])
-    train_datasets.E3statistics(model=model)
+    train_datasets.E3statistics(model=model, decay=jdata["model_options"]["prediction"]["decay"])
     trainer = Trainer(
         train_options=jdata["train_options"],
         common_options=jdata["common_options"],
diff --git a/dptb/nn/deeptb.py b/dptb/nn/deeptb.py
index 0b102dd6..9c548564 100644
--- a/dptb/nn/deeptb.py
+++ b/dptb/nn/deeptb.py
@@ -10,7 +10,7 @@ from dptb.nn.nnsk import NNSK
 from dptb.nn.dftbsk import DFTBSK
 from e3nn.o3 import Linear
-from dptb.nn.rescale import E3PerSpeciesScaleShift, E3PerEdgeSpeciesScaleShift
+from dptb.nn.rescale import E3PerSpeciesScaleShift, E3PerEdgeSpeciesScaleShift, E3PerEdgeSpeciesRadialDpdtScaleShift
 
 import logging
 log = logging.getLogger(__name__)
@@ -180,18 +180,32 @@ def __init__(
             device=self.device,
             **prediction_copy,
         )
+
+        if prediction_copy.get("decay"):
+            self.edge_prediction_h = E3PerEdgeSpeciesRadialDpdtScaleShift(
+                field=AtomicDataDict.EDGE_FEATURES_KEY,
+                num_types=n_species,
+                irreps_in=self.embedding.out_edge_irreps,
+                out_field = AtomicDataDict.EDGE_FEATURES_KEY,
+                shifts=0.,
+                scales=1.,
+                dtype=self.dtype,
+                device=self.device,
+                **prediction_copy,
+            )
+        else:
+            self.edge_prediction_h = E3PerEdgeSpeciesScaleShift(
+                field=AtomicDataDict.EDGE_FEATURES_KEY,
+                num_types=n_species,
+                irreps_in=self.embedding.out_edge_irreps,
+                out_field = AtomicDataDict.EDGE_FEATURES_KEY,
+                shifts=0.,
+                scales=1.,
+                dtype=self.dtype,
+                device=self.device,
+                **prediction_copy,
+            )
 
-        self.edge_prediction_h = E3PerEdgeSpeciesScaleShift(
-            field=AtomicDataDict.EDGE_FEATURES_KEY,
-            num_types=n_species,
-            irreps_in=self.embedding.out_edge_irreps,
-            out_field = AtomicDataDict.EDGE_FEATURES_KEY,
-            shifts=0.,
-            scales=1.,
-            dtype=self.dtype,
-            device=self.device,
-            **prediction_copy,
-        )
 
         if overlap:
             self.idp_sk = OrbitalMapper(self.idp.basis, method="sktb", device=self.device)
diff --git a/dptb/nn/rescale.py b/dptb/nn/rescale.py
index e559cfef..83e5e1d2 100644
--- a/dptb/nn/rescale.py
+++ b/dptb/nn/rescale.py
@@ -6,11 +6,14 @@ from typing import Optional, List, Union
 import torch.nn.functional
 from e3nn.o3 import Linear
-from .sktb import HoppingFormula
+from dptb.nn.sktb import HoppingFormula, bond_length_list
 from e3nn.util.jit import compile_mode
 from dptb.data import AtomicDataDict
+from dptb.utils.constants import atomic_num_dict
 import e3nn.o3 as o3
 
+log = logging.getLogger(__name__)
+
 class PerSpeciesScaleShift(torch.nn.Module):
     """Scale and/or shift a predicted per-atom property based on (learnable) per-species/type parameters.
@@ -550,7 +553,6 @@ def __init__(
         out_field: Optional[str] = None,
         scales_trainable: bool = False,
         shifts_trainable: bool = False,
-        shift_func: Optional[str] = "poly3pow",
         dtype: Union[str, torch.dtype] = torch.float32,
         device: Union[str, torch.device] = torch.device("cpu"),
         **kwargs,
@@ -566,9 +568,6 @@ def __init__(
         self.dtype = dtype
         self.shift_index = []
         self.scale_index = []
-        self.shift_func = shift_func
-
-        self.shift_func = HoppingFormula(functype=shift_func)
 
         start = 0
         start_scalar = 0
         for mul, ir in irreps_in:
@@ -611,6 +610,8 @@ def __init__(
             else:
                 self.register_buffer("shifts", shifts)
 
+        self.r0 = [] # initialize r0
+
     def set_scale_shift(self, scales: torch.Tensor=None, shifts: torch.Tensor=None):
         self.has_scales = scales is not None or self.has_scales
         if scales is not None:
@@ -622,18 +623,70 @@ def set_scale_shift(self, scales: torch.Tensor=None, shifts: torch.Tensor=None):
 
         self.has_shifts = shifts is not None or self.has_shifts
         if shifts is not None:
-            assert shifts.shape == (self.num_types*self.num_types, self.num_scalar, self.shift_func.num_paras), f"Invalid shape of shifts {shifts}"
+            assert shifts.shape == (self.num_types*self.num_types, self.num_scalar, 7), f"Invalid shape of shifts {shifts}"
             if self.shifts_trainable:
                 self.shifts = torch.nn.Parameter(shifts)
             else:
                 self.register_buffer("shifts", shifts)
 
-    def fit_radialdpdt_shift(self, decay):
-        shifts = torch.randn(self.num_types*self.num_types, self.num_scalar, self.shift_func.num_paras, dtype=self.dtype, device=self.device)
+    def fit_radialdpdt_shift(self, decay, idp):
+        shifts = torch.randn(self.num_types*self.num_types, self.num_scalar, 7, dtype=self.dtype, device=self.device)
         shifts.requires_grad_()
+        optimizer = torch.optim.Adam([shifts], lr=0.01)
+        lrsch = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=3000, threshold=1e-5, eps=1e-5, verbose=True)
+        bond_sym = list(decay.keys())
+        bsz = 128
+
+        for sym in idp.type_names:
+            self.r0.append(bond_length_list[atomic_num_dict[sym]-1])
+        self.r0 = torch.tensor(self.r0, device=self.device, dtype=self.dtype)
+
+        # TODO: check whether some bond types have too few samples; this may happen with sparse doping.
+        # TODO: check whether any bond type fails to cover the range between the equilibrium r0 and r_cut; this may happen in heterogeneous systems.
+        n_edge_length = []
+        edge_lengths = {}
+        scalar_decays = {}
+        for bsym in decay:
+            n_edge_length.append(len(decay[bsym]["edge_length"]))
+            edge_lengths[bsym] = decay[bsym]["edge_length"].type(self.dtype).to(self.device)
+            scalar_decays[bsym] = decay[bsym]["scalar_decay"].type(self.dtype).to(self.device)
+
+
+        if min(n_edge_length) <= bsz:
+            log.warning("Some bond types have fewer edges than the fitting batch size, so the fitted edge-decay curve may be unreliable; consider setting decay = False.")
+
+        edge_number = idp._index_to_ZZ.T
+        for i in range(40000):
+            optimizer.zero_grad()
+            rs = [None] * len(bond_sym)
+            frs = [None] * len(bond_sym)
+            # construct a random mini-batch per bond type
+            for bsym in decay:
+                bt = idp.bond_to_type[bsym]
+                random_index = torch.randint(0, len(edge_lengths[bsym]), (bsz,))
+                rs[bt] = edge_lengths[bsym][random_index]
+                frs[bt] = scalar_decays[bsym][:,random_index].T # [bsz, n_scalar]
+            rs = torch.cat(rs, dim=0)
+            frs = torch.cat(frs, dim=0)
+            r0 = 0.5*bond_length_list.type(self.dtype).to(self.device)[edge_number-1].sum(0)
+            r0 = r0.unsqueeze(1).repeat(1, bsz).reshape(-1)
+
+            paraArray=shifts.reshape(-1, 1, self.num_scalar, 7).repeat(1,bsz,1,1).reshape(-1, self.num_scalar, 7)
+
+            fr_ = self.poly5pow(
+                rij=rs,
+                paraArray=paraArray,
+                r0 = r0,
+            )
+
+            loss = (fr_ - frs).pow(2).mean()
+
+            log.info("Decay-function fitting step {}, loss: {:.4f}, lr: {:.5f}".format(i, loss.item(), lrsch.get_last_lr()[0]))
+            loss.backward()
+            optimizer.step()
+            lrsch.step(loss.item())
+
-        return shifts
+        return shifts.detach()
 
 
 
@@ -645,6 +698,7 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
 
         edge_center = data[AtomicDataDict.EDGE_INDEX_KEY][0]
 
         species_idx = data[AtomicDataDict.EDGE_TYPE_KEY].flatten()
+        edge_atom_type = data[AtomicDataDict.ATOM_TYPE_KEY].flatten()[data[AtomicDataDict.EDGE_INDEX_KEY]]
         in_field = data[self.field]
 
         assert len(in_field) == len(
@@ -654,9 +708,34 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
         if self.has_scales:
             in_field = self.scales[species_idx][:,self.scale_index].view(-1, self.irreps_in.dim) * in_field
         if self.has_shifts:
-            shifts = self.shifts[species_idx][:,self.shift_index[self.shift_index>=0]].view(-1, self.num_scalar)
+            shifts = self.shifts[species_idx][:,self.shift_index[self.shift_index>=0]].view(-1, self.num_scalar, 7)
+            r0 = self.r0[edge_atom_type].sum(0) * 0.5
+            shifts = self.poly5pow(
+                rij=data[AtomicDataDict.EDGE_LENGTH_KEY],
+                r0=r0,
+                paraArray=shifts
+            ) # [n_edge, n_scalar]
             in_field[:, self.shift_index>=0] = shifts + in_field[:, self.shift_index>=0]
 
         data[self.out_field] = in_field
 
-        return data
\ No newline at end of file
+        return data
+
+    def poly5pow(self, rij, paraArray, r0:torch.Tensor):
+        """Evaluate a fifth-order polynomial in (rij - r0), damped by a power law:
+
+        $$ h(rij) = [ sum_{k=0..5} alpha_{k+1} (rij - r0)^k / k! ] * (r0/rij)^(1 + |alpha_7|) $$
+        """
+
+        #alpha1, alpha2, alpha3, alpha4 = paraArray[:, 0], paraArray[:, 1]**2, paraArray[:, 2]**2, paraArray[:, 3]**2
+        alpha1, alpha2, alpha3, alpha4, alpha5, alpha6, alpha7 = paraArray[..., 0], paraArray[..., 1], paraArray[..., 2], paraArray[..., 3], paraArray[..., 4], paraArray[..., 5], paraArray[..., 6].abs()
+        #[N, n_op]
+        shape = [-1]+[1] * (len(alpha1.shape)-1)
+        # [-1, 1]
+        rij = rij.reshape(shape)
+        r0 = r0.reshape(shape)
+
+        r0 = r0 / 1.8897259886  # convert r0 from Bohr to Angstrom
+
+        return (alpha1 + alpha2 * (rij-r0) + 0.5 * alpha3 * (rij - r0)**2 + 1/6 * alpha4 * (rij-r0)**3 + 1./24 * alpha5 * (rij-r0)**4 + 1./120 * alpha6 * (rij-r0)**5) * (r0/rij)**(1 + alpha7)
+
\ No newline at end of file
diff --git a/dptb/nn/tensor_product.py b/dptb/nn/tensor_product.py
index c36fc68a..0c76f708 100644
--- a/dptb/nn/tensor_product.py
+++ b/dptb/nn/tensor_product.py
@@ -7,7 +7,7 @@
 
 
 
-_Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"))
+_Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"), weights_only=True)
 
 def wigner_D(l, alpha, beta, gamma):
     if not l < len(_Jd):
diff --git a/dptb/utils/argcheck.py b/dptb/utils/argcheck.py
index f4fb1708..a1bd4082 100644
--- a/dptb/utils/argcheck.py
+++ b/dptb/utils/argcheck.py
@@ -605,6 +605,7 @@ def sktb_prediction():
 def e3tb_prediction():
     doc_scales_trainable = "whether to scale the training target."
     doc_shifts_trainable = "whether to shift the training target."
+    doc_decay = "whether the edge normalization takes into account the decaying behaviour of the edge irreps."
     doc_neurons = "neurons in the neural network."
     doc_activation = "activation function."
     doc_if_batch_normalized = "whether to turn on batch normalization."
@@ -612,6 +613,7 @@ def e3tb_prediction():
     nn = [
         Argument("scales_trainable", bool, optional=True, default=False, doc=doc_scales_trainable),
         Argument("shifts_trainable", bool, optional=True, default=False, doc=doc_shifts_trainable),
+        Argument("decay", bool, optional=True, default=False, doc=doc_decay),
         Argument("neurons", list, optional=True, default=None, doc=doc_neurons),
         Argument("activation", str, optional=True, default="tanh", doc=doc_activation),
         Argument("if_batch_normalized", bool, optional=True, default=False, doc=doc_if_batch_normalized),

From af1032aa860a16bec033413c825a1027a348f997 Mon Sep 17 00:00:00 2001
From: zhanghao
Date: Sun, 5 Jan 2025 09:13:01 +0800
Subject: [PATCH 3/3] fix bug when prediction options are absent

---
 dptb/entrypoints/train.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/dptb/entrypoints/train.py b/dptb/entrypoints/train.py
index 012c1eb7..2ea194d8 100644
--- a/dptb/entrypoints/train.py
+++ b/dptb/entrypoints/train.py
@@ -192,7 +192,8 @@ def train(
     # build_model handles the init-model case where the provided model options do not equal the ones in the checkpoint.
     checkpoint = init_model if init_model else None
     model = build_model(checkpoint=checkpoint, model_options=jdata["model_options"], common_options=jdata["common_options"])
-    train_datasets.E3statistics(model=model, decay=jdata["model_options"]["prediction"]["decay"])
+    decay = jdata["model_options"].get("prediction", {}).get("decay", False)
+    train_datasets.E3statistics(model=model, decay=decay)
     trainer = Trainer(
         train_options=jdata["train_options"],
         common_options=jdata["common_options"],
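
---

For reference, the radial shift fitted by fit_radialdpdt_shift and evaluated by poly5pow has the closed form

    h(r) = [ a1 + a2 (r - r0) + a3/2 (r - r0)^2 + a4/6 (r - r0)^3 + a5/24 (r - r0)^4 + a6/120 (r - r0)^5 ] * (r0/r)^(1 + |a7|)

i.e. seven parameters per (bond type, scalar channel), with r0 taken per the patch as half the sum of the two atoms' reference bond lengths. The sketch below reproduces the fitting idea on synthetic data, outside dptb; the target curve, r0 value, and step count are illustrative stand-ins (the patch itself fits the per-bond-type scalar statistics for 40000 steps with Adam at lr=0.01, batch size 128, and a ReduceLROnPlateau scheduler):

    import torch

    def poly5pow(rij, para, r0):
        # para[..., 0:6] are Taylor coefficients a1..a6; para[..., 6] sets the power-law exponent
        dr = rij - r0
        poly = (para[..., 0] + para[..., 1]*dr + para[..., 2]*dr**2/2 + para[..., 3]*dr**3/6
                + para[..., 4]*dr**4/24 + para[..., 5]*dr**5/120)
        return poly * (r0 / rij) ** (1.0 + para[..., 6].abs())

    torch.manual_seed(0)
    r0 = torch.tensor(2.35)                         # illustrative equilibrium bond length
    r = 1.5 + 4.0 * torch.rand(4096)                # synthetic edge lengths for one bond type
    target = 0.8 * torch.exp(-(r - r0)) * (r0 / r)  # synthetic decaying scalar, stand-in for scalar_decay

    para = torch.randn(7, requires_grad=True)
    opt = torch.optim.Adam([para], lr=0.01)
    for step in range(2000):
        opt.zero_grad()
        idx = torch.randint(0, len(r), (128,))      # random mini-batch, as in fit_radialdpdt_shift
        loss = (poly5pow(r[idx], para, r0) - target[idx]).pow(2).mean()
        loss.backward()
        opt.step()
    print(f"final mse: {loss.item():.2e}")          # the fitted para then acts as the per-edge scalar shift

Since the polynomial multiplies a (r0/r)^(1+|a7|) envelope, the fitted shift decays smoothly to zero at large bond length, which is what lets it replace the constant per-bond-type shift when decay = True.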