From e95d945f9229dac05ec2c9cd6ea8f0e893b287da Mon Sep 17 00:00:00 2001 From: Kai Qi <48313251+Kai-Qi@users.noreply.github.com> Date: Wed, 21 May 2025 20:03:45 +0800 Subject: [PATCH 1/5] Update tensor_product.py --- dptb/nn/tensor_product.py | 66 ++++++++++++++++++++++++++++++--------- 1 file changed, 52 insertions(+), 14 deletions(-) diff --git a/dptb/nn/tensor_product.py b/dptb/nn/tensor_product.py index 8edb75ac..00f8e7ac 100644 --- a/dptb/nn/tensor_product.py +++ b/dptb/nn/tensor_product.py @@ -4,7 +4,7 @@ import torch.nn as nn from torch.nn import Linear import os - +from collections import defaultdict _Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"), weights_only=False) @@ -111,13 +111,36 @@ def forward(self, x, R, latents=None): if self.radial_emb: weights = self.radial_emb(latents) + + ###################################################################### + ###### 改进,对角量子数进分组处理 ###################################### x_ = torch.zeros(n, self.irreps_in.dim, dtype=x.dtype, device=x.device) - for (mul, (l,p)), slice in zip(self.irreps_in, self.irreps_in.slices()): - if l > 0: - angle = xyz_to_angles(R[:,[1,2,0]]) # (tensor(N), tensor(N)) - # The roataion matrix is SO3 rotation, therefore Irreps(l,1), is used here. - rot_mat_L = wigner_D(l, angle[0], angle[1], torch.zeros_like(angle[0])) - x_[:, slice] = torch.einsum('nji,nmj->nmi', rot_mat_L, x[:, slice].reshape(n,mul,-1)).reshape(n,-1) + + R_transformed = R[:, [1, 2, 0]] + alpha, beta = xyz_to_angles(R_transformed) + gamma = torch.zeros_like(alpha) + + # 按角量子数l分组处理不可约表示 + groups = defaultdict(list) + for (mul, (l, p)), slice_info in zip(self.irreps_in, self.irreps_in.slices()): + groups[l].append((mul, slice_info)) + + # 处理每个l的分组 + for l, group in groups.items(): + if l == 0 or not group: + continue + + muls, slices = zip(*group) + x_parts = [x[:, sl].reshape(n, mul, 2*l+1) for mul, sl in group] + x_combined = torch.cat(x_parts, dim=1) # [n, total_mul, 2l+1] + + rot_mat = wigner_D(l, alpha, beta, gamma) # [n, 2l+1, 2l+1] + transformed = torch.bmm(x_combined, rot_mat) # [n, total_mul, 2l+1] + + for part, slice_info in zip(transformed.split(muls, dim=1), slices): + x_[:, slice_info] = part.reshape(n, -1) + ###################################################################### + out = torch.zeros(n, self.irreps_out.dim, dtype=x.dtype, device=x.device) for m in range(self.irreps_out.lmax+1): @@ -139,12 +162,27 @@ def forward(self, x, R, latents=None): out.contiguous() - for (mul, (l,p)), slice in zip(self.irreps_out, self.irreps_out.slices()): - if l > 0: - angle = xyz_to_angles(R[:,[1,2,0]]) # (tensor(N), tensor(N)) - # The roataion matrix is SO3 rotation, therefore Irreps(l,1), is used here. - rot_mat_L = wigner_D(l, angle[0], angle[1], torch.zeros_like(angle[0])) - out[:, slice] = torch.einsum('nij,nmj->nmi', rot_mat_L, out[:, slice].reshape(n,mul,-1)).reshape(n,-1) + + ###################################################################### + ###### 改进,对角量子数进分组处理 ###################################### + out_groups = defaultdict(list) + for (mul, (l, p)), slice_info in zip(self.irreps_out, self.irreps_out.slices()): + out_groups[l].append((mul, slice_info)) + + for l, group in out_groups.items(): + if l == 0 or not group: + continue + + muls, slices = zip(*group) + out_parts = [out[:, sl].reshape(n, mul, 2*l+1) for mul, sl in group] + out_combined = torch.cat(out_parts, dim=1) + + rot_mat = wigner_D(l, alpha, beta, gamma) + transformed = torch.bmm(rot_mat, out_combined.transpose(1,2)).transpose(1,2) # [n, total_mul, 2l+1] + + for part, slice_info in zip(transformed.split(muls, dim=1), slices): + out[:, slice_info] = part.reshape(n, -1) + ###################################################################### return out @@ -225,4 +263,4 @@ def __init__(self, channels_list): def forward(self, inputs): - return self.net(inputs) \ No newline at end of file + return self.net(inputs) From aeccb93d785a3892792dce08f04873d8cd0f7b71 Mon Sep 17 00:00:00 2001 From: Kai Qi <48313251+Kai-Qi@users.noreply.github.com> Date: Wed, 21 May 2025 20:08:18 +0800 Subject: [PATCH 2/5] Update tensor_product.py --- dptb/nn/tensor_product.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/dptb/nn/tensor_product.py b/dptb/nn/tensor_product.py index 00f8e7ac..da37150a 100644 --- a/dptb/nn/tensor_product.py +++ b/dptb/nn/tensor_product.py @@ -110,22 +110,22 @@ def forward(self, x, R, latents=None): if self.radial_emb: weights = self.radial_emb(latents) - - - ###################################################################### - ###### 改进,对角量子数进分组处理 ###################################### + + # ====================================================================== + # ====== Improved: group input irreps by angular quantum number ======== + # ====================================================================== x_ = torch.zeros(n, self.irreps_in.dim, dtype=x.dtype, device=x.device) R_transformed = R[:, [1, 2, 0]] alpha, beta = xyz_to_angles(R_transformed) gamma = torch.zeros_like(alpha) - # 按角量子数l分组处理不可约表示 + # Group irreducible representations by angular quantum number groups = defaultdict(list) for (mul, (l, p)), slice_info in zip(self.irreps_in, self.irreps_in.slices()): groups[l].append((mul, slice_info)) - # 处理每个l的分组 + # Process each l group for l, group in groups.items(): if l == 0 or not group: continue @@ -139,8 +139,7 @@ def forward(self, x, R, latents=None): for part, slice_info in zip(transformed.split(muls, dim=1), slices): x_[:, slice_info] = part.reshape(n, -1) - ###################################################################### - + # ====================================================================== out = torch.zeros(n, self.irreps_out.dim, dtype=x.dtype, device=x.device) for m in range(self.irreps_out.lmax+1): @@ -162,9 +161,9 @@ def forward(self, x, R, latents=None): out.contiguous() - - ###################################################################### - ###### 改进,对角量子数进分组处理 ###################################### + # ====================================================================== + # ====== Improved: group input irreps by angular quantum number ======== + # ====================================================================== out_groups = defaultdict(list) for (mul, (l, p)), slice_info in zip(self.irreps_out, self.irreps_out.slices()): out_groups[l].append((mul, slice_info)) @@ -182,7 +181,6 @@ def forward(self, x, R, latents=None): for part, slice_info in zip(transformed.split(muls, dim=1), slices): out[:, slice_info] = part.reshape(n, -1) - ###################################################################### return out From a408b4a9f90035500a14c7e4148da709117e7e86 Mon Sep 17 00:00:00 2001 From: Kai Qi <48313251+Kai-Qi@users.noreply.github.com> Date: Fri, 13 Jun 2025 15:24:58 +0800 Subject: [PATCH 3/5] Update tensor_product.py --- dptb/nn/tensor_product.py | 152 ++++++++++++++++++++++++++++---------- 1 file changed, 114 insertions(+), 38 deletions(-) diff --git a/dptb/nn/tensor_product.py b/dptb/nn/tensor_product.py index da37150a..99f0c44b 100644 --- a/dptb/nn/tensor_product.py +++ b/dptb/nn/tensor_product.py @@ -4,10 +4,87 @@ import torch.nn as nn from torch.nn import Linear import os +import torch.nn.functional as F from collections import defaultdict - _Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"), weights_only=False) +_idx_data = torch.load(os.path.join(os.path.dirname(__file__), "z_rot_indices_lmax12.pt"), weights_only=False) + + +def build_z_rot_multi(angle_stack, mask, freq, reversed_inds, offsets, sizes): + """ + angle_stack: (3*N, ) # Input with alpha, beta, gamma stacked together + l_max: int + + Returns: (Xa, Xb, Xc) # Each is of shape (N, D_total, D_total) + """ + N_all = angle_stack.shape[0] + N = N_all // 3 + + D_total = sizes.sum().item() + + # Step 1: Vectorized computation of sine and cosine values + angle_expand = angle_stack[None, :, None] # (1, 3N, 1) + freq_expand = freq[:, None, :] # (L, 1, Mmax) + sin_val = torch.sin(freq_expand * angle_expand) # (L, 3N, Mmax) + cos_val = torch.cos(freq_expand * angle_expand) # (L, 3N, Mmax) + + # Step 2: Construct the block-diagonal matrix + M_total = angle_stack.new_zeros((N_all, D_total, D_total)) + idx_l, idx_row = torch.where(mask) # (K,), (K,) + idx_col_diag = idx_row + idx_col_anti = reversed_inds[idx_l, idx_row] + global_row = offsets[idx_l] + idx_row # (K,) + global_col_diag = offsets[idx_l] + idx_col_diag + global_col_anti = offsets[idx_l] + idx_col_anti + + # Assign values to the diagonal + M_total[:, global_row, global_col_diag] = cos_val[idx_l, :, idx_row].transpose(0,1) + # Assign values to non-overlapping anti-diagonals + overlap_mask = (global_row == global_col_anti) + M_total[:, global_row[~overlap_mask], global_col_anti[~overlap_mask]] = sin_val[idx_l[~overlap_mask], :, idx_row[~overlap_mask]].transpose(0,1) + + # Step 3: Split into three components corresponding to alpha, beta, gamma + Xa = M_total[:N] + Xb = M_total[N:2*N] + Xc = M_total[2*N:] + + return Xa, Xb, Xc + + +def batch_wigner_D(l_max, alpha, beta, gamma, _Jd): + """ + Compute Wigner D matrices for all L (from 0 to l_max) in a single batch. + Returns a tensor of shape [N, D, D], where D = sum(2l+1 for l in 0..l_max). + """ + device = alpha.device + N = alpha.shape[0] + idx_data = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in _idx_data.items()} + + # Load static data + sizes = idx_data["sizes"][:l_max+1] + offsets = idx_data["offsets"][:l_max+1] + mask = idx_data["mask"][:l_max+1] + freq = idx_data["freq"][:l_max+1] + reversed_inds = idx_data["reversed_inds"][:l_max+1] + + # Precompute block structure information + dims = [2*l + 1 for l in range(l_max + 1)] + offsets = torch.cumsum(torch.tensor([0] + dims[:-1], device='cuda'), 0) + D_total = sum(dims) + + # Construct block-diagonal J matrix + J_full_small = torch.zeros(D_total, D_total, device='cuda') + for l in range(l_max + 1): + start = offsets[l] + J_full_small[start:start+2*l+1, start:start+2*l+1] = _Jd[l] + + J_full = J_full_small.unsqueeze(0).expand(N, -1, -1) + angle_stack = torch.cat([alpha, beta, gamma], dim=0) + Xa, Xb, Xc = build_z_rot_multi(angle_stack, mask, freq, reversed_inds, offsets, sizes) + + return Xa @ J_full @ Xb @ J_full @ Xc + def wigner_D(l, alpha, beta, gamma): if not l < len(_Jd): @@ -54,7 +131,7 @@ def __init__( extra_m0_outsize: int = 0, ): super(SO2_Linear, self).__init__() - + self.irreps_in = irreps_in.simplify() self.irreps_out = (Irreps(f"{extra_m0_outsize}x0e") + irreps_out).simplify() @@ -105,41 +182,52 @@ def __init__( self.radial_emb = RadialFunction([latent_dim]+radial_channels+[self.m_in_index[-1]]) self.front = front + self.l_max = max(l for (_, (l, _)), _ in zip(self.irreps_in, self.irreps_in.slices()) if l > 0) + self.dims = {l: 2*l + 1 for l in range(self.l_max + 1)} + self.offsets = {} + offset = 0 + for l in range(self.l_max + 1): + self.offsets[l] = offset + offset += self.dims[l] + + def forward(self, x, R, latents=None): n, _ = x.shape if self.radial_emb: weights = self.radial_emb(latents) - - # ====================================================================== - # ====== Improved: group input irreps by angular quantum number ======== - # ====================================================================== - x_ = torch.zeros(n, self.irreps_in.dim, dtype=x.dtype, device=x.device) - R_transformed = R[:, [1, 2, 0]] - alpha, beta = xyz_to_angles(R_transformed) - gamma = torch.zeros_like(alpha) + x_ = torch.zeros_like(x) + angle = xyz_to_angles(R[:, [1,2,0]]) + + # Compute Wigner D matrices for all l at once + wigner_D_all = batch_wigner_D(self.l_max, angle[0], angle[1], torch.zeros_like(angle[0]), _Jd) - # Group irreducible representations by angular quantum number + # 1. group irreps by l groups = defaultdict(list) for (mul, (l, p)), slice_info in zip(self.irreps_in, self.irreps_in.slices()): groups[l].append((mul, slice_info)) - # Process each l group + # 2. Batch process all mul for each l for l, group in groups.items(): if l == 0 or not group: continue - + + # Batch combination muls, slices = zip(*group) x_parts = [x[:, sl].reshape(n, mul, 2*l+1) for mul, sl in group] x_combined = torch.cat(x_parts, dim=1) # [n, total_mul, 2l+1] - - rot_mat = wigner_D(l, alpha, beta, gamma) # [n, 2l+1, 2l+1] - transformed = torch.bmm(x_combined, rot_mat) # [n, total_mul, 2l+1] - - for part, slice_info in zip(transformed.split(muls, dim=1), slices): + + start = self.offsets[l] + rot_mat = wigner_D_all[:, start:start+self.dims[l], start:start+self.dims[l]] + + # Batch feature rotation (n, total_mul, 2l+1) + transformed = torch.bmm(x_combined, rot_mat) # (n, total_mul, 2l+1) + + # Split back into each slice in the original order + for part, slice_info, mul in zip(transformed.split(muls, dim=1), slices, muls): x_[:, slice_info] = part.reshape(n, -1) - # ====================================================================== + out = torch.zeros(n, self.irreps_out.dim, dtype=x.dtype, device=x.device) for m in range(self.irreps_out.lmax+1): @@ -161,26 +249,14 @@ def forward(self, x, R, latents=None): out.contiguous() - # ====================================================================== - # ====== Improved: group input irreps by angular quantum number ======== - # ====================================================================== - out_groups = defaultdict(list) - for (mul, (l, p)), slice_info in zip(self.irreps_out, self.irreps_out.slices()): - out_groups[l].append((mul, slice_info)) - for l, group in out_groups.items(): - if l == 0 or not group: - continue - - muls, slices = zip(*group) - out_parts = [out[:, sl].reshape(n, mul, 2*l+1) for mul, sl in group] - out_combined = torch.cat(out_parts, dim=1) - - rot_mat = wigner_D(l, alpha, beta, gamma) - transformed = torch.bmm(rot_mat, out_combined.transpose(1,2)).transpose(1,2) # [n, total_mul, 2l+1] - - for part, slice_info in zip(transformed.split(muls, dim=1), slices): - out[:, slice_info] = part.reshape(n, -1) + for (mul, (l, p)), slice_in in zip(self.irreps_out, self.irreps_out.slices()): + if l > 0: + start = self.offsets[l] + rot_mat = wigner_D_all[:, start:start+self.dims[l], start:start+self.dims[l]] + x_slice = out[:, slice_in].reshape(n, mul, -1) + rotated = torch.einsum('nij,nmj->nmi', rot_mat, x_slice) + out[:, slice_in] = rotated.reshape(n, -1) return out From a3e11d7fb0f77da03348383c31adf29b202cdce1 Mon Sep 17 00:00:00 2001 From: Kai Qi <48313251+Kai-Qi@users.noreply.github.com> Date: Fri, 13 Jun 2025 15:25:43 +0800 Subject: [PATCH 4/5] Add files via upload --- dptb/nn/z_rot_indices_lmax12.pt | Bin 0 -> 9416 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 dptb/nn/z_rot_indices_lmax12.pt diff --git a/dptb/nn/z_rot_indices_lmax12.pt b/dptb/nn/z_rot_indices_lmax12.pt new file mode 100644 index 0000000000000000000000000000000000000000..ae69bf6b5aad3c7af8040c2dbaf2ac206ad8a60d GIT binary patch literal 9416 zcmdT~&rcgi6rP1(z(61g`QZYAG_-{_82!Jj8F`_q}gt zcJ|wN(`xTA#s&u1t~SN`*u7k}V&=-5#d1M2a{8_OooH;am^bs`+c))8Qeii0JzoQx z%OVE5l{apJQZOskLMfcvGRwLV%vH57x5|1kXKI^9rJCE01)*ndF58dC8+c@XLxWzC z5=>Sq`e)+1ARNmoct`~Lvb{JKZf`;0+EvAb{_84JN7T>*7*Ip1@&E=O777(z*9vC2 zvS|cAtroSaR{XeJFoO`B&%lrg!x0LM#MK#53GdL?f*SVt|^D2$Z^+ldC- zNx^pNVTpr|n{YY2l9;bGFee1&q`;gq;d}at zs%>jkLo4!AZ@`QIFA1bLT#!hwH;^t0q*;OVh6!(G;H_)`k@TEiHcXgvKrF$0DWHy+ z@V13;AKp=w`*2BQFPIR{z+$!!vyJjS&47p_c^RT2IVO|iJb6hZFPpHEfp@L>t*oyb znrXmgNBRo9C(_@S=^ya)4@LT_2_I$PigjME;cA0~8p;k@GVouyyspEVI=mH5tNjCi zeK+*Qb`N8`r#)Bk1Y7f7#fy52m{SSR-jo(26l%U0e8q2aE2TK zd&m{o3GE6{s7XMCng)2tBcLdJIwxy?+`Wni@!o1WV=bmLno5TEIGs)BuIpdct_aB5 zuU&`U0ZcIF;HmF~y7gBAIT&^~Poj?Bcd6A9NP*ppr&h;UhwVuMHDGVY^Q69iqB`R5 zg{Qs`u9wKMAD)+c;aX#?Bc5lhZ5u-l+^9R|8Ee5cgZSO#nQ6imL5_pqiP*Rj$Z?Q7 ziD!H}uKvC7XdEp7@2$4ui!HX}SSmT)ZaZ$fG<5h(-8#4n=ZzbEKecu` zxXa}&AciCRxmO+h0gJaL4|XDONS~eR;12`51;ub^pWW6mB$W)mVnqk7gFnskc3cbx z_W7!Hj7lYAT~>71I!;JECp*-2)O4JZddByt>!|BEE%lsfqswnP&PqM!Kg zpTmux_g0>Cw1p>)r;LPzfe$TJE$j_0#^Q$iMZ{+9v zs*C&2f9Ahh?bLrRdH7HAZRB^z|B!DWA3^tDx*yYhjP7f6KcM>q?W?^07yte=Q9{4P z`+oiNva5coo9d-HsXnTU>Y@Ff_DkAt=z68~PU|Hk{UNV^`-kw)oz}nNs-Ns9>7esX zb<;jbb&{$EMc_b&J=0r-wXDjBprz{Gy$kH`<9E4E*air1XSZjJSd9zb-g| zzQqaf28z0Y?2m*_pjo^Uyn$kFAp5oL1bT{FuQyQK4P?K_oIsbSTLCS(f$Xc@3G`n} zpk+6Zec3sIp3bzYXvGa=UuBby>A!Qq14ypfv>FIpk{RcF7aF;AmC^XfTDgopM?Bsz a&)P`VgO(zM!VZGwYrkn!WivHt_0dX+)| literal 0 HcmV?d00001 From f69467a92efefd6eb8ba0d44cb59d94842935fa7 Mon Sep 17 00:00:00 2001 From: Kai Qi <48313251+Kai-Qi@users.noreply.github.com> Date: Fri, 20 Jun 2025 18:00:15 +0800 Subject: [PATCH 5/5] Update tensor_product.py --- dptb/nn/tensor_product.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dptb/nn/tensor_product.py b/dptb/nn/tensor_product.py index 99f0c44b..25521195 100644 --- a/dptb/nn/tensor_product.py +++ b/dptb/nn/tensor_product.py @@ -70,11 +70,11 @@ def batch_wigner_D(l_max, alpha, beta, gamma, _Jd): # Precompute block structure information dims = [2*l + 1 for l in range(l_max + 1)] - offsets = torch.cumsum(torch.tensor([0] + dims[:-1], device='cuda'), 0) + # offsets = torch.cumsum(torch.tensor([0] + dims[:-1], device=device), 0) D_total = sum(dims) # Construct block-diagonal J matrix - J_full_small = torch.zeros(D_total, D_total, device='cuda') + J_full_small = torch.zeros(D_total, D_total, device=device) for l in range(l_max + 1): start = offsets[l] J_full_small[start:start+2*l+1, start:start+2*l+1] = _Jd[l]