Skip to content

Commit 7c8ff53

Browse files
committed
add entity_ids in additional_molecule_feats
1 parent 2947827 commit 7c8ff53

File tree

2 files changed

+56
-23
lines changed

2 files changed

+56
-23
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 55 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Type, Literal, Callable, List, Any
66

77
import torch
8+
from torch import tensor
89
import torch.nn.functional as F
910
import einx
1011

@@ -169,7 +170,7 @@ def molecule_to_atom_input(
169170
else:
170171
atom_lens.append(num_atoms)
171172

172-
atom_lens = torch.tensor(atom_lens)
173+
atom_lens = tensor(atom_lens)
173174
total_atoms = atom_lens.sum().item()
174175

175176
# molecule_atom_lens
@@ -194,7 +195,7 @@ def molecule_to_atom_input(
194195

195196
atom_ids.append(atom_index[atom_symbol])
196197

197-
atom_ids = torch.tensor(atom_ids, dtype = torch.long)
198+
atom_ids = tensor(atom_ids, dtype = torch.long)
198199

199200
# handle maybe atompair embeds
200201

@@ -228,8 +229,8 @@ def molecule_to_atom_input(
228229

229230
updates.extend([bond_id, bond_id])
230231

231-
coordinates = torch.tensor(coordinates).long()
232-
updates = torch.tensor(updates).long()
232+
coordinates = tensor(coordinates).long()
233+
updates = tensor(updates).long()
233234

234235
mol_atompair_ids = einx.set_at('[h w], c [2], c -> [h w]', mol_atompair_ids, coordinates, updates)
235236

@@ -269,7 +270,7 @@ def molecule_to_atom_input(
269270
pos = mol.GetConformer().GetAtomPosition(i)
270271
all_atom_pos.append([pos.x, pos.y, pos.z])
271272

272-
all_atom_pos_tensor = torch.tensor(all_atom_pos)
273+
all_atom_pos_tensor = tensor(all_atom_pos)
273274

274275
dist_matrix = torch.cdist(all_atom_pos_tensor, all_atom_pos_tensor)
275276

@@ -281,9 +282,9 @@ def molecule_to_atom_input(
281282
offset += num_atoms
282283

283284
atom_input = AtomInput(
284-
atom_inputs = torch.tensor(atom_inputs, dtype = torch.float),
285+
atom_inputs = tensor(atom_inputs, dtype = torch.float),
285286
atompair_inputs = atompair_inputs,
286-
molecule_atom_lens = torch.tensor(atom_lens, dtype = torch.long),
287+
molecule_atom_lens = tensor(atom_lens, dtype = torch.long),
287288
molecule_ids = mol_input.molecule_ids,
288289
additional_token_feats = mol_input.additional_token_feats,
289290
additional_molecule_feats = mol_input.additional_molecule_feats,
@@ -380,25 +381,27 @@ def maybe_string_to_int(
380381

381382
index = {symbol: i for i, symbol in enumerate(entries.keys())}
382383

383-
return torch.tensor([index[c] for c in indices]).long()
384+
return tensor([index[c] for c in indices]).long()
384385

385386
@typecheck
386387
def alphafold3_input_to_molecule_input(
387388
alphafold3_input: Alphafold3Input
388389
) -> MoleculeInput:
389390

390-
ss_rnas = list(alphafold3_input.ss_rna)
391-
ss_dnas = list(alphafold3_input.ss_dna)
391+
i = alphafold3_input
392+
393+
ss_rnas = list(i.ss_rna)
394+
ss_dnas = list(i.ss_dna)
392395

393396
# each double stranded nucleic acid is added to the single stranded lists together with its reverse complement
394397
# rc stands for reverse complement
395398

396-
for seq in alphafold3_input.ds_rna:
399+
for seq in i.ds_rna:
397400
rc_fn = partial(reverse_complement, nucleic_acid_type = 'rna') if isinstance(seq, str) else reverse_complement_tensor
398401
rc_seq = rc_fn(seq)
399402
ss_rnas.extend([seq, rc_seq])
400403

401-
for seq in alphafold3_input.ds_dna:
404+
for seq in i.ds_dna:
402405
rc_fn = partial(reverse_complement, nucleic_acid_type = 'dna') if isinstance(seq, str) else reverse_complement_tensor
403406
rc_seq = rc_fn(seq)
404407
ss_dnas.extend([seq, rc_seq])
@@ -414,7 +417,7 @@ def alphafold3_input_to_molecule_input(
414417

415418
# convert all proteins to a List[Mol] of each peptide
416419

417-
proteins = alphafold3_input.proteins
420+
proteins = i.proteins
418421
mol_proteins = []
419422
protein_entries = []
420423
molecule_atom_indices = []
@@ -497,7 +500,7 @@ def alphafold3_input_to_molecule_input(
497500

498501
arange = torch.arange(num_tokens)[:, None]
499502

500-
molecule_types_lens_cumsum = torch.tensor([0, *molecule_type_token_lens]).cumsum(dim = -1)
503+
molecule_types_lens_cumsum = tensor([0, *molecule_type_token_lens]).cumsum(dim = -1)
501504
left, right = molecule_types_lens_cumsum[:-1], molecule_types_lens_cumsum[1:]
502505

503506
is_molecule_types = (arange >= left) & (arange < right)
@@ -552,8 +555,8 @@ def alphafold3_input_to_molecule_input(
552555

553556
updates.extend([True, True])
554557

555-
coordinates = torch.tensor(coordinates).long()
556-
updates = torch.tensor(updates).bool()
558+
coordinates = tensor(coordinates).long()
559+
updates = tensor(updates).bool()
557560

558561
has_bond = einx.set_at('[h w], c [2], c -> [h w]', has_bond, coordinates, updates)
559562

@@ -590,7 +593,7 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
590593

591594
unflattened_atom_parent_ids = [([asym_id] * num_tokens) for asym_id, num_tokens in enumerate([*num_protein_atoms, *num_ss_rna_atoms, *num_ss_dna_atoms, *num_ligand_atoms, num_metal_ions])]
592595

593-
atom_parent_ids = torch.tensor(flatten(unflattened_atom_parent_ids))
596+
atom_parent_ids = tensor(flatten(unflattened_atom_parent_ids))
594597

595598
# constructing the additional_molecule_feats
596599
# which is in turn used to derive relative positions
@@ -609,27 +612,57 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
609612

610613
unflattened_asym_ids = [([asym_id] * num_tokens) for asym_id, num_tokens in enumerate([*num_protein_tokens, *num_ss_rna_tokens, *num_ss_dna_tokens, *num_ligand_tokens, num_metal_ions])]
611614

612-
asym_ids = torch.tensor(flatten(unflattened_asym_ids))
615+
asym_ids = tensor(flatten(unflattened_asym_ids))
616+
617+
# entity ids
618+
619+
entity_ids = []
620+
curr_id = 0
621+
622+
def add_entity_id(length):
623+
nonlocal curr_id
624+
entity_ids.extend([curr_id] * length)
625+
curr_id += 1
626+
627+
add_entity_id(sum(num_protein_tokens))
628+
629+
for ss_rna in i.ss_rna:
630+
add_entity_id(len(ss_rna))
631+
632+
for ds_rna in i.ds_rna:
633+
add_entity_id(len(ds_rna) * 2)
634+
635+
for ss_dna in i.ss_dna:
636+
add_entity_id(len(ss_dna))
637+
curr_id += 1
638+
639+
for ds_dna in i.ds_dna:
640+
add_entity_id(len(ds_dna) * 2)
641+
642+
for l in mol_ligands:
643+
add_entity_id(l.GetNumAtoms())
644+
645+
add_entity_id(num_metal_ions)
646+
647+
entity_ids = tensor(entity_ids).long()
613648

614649
# concat for all of additional_molecule_feats
615650

616651
additional_molecule_feats = torch.stack((
617652
molecule_ids,
618653
torch.arange(num_tokens),
619654
asym_ids,
620-
torch.zeros(num_tokens).long(),
655+
entity_ids,
621656
torch.zeros(num_tokens).long(),
622657
), dim = -1)
623658

624659
# molecule atom indices
625660

626-
molecule_atom_indices = torch.tensor(molecule_atom_indices)
661+
molecule_atom_indices = tensor(molecule_atom_indices)
627662
molecule_atom_indices = pad_to_len(molecule_atom_indices, num_tokens, value = -1)
628663

629664
# create molecule input
630665

631-
i = alphafold3_input
632-
633666
molecule_input = MoleculeInput(
634667
molecules = molecules,
635668
molecule_token_pool_lens = token_pool_lens,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.83"
3+
version = "0.1.84"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)