Skip to content

Commit 23690bb

Browse files
committed
get around some jaxtyping issue
1 parent d6d34d2 commit 23690bb

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3063,12 +3063,12 @@ def __init__(
30633063
@typecheck
30643064
def forward(
30653065
self,
3066-
pred_coords: Float['b n 3'] | Float['b m 3'],
3067-
true_coords: Float['b n 3'] | Float['b m 3'],
3066+
pred_coords: Float['b m_or_n 3'],
3067+
true_coords: Float['b m_or_n 3'],
30683068
pred_frames: Float['b n 3 3'],
30693069
true_frames: Float['b n 3 3'],
30703070
molecule_atom_lens: Int['b n'] | None = None
3071-
) -> Float['b n n'] | Float['b m m']:
3071+
) -> Float['b m_or_n m_or_n']:
30723072
"""
30733073
pred_coords: predicted coordinates
30743074
true_coords: true coordinates

0 commit comments

Comments
 (0)