We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 216eea7 commit b0198d4Copy full SHA for b0198d4
alphafold3_pytorch/alphafold3.py
@@ -1720,11 +1720,13 @@ def forward(
1720
# distogram head
1721
1722
class DistogramHead(Module):
1723
+
1724
+ @typecheck
1725
def __init__(
1726
self,
1727
*,
1728
dim_pairwise = 128,
- num_dist_bins = 38 # think it is 38?
1729
+ num_dist_bins = 38, # think it is 38?
1730
):
1731
super().__init__()
1732
@@ -1737,9 +1739,9 @@ def __init__(
1737
1739
def forward(
1738
1740
1741
pairwise_repr: Float['b n n d']
- ):
- logits = self.to_distogram_logits(pairwise_repr)
1742
+ ) -> Float['b l n n']:
1743
1744
+ logits = self.to_distogram_logits(pairwise_repr)
1745
return logits
1746
1747
# confidence head
0 commit comments