Skip to content

Commit b0198d4

Browse files
committed
output type for distogram head
1 parent 216eea7 commit b0198d4

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1720,11 +1720,13 @@ def forward(
17201720
# distogram head
17211721

17221722
class DistogramHead(Module):
1723+
1724+
@typecheck
17231725
def __init__(
17241726
self,
17251727
*,
17261728
dim_pairwise = 128,
1727-
num_dist_bins = 38 # think it is 38?
1729+
num_dist_bins = 38, # think it is 38?
17281730
):
17291731
super().__init__()
17301732

@@ -1737,9 +1739,9 @@ def __init__(
17371739
def forward(
17381740
self,
17391741
pairwise_repr: Float['b n n d']
1740-
):
1741-
logits = self.to_distogram_logits(pairwise_repr)
1742+
) -> Float['b l n n']:
17421743

1744+
logits = self.to_distogram_logits(pairwise_repr)
17431745
return logits
17441746

17451747
# confidence head

0 commit comments

Comments
 (0)