Skip to content

Commit 34f8d44

Browse files
committed
line up molecule ids with paper, addressing #86
1 parent 85bb181 commit 34f8d44

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -544,8 +544,10 @@ def map_int_or_string_indices_to_mol(
544544
def maybe_string_to_int(
545545
entries: dict,
546546
indices: Int[' _'] | List[str] | str,
547-
other_index: int = 0
548547
) -> Int[' _']:
548+
549+
unknown_index = len(entries) - 1
550+
549551
if isinstance(indices, str):
550552
indices = list(indices)
551553

@@ -554,7 +556,7 @@ def maybe_string_to_int(
554556

555557
index = {symbol: i for i, symbol in enumerate(entries.keys())}
556558

557-
return tensor([index[c] for c in indices]).long()
559+
return tensor([index.get(c, unknown_index) for c in indices]).long()
558560

559561
@typecheck
560562
def alphafold3_input_to_molecule_input(
@@ -582,12 +584,14 @@ def alphafold3_input_to_molecule_input(
582584
ss_dnas.extend([seq, rc_seq])
583585

584586
# keep track of molecule_ids - for now it is
585-
# other(1) | proteins (20) | rna (4) | dna (4)
587+
# proteins (21) | rna (5) | dna (5) | gap? (1) - unknown for each biomolecule is the very last, ligand is 20
586588

587-
protein_offset = 1
588-
rna_offset = len(HUMAN_AMINO_ACIDS) + protein_offset
589+
rna_offset = len(HUMAN_AMINO_ACIDS)
589590
dna_offset = len(RNA_NUCLEOTIDES) + rna_offset
590591

592+
ligand_id = len(HUMAN_AMINO_ACIDS) - 1
593+
gap_id = len(DNA_NUCLEOTIDES) + dna_offset
594+
591595
molecule_ids = []
592596

593597
# convert all proteins to a List[Mol] of each peptide
@@ -609,7 +613,7 @@ def alphafold3_input_to_molecule_input(
609613

610614
src_tgt_atom_indices.extend([[entry['first_atom_idx'], entry['last_atom_idx']] for entry in protein_entries])
611615

612-
protein_ids = maybe_string_to_int(HUMAN_AMINO_ACIDS, protein) + protein_offset
616+
protein_ids = maybe_string_to_int(HUMAN_AMINO_ACIDS, protein)
613617
molecule_ids.append(protein_ids)
614618

615619
chainable_biomol_entries.append(protein_entries)
@@ -652,11 +656,15 @@ def alphafold3_input_to_molecule_input(
652656
metal_ions = alphafold3_input.metal_ions
653657
mol_metal_ions = map_int_or_string_indices_to_mol(METALS, metal_ions)
654658

659+
molecule_ids.append(tensor([gap_id] * len(mol_metal_ions)))
660+
655661
# convert ligands to rdchem.Mol
656662

657663
ligands = list(alphafold3_input.ligands)
658664
mol_ligands = [(mol_from_smile(ligand) if isinstance(ligand, str) else ligand) for ligand in ligands]
659665

666+
molecule_ids.append(tensor([ligand_id] * len(mol_ligands)))
667+
660668
# create the molecule input
661669

662670
all_protein_mols = flatten(mol_proteins)
@@ -766,7 +774,7 @@ def alphafold3_input_to_molecule_input(
766774

767775
# handle molecule ids
768776

769-
molecule_ids = torch.cat(molecule_ids)
777+
molecule_ids = torch.cat(molecule_ids).long()
770778
molecule_ids = pad_to_len(molecule_ids, num_tokens)
771779

772780
# handle atom_parent_ids

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.131"
3+
version = "0.1.132"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)