Skip to content

Commit c5cea69

Browse files
authored
Update alphafold3.py (#285)
1 parent 910e075 commit c5cea69

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2782,7 +2782,7 @@ def forward(
27822782
return_loss_breakdown = False,
27832783
single_structure_input=False,
27842784
verbose=None,
2785-
filepaths: List[str] | Tuple[str] | None = None,
2785+
filepath: List[str] | Tuple[str] | None = None,
27862786
) -> ElucidatedAtomDiffusionReturn:
27872787
verbose = default(verbose, self.verbose)
27882788

@@ -2865,7 +2865,7 @@ def forward(
28652865
)
28662866
except Exception as e:
28672867
# NOTE: For many (random) unit test inputs, permutation alignment can be unstable
2868-
logger.warning(f"Skipping multi-chain permutation alignment {f'for {filepaths}' if exists(filepaths) else ''} due to: {e}")
2868+
logger.warning(f"Skipping multi-chain permutation alignment {f'for {filepath}' if exists(filepath) else ''} due to: {e}")
28692869

28702870
# main diffusion mse loss
28712871

@@ -6499,7 +6499,8 @@ def forward(
64996499
max_conf_resolution: float = 4.0,
65006500
hard_validate: bool = False,
65016501
verbose: bool | None = None,
6502-
filepaths: List[str] | Tuple[str] | None = None
6502+
chains: Int["b 2"] | None = None,
6503+
filepath: List[str] | Tuple[str] | None = None,
65036504
) -> (
65046505
Float['b m 3'] |
65056506
List[Structure] |
@@ -7140,7 +7141,7 @@ def forward(
71407141
ligand_loss_weight = self.ligand_loss_weight,
71417142
single_structure_input = single_structure_input,
71427143
verbose = verbose,
7143-
filepaths = filepaths,
7144+
filepath = filepath,
71447145
)
71457146

71467147
# confidence head
@@ -7221,7 +7222,7 @@ def forward(
72217222
)
72227223
except Exception as e:
72237224
# NOTE: For many (random) unit test inputs, permutation alignment can be unstable
7224-
logger.warning(f"Skipping multi-chain permutation alignment {f'for {filepaths}' if exists(filepaths) else ''} due to: {e}")
7225+
logger.warning(f"Skipping multi-chain permutation alignment {f'for {filepath}' if exists(filepath) else ''} due to: {e}")
72257226

72267227
assert exists(
72277228
distogram_atom_indices

0 commit comments

Comments
 (0)