def build_z_rot_multi(angle_stack, mask, freq, reversed_inds, offsets, sizes):
    """Build block-diagonal z-rotation matrices for three stacked angle sets.

    For every selected ``(l, row)`` entry of ``mask`` the matrix gets
    ``cos(freq[l, row] * angle)`` on the diagonal and
    ``sin(freq[l, row] * angle)`` at column ``reversed_inds[l, row]`` of the
    same block; entries where both positions coincide keep the cosine.

    Args:
        angle_stack: 1-D tensor of length ``3 * N`` holding alpha, beta and
            gamma concatenated in that order.
        mask: boolean tensor ``(L, Mmax)`` selecting the valid rows of each
            block (presumably the first ``2l + 1`` entries of row ``l`` --
            confirm against the code that generated the table).
        freq: tensor ``(L, Mmax)`` of per-row frequencies.
        reversed_inds: integer tensor ``(L, Mmax)`` giving, for each row, the
            in-block column that receives the sine term.
        offsets: integer tensor ``(L,)`` with the first global row/column of
            each block.
        sizes: integer tensor ``(L,)`` of block sizes; their sum is the side
            length ``D_total`` of the returned matrices.

    Returns:
        Tuple ``(Xa, Xb, Xc)``, each of shape ``(N, D_total, D_total)`` and of
        the same dtype/device as ``angle_stack``.

    Raises:
        ValueError: if the length of ``angle_stack`` is not a multiple of 3.
    """
    n_all = angle_stack.shape[0]
    if n_all % 3 != 0:
        raise ValueError(
            f"angle_stack must hold three equally sized angle sets, got length {n_all}"
        )
    n = n_all // 3
    d_total = int(sizes.sum())

    # Keep the trigonometric values in the dtype of the angles: the result
    # buffer below is allocated with that dtype, and indexed assignment does
    # not tolerate a source tensor of a different dtype.
    freq = freq.to(dtype=angle_stack.dtype)

    # Step 1: all sines and cosines at once.
    phase = freq[:, None, :] * angle_stack[None, :, None]  # (L, 3N, Mmax)
    sin_val = torch.sin(phase)
    cos_val = torch.cos(phase)

    # Step 2: scatter them into the block-diagonal layout.
    m_total = angle_stack.new_zeros((n_all, d_total, d_total))
    idx_l, idx_row = torch.where(mask)  # (K,), (K,)
    block_start = offsets[idx_l]
    global_row = block_start + idx_row
    global_col_anti = block_start + reversed_inds[idx_l, idx_row]

    # Indexing with [idx_l, :, idx_row] puts the K selected entries first,
    # hence the transpose to (3N, K) before assignment.
    m_total[:, global_row, global_row] = cos_val[idx_l, :, idx_row].transpose(0, 1)

    # The centre of each block lies on both diagonals; it keeps the cosine.
    off_center = global_row != global_col_anti
    m_total[:, global_row[off_center], global_col_anti[off_center]] = (
        sin_val[idx_l[off_center], :, idx_row[off_center]].transpose(0, 1)
    )

    # Step 3: split back into the alpha, beta and gamma parts.
    return m_total[:n], m_total[n:2 * n], m_total[2 * n:]


def batch_wigner_D(l_max, alpha, beta, gamma, _Jd, idx_data=None):
    """Compute the Wigner D matrices for all ``l`` in ``0..l_max`` in one batch.

    The result is the product ``Xa @ J @ Xb @ J @ Xc`` where ``J`` is the
    block-diagonal matrix assembled from ``_Jd[0..l_max]`` and ``Xa``, ``Xb``,
    ``Xc`` are the z-rotation matrices of ``alpha``, ``beta`` and ``gamma``.

    Args:
        l_max: highest degree to include.
        alpha, beta, gamma: 1-D angle tensors of equal length ``N``.
        _Jd: sequence of square tensors, ``_Jd[l]`` of shape
            ``(2l + 1, 2l + 1)``.
        idx_data: optional mapping with the tensor entries ``"sizes"``,
            ``"offsets"``, ``"mask"``, ``"freq"`` and ``"reversed_inds"``.
            Defaults to the module-level ``_idx_data`` table.

    Returns:
        Tensor of shape ``(N, D, D)`` with ``D = sum(2l + 1 for l in 0..l_max)``,
        in the dtype and on the device of the angles.

    Raises:
        IndexError: if ``l_max`` exceeds what ``_Jd`` or the index tables cover.
    """
    if idx_data is None:
        idx_data = _idx_data

    n_l = l_max + 1
    available = min(len(_Jd), len(idx_data["sizes"]))
    if n_l > available:
        raise IndexError(
            f"l_max={l_max} exceeds the maximum supported degree {available - 1}"
        )

    device = alpha.device
    n = alpha.shape[0]
    angle_stack = torch.cat([alpha, beta, gamma], dim=0)
    dtype = angle_stack.dtype

    # Read the block offsets once as Python ints so that slicing below does
    # not go through 0-dim tensors.
    starts = idx_data["offsets"][:n_l].tolist()

    # Only the tables that are actually needed are moved to the device.
    sizes = idx_data["sizes"][:n_l].to(device)
    offsets = idx_data["offsets"][:n_l].to(device)
    mask = idx_data["mask"][:n_l].to(device)
    freq = idx_data["freq"][:n_l].to(device)
    reversed_inds = idx_data["reversed_inds"][:n_l].to(device)

    # Block-diagonal J, allocated in the dtype of the angles so the matrix
    # products below do not mix precisions.
    d_total = sum(2 * l + 1 for l in range(n_l))
    j_full = torch.zeros(d_total, d_total, dtype=dtype, device=device)
    for l, start in enumerate(starts):
        stop = start + 2 * l + 1
        j_full[start:stop, start:stop] = _Jd[l]
    j_batched = j_full.unsqueeze(0).expand(n, -1, -1)

    xa, xb, xc = build_z_rot_multi(angle_stack, mask, freq, reversed_inds, offsets, sizes)

    return xa @ j_batched @ xb @ j_batched @ xc
range(self.l_max + 1): + self.offsets[l] = offset + offset += self.dims[l] + + def forward(self, x, R, latents=None): n, _ = x.shape if self.radial_emb: weights = self.radial_emb(latents) - x_ = torch.zeros(n, self.irreps_in.dim, dtype=x.dtype, device=x.device) - angle = xyz_to_angles(R[:,[1,2,0]]) # (tensor(N), tensor(N)) - for (mul, (l,p)), slice in zip(self.irreps_in, self.irreps_in.slices()): - if l > 0: - # The roataion matrix is SO3 rotation, therefore Irreps(l,1), is used here. - rot_mat_L = wigner_D(l, angle[0], angle[1], torch.zeros_like(angle[0])) - x_[:, slice] = torch.einsum('nji,nmj->nmi', rot_mat_L, x[:, slice].reshape(n,mul,-1)).reshape(n,-1) - else: - x_[:, slice] = x[:, slice] + x_ = torch.zeros_like(x) + angle = xyz_to_angles(R[:, [1,2,0]]) + + # Compute Wigner D matrices for all l at once + wigner_D_all = batch_wigner_D(self.l_max, angle[0], angle[1], torch.zeros_like(angle[0]), _Jd) + + # 1. group irreps by l + groups = defaultdict(list) + for (mul, (l, p)), slice_info in zip(self.irreps_in, self.irreps_in.slices()): + groups[l].append((mul, slice_info)) + + # 2. 
Batch process all mul for each l + for l, group in groups.items(): + if l == 0 or not group: + continue + + # Batch combination + muls, slices = zip(*group) + x_parts = [x[:, sl].reshape(n, mul, 2*l+1) for mul, sl in group] + x_combined = torch.cat(x_parts, dim=1) # [n, total_mul, 2l+1] + + start = self.offsets[l] + rot_mat = wigner_D_all[:, start:start+self.dims[l], start:start+self.dims[l]] + + # Batch feature rotation (n, total_mul, 2l+1) + transformed = torch.bmm(x_combined, rot_mat) # (n, total_mul, 2l+1) + + # Split back into each slice in the original order + for part, slice_info, mul in zip(transformed.split(muls, dim=1), slices, muls): + x_[:, slice_info] = part.reshape(n, -1) out = torch.zeros(n, self.irreps_out.dim, dtype=x.dtype, device=x.device) for m in range(self.irreps_out.lmax+1): @@ -153,11 +260,14 @@ def forward(self, x, R, latents=None): final_addition = linear_output.transpose(1, 2).contiguous().reshape(n, -1) out[:, self.m_out_mask[m]] += final_addition - for (mul, (l,p)), slice in zip(self.irreps_out, self.irreps_out.slices()): + + for (mul, (l, p)), slice_in in zip(self.irreps_out, self.irreps_out.slices()): if l > 0: - # The roataion matrix is SO3 rotation, therefore Irreps(l,1), is used here. 
- rot_mat_L = wigner_D(l, angle[0], angle[1], torch.zeros_like(angle[0])) - out[:, slice] = torch.einsum('nij,nmj->nmi', rot_mat_L, out[:, slice].reshape(n,mul,-1)).reshape(n,-1) + start = self.offsets[l] + rot_mat = wigner_D_all[:, start:start+self.dims[l], start:start+self.dims[l]] + x_slice = out[:, slice_in].reshape(n, mul, -1) + rotated = torch.einsum('nij,nmj->nmi', rot_mat, x_slice) + out[:, slice_in] = rotated.reshape(n, -1) out.contiguous() @@ -240,4 +350,4 @@ def __init__(self, channels_list): def forward(self, inputs): - return self.net(inputs) \ No newline at end of file + return self.net(inputs) diff --git a/dptb/nn/z_rot_indices_lmax12.pt b/dptb/nn/z_rot_indices_lmax12.pt new file mode 100644 index 00000000..ae69bf6b Binary files /dev/null and b/dptb/nn/z_rot_indices_lmax12.pt differ