diff --git a/app.py b/app.py index ff2dcc1a..58f9563b 100644 --- a/app.py +++ b/app.py @@ -24,6 +24,7 @@ from mlipaudit.dihedral_scan import DihedralScanBenchmark from mlipaudit.folding_stability import FoldingStabilityBenchmark from mlipaudit.io import load_benchmark_results_from_disk +from mlipaudit.noncovalent_interactions import NoncovalentInteractionsBenchmark from mlipaudit.reactivity import ReactivityBenchmark from mlipaudit.ring_planarity import RingPlanarityBenchmark from mlipaudit.small_molecule_minimization import ( @@ -36,6 +37,7 @@ conformer_selection_page, dihedral_scan_page, folding_stability_page, + noncovalent_interactions_page, reactivity_page, ring_planarity_page, small_molecule_minimization_page, @@ -50,6 +52,7 @@ BENCHMARKS: list[type[Benchmark]] = [ ConformerSelectionBenchmark, DihedralScanBenchmark, + NoncovalentInteractionsBenchmark, TautomersBenchmark, RingPlanarityBenchmark, SmallMoleculeMinimizationBenchmark, @@ -104,7 +107,14 @@ def _func(): title="Tautomers", url_path="tautomers", ) - +noncovalent_interactions = st.Page( + functools.partial( + noncovalent_interactions_page, + data_func=_data_func_from_key("noncovalent_interactions", data), + ), + title="Noncovalent Interactions", + url_path="noncovalent_interactions", +) ring_planarity = st.Page( functools.partial( ring_planarity_page, @@ -174,6 +184,7 @@ def _func(): conformer_selection, dihedral_scan, tautomers, + noncovalent_interactions, ring_planarity, small_molecule_minimization, bond_length_distribution, diff --git a/app_data/noncovalent_interactions/n_systems_per_subset.json b/app_data/noncovalent_interactions/n_systems_per_subset.json new file mode 100644 index 00000000..4691b6cb --- /dev/null +++ b/app_data/noncovalent_interactions/n_systems_per_subset.json @@ -0,0 +1,46 @@ +{ + "Dispersion: NobleGases": 140, + "Dispersion: PS": 103, + "Dispersion: HCNO": 79, + "Dispersion: Halogens": 94, + "Dispersion: Boron": 26, + "Hydrogen bonds: NH-O": 65, + "Hydrogen bonds: noHB": 113, + 
"Hydrogen bonds: CH-N": 20, + "Hydrogen bonds: OH-O": 60, + "Hydrogen bonds: NH-N": 53, + "Hydrogen bonds: OH-N": 45, + "Hydrogen bonds: CH-O": 19, + "Hydrogen bonds: XH-Cl": 32, + "Hydrogen bonds: XH-I": 19, + "Hydrogen bonds: XH-O": 51, + "Hydrogen bonds: XH-S": 54, + "Hydrogen bonds: XH-P": 52, + "Hydrogen bonds: XH-N": 34, + "Hydrogen bonds: XH-F": 41, + "Hydrogen bonds: XH-Br": 17, + "Ionic hydrogen bonds: NH(+)-O": 15, + "Ionic hydrogen bonds: OH-O(-)": 15, + "Ionic hydrogen bonds: OH-N(-)": 3, + "Ionic hydrogen bonds: CH-O(-)": 15, + "Ionic hydrogen bonds: NH(+)-N": 15, + "Ionic hydrogen bonds: NH-N(-)": 6, + "Ionic hydrogen bonds: CH-N(-)": 3, + "Ionic hydrogen bonds: NH-O(-)": 15, + "Ionic hydrogen bonds: NH-C(-)": 5, + "Ionic hydrogen bonds: NH(+)-C": 4, + "Ionic hydrogen bonds: OH-C(-)": 2, + "Ionic hydrogen bonds: OH(+)-O": 1, + "Ionic hydrogen bonds: CH-C(-)": 1, + "Repulsive contacts: NobleGases": 180, + "Repulsive contacts: Halogens": 235, + "Repulsive contacts: HCNO": 170, + "Repulsive contacts: PS": 154, + "Sigma hole: Cl": 29, + "Sigma hole: P": 33, + "Sigma hole: Br": 36, + "Sigma hole: S": 31, + "Sigma hole: Se": 44, + "Sigma hole: I": 42, + "Sigma hole: As": 35 +} diff --git a/docs/source/api_reference/index.rst b/docs/source/api_reference/index.rst index 146768e5..5b065e01 100644 --- a/docs/source/api_reference/index.rst +++ b/docs/source/api_reference/index.rst @@ -11,6 +11,7 @@ API reference benchmark small_molecules/conformer_selection small_molecules/dihedral_scan + small_molecules/noncovalent_interactions small_molecules/tautomers small_molecules/ring_planarity small_molecules/minimization diff --git a/docs/source/api_reference/small_molecules/noncovalent_interactions.rst b/docs/source/api_reference/small_molecules/noncovalent_interactions.rst new file mode 100644 index 00000000..50283d1b --- /dev/null +++ b/docs/source/api_reference/small_molecules/noncovalent_interactions.rst @@ -0,0 +1,22 @@ +.. 
_noncovalent_interactions_api: + +Noncovalent Interactions +======================== + +.. module:: mlipaudit.noncovalent_interactions.noncovalent_interactions + +.. autoclass:: NoncovalentInteractionsBenchmark + + .. automethod:: __init__ + + .. automethod:: run_model + + .. automethod:: analyze + +.. autoclass:: NoncovalentInteractionsResult + +.. autoclass:: NoncovalentInteractionsSystemResult + +.. autoclass:: NoncovalentInteractionsModelOutput + +.. autoclass:: NoncovalentInteractionsSystemModelOutput diff --git a/docs/source/benchmarks/small_molecules/img/pes.png b/docs/source/benchmarks/small_molecules/img/pes.png index 710b47e0..08229585 100644 Binary files a/docs/source/benchmarks/small_molecules/img/pes.png and b/docs/source/benchmarks/small_molecules/img/pes.png differ diff --git a/docs/source/benchmarks/small_molecules/index.rst b/docs/source/benchmarks/small_molecules/index.rst index b5b0daee..a5314e45 100644 --- a/docs/source/benchmarks/small_molecules/index.rst +++ b/docs/source/benchmarks/small_molecules/index.rst @@ -12,6 +12,7 @@ sampling, stability, and interactions with other molecules. Conformer Selection Dihedral scan + Noncovalent Interactions Ring planarity Tautomers Minimization diff --git a/docs/source/benchmarks/small_molecules/noncovalent_interactions.rst b/docs/source/benchmarks/small_molecules/noncovalent_interactions.rst new file mode 100644 index 00000000..36d0a21d --- /dev/null +++ b/docs/source/benchmarks/small_molecules/noncovalent_interactions.rst @@ -0,0 +1,74 @@ +.. _noncovalent_interactions: + +Non-covalent interactions +========================= + +Purpose +------- +This benchmark tests if the **MLIP** can reproduce interaction energies of molecular complexes driven by non-covalent interactions. +Non-covalent interactions are of highest importance for the structure and function of every biological molecule. 
This benchmark +assesses a broad range of interaction types: London dispersion, hydrogen bonds, ionic hydrogen bonds, repulsive contacts and sigma +hole interactions. + + +Description +----------- +The benchmark runs energy inference on all structures of the distance scans of bi-molecular complexes in the dataset. The key +metric is the **RMSE of the interaction energy** with respect to the reference data. The interaction energy is the minimum of the +energy well in the distance scan, relative to the energy of the dissociated complex. For repulsive contacts, the maximum of the +energy profile is used instead. Some of the molecular complexes in the benchmark dataset contain exotic elements (see the *Dataset* +section below). If the **MLIP** has never seen an element of a molecular complex, this complex will be skipped in the benchmark. + +.. list-table:: + :widths: 25 45 + :header-rows: 0 + + * - .. figure:: img/butadiene_diazomethane.png + :width: 100% + :align: center + :figclass: align-center + + - .. figure:: img/pes.png + :width: 100% + :align: center + :figclass: align-center + +Dataset +------- +This benchmark uses the datasets from the `NCI Atlas <http://www.nciatlas.org/>`_, with dissociation energy profiles. +These datasets contain **QM** optimized geometries, along with **CCSD(T)/CBS** level interaction energies. The molecular complexes of +these datasets contain typical organic small molecules, but also more exotic species and elements. Here is a summary of the +datasets used in this benchmark: + +.. 
list-table:: NCI Atlas Datasets + :widths: 20 30 50 + :header-rows: 1 + + * - Dataset Name + - Type of interaction + - Subsets + * - D442x10 + - London dispersion + - Noble Gases, Boron, HCNO, Halogens, PS + * - HB375x10 + - Hydrogen bonds + - OH-N, OH-O, OH-C, NH-N, NH-O, … + * - HB300SPXx10 + - Hydrogen bonds extended to S, P and halogens + - XH-S, XH-P, XH-Cl, XH-Br, … + * - IHB100x10 + - Ionic hydrogen bonds + - O, N, C with cationic donors and anionic acceptors + * - R739x5 + - Repulsive contacts + - Noble Gases, HCNO, Halogens, PS + * - SH250x10 + - Sigma hole interactions + - P, S, Se, As, Br, Cl, I + +Interpretation +-------------- +The **RMSE** of the interaction energies should be **as low as possible**. This metric is likely to be very different for the different +interaction types and data subsets. The **RMSE** of the interaction energy **should be compared per interaction type** and then in a more +fine-grained visualization for the data subsets to identify areas of weakness for the **MLIP**. Within these areas of weakness, +individual dissociation energy profiles can be visually inspected to see how they compare to the reference. 
diff --git a/src/mlipaudit/main.py b/src/mlipaudit/main.py index 83c5622b..3d2045f2 100644 --- a/src/mlipaudit/main.py +++ b/src/mlipaudit/main.py @@ -26,6 +26,7 @@ from mlipaudit.dihedral_scan import DihedralScanBenchmark from mlipaudit.folding_stability import FoldingStabilityBenchmark from mlipaudit.io import write_benchmark_results_to_disk +from mlipaudit.noncovalent_interactions import NoncovalentInteractionsBenchmark from mlipaudit.reactivity import ReactivityBenchmark from mlipaudit.ring_planarity import RingPlanarityBenchmark from mlipaudit.small_molecule_minimization import SmallMoleculeMinimizationBenchmark @@ -38,6 +39,7 @@ BENCHMARKS = [ ConformerSelectionBenchmark, TautomersBenchmark, + NoncovalentInteractionsBenchmark, DihedralScanBenchmark, RingPlanarityBenchmark, SmallMoleculeMinimizationBenchmark, diff --git a/src/mlipaudit/noncovalent_interactions/__init__.py b/src/mlipaudit/noncovalent_interactions/__init__.py new file mode 100644 index 00000000..0605f499 --- /dev/null +++ b/src/mlipaudit/noncovalent_interactions/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2025 InstaDeep Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from mlipaudit.noncovalent_interactions.noncovalent_interactions import ( + NoncovalentInteractionsBenchmark, + NoncovalentInteractionsResult, +) diff --git a/src/mlipaudit/noncovalent_interactions/noncovalent_interactions.py b/src/mlipaudit/noncovalent_interactions/noncovalent_interactions.py new file mode 100644 index 00000000..af612e19 --- /dev/null +++ b/src/mlipaudit/noncovalent_interactions/noncovalent_interactions.py @@ -0,0 +1,428 @@ +# Copyright 2025 InstaDeep Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import functools +import logging + +import numpy as np +from ase import Atoms, units +from mlip.inference import run_batched_inference +from pydantic import BaseModel, TypeAdapter + +from mlipaudit.benchmark import Benchmark, BenchmarkResult, ModelOutput +from mlipaudit.utils import skip_unallowed_elements + +logger = logging.getLogger("mlipaudit") + +NCI_ATLAS_FILENAME = "NCI_Atlas.json" + +REPULSIVE_DATASETS = ["NCIA_R739x5"] + +DATASET_RAW_TO_DESCRIPTIVE = { + "D442x10": "Dispersion", + "HB375x10": "Hydrogen bonds", + "HB300SPXx10": "Hydrogen bonds", + "IHB100x10": "Ionic hydrogen bonds", + "R739x5": "Repulsive contacts", + "SH250x10": "Sigma hole", +} + +GROUP_RAW_TO_DESCRIPTIVE = { + "CH-Oa": "CH-O(-)", + "CH-Na": "CH-N(-)", + "CH-Ca": "CH-C(-)", + "NH-Oa": "NH-O(-)", + "NH-Na": "NH-N(-)", + "NH-Ca": "NH-C(-)", + "OH-Oa": "OH-O(-)", + "OH-Na": "OH-N(-)", + "OH-Ca": "OH-C(-)", + "NHk-O": "NH(+)-O", + "NHk-C": "NH(+)-C", + "NHk-N": "NH(+)-N", + "OHk-O": "OH(+)-O", + "B": "Boron", +} + + +class NoncovalentInteractionsSystemResult(BenchmarkResult): + """Results object for the noncovalent interactions benchmark for a single + bi-molecular system. + + Attributes: + system_id: The system id. + structure_name: The structure name. + dataset: The dataset name. + group: The group name. + reference_interaction_energy: The reference interaction energy. + mlip_interaction_energy: The MLIP interaction energy. + deviation: The deviation between the reference and MLIP interaction + energies. + reference_energy_profile: The reference energy profile. + energy_profile: The MLIP energy profile. + distance_profile: The distance profile. 
class NoncovalentInteractionsResult(BenchmarkResult):
    """Results object for the noncovalent interactions benchmark.

    Bundles the per-system results with error statistics computed from the
    per-system ``deviation`` values (MLIP interaction energy minus reference
    interaction energy).

    Attributes:
        systems: The systems results.
        n_skipped_unallowed_elements: The number of structures skipped due to unallowed
            elements.
        rmse_interaction_energy_all: The RMSE of the interaction energy over all
            tested systems.
        rmse_interaction_energy_subsets: The RMSE of the interaction energy per subset.
            Keys have the form ``"<dataset>: <group>"``.
        mae_interaction_energy_subsets: The MAE of the interaction energy per subset.
            Keys have the form ``"<dataset>: <group>"``.
        rmse_interaction_energy_datasets: The RMSE of the interaction energy per
            dataset, keyed by the dataset name stored on each system result.
        mae_interaction_energy_datasets: The MAE of the interaction energy per
            dataset, keyed by the dataset name stored on each system result.
    """

    systems: list[NoncovalentInteractionsSystemResult]
    # Defaults to 0 so that results can be built without an explicit skip count.
    n_skipped_unallowed_elements: int = 0
    rmse_interaction_energy_all: float
    rmse_interaction_energy_subsets: dict[str, float]
    mae_interaction_energy_subsets: dict[str, float]
    rmse_interaction_energy_datasets: dict[str, float]
    mae_interaction_energy_datasets: dict[str, float]
def compute_total_interaction_energy(
    distance_profile: list[float],
    interaction_energy_profile: list[float],
    repulsive: bool = False,
) -> float:
    """Compute the total interaction energy of a distance scan.

    The result is the extremum of the energy profile measured relative to the
    dissociated complex, where the dissociated complex is the scan point with
    the largest distance.

    Args:
        distance_profile: Distances between the two interacting molecules, one
            per scan point.
        interaction_energy_profile: Interaction energies at the distances given
            in ``distance_profile`` (same ordering).
        repulsive: If True, the maximum of the energy profile is used as the
            extremum instead of the minimum. Defaults to False.

    Returns:
        The extremum of the energy profile minus the energy at the largest
        distance.
    """
    # Attractive scans are characterised by the bottom of the well, repulsive
    # scans by the top of the barrier.
    select_extremum = np.max if repulsive else np.min
    extremum_energy = select_extremum(interaction_energy_profile)

    # The scan point with the largest separation defines the zero of energy.
    farthest_idx = np.argmax(distance_profile)
    return extremum_energy - interaction_energy_profile[farthest_idx]


def _descriptive_data_subset_name(
    dataset_name: str,
    group: str,
) -> tuple[str, str]:
    """Translate raw dataset and group identifiers into descriptive names.

    The ``NCIA_`` marker is removed from the dataset name before the lookup.
    Identifiers without an entry in the translation tables are returned as
    they are (the dataset name without the ``NCIA_`` marker).

    Args:
        dataset_name: The raw dataset name.
        group: The raw group name.

    Returns:
        A tuple of the descriptive dataset name and the descriptive group name.
    """
    stripped_name = dataset_name.replace("NCIA_", "")
    return (
        DATASET_RAW_TO_DESCRIPTIVE.get(stripped_name, stripped_name),
        GROUP_RAW_TO_DESCRIPTIVE.get(group, group),
    )
class NoncovalentInteractionsBenchmark(Benchmark):
    """Benchmark for noncovalent interactions.

    Attributes:
        name: The unique benchmark name used to run the benchmark from the CLI;
            it also determines the output folder name for the result file.
            The name is ``noncovalent_interactions``.
        result_class: The type of `BenchmarkResult` returned by
            ``self.analyze()``, which is ``NoncovalentInteractionsResult``.
    """

    name = "noncovalent_interactions"
    result_class = NoncovalentInteractionsResult

    def run_model(self) -> None:
        """Run a single point energy calculation for every scan geometry.

        All geometries of all systems are evaluated in one batched inference
        call with the force field. Systems reported by
        ``skip_unallowed_elements`` are left out. The predicted energy
        profiles are stored in ``self.model_output``.
        """
        skipped_structures = skip_unallowed_elements(
            self.force_field,
            [
                (structure.system_id, structure.atoms)
                for structure in self._nci_atlas_data.values()
            ],
        )

        # Flat list of all geometries plus, per system, the positions of its
        # geometries inside that flat list.
        atoms_all: list[Atoms] = []
        atoms_all_idx_map: dict[str, list[int]] = {}
        for structure in self._nci_atlas_data.values():
            if structure.system_id in skipped_structures:
                continue
            first_idx = len(atoms_all)
            atoms_all.extend(
                Atoms(symbols=structure.atoms, positions=coord)
                for coord in structure.coords
            )
            atoms_all_idx_map[structure.system_id] = list(
                range(first_idx, len(atoms_all))
            )

        logger.info("Running energy calculations...")
        if skipped_structures:
            logger.info(
                "Skipping %s structures because of unallowed elements.",
                len(skipped_structures),
            )

        predictions = run_batched_inference(
            atoms_all,
            self.force_field,
            batch_size=128,
        )

        self.model_output = NoncovalentInteractionsModelOutput(
            systems=[
                NoncovalentInteractionsSystemModelOutput(
                    system_id=system_id,
                    energy_profile=[predictions[idx].energy for idx in indices],
                )
                for system_id, indices in atoms_all_idx_map.items()
            ],
            n_skipped_unallowed_elements=len(skipped_structures),
        )

    def analyze(self) -> NoncovalentInteractionsResult:
        """Calculate the total interaction energies and their deviations.

        For every evaluated system the MLIP energy profile is converted with
        the ``units.kcal / units.mol`` factor, the total interaction energy is
        computed for both the MLIP and the reference profile, and the signed
        deviation (MLIP minus reference) is stored.

        Returns:
            A `NoncovalentInteractionsResult` object with the benchmark results.

        Raises:
            RuntimeError: If called before `run_model()`.
        """
        if self.model_output is None:
            raise RuntimeError("Must call run_model() first.")

        # NOTE(review): dividing by this factor presumably converts eV to
        # kcal/mol (ASE unit constants) -- confirm the force field output unit.
        kcal_per_mol = units.kcal / units.mol

        results = []
        for system in self.model_output.systems:
            mlip_energy_profile = [
                energy / kcal_per_mol for energy in system.energy_profile
            ]
            reference = self._nci_atlas_data[system.system_id]
            distance_profile = reference.distance_profile
            ref_energy_profile = reference.interaction_energy_profile

            # Repulsive scans are characterised by their maximum, not minimum.
            repulsive = reference.dataset_name in REPULSIVE_DATASETS
            ref_interaction_energy = compute_total_interaction_energy(
                distance_profile, ref_energy_profile, repulsive=repulsive
            )
            mlip_interaction_energy = compute_total_interaction_energy(
                distance_profile, mlip_energy_profile, repulsive=repulsive
            )

            dataset_descriptive, group_descriptive = _descriptive_data_subset_name(
                reference.dataset_name,
                reference.group,
            )

            results.append(
                NoncovalentInteractionsSystemResult(
                    system_id=system.system_id,
                    structure_name=reference.system_name,
                    dataset=dataset_descriptive,
                    group=group_descriptive,
                    reference_interaction_energy=ref_interaction_energy,
                    mlip_interaction_energy=mlip_interaction_energy,
                    deviation=mlip_interaction_energy - ref_interaction_energy,
                    reference_energy_profile=ref_energy_profile,
                    energy_profile=mlip_energy_profile,
                    distance_profile=distance_profile,
                )
            )

        return _compute_metrics_from_system_results(
            results, self.model_output.n_skipped_unallowed_elements
        )

    @functools.cached_property
    def _nci_atlas_data(self) -> dict[str, MolecularSystem]:
        """Load and validate the NCI Atlas input file (cached after first use).

        With ``fast_dev_run`` enabled only the two systems ``1.03.03`` and
        ``1.01.01`` are kept.
        """
        dataset_path = self.data_input_dir / self.name / NCI_ATLAS_FILENAME
        with open(dataset_path, "r", encoding="utf-8") as f:
            nci_atlas_data = Systems.validate_json(f.read())

        if self.fast_dev_run:
            nci_atlas_data = {
                system_id: nci_atlas_data[system_id]
                for system_id in ("1.03.03", "1.01.01")
            }

        return nci_atlas_data
def _process_data_into_rmse_per_dataset(
    data: BenchmarkResultForMultipleModels,
    model_select: list[str],
    conversion_factor: float,
    subset: bool = False,
) -> pd.DataFrame:
    """Collect the RMSE per dataset (or per subset) into a DataFrame.

    Args:
        data: The benchmark results.
        model_select: The models to include in the DataFrame. An empty list
            selects all models.
        conversion_factor: The conversion factor for the energy unit.
        subset: If True, use the RMSE per data subset, otherwise the RMSE per
            dataset.

    Returns:
        A pandas DataFrame with one row per included model (indexed by model
        name, in the iteration order of ``data``) and one column per dataset
        or subset, scaled by ``conversion_factor``.
    """
    rmse_per_model: dict[ModelName, dict[str, float]] = {}
    for model_name, results in data.items():
        if model_select and model_name not in model_select:
            continue
        rmse_per_model[model_name] = (
            results.rmse_interaction_energy_subsets
            if subset
            else results.rmse_interaction_energy_datasets
        )

    # Label the rows with the models that were actually collected. Using
    # `model_select` as index mislabels rows whenever its order differs from
    # the order of `data`, and fails when the lengths differ.
    df = pd.DataFrame(list(rmse_per_model.values()), index=list(rmse_per_model))
    df = df.dropna(axis=1, how="all")
    return df * conversion_factor


def _get_best_model_name(
    data: BenchmarkResultForMultipleModels,
    model_select: list[str],
) -> str:
    """Get the name of the best model based on lowest RMSE over all tested systems.

    Args:
        data: The benchmark results.
        model_select: The models to consider. An empty list selects all models.

    Returns:
        The name of the best model.

    Raises:
        ValueError: If no model is left to choose from.
    """
    rmse_per_model = {
        model_name: results.rmse_interaction_energy_all
        for model_name, results in data.items()
        if not model_select or model_name in model_select
    }
    if not rmse_per_model:
        raise ValueError("No model results available to determine the best model.")
    return min(rmse_per_model, key=rmse_per_model.__getitem__)


def _sorted_relative_profile(
    energy_profile: list[float],
    sort_order: np.ndarray,
    baseline_idx: int,
    conversion_factor: float,
) -> list[float]:
    """Return a converted energy profile in `sort_order`, relative to a baseline."""
    baseline = float(energy_profile[baseline_idx]) * conversion_factor
    return [
        float(energy) * conversion_factor - baseline
        for energy in np.array(energy_profile)[sort_order]
    ]


def _get_energy_profiles_for_subset(
    data: BenchmarkResultForMultipleModels,
    subset: str,
    model_select: list[str],
    conversion_factor: float,
) -> dict[ModelName, dict[str, tuple[list[float], list[float]]]]:
    """Get the energy profiles for a subset.

    Args:
        data: The benchmark results.
        subset: The subset to get the energy profiles for, in the form
            ``"<dataset>: <group>"``.
        model_select: The models to include. An empty list selects all models.
        conversion_factor: The conversion factor for the energy unit.

    Returns:
        A dictionary mapping each included model name, plus the extra key
        ``"Reference"``, to a dictionary from structure name to a tuple of
        (sorted distances, energies). The energies are ordered by increasing
        distance and given relative to the energy at the largest distance.
    """
    energy_profiles_per_model: dict[
        ModelName, dict[str, tuple[list[float], list[float]]]
    ] = {}
    for model_name, results in data.items():
        if model_select and model_name not in model_select:
            continue

        model_profiles: dict[str, tuple[list[float], list[float]]] = {}
        energy_profiles_per_model[model_name] = model_profiles
        # Keep the reference curves gathered from earlier models: models may
        # have skipped different systems, so resetting this entry per model
        # would drop reference profiles only earlier models provide.
        reference_profiles = energy_profiles_per_model.setdefault("Reference", {})

        for system_results in results.systems:
            if f"{system_results.dataset}: {system_results.group}" != subset:
                continue

            distance_profile = system_results.distance_profile
            sort_order = np.argsort(distance_profile)
            # The point with the largest distance defines the zero of energy.
            max_dist_idx = np.argmax(distance_profile)

            model_profiles[system_results.structure_name] = (
                sorted(distance_profile),
                _sorted_relative_profile(
                    system_results.energy_profile,
                    sort_order,
                    max_dist_idx,
                    conversion_factor,
                ),
            )
            reference_profiles[system_results.structure_name] = (
                sorted(distance_profile),
                _sorted_relative_profile(
                    system_results.reference_energy_profile,
                    sort_order,
                    max_dist_idx,
                    conversion_factor,
                ),
            )
    return energy_profiles_per_model
+ """ + st.markdown("# Non-covalent interactions") + st.sidebar.markdown("# Non-covalent interactions") + + st.markdown( + "This benchmark tests if the MLIPs can reproduce interaction energies of " + "molecular complexes driven by non-covalent interactions. The benchmark " + "uses six datasets from the NCI Atlas: D442x10 (London dispersion), " + "HB375x10 (hydrogen bonds), HB300SPXx10 (hydrogen bonds extended to S, P " + "and halogens), IHB100x10 (ionic hydrogen bonds), R739x5 (repulsive " + "contacts) and SH250x10 (sigma hole interaction). These contain QM " + "optimized geometries of distance scans of bi-molecular complexes, where " + "the two molecules interact via non-covalent interactions with associated " + "energies. Each dataset contains multiple subsets which specify certain " + "categories of interactions. The key metric is the RMSE of the interaction " + "energies. These are defined as the bottom of the potential energy curve " + "minus the energy of the separated molecules. For repulsive contacts, the " + "maximum of the energy profile is used instead." + ) + st.markdown( + "For more information see the [docs](https://instadeepai.github.io/mlipaudit-open/benchmarks/small_molecules/noncovalent_interactions.html)" + " and the [NCI Atlas webpage](http://www.nciatlas.org/)." + ) + st.markdown( + "The benchmark skips all structures that contain elements which were completely" + " absent from the MLIP's training data. Scroll to the end of the page to find " + "an overview of how many structures were skipped for each dataset." 
+ ) + + with st.sidebar.container(): + unit_selection = st.selectbox( + "Select an energy unit:", + ["kcal/mol", "eV"], + ) + + # Set conversion factor based on selection + if unit_selection == "kcal/mol": + conversion_factor = 1.0 + unit_label = "kcal/mol" + else: + conversion_factor = units.kcal / units.mol + unit_label = "eV" + + data = data_func() + + model_names = list(set(data.keys())) + model_select = st.sidebar.multiselect( + "Select model(s)", model_names, default=model_names + ) + selected_models = model_select if model_select else model_names + + df = _process_data_into_rmse_per_dataset( + data, + selected_models, + conversion_factor, + subset=False, + ) + + st.markdown("## Best model summary") + best_model_name = _get_best_model_name( + data, + selected_models, + ) + + st.markdown( + f"The best model **{best_model_name}** based on lowest RMSE " + "over all tested systems." + ) + + cols_metrics = st.columns(len(df.columns)) + for i, dataset in enumerate(df.columns): + with cols_metrics[i]: + st.metric(dataset, f"{float(df.loc[best_model_name, dataset]):.3f}") + + st.markdown("## Summary statistics") + st.markdown( + "This table shows the average RMSE of the interaction energies for each " + "interaction type and model. For a more fine-grained breakdown of " + "interaction types, see the bar plot below." 
+ ) + st.dataframe(df) + + st.markdown("## RMSE per data subset") + df_subset = _process_data_into_rmse_per_dataset( + data, + selected_models, + conversion_factor, + subset=True, + ) + + # Reshape dataframe for Altair plotting + df_melted = ( + df_subset.reset_index() + .melt(id_vars=["index"], var_name="Interaction type", value_name="RMSE") + .rename(columns={"index": "Model name"}) + ) + + # Create horizontal bar plot + selection = alt.selection_point(fields=["Model name"], bind="legend") + + chart = ( + alt.Chart(df_melted) + .mark_bar() + .add_params(selection) + .encode( + y=alt.Y("Interaction type:N", title="Interaction Type"), + x=alt.X("RMSE:Q", title="RMSE"), + color=alt.Color("Model name:N", title="Model Name"), + opacity=alt.condition(selection, alt.value(0.8), alt.value(0.3)), + tooltip=["Model name:N", "Interaction type:N", "RMSE:Q"], + ) + .resolve_scale(color="independent") + .properties( + width=800, + height=max(len(df_melted) * 50, 400), + ) + ) + + st.altair_chart(chart, use_container_width=True) + + st.markdown("## Energy profiles") + + st.markdown( + "The energy profiles below show the energy of the complex as a " + "function of the distance between the two molecules. For more " + "information about the molecular complexes indicated by the " + "structure names, browse the datasets on the [NCI Atlas webpage](http://www.nciatlas.org/)." 
+ ) + + available_subsets: set[str] = set() + for _, results in data.items(): + available_subsets.update(results.rmse_interaction_energy_subsets.keys()) + + dataset_selector_list = [] + subset_selector_list = [] + for subset_name in available_subsets: + dataset_selector_list.append(subset_name.split(":")[0].strip()) + subset_selector_list.append(subset_name.split(":")[1].strip()) + + dataset_selector = st.selectbox( + "Select a dataset", + dataset_selector_list, + ) + subset_selector = st.selectbox( + "Select a subset", + subset_selector_list, + ) + model_selector_sorting = st.selectbox( + "Select a model for sorting by interaction energy error", + model_select, + ) + + selected_subset = f"{dataset_selector}: {subset_selector}" + + energy_profiles_per_model = _get_energy_profiles_for_subset( + data, + selected_subset, + model_select, + conversion_factor, + ) + + # Get structure names and sort them by deviation for the selected model + if ( + model_selector_sorting in energy_profiles_per_model + and energy_profiles_per_model[model_selector_sorting] + ): + structure_names = list(energy_profiles_per_model[model_selector_sorting].keys()) + + # Create a mapping from structure_name to deviation for sorting + structure_to_deviation = {} + for system_result in data[model_selector_sorting].systems: + if ( + f"{system_result.dataset}: {system_result.group}" == selected_subset + and system_result.structure_name in structure_names + ): + structure_to_deviation[system_result.structure_name] = abs( + system_result.deviation + ) + + structure_names.sort(key=lambda x: structure_to_deviation.get(x, 0)) + + # Add structure selection dropdown + selected_structure = st.selectbox( + "Select a structure to plot", + structure_names, + ) + + # Create DataFrame for Altair plotting + plot_data = [] + for model_name, structure_data in energy_profiles_per_model.items(): + for structure_name, (distances, energies) in structure_data.items(): + if structure_name == selected_structure: + for 
dist, energy in zip(distances, energies): + plot_data.append({ + "distance": dist, + "energy": energy, + "model": model_name, + }) + break + + if plot_data: + df_plot = pd.DataFrame(plot_data) + + # Create the line plot + line_chart = ( + alt.Chart(df_plot) + .mark_line( + point=alt.OverlayMarkDef(size=50, filled=False, strokeWidth=2) + ) + .encode( + x=alt.X("distance:Q", title="Distance (Å)"), + y=alt.Y("energy:Q", title=f"Energy ({unit_label})"), + color=alt.Color("model:N", title="Model"), + tooltip=["distance:Q", "energy:Q", "model:N"], + ) + .interactive() + .properties(width=800, height=400) + ) + + st.altair_chart(line_chart, use_container_width=True) + else: + st.warning( + "No energy profile data available for the selected subset and " + "structure." + ) + else: + st.warning( + f"No energy profile data available for model '{model_selector_sorting}' " + "in the selected subset." + ) + + st.markdown("## Skipped structures per dataset") + st.markdown( + "This table shows the number of structures that were skipped for each data " + "subset and model. The first row shows the total number of structures in " + "each data subset." 
+ ) + + with open( + NCI_ATLAS_DIR / "n_systems_per_subset.json", + mode="r", + encoding="utf-8", + ) as f: + n_systems_per_subset = json.load(f) + + subsets = list(n_systems_per_subset.keys()) + + converted_data = [] + for model_name, results in data.items(): + if ( + len(model_select) > 0 + and model_name in model_select + or len(model_select) == 0 + ): + n_systems_per_subset_for_model = {} + n_skipped_per_subset_for_model = {} + for subset in subsets: + n_systems_per_subset_for_model[subset] = 0 + for system in results.systems: + subset_name_system = f"{system.dataset}: {system.group}" + if subset_name_system == subset: + n_systems_per_subset_for_model[subset] += 1 + + for subset in subsets: + n_skipped_per_subset_for_model[subset] = ( + n_systems_per_subset[subset] + - n_systems_per_subset_for_model[subset] + ) + + converted_data.append(n_skipped_per_subset_for_model) + + df = pd.DataFrame(converted_data, index=selected_models) + st.dataframe(df) diff --git a/src/mlipaudit/utils/__init__.py b/src/mlipaudit/utils/__init__.py index 7f9cc5bc..3211d3f9 100644 --- a/src/mlipaudit/utils/__init__.py +++ b/src/mlipaudit/utils/__init__.py @@ -16,3 +16,4 @@ create_ase_trajectory_from_simulation_state, create_mdtraj_trajectory_from_simulation_state, ) +from mlipaudit.utils.unallowed_elements import skip_unallowed_elements diff --git a/src/mlipaudit/utils/unallowed_elements.py b/src/mlipaudit/utils/unallowed_elements.py new file mode 100644 index 00000000..4b05a558 --- /dev/null +++ b/src/mlipaudit/utils/unallowed_elements.py @@ -0,0 +1,42 @@ +# Copyright 2025 InstaDeep Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ase.data import chemical_symbols
from mlip.models import ForceField


def skip_unallowed_elements(
    force_field: ForceField,
    structure_tuples: list[tuple[str, list[str]]],
) -> set[str]:
    """Return the identifiers of structures that contain unallowed elements.

    A structure is flagged when at least one of its atom symbols does not
    correspond to an atomic number listed in
    ``force_field.allowed_atomic_numbers``.

    Args:
        force_field: The force field whose ``allowed_atomic_numbers`` define
            which elements are supported.
        structure_tuples: A list of tuples, where each tuple contains a
            structure identifier and the list of atom symbols of that
            structure.

    Returns:
        The set of structure identifiers that contain at least one unallowed
        element. The set is empty when every structure is fully supported.
    """
    # Translate atomic numbers to symbols once so that every structure can be
    # checked with a plain set-membership test.
    allowed_symbols = {
        chemical_symbols[z] for z in force_field.allowed_atomic_numbers
    }

    # ``issuperset`` accepts any iterable, so the symbol list does not need to
    # be converted into a temporary set first.
    return {
        structure_id
        for structure_id, atom_symbols in structure_tuples
        if not allowed_symbols.issuperset(atom_symbols)
    }
+# See the License for the specific language governing permissions and +# limitations under the License. + from pathlib import Path from typing import Callable from unittest.mock import MagicMock, create_autospec diff --git a/tests/data/noncovalent_interactions/NCI_Atlas.json b/tests/data/noncovalent_interactions/NCI_Atlas.json new file mode 100644 index 00000000..59d016a9 --- /dev/null +++ b/tests/data/noncovalent_interactions/NCI_Atlas.json @@ -0,0 +1,540 @@ +{ + "1.03.03": { + "system_id": "1.03.03", + "system_name": "hydrogen...nitrogen", + "dataset_name": "NCIA_D442x10", + "group": "HCNO", + "total_charge": 0.0, + "charge_a": 0.0, + "charge_b": 0.0, + "selection_indices_a": [ + 0, + 1 + ], + "selection_indices_b": [ + 2, + 3 + ], + "atoms": [ + "N", + "N", + "H", + "H" + ], + "distance_profile": [ + 20.0, + 10.0, + 10.5, + 11.0, + 9.5, + 8.0, + 9.0, + 8.5, + 12.5, + 15.0 + ], + "interaction_energy_profile": [ + -0.009, + -0.166, + -0.17, + -0.159, + -0.127, + 0.613, + -0.024, + 0.194, + -0.103, + -0.043 + ], + "coords": [ + [ + [ + 0.148046572, + 0.648135462, + -0.453788264 + ], + [ + -0.152762313, + -0.363401955, + -0.181097735 + ], + [ + -0.023934039, + -2.62501083, + 5.296606358 + ], + [ + 0.07923179, + -2.620762929, + 6.031071857 + ] + ], + [ + [ + 0.148046572, + 0.648135462, + -0.453788264 + ], + [ + -0.152762313, + -0.363401955, + -0.181097735 + ], + [ + -0.037443757, + -1.380030867, + 2.603697365 + ], + [ + 0.065722072, + -1.375782966, + 3.338162864 + ] + ], + [ + [ + 0.148046572, + 0.648135462, + -0.453788264 + ], + [ + -0.152762313, + -0.363401955, + -0.181097735 + ], + [ + -0.036768271, + -1.442279865, + 2.738342815 + ], + [ + 0.066397558, + -1.438031964, + 3.472808314 + ] + ], + [ + [ + 0.148046572, + 0.648135462, + -0.453788264 + ], + [ + -0.152762313, + -0.363401955, + -0.181097735 + ], + [ + -0.036092785, + -1.504528863, + 2.872988264 + ], + [ + 0.067073044, + -1.500280962, + 3.607453763 + ] + ], + [ + [ + 0.148046572, + 0.648135462, + 
-0.453788264 + ], + [ + -0.152762313, + -0.363401955, + -0.181097735 + ], + [ + -0.038119243, + -1.317781869, + 2.469051915 + ], + [ + 0.065046586, + -1.313533968, + 3.203517414 + ] + ], + [ + [ + 0.148046572, + 0.648135462, + -0.453788264 + ], + [ + -0.152762313, + -0.363401955, + -0.181097735 + ], + [ + -0.040145701, + -1.131034874, + 2.065115566 + ], + [ + 0.063020128, + -1.126786973, + 2.799581065 + ] + ], + [ + [ + 0.148046572, + 0.648135462, + -0.453788264 + ], + [ + -0.152762313, + -0.363401955, + -0.181097735 + ], + [ + -0.038794729, + -1.255532871, + 2.334406466 + ], + [ + 0.0643711, + -1.25128497, + 3.068871965 + ] + ], + [ + [ + 0.148046572, + 0.648135462, + -0.453788264 + ], + [ + -0.152762313, + -0.363401955, + -0.181097735 + ], + [ + -0.039470215, + -1.193283873, + 2.199761016 + ], + [ + 0.063695614, + -1.189035972, + 2.934226515 + ] + ], + [ + [ + 0.148046572, + 0.648135462, + -0.453788264 + ], + [ + -0.152762313, + -0.363401955, + -0.181097735 + ], + [ + -0.034066327, + -1.691275858, + 3.276924613 + ], + [ + 0.069099502, + -1.687027957, + 4.011390112 + ] + ], + [ + [ + 0.148046572, + 0.648135462, + -0.453788264 + ], + [ + -0.152762313, + -0.363401955, + -0.181097735 + ], + [ + -0.030688898, + -2.002520848, + 3.950151861 + ], + [ + 0.072476931, + -1.998272947, + 4.68461736 + ] + ] + ] + }, + "1.01.01": { + "system_id": "1.01.01", + "system_name": "hydrogen...hydrogen", + "dataset_name": "NCIA_D442x10", + "group": "HCNO", + "total_charge": 0.0, + "charge_a": 0.0, + "charge_b": 0.0, + "selection_indices_a": [ + 0, + 1 + ], + "selection_indices_b": [ + 2, + 3 + ], + "atoms": [ + "H", + "H", + "H", + "H" + ], + "distance_profile": [ + 20.0, + 11.0, + 10.5, + 10.0, + 8.5, + 9.0, + 12.5, + 8.0, + 9.5, + 15.0 + ], + "interaction_energy_profile": [ + -0.004, + -0.078, + -0.087, + -0.09, + 0.026, + -0.049, + -0.048, + 0.176, + -0.081, + -0.02 + ], + "coords": [ + [ + [ + 0.294609617, + 0.81227302, + -1.852328373 + ], + [ + 0.056084218, + 0.417807318, + 
-1.271262175 + ], + [ + -0.695326527, + -1.981386764, + 4.26995891 + ], + [ + -0.262707793, + -1.378977922, + 4.263145027 + ] + ], + [ + [ + 0.294609617, + 0.81227302, + -1.852328373 + ], + [ + 0.056084218, + 0.417807318, + -1.271262175 + ], + [ + -0.422023309, + -1.022758807, + 1.835677885 + ], + [ + 0.010595425, + -0.420349965, + 1.828864002 + ] + ], + [ + [ + 0.294609617, + 0.81227302, + -1.852328373 + ], + [ + 0.056084218, + 0.417807318, + -1.271262175 + ], + [ + -0.406839797, + -0.969501699, + 1.700440051 + ], + [ + 0.025778937, + -0.367092857, + 1.693626168 + ] + ], + [ + [ + 0.294609617, + 0.81227302, + -1.852328373 + ], + [ + 0.056084218, + 0.417807318, + -1.271262175 + ], + [ + -0.391656285, + -0.91624459, + 1.565202216 + ], + [ + 0.040962449, + -0.313835748, + 1.558388333 + ] + ], + [ + [ + 0.294609617, + 0.81227302, + -1.852328373 + ], + [ + 0.056084218, + 0.417807318, + -1.271262175 + ], + [ + -0.346105749, + -0.756473264, + 1.159488712 + ], + [ + 0.086512985, + -0.154064422, + 1.152674829 + ] + ], + [ + [ + 0.294609617, + 0.81227302, + -1.852328373 + ], + [ + 0.056084218, + 0.417807318, + -1.271262175 + ], + [ + -0.361289261, + -0.809730373, + 1.294726547 + ], + [ + 0.071329473, + -0.207321531, + 1.287912664 + ] + ], + [ + [ + 0.294609617, + 0.81227302, + -1.852328373 + ], + [ + 0.056084218, + 0.417807318, + -1.271262175 + ], + [ + -0.467573846, + -1.182530134, + 2.24139139 + ], + [ + -0.034955112, + -0.580121292, + 2.234577507 + ] + ], + [ + [ + 0.294609617, + 0.81227302, + -1.852328373 + ], + [ + 0.056084218, + 0.417807318, + -1.271262175 + ], + [ + -0.330922237, + -0.703216155, + 1.024250877 + ], + [ + 0.101696497, + -0.100807313, + 1.017436994 + ] + ], + [ + [ + 0.294609617, + 0.81227302, + -1.852328373 + ], + [ + 0.056084218, + 0.417807318, + -1.271262175 + ], + [ + -0.376472773, + -0.862987481, + 1.429964381 + ], + [ + 0.056145961, + -0.260578639, + 1.423150498 + ] + ], + [ + [ + 0.294609617, + 0.81227302, + -1.852328373 + ], + [ + 0.056084218, + 
# Copyright 2025 InstaDeep Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from pathlib import Path

import pytest
from ase import units

from mlipaudit.noncovalent_interactions import (
    NoncovalentInteractionsBenchmark,
)
from mlipaudit.noncovalent_interactions.noncovalent_interactions import (
    NoncovalentInteractionsModelOutput,
    NoncovalentInteractionsResult,
    NoncovalentInteractionsSystemModelOutput,
    NoncovalentInteractionsSystemResult,
    compute_total_interaction_energy,
)

INPUT_DATA_DIR = Path(__file__).parent.parent / "data"

# Interaction energy profiles (kcal/mol) of the two test systems, keyed by
# system id. The values mirror the "interaction_energy_profile" entries of
# tests/data/noncovalent_interactions/NCI_Atlas.json.
_REFERENCE_PROFILES_KCAL_MOL = {
    "1.03.03": [
        -0.009,
        -0.166,
        -0.17,
        -0.159,
        -0.127,
        0.613,
        -0.024,
        0.194,
        -0.103,
        -0.043,
    ],
    "1.01.01": [
        -0.004,
        -0.078,
        -0.087,
        -0.09,
        0.026,
        -0.049,
        -0.048,
        0.176,
        -0.081,
        -0.02,
    ],
}


def _model_output_from_profiles(profiles):
    """Wrap per-system energy profiles into a model output object.

    Args:
        profiles: Mapping from system id to the energy profile to store for
            that system. Systems are emitted in the mapping's iteration order.

    Returns:
        A NoncovalentInteractionsModelOutput holding one system entry per
        mapping item and reporting zero skipped structures.
    """
    return NoncovalentInteractionsModelOutput(
        systems=[
            NoncovalentInteractionsSystemModelOutput(
                system_id=system_id,
                energy_profile=profile,
            )
            for system_id, profile in profiles.items()
        ],
        n_skipped_unallowed_elements=0,
    )


@pytest.fixture
def noncovalent_interactions_benchmark(
    request, mocked_benchmark_init, mock_force_field
) -> NoncovalentInteractionsBenchmark:
    """Build a NoncovalentInteractionsBenchmark that reads the test data.

    The `fast_dev_run` flag is taken from `request.param` when the fixture is
    parametrized indirectly and falls back to False otherwise. The
    `mocked_benchmark_init` fixture is requested but never referenced in the
    body (presumably it is needed for its patching side effects; see conftest).

    Returns:
        An initialized NoncovalentInteractionsBenchmark instance.
    """
    return NoncovalentInteractionsBenchmark(
        force_field=mock_force_field,
        data_input_dir=INPUT_DATA_DIR,
        fast_dev_run=getattr(request, "param", False),
    )


@pytest.mark.parametrize(
    "noncovalent_interactions_benchmark", [True, False], indirect=True
)
def test_full_run_with_mocked_inference(
    noncovalent_interactions_benchmark, mocked_batched_inference, mocker
):
    """Run `run_model` and `analyze` with batched inference patched out.

    Checks the types and sizes of the model output and of the analysis result
    against the loaded NCI Atlas test data, for both `fast_dev_run` settings.
    """
    benchmark = noncovalent_interactions_benchmark
    benchmark.force_field.allowed_atomic_numbers = list(range(1, 92))

    # Replace the real inference call; the returned mock is not needed.
    mocker.patch(
        "mlipaudit.noncovalent_interactions.noncovalent_interactions.run_batched_inference",
        side_effect=mocked_batched_inference,
    )

    benchmark.run_model()

    model_output = benchmark.model_output
    atlas_data = benchmark._nci_atlas_data

    assert type(model_output) is NoncovalentInteractionsModelOutput
    assert len(model_output.systems) == len(atlas_data)

    first_output = model_output.systems[0]
    assert type(first_output) is NoncovalentInteractionsSystemModelOutput
    assert len(first_output.energy_profile) == len(atlas_data["1.03.03"].coords)

    result = benchmark.analyze()

    assert type(result) is NoncovalentInteractionsResult
    assert len(result.systems) == len(atlas_data)
    assert all(system.dataset == "Dispersion" for system in result.systems)

    first_result = result.systems[0]
    assert type(first_result) is NoncovalentInteractionsSystemResult
    assert first_result.system_id == "1.03.03"

    # Every profile of the first system must have one entry per scan point.
    n_points = len(atlas_data["1.03.03"].distance_profile)
    assert len(first_result.reference_energy_profile) == n_points
    assert len(first_result.energy_profile) == n_points
    assert len(first_result.distance_profile) == n_points


def test_compute_total_interaction_energy():
    """Check compute_total_interaction_energy on three small profiles.

    Covers an attractive profile, a repulsive profile, and an attractive
    profile whose distances are not given in ascending order.
    """
    cases = [
        # (distance profile, interaction energy profile, repulsive, expected)
        ([1.0, 2.0, 3.0], [1.5, -1.0, 0.0], False, -1.0),
        ([1.0, 2.0, 3.0], [1.4, 1.0, 0.0], True, 1.4),
        ([2.0, 3.0, 1.0], [-1.0, 0.0, 1.5], False, -1.0),
    ]

    for distances, energies, is_repulsive, expected in cases:
        total = compute_total_interaction_energy(
            distances, energies, repulsive=is_repulsive
        )
        assert total == pytest.approx(expected)


def test_analyze_raises_error_if_run_first(noncovalent_interactions_benchmark):
    """Check that `analyze` raises a RuntimeError when no model run preceded it."""
    pattern = re.escape("Must call run_model() first.")
    with pytest.raises(RuntimeError, match=pattern):
        noncovalent_interactions_benchmark.analyze()


@pytest.mark.parametrize(
    "noncovalent_interactions_benchmark, expected_molecules",
    [(True, 2), (False, 2)],
    indirect=["noncovalent_interactions_benchmark"],
)
def test_data_loading(noncovalent_interactions_benchmark, expected_molecules):
    """Check the systems exposed by `_nci_atlas_data` for both run modes."""
    benchmark = noncovalent_interactions_benchmark
    atlas_data = benchmark._nci_atlas_data

    assert len(atlas_data) == expected_molecules
    assert atlas_data["1.03.03"].system_id == "1.03.03"

    # The second system is only asserted for a full (non fast-dev) run.
    if benchmark.fast_dev_run:
        return
    assert atlas_data["1.01.01"].system_id == "1.01.01"


def test_perfect_agreement(noncovalent_interactions_benchmark):
    """Check that all error metrics vanish when the model matches the reference.

    The model output is the reference profile of each system converted from
    kcal/mol to eV.
    """
    benchmark = noncovalent_interactions_benchmark

    kcal_mol_in_ev = units.kcal / units.mol
    benchmark.model_output = _model_output_from_profiles({
        system_id: [energy * kcal_mol_in_ev for energy in profile]
        for system_id, profile in _REFERENCE_PROFILES_KCAL_MOL.items()
    })

    result = benchmark.analyze()

    assert all(
        system.dataset == "Dispersion" and system.group == "HCNO"
        for system in result.systems
    )

    assert result.systems[0].reference_interaction_energy == pytest.approx(-0.161)
    assert result.systems[1].reference_interaction_energy == pytest.approx(-0.086)

    zero = pytest.approx(0.0)
    assert result.rmse_interaction_energy_all == zero
    assert result.rmse_interaction_energy_datasets["Dispersion"] == zero
    assert result.mae_interaction_energy_datasets["Dispersion"] == zero
    assert result.rmse_interaction_energy_subsets["Dispersion: HCNO"] == zero
    assert result.mae_interaction_energy_subsets["Dispersion: HCNO"] == zero


def test_bad_agreement(noncovalent_interactions_benchmark):
    """Check deviations and error metrics when the model predicts zero energy.

    Both systems get an all-zero energy profile of ten points.
    """
    benchmark = noncovalent_interactions_benchmark

    benchmark.model_output = _model_output_from_profiles({
        "1.03.03": [0.0] * 10,
        "1.01.01": [0.0] * 10,
    })

    result = benchmark.analyze()

    assert all(
        system.dataset == "Dispersion" and system.group == "HCNO"
        for system in result.systems
    )

    first_system, second_system = result.systems[0], result.systems[1]
    assert first_system.reference_interaction_energy == pytest.approx(-0.161)
    assert second_system.reference_interaction_energy == pytest.approx(-0.086)

    assert first_system.deviation == pytest.approx(0.161)
    assert second_system.deviation == pytest.approx(0.086)

    expected_rmse = 0.12906781163403988
    assert result.rmse_interaction_energy_all == pytest.approx(expected_rmse)
    assert result.mae_interaction_energy_datasets["Dispersion"] == pytest.approx(0.1235)
    assert result.rmse_interaction_energy_datasets["Dispersion"] == pytest.approx(
        expected_rmse
    )
from mlipaudit.utils import skip_unallowed_elements


def test_allowed_elements_are_not_skipped(mock_force_field):
    """Check that no structure is flagged when all of its elements are allowed.

    The mocked force field allows atomic numbers 1 through 91, which covers
    every symbol used by the structures below.
    """
    mock_force_field.allowed_atomic_numbers = [*range(1, 92)]

    candidates = [
        ("mol_1", ["H", "C", "N", "O"]),
        ("mol_2", ["H", "H", "H", "C", "H", "O"]),
        ("mol_3", ["H", "H", "H", "C", "N", "H", "H"]),
        ("mol_4", ["He", "C", "H", "H", "H", "H"]),
        ("mol_5", ["He", "B", "Si", "Ge", "As", "Na", "Cl"]),
    ]

    flagged = skip_unallowed_elements(mock_force_field, candidates)

    assert flagged == set()


def test_unallowed_elements_are_skipped(mock_force_field):
    """Check that structures with an element outside H, C, N, O are flagged.

    "mol_2" contains He and "mol_3" contains F, neither of which is among the
    allowed atomic numbers 1, 6, 7 and 8.
    """
    mock_force_field.allowed_atomic_numbers = [1, 6, 7, 8]

    candidates = [
        ("mol_1", ["H", "C", "N", "O"]),
        ("mol_2", ["He", "H", "H", "C", "H", "O"]),
        ("mol_3", ["F", "F", "F", "C", "N", "H", "H"]),
    ]

    flagged = skip_unallowed_elements(mock_force_field, candidates)

    assert flagged == {"mol_2", "mol_3"}