2 changes: 1 addition & 1 deletion docs/advanced/e3tb/loss_analysis.md
@@ -39,7 +39,7 @@ loader = DataLoader(dataset, batch_size=10, shuffle=False, num_workers=0)
for data in tqdm(loader, desc="doing error analysis"):
with torch.no_grad():
data = data.to("cuda")
batch_info = data.get_batch_info()
batch_info = data.get_batchinfo()
ref_data = AtomicData.to_AtomicDataDict(data)
data = model(ref_data)
data.update(batch_info)
4 changes: 2 additions & 2 deletions dptb/nn/embedding/lem.py
@@ -556,7 +556,7 @@ def __init__(
res_update_params.clamp_(-6.0, 6.0)

if res_update_ratios_learnable:
self._latent_resnet_update_params = torch.nn.Parameter(
self._res_update_params = torch.nn.Parameter(
res_update_params
)
else:
@@ -731,7 +731,7 @@ def __init__(
res_update_params.clamp_(-6.0, 6.0)

if res_update_ratios_learnable:
self._latent_resnet_update_params = torch.nn.Parameter(
self._res_update_params = torch.nn.Parameter(
res_update_params
)
else:
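The hunks above, and the matching ones in slem.py and trinity.py below, rename the gate parameter of the residual update from `_latent_resnet_update_params` to `_res_update_params`. Below is a minimal sketch of that gated-residual pattern, assuming a sigmoid gate initialized from a mixing ratio; the module and argument names (`ResidualGate`, `init_ratio`) are illustrative, not taken from the repo.

```python
import torch

class ResidualGate(torch.nn.Module):
    """Sketch of a clamped, optionally learnable residual-update gate."""

    def __init__(self, num_features: int, init_ratio: float = 0.5,
                 res_update_ratios_learnable: bool = True):
        super().__init__()
        # inverse sigmoid of the initial mixing ratio, clamped so the gate
        # stays in a well-conditioned range
        res_update_params = torch.full((num_features,), init_ratio).logit()
        res_update_params.clamp_(-6.0, 6.0)
        if res_update_ratios_learnable:
            self._res_update_params = torch.nn.Parameter(res_update_params)
        else:
            self.register_buffer("_res_update_params", res_update_params)

    def forward(self, old: torch.Tensor, new: torch.Tensor) -> torch.Tensor:
        ratio = torch.sigmoid(self._res_update_params)  # in (0, 1)
        return ratio * new + (1.0 - ratio) * old

# usage: blend a block's output back into the running features
gate = ResidualGate(num_features=8)
x = gate(torch.randn(4, 8), torch.randn(4, 8))
```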
6 changes: 3 additions & 3 deletions dptb/nn/embedding/slem.py
@@ -580,7 +580,7 @@ def __init__(
res_update_params.clamp_(-6.0, 6.0)

if res_update_ratios_learnable:
self._latent_resnet_update_params = torch.nn.Parameter(
self._res_update_params = torch.nn.Parameter(
res_update_params
)
else:
@@ -743,7 +743,7 @@ def __init__(
res_update_params.clamp_(-6.0, 6.0)

if res_update_ratios_learnable:
self._latent_resnet_update_params = torch.nn.Parameter(
self._res_update_params = torch.nn.Parameter(
res_update_params
)
else:
@@ -904,7 +904,7 @@ def __init__(
res_update_params.clamp_(-6.0, 6.0)

if res_update_ratios_learnable:
self._latent_resnet_update_params = torch.nn.Parameter(
self._res_update_params = torch.nn.Parameter(
res_update_params
)
else:
16 changes: 8 additions & 8 deletions dptb/nn/embedding/trinity.py
@@ -406,7 +406,7 @@ def __init__(

self.sln_n = SeperableLayerNorm(
irreps=self.irreps_out,
eps=1e-5,
eps=5e-3,
affine=True,
normalization='component',
std_balance_degrees=True,
@@ -416,7 +416,7 @@ def __init__(

self.sln_e = SeperableLayerNorm(
irreps=self.irreps_out,
eps=1e-5,
eps=5e-3,
affine=True,
normalization='component',
std_balance_degrees=True,
@@ -586,7 +586,7 @@ def __init__(

self.sln = SeperableLayerNorm(
irreps=self.irreps_in,
eps=1e-5,
eps=5e-3,
affine=True,
normalization='component',
std_balance_degrees=True,
@@ -596,7 +596,7 @@ def __init__(

self.sln_e = SeperableLayerNorm(
irreps=self.edge_irreps_in,
eps=1e-5,
eps=5e-3,
affine=True,
normalization='component',
std_balance_degrees=True,
@@ -672,7 +672,7 @@ def __init__(
res_update_params.clamp_(-6.0, 6.0)

if res_update_ratios_learnable:
self._latent_resnet_update_params = torch.nn.Parameter(
self._res_update_params = torch.nn.Parameter(
res_update_params
)
else:
@@ -791,7 +791,7 @@ def __init__(

self.sln_e = SeperableLayerNorm(
irreps=self.irreps_in,
eps=1e-5,
eps=5e-3,
affine=True,
normalization='component',
std_balance_degrees=True,
@@ -801,7 +801,7 @@ def __init__(

self.sln_n = SeperableLayerNorm(
irreps=self.irreps_in,
eps=1e-5,
eps=5e-3,
affine=True,
normalization='component',
std_balance_degrees=True,
@@ -847,7 +847,7 @@ def __init__(
res_update_params.clamp_(-6.0, 6.0)

if res_update_ratios_learnable:
self._latent_resnet_update_params = torch.nn.Parameter(
self._res_update_params = torch.nn.Parameter(
res_update_params
)
else:
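All of the `SeperableLayerNorm` instances in trinity.py move from `eps=1e-5` to `eps=5e-3`. With `normalization='component'` the rescaling factor behaves roughly like `(mean(x**2) + eps) ** -0.5`, so a larger eps mainly limits how strongly nearly-silent channels are amplified. A purely illustrative comparison (the numbers below are not from the code):

```python
import torch

feature_norm = torch.tensor([1e-8])        # a nearly-zero channel
print((feature_norm + 1e-5).pow(-0.5))     # ~316: tiny eps lets the rescaling blow up
print((feature_norm + 5e-3).pow(-0.5))     # ~14:  larger eps keeps it bounded
```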
19 changes: 15 additions & 4 deletions dptb/nn/norm.py
@@ -46,6 +46,7 @@ def __init__(
count_shift = 0
self.shift_index = []
self.scale_index = []
self.scalar_weight_index = []
for mul, ir in self.irreps:
if str(ir) == "0e":
self.num_scalar += mul
@@ -55,15 +56,18 @@ def __init__(
self.shift_index += [-1] * mul * ir.dim

for _ in range(mul):
if str(ir) == "0e":
self.scalar_weight_index += [count_scales]
self.scale_index += [count_scales] * ir.dim
count_scales += 1

self.shift_index = torch.as_tensor(self.shift_index, dtype=torch.int64, device=self.device)
self.scale_index = torch.as_tensor(self.scale_index, dtype=torch.int64, device=self.device)
self.scalar_weight_index = torch.as_tensor(self.scalar_weight_index, dtype=torch.int64, device=self.device)

if self.affine:
self.affine_weight = nn.Parameter(torch.ones(self.irreps.num_irreps))
self.affine_bias = nn.Parameter(torch.zeros(self.num_scalar))
self.affine_weight = nn.Parameter(torch.ones(1,self.irreps.num_irreps))
self.affine_bias = nn.Parameter(torch.zeros(1,self.num_scalar))
else:
self.register_parameter('affine_weight', None)

@@ -104,6 +108,8 @@ def forward(self, x):
feature_mean = x[:, self.shift_index.ge(0)].mean(dim=1, keepdim=True)
x = x + 0. # to avoid the inplace operation of x
x[:, self.shift_index.ge(0)] = x[:, self.shift_index.ge(0)] - feature_mean
# compute norm of x0
scalar_norm = 1.0 / (x[:, self.shift_index.ge(0)].norm(dim=1, keepdim=True) + self.eps) # [N, 1]

# 2. compute the norm across all irreps except for 0e
if self.lmax > 0:
@@ -123,9 +129,14 @@ def forward(self, x):
feature_norm = x[:,self.shift_index.lt(0)].pow(2).mean(1, keepdim=True)

feature_norm = (feature_norm + self.eps).pow(-0.5)
weight = self.affine_weight * feature_norm # [1, n_ir] * [N, 1] = [N, n_ir]
if self.affine:
weight = self.affine_weight * feature_norm # [1, n_ir] * [N, 1] = [N, n_ir]
weight[:,self.scalar_weight_index] = self.affine_weight[:, self.scalar_weight_index] * scalar_norm # [N, n_ir0], only for 0e
x = x * weight[:, self.scale_index]
else:
x[:,self.shift_index.lt(0)] = x[:,self.shift_index.lt(0)] * weight[:, self.scale_index]
x[:,self.shift_index.ge(0)] = x[:,self.shift_index.ge(0)] * scalar_norm

x = x * weight[:, self.scale_index]
x[:, self.shift_index.ge(0)] = x[:, self.shift_index.ge(0)] + self.affine_bias

return x
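The new forward path splits the normalization: the scalar (0e) channels are mean-centered and rescaled by the inverse norm of the scalar block (tracked through the new `scalar_weight_index`), while the l > 0 components keep the component-wise RMS norm, and the affine weights are applied per irrep on top of either factor. A simplified sketch of that split, operating on pre-separated scalar and tensor blocks instead of the index-based masking used in `norm.py`:

```python
import torch

def separable_norm_sketch(x_scalar, x_tensor, weight_scalar=None,
                          weight_tensor=None, bias=None, eps=5e-3):
    """x_scalar: [N, n0] 0e channels; x_tensor: [N, d] all l > 0 components."""
    # 0e channels: center, then rescale by the inverse norm of the scalar block
    x_scalar = x_scalar - x_scalar.mean(dim=1, keepdim=True)
    scalar_norm = 1.0 / (x_scalar.norm(dim=1, keepdim=True) + eps)          # [N, 1]

    # l > 0 channels: component-wise RMS norm
    feature_norm = (x_tensor.pow(2).mean(dim=1, keepdim=True) + eps).pow(-0.5)

    if weight_scalar is not None:  # affine branch
        x_scalar = x_scalar * (weight_scalar * scalar_norm) + bias
        x_tensor = x_tensor * (weight_tensor * feature_norm)
    else:
        x_scalar = x_scalar * scalar_norm
        x_tensor = x_tensor * feature_norm
    return x_scalar, x_tensor

# usage with per-channel affine weights of shape [1, n_channels]
xs, xt = separable_norm_sketch(
    torch.randn(5, 3), torch.randn(5, 9),
    weight_scalar=torch.ones(1, 3), weight_tensor=torch.ones(1, 9),
    bias=torch.zeros(1, 3),
)
```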
11 changes: 7 additions & 4 deletions dptb/nn/rescale.py
@@ -301,11 +301,14 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:

edge_center = data[AtomicDataDict.EDGE_INDEX_KEY][0]

species_idx = data[AtomicDataDict.EDGE_TYPE_KEY].flatten()
in_field = data[self.field]
mask = data[self.field][:,0] != 0 # strictly zero valued point must come from masked edges
in_field = data[self.field][mask]
species_idx = data[AtomicDataDict.EDGE_TYPE_KEY].flatten()[mask]



assert len(in_field) == len(
edge_center
edge_center[mask]
), "in_field doesnt seem to have correct per-edge shape"

if self.has_scales:
@@ -314,7 +317,7 @@
shifts = self.shifts[species_idx][:,self.shift_index[self.shift_index>=0]].view(-1, self.num_scalar)
in_field[:, self.shift_index>=0] = shifts + in_field[:, self.shift_index>=0]

data[self.out_field] = in_field
data[self.out_field][mask] = in_field

return data

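The per-edge rescale now skips rows whose features are exactly zero, treating them as padding from masked edges, and writes scaled and shifted values back only for the surviving rows. A small self-contained sketch of that masking, using placeholder tensors rather than the real `AtomicDataDict` fields:

```python
import torch

edge_feats = torch.tensor([[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]])  # row 0 is a masked edge
edge_type  = torch.tensor([0, 1, 0])                              # per-edge type index
scales     = torch.tensor([[2.0, 2.0], [0.5, 0.5]])               # per-type scales
shifts     = torch.tensor([[0.1, 0.1], [0.0, 0.0]])               # per-type shifts

mask = edge_feats[:, 0] != 0      # strictly zero-valued rows come from masked edges
out = edge_feats.clone()
out[mask] = edge_feats[mask] * scales[edge_type[mask]] + shifts[edge_type[mask]]
# the masked row is left untouched instead of receiving a spurious scale/shift
```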
33 changes: 24 additions & 9 deletions dptb/nn/tensor_product.py
@@ -112,12 +112,14 @@ def forward(self, x, R, latents=None):
weights = self.radial_emb(latents)

x_ = torch.zeros(n, self.irreps_in.dim, dtype=x.dtype, device=x.device)
angle = xyz_to_angles(R[:,[1,2,0]]) # (tensor(N), tensor(N))
for (mul, (l,p)), slice in zip(self.irreps_in, self.irreps_in.slices()):
if l > 0:
angle = xyz_to_angles(R[:,[1,2,0]]) # (tensor(N), tensor(N))
# The rotation matrix is an SO(3) rotation, therefore Irreps(l, 1) is used here.
rot_mat_L = wigner_D(l, angle[0], angle[1], torch.zeros_like(angle[0]))
x_[:, slice] = torch.einsum('nji,nmj->nmi', rot_mat_L, x[:, slice].reshape(n,mul,-1)).reshape(n,-1)
else:
x_[:, slice] = x[:, slice]

out = torch.zeros(n, self.irreps_out.dim, dtype=x.dtype, device=x.device)
for m in range(self.irreps_out.lmax+1):
@@ -130,22 +132,35 @@ def forward(self, x, R, latents=None):
else:
out[:, self.m_out_mask[m]] += self.fc_m0(x_[:, self.m_in_mask[m]])
else:
# 1. Prepare input data. .contiguous() is for in-place ops and performance.
# Shape becomes (n, 2, C_in_m / 2)
x_m_in = x_[:, self.m_in_mask[m]].reshape(n, -1, 2).transpose(1, 2).contiguous()

if self.front and self.radial_emb:
out[:, self.m_out_mask[m]] += self.m_linear[m-1](x_[:, self.m_in_mask[m]].reshape(n, 2, -1)*radial_weight).reshape(n, -1)
# Apply weight before linear layer. Use in-place mul_ to save memory.
x_m_in.mul_(radial_weight)
linear_output = self.m_linear[m - 1](x_m_in)
elif self.radial_emb:
out[:, self.m_out_mask[m]] += (self.m_linear[m-1](x_[:, self.m_in_mask[m]].reshape(n, 2, -1))*radial_weight).reshape(n, -1)
# Apply weight after linear layer. Use in-place mul_ to save memory.
linear_output = self.m_linear[m - 1](x_m_in)
linear_output.mul_(radial_weight)
else:
out[:, self.m_out_mask[m]] += self.m_linear[m-1](x_[:, self.m_in_mask[m]].reshape(n, 2, -1)).reshape(n, -1)

out.contiguous()
# No radial embedding, just pass through the linear layer.
linear_output = self.m_linear[m - 1](x_m_in)

# 2. Reshape output and add to the result tensor.
# .contiguous() is necessary before .reshape() after a .transpose().
final_addition = linear_output.transpose(1, 2).contiguous().reshape(n, -1)
out[:, self.m_out_mask[m]] += final_addition

for (mul, (l,p)), slice in zip(self.irreps_out, self.irreps_out.slices()):
if l > 0:
angle = xyz_to_angles(R[:,[1,2,0]]) # (tensor(N), tensor(N))
# The rotation matrix is an SO(3) rotation, therefore Irreps(l, 1) is used here.
rot_mat_L = wigner_D(l, angle[0], angle[1], torch.zeros_like(angle[0]))
out[:, slice] = torch.einsum('nij,nmj->nmi', rot_mat_L, out[:, slice].reshape(n,mul,-1)).reshape(n,-1)

out.contiguous()

return out

class SO2_m_Linear(torch.nn.Module):
@@ -192,8 +207,8 @@ def __init__(
def forward(self, x_m):
# x_m ~ [N, 2, n_irreps_m]
x_m = self.fc(x_m)
x_r = x_m.narrow(2, 0, self.fc.out_features // 2)
x_i = x_m.narrow(2, self.fc.out_features // 2, self.fc.out_features // 2)
x_r = x_m.narrow(2, 0, self.fc.out_features // 2) #[wmfm, wmf-m]
x_i = x_m.narrow(2, self.fc.out_features // 2, self.fc.out_features // 2) #[w-mfm, w-mf-m]
x_m_r = x_r.narrow(1, 0, 1) - x_i.narrow(1, 1, 1) #x_r[:, 0] - x_i[:, 1]
x_m_i = x_r.narrow(1, 1, 1) + x_i.narrow(1, 0, 1) #x_r[:, 1] + x_i[:, 0]
x_out = torch.cat((x_m_r, x_m_i), dim=1)
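The SO(2) refactor keeps the complex-style mixing of each ±m pair (annotated in `SO2_m_Linear.forward` above) but reorganizes each per-m block into an explicit `[N, 2, C_m]` layout via `reshape(n, -1, 2).transpose(1, 2).contiguous()`, and applies the radial weights with in-place `mul_` either before or after the linear map depending on `self.front`. A standalone sketch of the ±m mixing itself; the layer size and input layout here are illustrative, not the repo's actual irreps bookkeeping:

```python
import torch

n, c_in, c_out = 4, 6, 6
fc = torch.nn.Linear(c_in, 2 * c_out, bias=False)   # produces (w_m f, w_{-m} f) pairs

x_m = torch.randn(n, 2, c_in)                        # [N, 2, n_irreps_m]: (m, -m) components
y   = fc(x_m)                                        # [N, 2, 2 * c_out]
x_r = y.narrow(2, 0, c_out)                          # "real" half of the weighted features
x_i = y.narrow(2, c_out, c_out)                      # "imaginary" half
out_r = x_r.narrow(1, 0, 1) - x_i.narrow(1, 1, 1)    # Re(w * f)
out_i = x_r.narrow(1, 1, 1) + x_i.narrow(1, 0, 1)    # Im(w * f)
out = torch.cat((out_r, out_i), dim=1)               # [N, 2, c_out]
```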
Binary file modified dptb/tests/data/e3_band/ref_model/nnenv.ep1474.pth
Binary file not shown.