Skip to content

Commit f7f1a62

Browse files
committed
make tests pass
1 parent 76126d2 commit f7f1a62

File tree

5 files changed

+14
-4
lines changed

5 files changed

+14
-4
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from typing import Any, Callable, Dict, List, Literal, Set, Tuple, Type
1313

1414
import einx
15-
from einops import pack
15+
from einops import pack, rearrange
1616

1717
import numpy as np
1818
from numpy.lib.format import open_memmap
@@ -391,7 +391,7 @@ class MoleculeInput:
391391
is_molecule_types: Bool[f'n {IS_MOLECULE_TYPES}']
392392
src_tgt_atom_indices: Int['n 2']
393393
token_bonds: Bool['n n']
394-
is_molecule_mod: Bool['n num_mods'] | None = None
394+
is_molecule_mod: Bool['n num_mods'] | Bool[' n'] | None = None
395395
molecule_atom_indices: List[int | None] | None = None
396396
distogram_atom_indices: List[int | None] | None = None
397397
missing_atom_indices: List[Int[' _'] | None] | None = None
@@ -657,6 +657,13 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
657657

658658
chains = tensor([default(chain, -1) for chain in i.chains]).long()
659659

660+
# handle is_molecule_mod being one dimensional
661+
662+
is_molecule_mod = i.is_molecule_mod
663+
664+
if is_molecule_mod.ndim == 1:
665+
is_molecule_mod = rearrange(is_molecule_mod, 'n -> n 1')
666+
660667
# atom input
661668

662669
atom_input = AtomInput(
@@ -666,7 +673,7 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
666673
molecule_ids=i.molecule_ids,
667674
molecule_atom_indices=i.molecule_atom_indices,
668675
distogram_atom_indices=i.distogram_atom_indices,
669-
is_molecule_mod=i.is_molecule_mod,
676+
is_molecule_mod=is_molecule_mod,
670677
msa=i.msa,
671678
templates=i.templates,
672679
msa_mask=i.msa_mask,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.95"
3+
version = "0.2.96"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/configs/trainer_with_atom_dataset_created_from_pdb.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ model:
1717
num_plddt_bins: 50
1818
num_pde_bins: 64
1919
num_pae_bins: 64
20+
num_molecule_mods: 1
2021
sigma_data: 16
2122
diffusion_num_augmentations: 4
2223
loss_confidence_weight: 0.0001

tests/configs/trainer_with_pdb_dataset.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ model:
1717
num_plddt_bins: 50
1818
num_pde_bins: 64
1919
num_pae_bins: 64
20+
num_molecule_mods: 1
2021
sigma_data: 16
2122
diffusion_num_augmentations: 4
2223
loss_confidence_weight: 0.0001

tests/configs/trainer_with_pdb_dataset_and_weighted_sampling.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ model:
1717
num_plddt_bins: 50
1818
num_pde_bins: 64
1919
num_pae_bins: 64
20+
num_molecule_mods: 1
2021
sigma_data: 16
2122
diffusion_num_augmentations: 4
2223
loss_confidence_weight: 0.0001

0 commit comments

Comments
 (0)