Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions run_alphafold.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import time
import typing
from typing import overload

import tarfile
from absl import app
from absl import flags
from alphafold3.common import folding_input
Expand Down Expand Up @@ -303,7 +303,12 @@
' and is non-empty. Useful to set this to True to run the data pipeline and'
' the inference separately, but use the same output directory.',
)

# Opt-in switch: archive the finished output directory as a single .tar.gz and
# delete the uncompressed copy afterwards. Off by default.
_COMPRESS_OUTPUT_DIR = flags.DEFINE_bool(
    name='compress_output_dir',
    default=False,
    help=(
        'If True, compress the entire output directory into a single .tar.gz'
        ' archive after all outputs are written, and remove the uncompressed'
        ' version.'
    ),
)

def make_model_config(
*,
Expand Down Expand Up @@ -528,6 +533,7 @@ def write_outputs(
all_inference_results: Sequence[ResultsForSeed],
output_dir: os.PathLike[str] | str,
job_name: str,
compress_output_dir: bool = False,
) -> None:
"""Writes outputs to the specified output directory."""
ranking_scores = []
Expand Down Expand Up @@ -590,6 +596,12 @@ def write_outputs(
writer.writerow(['seed', 'sample', 'ranking_score'])
writer.writerows(ranking_scores)

if compress_output_dir:
  # Normalise the path before deriving names from it. With a trailing
  # separator (e.g. 'out/job/'), basename() would return '' and the archive
  # path would become 'out/job/.tar.gz' -- i.e. inside the directory that is
  # being archived and then deleted, so the archive itself would be lost.
  output_root = os.path.normpath(output_dir)
  archive_path = f'{output_root}.tar.gz'
  # The archive is written next to the output directory; its single top-level
  # entry is named after the directory so that it extracts to the same name.
  with tarfile.open(archive_path, 'w:gz') as tar:
    tar.add(output_root, arcname=os.path.basename(output_root))
  # Only reached if the archive was written and closed without raising.
  shutil.rmtree(output_root)
  print(f"Compressed outputs to {archive_path} and removed {output_dir}")

def replace_db_dir(path_with_db_dir: str, db_dirs: Sequence[str]) -> str:
"""Replaces the DB_DIR placeholder in a path with the given DB_DIR."""
Expand Down Expand Up @@ -618,6 +630,7 @@ def process_fold_input(
conformer_max_iterations: int | None = None,
resolve_msa_overlaps: bool = True,
force_output_dir: bool = False,
compress_output_dir: bool = False
) -> folding_input.Input:
...

Expand All @@ -633,6 +646,7 @@ def process_fold_input(
conformer_max_iterations: int | None = None,
resolve_msa_overlaps: bool = True,
force_output_dir: bool = False,
compress_output_dir: bool = False
) -> Sequence[ResultsForSeed]:
...

Expand All @@ -647,6 +661,7 @@ def process_fold_input(
conformer_max_iterations: int | None = None,
resolve_msa_overlaps: bool = True,
force_output_dir: bool = False,
compress_output_dir: bool = False
) -> folding_input.Input | Sequence[ResultsForSeed]:
"""Runs data pipeline and/or inference on a single fold input.

Expand Down Expand Up @@ -733,6 +748,7 @@ def process_fold_input(
all_inference_results=all_inference_results,
output_dir=output_dir,
job_name=fold_input.sanitised_name(),
compress_output_dir=compress_output_dir
)
output = all_inference_results

Expand Down Expand Up @@ -887,6 +903,7 @@ def main(_):
conformer_max_iterations=_CONFORMER_MAX_ITERATIONS.value,
resolve_msa_overlaps=_RESOLVE_MSA_OVERLAPS.value,
force_output_dir=_FORCE_OUTPUT_DIR.value,
compress_output_dir=_COMPRESS_OUTPUT_DIR.value,
)
num_fold_inputs += 1

Expand Down