-
Notifications
You must be signed in to change notification settings - Fork 2
Optionally include hydrogens in RMSDs #143
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
14cbfe7
85fc6af
071e061
a31957d
5ab89b3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -32,6 +32,7 @@ | |||
| RMSDCollection, | ||||
| RMSECollection, | ||||
| _normalize, | ||||
| get_rmsd, | ||||
| ) | ||||
| from yammbs.torsion.inputs import QCArchiveTorsionDataset | ||||
| from yammbs.torsion.models import ( | ||||
|
|
@@ -385,13 +386,16 @@ def get_rmsd( | |||
| force_field: str, | ||||
| torsion_ids: list[int] | None = None, | ||||
| skip_check: bool = False, | ||||
| include_hydrogens: bool = False, | ||||
| restraint_k: float = 0.0, | ||||
| ) -> RMSDCollection: | ||||
| """Get the RMSD summed over the torsion profile.""" | ||||
| from openff.toolkit import Molecule | ||||
|
|
||||
| if not torsion_ids: | ||||
| torsion_ids = self.get_torsion_ids() | ||||
|
|
||||
| if not skip_check: | ||||
| if skip_check is None: | ||||
| # TODO: Copy this into each get_* method? | ||||
| LOGGER.info("Calling optimize_mm from inside of get_log_sse.") | ||||
| self.optimize_mm(force_field=force_field, restraint_k=restraint_k) | ||||
|
|
@@ -407,7 +411,21 @@ def get_rmsd( | |||
| allow_undefined_stereo=True, | ||||
| ) | ||||
|
|
||||
| rmsds.append(RMSD.from_data(torsion_id, molecule, qm_points, mm_points)) | ||||
| rmsds.append( | ||||
| RMSD( | ||||
| id=torsion_id, | ||||
| rmsd=sum( | ||||
| get_rmsd( | ||||
| molecule, | ||||
| qm_points[key], | ||||
| mm_points[key], | ||||
| include_hydrogens=include_hydrogens, | ||||
| ) | ||||
| for key in qm_points | ||||
| ) | ||||
| / len(qm_points), | ||||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note this detail - without which there were numeric differences, i.e. 24x for a scan of 24 grid points
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just noting that
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch! Would the same suggestion also apply to yammbs/yammbs/torsion/analysis.py Line 122 in a31957d
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since this is not strictly related to the intent of the PR I'm going to split it out into a separate issue: #186 |
||||
| ), | ||||
| ) | ||||
|
|
||||
| return rmsds | ||||
|
|
||||
|
|
@@ -541,6 +559,7 @@ def get_outputs(self) -> MinimizedTorsionDataset: | |||
|
|
||||
| def get_metrics( | ||||
| self, | ||||
| include_hydrogens: bool = False, | ||||
| force_fields: Iterable[str] | None = None, | ||||
| js_temperature: float = 500.0, | ||||
| restraint_k: float = 0.0, | ||||
|
|
@@ -550,6 +569,8 @@ def get_metrics( | |||
|
|
||||
| Parameters | ||||
| ---------- | ||||
| include_hydrogens | ||||
| Whether RMSDs should include hydrogens (all-atom) or not (heavy atom) | ||||
| force_fields : Iterable[str] | None | ||||
| Iterable of force fields to compute metrics for. If None, compute for all available. | ||||
| js_temperature : float | ||||
|
|
@@ -583,21 +604,19 @@ def get_metrics( | |||
| # TODO: Optimize this for speed | ||||
| for force_field in force_fields: | ||||
| rmses = self.get_rmse(force_field=force_field, skip_check=True).to_dataframe() | ||||
| rmsds = self.get_rmsd(force_field=force_field, skip_check=True).to_dataframe() | ||||
| mean_errors = self.get_mean_error(force_field=force_field, skip_check=True).to_dataframe() | ||||
| js_distances = self.get_js_distance( | ||||
| force_field=force_field, | ||||
| skip_check=True, | ||||
| temperature=js_temperature, | ||||
| ).to_dataframe() | ||||
|
|
||||
| dataframe = rmses.join(rmsds).join(mean_errors).join(js_distances) | ||||
| dataframe = rmses.join(mean_errors).join(js_distances) | ||||
|
|
||||
| dataframe = dataframe.replace({pandas.NA: numpy.nan}) | ||||
|
|
||||
| metrics.metrics[force_field] = { | ||||
| id: Metric( # type: ignore[misc] | ||||
| rmsd=row["rmsd"], | ||||
| rmse=row["rmse"], | ||||
| mean_error=row["mean_error"], | ||||
| js_distance=(row["js_distance"], row["js_temperature"]), | ||||
|
|
@@ -829,7 +848,6 @@ def get_summary_df( | |||
| rows = [] | ||||
| metrics = self.get_metrics().metrics | ||||
| metrics_to_plot = { | ||||
| "RMSD / A": lambda x: x.rmsd, | ||||
| "RMSE / kcal mol-1": lambda x: x.rmse, | ||||
| "Mean Error / kcal mol-1": lambda x: x.mean_error, | ||||
| "JS Distance": lambda x: x.js_distance[0], | ||||
|
|
||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(not blocking) It's quite confusing that
get_rmsdis the name of both a function and a method (with different purposes/signatures) - could be a good thing to clean up/disambiguate in the code