22
33from functools import wraps , partial
44from dataclasses import dataclass , asdict , field
5- from typing import Type , Literal , Callable , List , Any
5+ from typing import Type , Literal , Callable , List , Any , Tuple
66
77import torch
88from torch import tensor
@@ -350,7 +350,7 @@ class Alphafold3Input:
350350 ligands : List [Mol | str ] = imm_list () # can be given as smiles
351351 ds_dna : List [Int [' _' ] | str ] = imm_list ()
352352 ds_rna : List [Int [' _' ] | str ] = imm_list ()
353- atom_parent_ids : Int ['m' ] | None = None
353+ atom_parent_ids : Int [' m' ] | None = None
354354 additional_token_feats : Float [f'n dtf' ] | None = None
355355 templates : Float ['t n n dt' ] | None = None
356356 msa : Float ['s n dm' ] | None = None
@@ -508,7 +508,7 @@ def alphafold3_input_to_molecule_input(
508508 # convert ligands to rdchem.Mol
509509
510510 ligands = list (alphafold3_input .ligands )
511- mol_ligands = [(mol_from_smile (l ) if isinstance (l , str ) else l ) for l in ligands ]
511+ mol_ligands = [(mol_from_smile (ligand ) if isinstance (ligand , str ) else ligand ) for ligand in ligands ]
512512
513513 # create the molecule input
514514
@@ -643,7 +643,7 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
643643 num_protein_atoms = get_num_atoms_per_chain (mol_proteins )
644644 num_ss_rna_atoms = get_num_atoms_per_chain (mol_ss_rnas )
645645 num_ss_dna_atoms = get_num_atoms_per_chain (mol_ss_dnas )
646- num_ligand_atoms = [l .GetNumAtoms () for l in mol_ligands ]
646+ num_ligand_atoms = [ligand .GetNumAtoms () for ligand in mol_ligands ]
647647
648648 atom_counts = [* num_protein_atoms , * num_ss_rna_atoms , * num_ss_dna_atoms , * num_ligand_atoms , num_metal_ions ]
649649
@@ -665,7 +665,7 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
665665 num_protein_tokens = [len (protein ) for protein in proteins ]
666666 num_ss_rna_tokens = [len (rna ) for rna in ss_rnas ]
667667 num_ss_dna_tokens = [len (dna ) for dna in ss_dnas ]
668- num_ligand_tokens = [l .GetNumAtoms () for l in mol_ligands ]
668+ num_ligand_tokens = [ligand .GetNumAtoms () for ligand in mol_ligands ]
669669
670670 token_repeats = tensor ([* num_protein_tokens , * num_ss_rna_tokens , * num_ss_dna_tokens , * num_ligand_tokens , num_metal_ions ])
671671
0 commit comments