diff --git a/run_alphafold.py b/run_alphafold.py index b7c6590..0cd8d87 100644 --- a/run_alphafold.py +++ b/run_alphafold.py @@ -32,7 +32,7 @@ import time import typing from typing import overload - +import tarfile from absl import app from absl import flags from alphafold3.common import folding_input @@ -303,7 +303,12 @@ ' and is non-empty. Useful to set this to True to run the data pipeline and' ' the inference separately, but use the same output directory.', ) - +_COMPRESS_OUTPUT_DIR = flags.DEFINE_bool( + 'compress_output_dir', + False, + 'If True, compress the entire output directory into a single .tar.gz archive ' + 'after all outputs are written, and remove the uncompressed version.', +) def make_model_config( *, @@ -528,6 +533,7 @@ def write_outputs( all_inference_results: Sequence[ResultsForSeed], output_dir: os.PathLike[str] | str, job_name: str, + compress_output_dir: bool = False, ) -> None: """Writes outputs to the specified output directory.""" ranking_scores = [] @@ -590,6 +596,12 @@ def write_outputs( writer.writerow(['seed', 'sample', 'ranking_score']) writer.writerows(ranking_scores) + if compress_output_dir: + archive_path = f"{output_dir}.tar.gz" + with tarfile.open(archive_path, "w:gz") as tar: + tar.add(output_dir, arcname=os.path.basename(output_dir)) + shutil.rmtree(output_dir) + print(f"Compressed outputs to {archive_path} and removed {output_dir}") def replace_db_dir(path_with_db_dir: str, db_dirs: Sequence[str]) -> str: """Replaces the DB_DIR placeholder in a path with the given DB_DIR.""" @@ -618,6 +630,7 @@ def process_fold_input( conformer_max_iterations: int | None = None, resolve_msa_overlaps: bool = True, force_output_dir: bool = False, + compress_output_dir: bool = False ) -> folding_input.Input: ... @@ -633,6 +646,7 @@ def process_fold_input( conformer_max_iterations: int | None = None, resolve_msa_overlaps: bool = True, force_output_dir: bool = False, + compress_output_dir: bool = False ) -> Sequence[ResultsForSeed]: ... @@ -647,6 +661,7 @@ def process_fold_input( conformer_max_iterations: int | None = None, resolve_msa_overlaps: bool = True, force_output_dir: bool = False, + compress_output_dir: bool = False ) -> folding_input.Input | Sequence[ResultsForSeed]: """Runs data pipeline and/or inference on a single fold input. @@ -733,6 +748,7 @@ def process_fold_input( all_inference_results=all_inference_results, output_dir=output_dir, job_name=fold_input.sanitised_name(), + compress_output_dir=compress_output_dir ) output = all_inference_results @@ -887,6 +903,7 @@ def main(_): conformer_max_iterations=_CONFORMER_MAX_ITERATIONS.value, resolve_msa_overlaps=_RESOLVE_MSA_OVERLAPS.value, force_output_dir=_FORCE_OUTPUT_DIR.value, + compress_output_dir=_COMPRESS_OUTPUT_DIR.value, ) num_fold_inputs += 1