Skip to content

Commit 863a03c

Browse files
Merge pull request #39 from michellab/bugfix-close-plots
Ensure all matplotlib resources get closed after analysis
2 parents a347fe6 + 8c051dd commit 863a03c

File tree

3 files changed

+153
-143
lines changed

3 files changed

+153
-143
lines changed

a3fe/analyse/plot.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,8 @@ def plot_equilibration_time(lam_windows: _List["LamWindows"], output_dir: str) -
544544
transparent=False,
545545
)
546546

547+
_plt.close(fig)
548+
547549

548550
def plot_overlap_mat(
549551
ax: _plt.Axes,
@@ -735,6 +737,8 @@ def plot_overlap_mats(
735737
)
736738
fig.savefig(name)
737739

740+
_plt.close(fig)
741+
738742

739743
def plot_convergence(
740744
fracts: _np.ndarray,

a3fe/run/stage.py

Lines changed: 148 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import threading as _threading
1010
from copy import deepcopy as _deepcopy
1111
from math import ceil as _ceil
12+
import matplotlib.pyplot as _plt
1213
from multiprocessing import get_context as _get_context
1314
from time import sleep as _sleep
1415
from typing import Any as _Any
@@ -771,171 +772,175 @@ def analyse(
771772
"Despite equilibration being detected, no equilibration time was found."
772773
)
773774

774-
if get_frnrg:
775-
self._logger.info(
776-
f"Computing free energy changes using the MBAR for runs {run_nos}"
777-
)
775+
try: # Conduct analysis
776+
if get_frnrg:
777+
self._logger.info(
778+
f"Computing free energy changes using the MBAR for runs {run_nos}"
779+
)
778780

779-
# Remove unequilibrated data from the equilibrated output directory
780-
for win in self.lam_windows:
781-
win._write_equilibrated_simfiles()
781+
# Remove unequilibrated data from the equilibrated output directory
782+
for win in self.lam_windows:
783+
win._write_equilibrated_simfiles()
784+
785+
# Run MBAR and compute mean and 95 % C.I. of free energy
786+
if not slurm:
787+
free_energies, errors, mbar_outfiles, _ = _run_mbar(
788+
run_nos=run_nos,
789+
output_dir=self.output_dir,
790+
percentage_end=fraction * 100,
791+
percentage_start=0,
792+
subsampling=subsampling,
793+
equilibrated=True,
794+
)
795+
else:
796+
jobs, mbar_outfiles, tmp_simfiles = _submit_mbar_slurm(
797+
output_dir=self.output_dir,
798+
virtual_queue=self.virtual_queue,
799+
run_nos=run_nos,
800+
run_somd_dir=self.input_dir,
801+
percentage_end=fraction * 100,
802+
percentage_start=0,
803+
subsampling=subsampling,
804+
equilibrated=True,
805+
)
782806

783-
# Run MBAR and compute mean and 95 % C.I. of free energy
784-
if not slurm:
785-
free_energies, errors, mbar_outfiles, _ = _run_mbar(
786-
run_nos=run_nos,
807+
free_energies, errors, *_ = _collect_mbar_slurm(
808+
output_dir=self.output_dir,
809+
run_nos=run_nos,
810+
jobs=jobs,
811+
mbar_out_files=mbar_outfiles,
812+
virtual_queue=self.virtual_queue,
813+
tmp_simfiles=tmp_simfiles,
814+
)
815+
816+
mean_free_energy = _np.mean(free_energies)
817+
# Gaussian 95 % C.I.
818+
conf_int = (
819+
_stats.t.interval(
820+
0.95,
821+
len(free_energies) - 1,
822+
mean_free_energy,
823+
scale=_stats.sem(free_energies),
824+
)[1]
825+
- mean_free_energy
826+
) # 95 % C.I.
827+
828+
# Write overall MBAR stats to file
829+
with open(f"{self.output_dir}/overall_stats.dat", "a") as ofile:
830+
if get_frnrg:
831+
ofile.write(
832+
"###################################### Free Energies ########################################\n"
833+
)
834+
ofile.write(
835+
f"Mean free energy: {mean_free_energy: .3f} + /- {conf_int:.3f} kcal/mol\n"
836+
)
837+
for i in range(len(free_energies)):
838+
ofile.write(
839+
f"Free energy from run {i + 1}: {free_energies[i]: .3f} +/- {errors[i]:.3f} kcal/mol\n"
840+
)
841+
ofile.write(
842+
"Errors are 95 % C.I.s based on the assumption of a Gaussian distribution of free energies\n"
843+
)
844+
ofile.write(f"Runs analysed: {run_nos}\n")
845+
846+
# Plot overlap matrices and PMFs
847+
_plot_overlap_mats(
787848
output_dir=self.output_dir,
788-
percentage_end=fraction * 100,
789-
percentage_start=0,
790-
subsampling=subsampling,
791-
equilibrated=True,
849+
nlam=len(self.lam_windows),
850+
mbar_outfiles=mbar_outfiles,
792851
)
793-
else:
794-
jobs, mbar_outfiles, tmp_simfiles = _submit_mbar_slurm(
852+
_plot_mbar_pmf(mbar_outfiles, self.output_dir)
853+
equilibrated_gradient_data = _GradientData(
854+
lam_winds=self.lam_windows, equilibrated=True
855+
)
856+
_plot_overlap_mats(
795857
output_dir=self.output_dir,
796-
virtual_queue=self.virtual_queue,
797-
run_nos=run_nos,
798-
run_somd_dir=self.input_dir,
799-
percentage_end=fraction * 100,
800-
percentage_start=0,
801-
subsampling=subsampling,
802-
equilibrated=True,
858+
nlam=len(self.lam_windows),
859+
predicted=True,
860+
gradient_data=equilibrated_gradient_data,
803861
)
804862

805-
free_energies, errors, *_ = _collect_mbar_slurm(
863+
# Plot RMSDS
864+
if plot_rmsds:
865+
self._logger.info("Plotting RMSDs")
866+
_plot_rmsds(
867+
lam_windows=self.lam_windows,
806868
output_dir=self.output_dir,
807-
run_nos=run_nos,
808-
jobs=jobs,
809-
mbar_out_files=mbar_outfiles,
810-
virtual_queue=self.virtual_queue,
811-
tmp_simfiles=tmp_simfiles,
869+
selection="resname LIG and (not name H*)",
812870
)
813871

814-
mean_free_energy = _np.mean(free_energies)
815-
# Gaussian 95 % C.I.
816-
conf_int = (
817-
_stats.t.interval(
818-
0.95,
819-
len(free_energies) - 1,
820-
mean_free_energy,
821-
scale=_stats.sem(free_energies),
822-
)[1]
823-
- mean_free_energy
824-
) # 95 % C.I.
825-
826-
# Write overall MBAR stats to file
827-
with open(f"{self.output_dir}/overall_stats.dat", "a") as ofile:
828-
if get_frnrg:
829-
ofile.write(
830-
"###################################### Free Energies ########################################\n"
831-
)
832-
ofile.write(
833-
f"Mean free energy: {mean_free_energy: .3f} + /- {conf_int:.3f} kcal/mol\n"
834-
)
835-
for i in range(len(free_energies)):
836-
ofile.write(
837-
f"Free energy from run {i + 1}: {free_energies[i]: .3f} +/- {errors[i]:.3f} kcal/mol\n"
838-
)
839-
ofile.write(
840-
"Errors are 95 % C.I.s based on the assumption of a Gaussian distribution of free energies\n"
841-
)
842-
ofile.write(f"Runs analysed: {run_nos}\n")
843-
844-
# Plot overlap matrices and PMFs
845-
_plot_overlap_mats(
846-
output_dir=self.output_dir,
847-
nlam=len(self.lam_windows),
848-
mbar_outfiles=mbar_outfiles,
849-
)
850-
_plot_mbar_pmf(mbar_outfiles, self.output_dir)
872+
# Analyse the gradient data and make plots
873+
self._logger.info("Plotting gradients data")
851874
equilibrated_gradient_data = _GradientData(
852-
lam_winds=self.lam_windows, equilibrated=True
875+
lam_winds=self.lam_windows, equilibrated=True, run_nos=run_nos
853876
)
854-
_plot_overlap_mats(
877+
for plot_type in [
878+
"mean",
879+
"stat_ineff",
880+
"integrated_sem",
881+
"integrated_var",
882+
"pred_best_simtime",
883+
]:
884+
_plot_gradient_stats(
885+
gradients_data=equilibrated_gradient_data,
886+
output_dir=self.output_dir,
887+
plot_type=plot_type,
888+
)
889+
_plot_gradient_hists(
890+
gradients_data=equilibrated_gradient_data,
855891
output_dir=self.output_dir,
856-
nlam=len(self.lam_windows),
857-
predicted=True,
858-
gradient_data=equilibrated_gradient_data,
892+
run_nos=run_nos,
859893
)
860-
861-
# Plot RMSDS
862-
if plot_rmsds:
863-
self._logger.info("Plotting RMSDs")
864-
_plot_rmsds(
865-
lam_windows=self.lam_windows,
894+
_plot_gradient_timeseries(
895+
gradients_data=equilibrated_gradient_data,
866896
output_dir=self.output_dir,
867-
selection="resname LIG and (not name H*)",
897+
run_nos=run_nos,
868898
)
869899

870-
# Analyse the gradient data and make plots
871-
self._logger.info("Plotting gradients data")
872-
equilibrated_gradient_data = _GradientData(
873-
lam_winds=self.lam_windows, equilibrated=True, run_nos=run_nos
874-
)
875-
for plot_type in [
876-
"mean",
877-
"stat_ineff",
878-
"integrated_sem",
879-
"integrated_var",
880-
"pred_best_simtime",
881-
]:
882-
_plot_gradient_stats(
883-
gradients_data=equilibrated_gradient_data,
884-
output_dir=self.output_dir,
885-
plot_type=plot_type,
900+
# Make plots of equilibration time
901+
self._logger.info("Plotting equilibration times")
902+
_plot_equilibration_time(
903+
lam_windows=self.lam_windows, output_dir=self.output_dir
886904
)
887-
_plot_gradient_hists(
888-
gradients_data=equilibrated_gradient_data,
889-
output_dir=self.output_dir,
890-
run_nos=run_nos,
891-
)
892-
_plot_gradient_timeseries(
893-
gradients_data=equilibrated_gradient_data,
894-
output_dir=self.output_dir,
895-
run_nos=run_nos,
896-
)
897905

898-
# Make plots of equilibration time
899-
self._logger.info("Plotting equilibration times")
900-
_plot_equilibration_time(
901-
lam_windows=self.lam_windows, output_dir=self.output_dir
902-
)
906+
# Check and plot the Gelman-Rubin stat
907+
rhat_dict = _check_equil_multiwindow_gelman_rubin(
908+
lambda_windows=self.lam_windows, output_dir=self.output_dir
909+
)
910+
rhat_equil = {lam: rhat < 1.1 for lam, rhat in rhat_dict.items()}
911+
for lam, equil in rhat_equil.items():
912+
if not equil:
913+
self._logger.warning(
914+
f"The Gelman-Rubin statistic for lambda = {lam} is greater than 1.1. "
915+
"This suggests that the repeat simulations have not converged to the "
916+
"same distirbution and there is a sampling issue."
917+
)
903918

904-
# Check and plot the Gelman-Rubin stat
905-
rhat_dict = _check_equil_multiwindow_gelman_rubin(
906-
lambda_windows=self.lam_windows, output_dir=self.output_dir
907-
)
908-
rhat_equil = {lam: rhat < 1.1 for lam, rhat in rhat_dict.items()}
909-
for lam, equil in rhat_equil.items():
910-
if not equil:
911-
self._logger.warning(
912-
f"The Gelman-Rubin statistic for lambda = {lam} is greater than 1.1. "
913-
"This suggests that the repeat simulations have not converged to the "
914-
"same distirbution and there is a sampling issue."
915-
)
919+
# Write out stats
920+
with open(f"{self.output_dir}/overall_stats.dat", "a") as ofile:
921+
for win in self.lam_windows:
922+
ofile.write(
923+
f"Equilibration time for lambda = {win.lam}: {win.equil_time:.3f} ns per simulation\n"
924+
)
925+
ofile.write(
926+
f"Total time simulated for lambda = {win.lam}: {win.sims[0].tot_simtime:.3f} ns per simulation\n"
927+
)
916928

917-
# Write out stats
918-
with open(f"{self.output_dir}/overall_stats.dat", "a") as ofile:
919-
for win in self.lam_windows:
920-
ofile.write(
921-
f"Equilibration time for lambda = {win.lam}: {win.equil_time:.3f} ns per simulation\n"
922-
)
923-
ofile.write(
924-
f"Total time simulated for lambda = {win.lam}: {win.sims[0].tot_simtime:.3f} ns per simulation\n"
925-
)
929+
if get_frnrg:
930+
self._logger.info(
931+
f"Overall free energy changes: {free_energies} kcal mol-1"
932+
) # type: ignore
933+
self._logger.info(f"Overall errors: {errors} kcal mol-1") # type: ignore
934+
self._logger.info(f"Analysed runs: {run_nos}")
935+
# Update the interally-stored results
936+
self._delta_g = free_energies
937+
self._delta_g_er = errors
938+
return free_energies, errors # type: ignore
939+
else:
940+
return None, None
926941

927-
if get_frnrg:
928-
self._logger.info(
929-
f"Overall free energy changes: {free_energies} kcal mol-1"
930-
) # type: ignore
931-
self._logger.info(f"Overall errors: {errors} kcal mol-1") # type: ignore
932-
self._logger.info(f"Analysed runs: {run_nos}")
933-
# Update the interally-stored results
934-
self._delta_g = free_energies
935-
self._delta_g_er = errors
936-
return free_energies, errors # type: ignore
937-
else:
938-
return None, None
942+
finally: # Ensure that all plotting resources are closed
943+
_plt.close("all")
939944

940945
def get_results_df(self, save_csv: bool = True) -> _pd.DataFrame:
941946
"""

docs/CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Change Log
66
====================
77
- Fix bug which caused somd.rst7 files in the ensemble equilibration directories to be incorrectly numbered in some cases.
88
- Fix bug which caused the output directory to be incorrectly replaced with "output" in some cases.
9+
- Ensure that all plotting resources get closed after analysis to avoid continually increasing memory usage.
910

1011
0.3.1
1112
====================

0 commit comments

Comments
 (0)