Skip to content

Commit 6bd8813

Browse files
committed
simplify onehot fn
1 parent 83955fc commit 6bd8813

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -942,12 +942,10 @@ def forward(
942942
)
943943

944944
def onehot(x, bins):
945-
x, packed_shape = pack_one(x, '*')
946-
dist_from_bins = einx.subtract('i, j -> i j', x, bins)
947-
indexes = dist_from_bins.abs().min(dim = 1, keepdim = True).indices
948-
indexes = rearrange(indexes.long(), 'i j -> (i j) 1')
949-
one_hots = torch.zeros(indexes.shape[0], len(bins)).scatter_(1, indexes, 1)
950-
return unpack_one(one_hots, packed_shape, '* d')
945+
dist_from_bins = einx.subtract('... i, j -> ... i j', x, bins)
946+
indices = dist_from_bins.abs().min(dim = -1, keepdim = True).indices
947+
one_hots = F.one_hot(indices.long(), num_classes = len(bins))
948+
return one_hots.float()
951949

952950
r_arange = torch.arange(2*self.r_max + 2, device = device)
953951
s_arange = torch.arange(2*self.s_max + 2, device = device)

0 commit comments

Comments
 (0)