Skip to content

Commit 9ab2eb0

Browse files
committed
fix smooth lddt loss
1 parent 6dad1e5 commit 9ab2eb0

File tree

4 files changed

+5
-5
lines changed

4 files changed

+5
-5
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2047,7 +2047,7 @@ def forward(
20472047

20482048
smooth_lddt_loss = self.smooth_lddt_loss(
20492049
denoised_atom_pos,
2050-
normalized_atom_pos,
2050+
atom_pos_ground_truth,
20512051
atom_is_dna,
20522052
atom_is_rna,
20532053
coords_mask = atom_mask
@@ -2488,7 +2488,7 @@ def __init__(
24882488
self,
24892489
*,
24902490
dim_single_inputs,
2491-
atompair_dist_bins: Float['d'],
2491+
atompair_dist_bins: Float[' d'],
24922492
dim_single = 384,
24932493
dim_pairwise = 128,
24942494
num_plddt_bins = 50,

alphafold3_pytorch/attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from __future__ import annotations
2-
from typing import NamedTuple
2+
from typing import NamedTuple, Tuple
33

44
import torch
55
from torch import nn

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.0.33"
3+
version = "0.0.34"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,7 @@ def test_alphafold3():
428428
pde_labels = pde_labels,
429429
plddt_labels = plddt_labels,
430430
resolved_labels = resolved_labels,
431+
diffusion_add_smooth_lddt_loss = True,
431432
return_loss_breakdown = True
432433
)
433434

@@ -580,7 +581,6 @@ def test_alphafold3_with_packed_atom_repr():
580581

581582
loss.backward()
582583

583-
print(residue_atom_lens)
584584
sampled_atom_pos = alphafold3(
585585
num_sample_steps = 16,
586586
atom_inputs = atom_inputs,

0 commit comments

Comments
 (0)