diff --git a/dmpfold/network.py b/dmpfold/network.py index 5bf681f..a801163 100644 --- a/dmpfold/network.py +++ b/dmpfold/network.py @@ -244,7 +244,8 @@ def forward(self, x, x2, nloops=5, refine_steps=0): dm = torch.abs(dm) # See https://math.stackexchange.com/questions/156161/finding-the-coordinates-of-points-from-distance-matrix M = 0.5 * (dm[:, 0:1, :].expand(-1, nres, -1) ** 2 + dm[:, :, 0:1].expand(-1, -1, nres) ** 2 - dm ** 2) - w, v = torch.symeig(M.float(), eigenvectors=True) + # w, v = torch.symeig(M.float(), eigenvectors=True) + w, v = torch.linalg.eigh(M.float()) w = torch.clamp(F.relu(w, inplace=False), min = 1e-8) w = torch.diag_embed(w.sqrt()) mds_coords = torch.matmul(v, w)[:, :, -8:] @@ -289,7 +290,8 @@ def forward(self, x, x2, nloops=5, refine_steps=0): dm = torch.abs(dm) # See https://math.stackexchange.com/questions/156161/finding-the-coordinates-of-points-from-distance-matrix M = 0.5 * (dm[:, 0:1, :].expand(-1, nres, -1) ** 2 + dm[:, :, 0:1].expand(-1, -1, nres) ** 2 - dm ** 2) - w, v = torch.symeig(M.float(), eigenvectors=True) + # w, v = torch.symeig(M.float(), eigenvectors=True) + w, v = torch.linalg.eigh(M.float()) w = torch.clamp(F.relu(w, inplace=False), min = 1e-8) w = torch.diag_embed(w.sqrt()) mds_coords = torch.matmul(v, w)[:, :, -8:]