Commit 09ca3ba

switch it back to the way Alex had before and release 0.2.0

1 parent ecf8bfd · commit 09ca3ba

File tree: 3 files changed, +4 -4 lines

alphafold3_pytorch/alphafold3.py (2 additions, 2 deletions)

@@ -3382,7 +3382,7 @@ def forward(
         return_present_sampled_atoms: bool = False,
         num_rollout_steps: int = 20,
         rollout_show_tqdm_pbar: bool = False
-    ) -> Float['b m 3'] | List[Float['l 3']] | Float[''] | Tuple[Float[''], LossBreakdown]:
+    ) -> Float['b m 3'] | Float['l 3'] | Float[''] | Tuple[Float[''], LossBreakdown]:
 
         atom_seq_len = atom_inputs.shape[-2]
 
@@ -3627,7 +3627,7 @@ def forward(
         sampled_atom_pos = einx.where('b m, b m c, -> b m c', atom_mask, sampled_atom_pos, 0.)
 
         if exists(missing_atom_mask) and return_present_sampled_atoms:
-            sampled_atom_pos = [one_sampled_atom_pos[~one_missing_atom_mask] for one_sampled_atom_pos, one_missing_atom_mask in zip(sampled_atom_pos, missing_atom_mask)]
+            sampled_atom_pos = sampled_atom_pos[~missing_atom_mask]
 
         return sampled_atom_pos
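Note on the hunk above: the return annotation drops the List[Float['l 3']] variant because, when return_present_sampled_atoms is set and a missing_atom_mask exists, present atoms are now selected by boolean indexing, which concatenates them across the batch into a single Float['l 3'] tensor instead of returning a per-sample list. A minimal sketch (not repository code; shapes assumed from the diff) of how the two behaviors relate:

    import torch

    b, m = 2, 5
    sampled_atom_pos = torch.randn(b, m, 3)          # (batch, atoms, 3)
    missing_atom_mask = torch.rand(b, m) > 0.5       # True where an atom is missing

    # previous behavior: one (l_i, 3) tensor of present atoms per batch element
    per_sample = [pos[~mask] for pos, mask in zip(sampled_atom_pos, missing_atom_mask)]

    # restored behavior: a single (l, 3) tensor, present atoms concatenated over the batch
    flattened = sampled_atom_pos[~missing_atom_mask]

    assert torch.equal(flattened, torch.cat(per_sample, dim = 0))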

pyproject.toml (1 addition, 1 deletion)

@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.1.136"
+version = "0.2.0"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

tests/test_input.py (1 addition, 1 deletion)

@@ -191,7 +191,7 @@ def test_pdbinput_input():
 
     alphafold3.eval()
 
-    sampled_atom_pos, = alphafold3(
+    sampled_atom_pos = alphafold3(
         **batched_eval_atom_input.dict(), return_loss=False, return_present_sampled_atoms=True
     )
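Note on the test change above: the trailing comma in `sampled_atom_pos, = alphafold3(...)` was single-element sequence unpacking, which matched the earlier list return; with the restored tensor return a plain assignment is enough. A hedged illustration (not repository code, assuming the old path produced a one-element list):

    import torch

    old_return = [torch.randn(7, 3)]        # previous path: a list holding one (l, 3) tensor
    sampled_atom_pos, = old_return          # trailing comma unpacks the single element

    new_return = torch.randn(7, 3)          # restored path: a single (l, 3) tensor
    sampled_atom_pos = new_return           # plain assignment, no unpacking needed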
