diff --git a/meeko/molsetup.py b/meeko/molsetup.py index 1393defe..0c3efb9e 100644 --- a/meeko/molsetup.py +++ b/meeko/molsetup.py @@ -1562,7 +1562,7 @@ def _decode_object(cls, obj: dict[str, Any]): int(k): [string_to_tuple(t) for t in v] for k, v in obj["atom_to_ring_id"].items() } - rdkit_molsetup.rmsd_symmetry_indices = list(map(string_to_tuple, obj["rmsd_symmetry_indices"])) + rdkit_molsetup.rmsd_symmetry_indices = tuple(map(string_to_tuple, obj["rmsd_symmetry_indices"])) return rdkit_molsetup # endregion @@ -1641,7 +1641,6 @@ def from_mol( # functions molsetup = cls() molsetup.mol = mol - molsetup.atom_true_count = molsetup.get_num_mol_atoms() molsetup.name = molsetup.get_mol_name() coords = rdkit_conformer.GetPositions() molsetup.init_atom(compute_gasteiger_charges, read_charges_from_prop, coords) diff --git a/meeko/polymer.py b/meeko/polymer.py index eadf9a61..cb91034e 100644 --- a/meeko/polymer.py +++ b/meeko/polymer.py @@ -36,7 +36,6 @@ import numpy as np -data_path = files("meeko") / "data" periodic_table = Chem.GetPeriodicTable() try: @@ -643,13 +642,18 @@ def add_dict(self, data, overwrite=False): res_template = ResidueTemplate.from_dict(value) self.residue_templates[key] = res_template for link_label, value in data.get("padders", {}).items(): - if overwrite or key not in self.padders: - padder = ResiduePadder.from_dict(data) + if overwrite or link_label not in self.padders: + padder = ResiduePadder.from_dict(value) self.padders[link_label] = padder return + + @staticmethod + def _default_data_path(): + return files("meeko") / "data" @staticmethod - def lookup_filename(filename, data_path): + def lookup_filename(filename, data_path = None): + data_path = data_path or ResidueChemTemplates._default_data_path() p = pathlib.Path(filename) if not p.exists(): if (data_path / p).exists(): @@ -661,7 +665,8 @@ def lookup_filename(filename, data_path): return filename @classmethod - def from_json_file(cls, filename): + def from_json_file(cls, filename, data_path = None): + data_path = data_path or ResidueChemTemplates._default_data_path() filename = cls.lookup_filename(filename, data_path) with open(filename) as f: jsonstr = f.read() @@ -681,7 +686,8 @@ def from_json_file(cls, filename): def create_from_defaults(cls): return cls.from_json_file("residue_chem_templates") - def add_json_file(self, filename): + def add_json_file(self, filename, data_path = None): + data_path = data_path or ResidueChemTemplates._default_data_path() filename = self.lookup_filename(filename, data_path) with open(filename) as f: jsonstr = f.read() @@ -1052,6 +1058,12 @@ def _decode_object(cls, obj: dict[str, Any]): k: Monomer.from_dict(v) for k, v in obj["monomers"].items() } polymer.log = obj["log"] + for nested_key in polymer.log: + if isinstance(polymer.log[nested_key], dict): + polymer.log[nested_key] = { + k: tuple(v) if isinstance(v, list) else v + for k, v in polymer.log[nested_key].items() + } return polymer # endregion @@ -2425,7 +2437,7 @@ def __init__( # (JSON-unbound) computed attributes # TODO convert link indices/labels in template to rdkit_mol indices herein - # self.link_labels = {} + self.link_labels = {} self.template = None @staticmethod @@ -2628,6 +2640,8 @@ class ResiduePadder(BaseJSONParsable): ---------- rxn : rdChemReactions.ChemicalReaction Reaction SMARTS of a single-reactant, single-product reaction for padding. + adjacent_smarts : str + SMARTS pattern for identifying atoms in the adjacent residue to copy positions from. adjacent_smartsmol : Chem.Mol SMARTS molecule with mapping numbers to copy atom positions from part of adjacent residue. adjacent_smartsmol_mapidx : list @@ -2671,12 +2685,14 @@ def __init__(self, rxn_smarts: str, adjacent_res_smarts: str = None, auto_blunt: # Fill in adjacent_smartsmol_mapidx if adjacent_res_smarts is None: + self.adjacent_smarts = None self.adjacent_smartsmol = None self.adjacent_smartsmol_mapidx = None return - # Ensure adjacent_res_smarts is None or a valid SMARTS - self.adjacent_smartsmol = self._initialize_adj_smartsmol(adjacent_res_smarts) + # Ensure adjacent_res_smarts is None or a valid SMARTS + self.adjacent_smarts = adjacent_res_smarts + self.adjacent_smartsmol = self._initialize_adj_smartsmol(self.adjacent_smarts) # Ensure the mapping numbers are the same in adjacent_smartsmol and rxn_smarts's product self._check_adj_smarts(self.rxn, self.adjacent_smartsmol) @@ -2875,7 +2891,7 @@ def _check_target_mol(self, target_mol: Chem.Mol): def json_encoder(cls, obj: "ResiduePadder") -> Optional[dict[str, Any]]: output_dict = { "rxn_smarts": rdChemReactions.ReactionToSmarts(obj.rxn), - "adjacent_res_smarts": serialize_optional(Chem.MolToSmarts, obj.adjacent_smartsmol), + "adjacent_res_smarts": obj.adjacent_smarts, "auto_blunt": obj.auto_blunt, } # we are not serializing the adjacent_smartsmol_mapidx as that will diff --git a/test/json_serialization_test.py b/test/json_serialization_test.py index d998727a..2fad49f0 100644 --- a/test/json_serialization_test.py +++ b/test/json_serialization_test.py @@ -1,72 +1,77 @@ -import collections -import json -import meeko -import numpy +import numpy as np import pathlib import pytest +from rdkit import Chem +from rdkit.Chem import rdChemReactions +import meeko +from meeko import MoleculePreparation +import warnings +import logging +logging.basicConfig(level=logging.INFO, format="%(message)s") +logger = logging.getLogger(__name__) +# JSONParsable classes subject to serialization tests from meeko import ( Monomer, Polymer, - MoleculePreparation, - MoleculeSetup, RDKitMoleculeSetup, ResiduePadder, ResidueTemplate, ResidueChemTemplates, - PDBQTWriterLegacy, ) +from meeko.molsetup import Atom, Bond, Ring, Restraint -from meeko import polymer -from meeko.molsetup import Atom, Bond, Ring, RingClosureInfo, Restraint - -from rdkit import Chem -from rdkit.Chem import rdChemReactions - -from meeko.utils.pdbutils import PDBAtomInfo +# Registry of class : set of attributes to skip for testing +EQUALITY_SKIP_FIELDS = { + Monomer: {"template", "link_labels"}, +} +# Optional dependency for test_dihedral_equality try: import openforcefields _got_openff = True except ImportError as err: _got_openff = False -# from ..meeko.utils.pdbutils import PDBAtomInfo - +# Test data: starting files for polymer creation pkgdir = pathlib.Path(meeko.__file__).parents[1] - -# Test Data ahhy_example = pkgdir / "test/polymer_data/AHHY.pdb" ahhy_v061_json = pkgdir / "test/polymer_data/AHHY-v0.6.1.json" just_one_ALA_missing = ( pkgdir / "test/polymer_data/just-one-ALA-missing-CB.pdb" ) +pqr_example = pkgdir / "test/polymer_data/1FAS_dry.pqr" # Polymer creation data chem_templates = ResidueChemTemplates.create_from_defaults() mk_prep = MoleculePreparation() -def test_read_v061_polymer(): - with open(ahhy_v061_json) as f: - json_str = f.read() - polymer = Polymer.from_json(json_str) - return - # region Fixtures @pytest.fixture def populated_polymer(): - file = open(ahhy_example) - pdb_str = file.read() + """fixture for a populated polymer object""" + with open(ahhy_example) as file: + pdb_str = file.read() polymer = Polymer.from_pdb_string( pdb_str, chem_templates, mk_prep, blunt_ends=[("A:1", 0)] ) return polymer +@pytest.fixture +def populated_polymer_v061(): + """fixture for a populated polymer object, from a v0.6.1 JSON file""" + with open(ahhy_v061_json) as f: + json_str = f.read() + polymer = Polymer.from_json(json_str) + if len(polymer.monomers) == 0: + raise ValueError("Polymer creation failed") + return polymer @pytest.fixture def populated_polymer_missing(): - file = open(just_one_ALA_missing) - pdb_str = file.read() + """fixture for a populated polymer object, with one residue missing""" + with open(just_one_ALA_missing) as file: + pdb_str = file.read() polymer = Polymer.from_pdb_string( pdb_str, chem_templates, @@ -76,254 +81,235 @@ def populated_polymer_missing(): ) return polymer - -@pytest.fixture -def populated_monomer(populated_polymer): - polymer = populated_polymer - return polymer.monomers["A:1"] - - -@pytest.fixture -def populated_rdkit_molsetup(populated_monomer): - monomer = populated_monomer - return monomer.molsetup - - -@pytest.fixture -def populated_residue_chem_templates(populated_polymer): - polymer = populated_polymer - return polymer.residue_chem_templates - - @pytest.fixture -def populated_residue_template(populated_residue_chem_templates): - res_chem_templates = populated_residue_chem_templates - return res_chem_templates.residue_templates["G"] - +def populated_polymer_from_pqr(): + """fixture for a populated polymer object""" + with open(pqr_example) as file: + pqr_str = file.read() + polymer = Polymer.from_pqr_string( + pqr_str, chem_templates, mk_prep + ) + return polymer @pytest.fixture -def populated_residue_padder(populated_residue_chem_templates): - res_chem_templates = populated_residue_chem_templates - return res_chem_templates.padders["5-prime"] - - +def populated_residue_chem_templates(): + """fixture for a populated ResidueChemTemplates object from default""" + return ResidueChemTemplates.create_from_defaults() # endregion -# region Test Cases -def test_rdkit_molsetup_encoding_decoding(populated_rdkit_molsetup): - """ - Takes a fully populated RDKitMoleculeSetup, checks that it can be serialized to JSON and deserialized back into an - object without any errors, then checks that the deserialized object matches the starting object and that the - attribute types, values, and structure of the deserialized object are as expected for an RDKitMoleculeSetup. - - Parameters - ---------- - populated_rdkit_molsetup: RDKitMoleculeSetup - Takes as input a populated RDKitMoleculeSetup object. - - Returns - ------- - None - """ - # TODO: Certain fields are empty in this example, and if we want to make sure that json is working in all scenarios - # we will need to make other tests for those empty fields. - # Encode and decode MoleculeSetup from json - starting_molsetup = populated_rdkit_molsetup - json_str = starting_molsetup.to_json() - decoded_molsetup = RDKitMoleculeSetup.from_json(json_str) - - # First asserts that all types are as expected - assert isinstance(starting_molsetup, RDKitMoleculeSetup) - assert isinstance(decoded_molsetup, RDKitMoleculeSetup) - - # Go through MoleculeSetup attributes and check that they are the expected type and match the MoleculeSetup object - # before serialization. - check_molsetup_equality(decoded_molsetup, starting_molsetup) - return - - -def test_monomer_encoding_decoding(populated_monomer): - """ - Takes a fully populated Monomer, checks that it can be serialized to JSON and deserialized back into an - object without any errors, then checks that the deserialized object matches the starting object and that the - attribute types, values, and structure of the deserialized object are as expected for an Monomer. - - Parameters - ---------- - populated_monomer: Monomer - Takes as input a populated Monomer object. - - Returns - ------- - None - """ - # Starts by getting a Monomer object, converting it to a json string, and then decoding the string into - # a new Monomer object - starting_monomer = populated_monomer - json_str = starting_monomer.to_json() - - decoded_monomer = Monomer.from_json(json_str) - - # Asserts that the starting and ending objects have the expected Monomer type - assert isinstance(starting_monomer, Monomer) - assert isinstance(decoded_monomer, Monomer) - - check_monomer_equality(decoded_monomer, starting_monomer) - return - - -def test_pdbqt_writing_from_decoded_polymer(populated_polymer): - """ - Takes a fully populated Polymer, writes a PDBQT string from it, encodes and decodes it, writes - another PDBQT string from the decoded polymer, and then checks that the PDBQT strings are identical. - - Parameters - ---------- - populated_polymer: Polymer - Takes as input a populated Polymer object. - - Returns - ------- - None - """ - - starting_polymer = populated_polymer - starting_pdbqt = PDBQTWriterLegacy.write_from_polymer(starting_polymer) - json_str = starting_polymer.to_json() - decoded_polymer = Polymer.from_json(json_str) - decoded_pdbqt = PDBQTWriterLegacy.write_from_polymer(decoded_polymer) - assert decoded_pdbqt == starting_pdbqt - return - - - -def test_residue_template_encoding_decoding(populated_residue_template): +# region Helper Functions +def subobject_factory(cls, root): """ - Takes a fully populated ResidueTemplate, checks that it can be serialized to JSON and deserialized back into an - object without any errors, then checks that the deserialized object matches the starting object and that the - attribute types, values, and structure of the deserialized object are as expected for an ResidueTemplate. + Factory function to create subobjects based on the class and root object. Parameters ---------- - populated_residue_template: ResidueTemplate - Takes as input a populated ResidueTemplate object. - + cls : type + The class of the subobject to create. + root : object + The root object from which to create the subobject. + Returns ------- - None - """ - # Starts by getting a ResidueTemplate object, converting it to a json string, and then decoding the string into - # a new ResidueTemplate object - starting_template = populated_residue_template - json_str = starting_template.to_json() - decoded_template = ResidueTemplate.from_json(json_str) - - # Asserts that the starting and ending objects have the expected ResidueTemplate type - assert isinstance(starting_template, ResidueTemplate) - assert isinstance(decoded_template, ResidueTemplate) - - # Checks that the two residue templates are equal - check_residue_template_equality(decoded_template, starting_template) - return - - -def test_residue_padder_encoding_decoding(populated_residue_padder): - """ - Takes a fully populated ResiduePadder, checks that it can be serialized to JSON and deserialized back into an - object without any errors, then checks that the deserialized object matches the starting object and that the - attribute types, values, and structure of the deserialized object are as expected for an ResiduePadder. - - Parameters - ---------- - populated_residue_padder: ResiduePadder - Takes as input a populated ResiduePadder object. - - Returns - ------- - None - """ - # Starts by getting a ResiduePadder object, converting it to a json string, and then decoding the string into - # a new ResiduePadder object - starting_padder = populated_residue_padder - json_str = starting_padder.to_json() - decoded_padder = ResiduePadder.from_json(json_str) - - # Asserts that the starting and ending objects have the expected ResiduePadder type - assert isinstance(starting_padder, ResiduePadder) - assert isinstance(decoded_padder, ResiduePadder) - - # Checks that the two residue padders are equal - check_residue_padder_equality(decoded_padder, starting_padder) - return - - -def test_residue_chem_templates_encoding_decoding(populated_residue_chem_templates): - """ - Takes a fully populated ResidueChemTemplates, checks that it can be serialized to JSON and deserialized back into an - object without any errors, then checks that the deserialized object matches the starting object and that the - attribute types, values, and structure of the deserialized object are as expected for an ResidueChemTemplates. - + iterable + An iterable of subobjects of the specified class. + + Raises + ------ + ValueError + If the class or root object is not recognized by given schema. + """ + # Polymer-based hierarchy + if isinstance(root, Polymer): + if cls is Polymer: + return [root] + if cls is Monomer: + return root.monomers.values() + if cls is RDKitMoleculeSetup: + return [m.molsetup for m in root.monomers.values()] + + # ResidueChemTemplates hierarchy + if isinstance(root, ResidueChemTemplates): + if cls is ResidueChemTemplates: + return [root] + if cls is ResidueTemplate: + return root.residue_templates.values() + if cls is ResiduePadder: + return root.padders.values() + + # RDKitMoleculeSetup hierarchy + if isinstance(root, RDKitMoleculeSetup): + if cls is RDKitMoleculeSetup: + return [root] + if cls is Atom: + return root.atoms + if cls is Bond: + return root.bond_info.values() + if cls is Ring: + return root.rings.values() + if cls is Restraint: + return root.restraints + + raise ValueError(f"Unexpected class or root: {cls}, {type(root)}") + +def deep_assert_equal(decoded, original, path="root"): + """Recursively compares two objects with support for type-aware handling and skip lists. + Parameters ---------- - populated_residue_chem_templates: ResidueChemTemplates - Takes as input a populated ResidueChemTemplates object. - - Returns - ------- - None - """ - # Starts by getting a ResidueChemTemplates object, converting it to a json string, and then decoding the string into - # a new ResidueChemTemplates object - starting_templates = populated_residue_chem_templates - json_str = starting_templates.to_json() - decoded_templates = ResidueChemTemplates.from_json(json_str) - - # Asserts that the starting and ending objects have the expected ResidueChemTemplates type - assert isinstance(starting_templates, ResidueChemTemplates) - assert isinstance(decoded_templates, ResidueChemTemplates) - - # Checks that the two chem templates are equal - check_residue_chem_templates_equality(decoded_templates, starting_templates) + decoded : object + The decoded object to compare. + original : object + The original object to compare against. + path : str + The current path in the object hierarchy for error reporting. + + Raises + ------ + AssertionError + If the objects are not equal or if there are type mismatches. + """ + if type(decoded) != type(original): + raise AssertionError(f"[{path}] Type mismatch: {type(decoded)} != {type(original)}") + + # Basic types + if isinstance(decoded, (int, float, bool, str)): + assert decoded == original, f"[{path}] Value mismatch: {decoded} != {original}" + return + + # Dicts + if isinstance(decoded, dict): + assert decoded.keys() == original.keys(), f"[{path}] Dict keys mismatch" + for key in decoded: + deep_assert_equal(decoded[key], original[key], path=f"{path}.{key}") + return + + # Lists or Tuples + if isinstance(decoded, (list, tuple)): + assert len(decoded) == len(original), f"[{path}] Length mismatch" + for i, (d_item, o_item) in enumerate(zip(decoded, original)): + deep_assert_equal(d_item, o_item, path=f"{path}[{i}]") + return + + # Numpy arrays + if isinstance(decoded, np.ndarray): + assert np.allclose(decoded, original), f"[{path}] Numpy arrays not equal" + return + + # RDKit Molecules + if isinstance(decoded, Chem.Mol): + decoded_smiles = Chem.MolToSmiles(decoded) + original_smiles = Chem.MolToSmiles(original) + assert decoded_smiles == original_smiles, f"[{path}] Mol SMILES mismatch" + return + + # RDKit Reactions + if isinstance(decoded, rdChemReactions.ChemicalReaction): + assert rdChemReactions.ReactionToSmarts(decoded) == rdChemReactions.ReactionToSmarts(original), f"[{path}] Reaction SMARTS mismatch" + return + + # Custom objects with attributes + if hasattr(decoded, "__dict__"): + cls = type(decoded) + skip_attrs = EQUALITY_SKIP_FIELDS.get(cls, set()) + + # Check for extra attributes that are not in the original + decoded_attr = set(dir(decoded)) + original_attr = set(dir(original)) + if decoded_attr - original_attr: + raise AssertionError(f"[{path}] Extra attributes in decoded object: {decoded_attr - original_attr}") + + for attr in original_attr: + # skip private + if attr.startswith("_"): + continue + # skip methods/functions/descriptors + try: + orig_val = getattr(original, attr) + except Exception: + continue + if callable(orig_val): + continue + # skip attributes if explicitly stated + if attr in skip_attrs: + continue + if not hasattr(decoded, attr): + raise AssertionError(f"[{path}] Missing attribute in decoded object: {attr}") + decoded_val = getattr(decoded, attr) + original_val = getattr(original, attr) + logger.info(f"[{path}] Checking attribute: {attr}") + deep_assert_equal(decoded_val, original_val, path=f"{path}.{attr}") + return + + # Fallback + assert decoded == original, f"[{path}] Fallback mismatch: {decoded} != {original}" return +# endregion -def test_polymer_encoding_decoding( - populated_polymer, populated_polymer_missing -): - """ - Takes a fully populated Polymer, checks that it can be serialized to JSON and deserialized back into an - object without any errors, then checks that the deserialized object matches the starting object and that the - attribute types, values, and structure of the deserialized object are as expected for a Polymer. - - Parameters - ---------- - populated_polymer: Polymer - Takes as input a populated Polymer object. - - Returns - ------- - None - """ - # Starts by getting a Polymer object, converting it to a json string, and then decoding the string into - # a new Polymer object - polymers = ( - populated_polymer, - populated_polymer_missing, - ) - for starting_polymer in polymers: - json_str = starting_polymer.to_json() - decoded_polymer = Polymer.from_json(json_str) - - # Asserts that the starting and ending objects have the expected Polymer type - assert isinstance(starting_polymer, Polymer) - assert isinstance(decoded_polymer, Polymer) - - # Checks that the two polymers are equal - check_polymer_equality(decoded_polymer, starting_polymer) - return +# region Hierachical Tests +# iterate over nested classes in the Polymer hierarchy +@pytest.mark.parametrize("polymer_fixture", [ + "populated_polymer", "populated_polymer_v061", "populated_polymer_missing", "populated_polymer_from_pqr" +]) +@pytest.mark.parametrize("cls", [ + Polymer, + Monomer, + RDKitMoleculeSetup, +]) +# check for seralization/deserialization and deep equality +def test_json_roundtrip(cls, polymer_fixture, request): + """Tests starting from a populated polymer object""" + polymer = request.getfixturevalue(polymer_fixture) + for obj in subobject_factory(cls, polymer): + if obj is None: + warnings.warn(f"Subobject of type {cls.__name__} is None — skipping.", stacklevel=1) + continue + json_str = obj.to_json() + decoded = cls.from_json(json_str) + assert isinstance(decoded, cls) + deep_assert_equal(decoded, obj) + +# same test but starting from the default ResidueChemTemplates object +@pytest.mark.parametrize("cls", [ + ResidueChemTemplates, + ResidueTemplate, + ResiduePadder, +]) +def test_json_rct(cls, populated_residue_chem_templates): + for obj in subobject_factory(cls, populated_residue_chem_templates): + if obj is None: + warnings.warn(f"Subobject of type {cls.__name__} is None — skipping.", stacklevel=1) + continue + json_str = obj.to_json() + decoded = cls.from_json(json_str) + assert isinstance(decoded, cls) + deep_assert_equal(decoded, obj) + +# same test but iterating over the RDKitMoleculeSetup hierarchy +@pytest.mark.parametrize("cls", [ + RDKitMoleculeSetup, + Atom, + Bond, + Ring, + Restraint, +]) +# the RDKitMoleculeSetup instances used for this test are created from the populated polymer +def test_json_molsetup(cls, populated_polymer): + for molsetup in subobject_factory(RDKitMoleculeSetup, populated_polymer): + for obj in subobject_factory(cls, molsetup): + if obj is None: + warnings.warn(f"Subobject of type {cls.__name__} is None — skipping.", stacklevel=1) + continue + json_str = obj.to_json() + decoded = cls.from_json(json_str) + assert isinstance(decoded, cls) + deep_assert_equal(decoded, obj) +# endregion +# region Other Tests def test_load_reference_json(): fn = str(pkgdir/"test"/"polymer_data"/"AHHY_reference_fewer_templates.json") with open(fn) as f: @@ -344,7 +330,7 @@ def test_dihedral_equality(): starting_molsetup = mk_prep(mol)[0] json_str = starting_molsetup.to_json() decoded_molsetup = RDKitMoleculeSetup.from_json(json_str) - check_molsetup_equality(starting_molsetup, decoded_molsetup) + deep_assert_equal(starting_molsetup, decoded_molsetup) return @@ -361,435 +347,4 @@ def test_broken_bond(): count_breakable += bond_info.breakable assert count_rotatable == 10 assert count_breakable == 1 - -# endregion - - -# region Object Equality Checks -def check_molsetup_equality(decoded_obj: MoleculeSetup, starting_obj: MoleculeSetup): - """ - Asserts that two MoleculeSetup objects are equal, and that the decoded_obj input has fields contain correctly typed - data. - - Parameters - ---------- - decoded_obj: MoleculeSetup - A MoleculeSetup object that we want to check is correctly typed and contains the correct data. - starting_obj: MoleculeSetup - A MoleculeSetup object with the desired values to check the decoded object against. - - Returns - ------- - None - """ - - # Checks if the MoleculeSetup is an RDKitMoleculeSetup, and if so also checks the RDKitMoleculeSetup attributes - if isinstance(starting_obj, RDKitMoleculeSetup): - assert isinstance(decoded_obj.mol, Chem.rdchem.Mol) - pass - - # Going through and checking MoleculeSetup attributes - assert decoded_obj.name == starting_obj.name - assert isinstance(decoded_obj.pseudoatom_count, int) - assert decoded_obj.pseudoatom_count == starting_obj.pseudoatom_count - - # Checking atoms - atom_idx = 0 - assert len(decoded_obj.atoms) == len(starting_obj.atoms) - for atom in decoded_obj.atoms: - assert isinstance(atom, Atom) - assert atom.index == atom_idx - check_atom_equality(atom, starting_obj.atoms[atom_idx]) - atom_idx += 1 - - # Checking bonds - for bond_id in starting_obj.bond_info: - assert isinstance(decoded_obj.bond_info[bond_id], Bond) - assert bond_id in decoded_obj.bond_info - check_bond_equality( - decoded_obj.bond_info[bond_id], starting_obj.bond_info[bond_id] - ) - - # Checking rings - for ring_id in starting_obj.rings: - assert isinstance(decoded_obj.rings[ring_id], Ring) - assert ring_id in decoded_obj.rings - check_ring_equality(decoded_obj.rings[ring_id], starting_obj.rings[ring_id]) - assert isinstance(decoded_obj.ring_closure_info, RingClosureInfo) - assert ( - decoded_obj.ring_closure_info.bonds_removed - == starting_obj.ring_closure_info.bonds_removed - ) - for key in starting_obj.ring_closure_info.pseudos_by_atom: - assert key in decoded_obj.ring_closure_info.pseudos_by_atom - assert ( - decoded_obj.ring_closure_info.pseudos_by_atom[key] - == starting_obj.ring_closure_info.pseudos_by_atom[key] - ) - - # Checking other fields - assert len(decoded_obj.rotamers) == len(starting_obj.rotamers) - for idx, component_dict in enumerate(starting_obj.rotamers): - decoded_dict = decoded_obj.rotamers[idx] - for key in component_dict: - assert key in decoded_dict - assert decoded_dict[key] == component_dict[key] - for key in starting_obj.atom_params: - assert key in decoded_obj.atom_params - assert decoded_obj.atom_params[key] == starting_obj.atom_params[key] - assert len(decoded_obj.restraints) == len(starting_obj.restraints) - for idx, restraint in starting_obj.restraints: - assert isinstance(decoded_obj.restraints[idx], Restraint) - check_restraint_equality( - decoded_obj.restraints[idx], starting_obj.restraints[idx] - ) - - # dihedrals - assert decoded_obj.dihedral_partaking_atoms == starting_obj.dihedral_partaking_atoms - assert decoded_obj.dihedral_interactions == starting_obj.dihedral_interactions - assert decoded_obj.dihedral_labels == starting_obj.dihedral_labels - - # Checking flexibility model - for key in starting_obj.flexibility_model: - assert key in decoded_obj.flexibility_model - assert decoded_obj.flexibility_model[key] == starting_obj.flexibility_model[key] - return - - -def check_atom_equality(decoded_obj: Atom, starting_obj: Atom): - """ - Asserts that two Atom objects are equal, and that the decoded_obj input has fields contain correctly typed - data. - - Parameters - ---------- - decoded_obj: Atom - An Atom object that we want to check is correctly typed and contains the correct data. - starting_obj: Atom - An Atom object with the desired values to check the decoded object against. - - Returns - ------- - None - """ - correct_val_type = True - # np.array conversion checks - assert isinstance(decoded_obj.coord, numpy.ndarray) - for i_vec in decoded_obj.interaction_vectors: - correct_val_type = correct_val_type and isinstance(i_vec, numpy.ndarray) - assert correct_val_type - - # Checks for equality between decoded and original fields - assert isinstance(decoded_obj.index, int) - assert decoded_obj.index == starting_obj.index - # Only checks pdb info if the starting object's pdbinfo was a string. Otherwise, the decoder is not going to convert - # the pdbinfo field back to the PDBInfo type right now. - if isinstance(starting_obj.pdbinfo, str): - assert decoded_obj.pdbinfo == starting_obj.pdbinfo - assert isinstance(decoded_obj.charge, float) - assert decoded_obj.charge == starting_obj.charge - for idx, val in enumerate(decoded_obj.coord): - assert val == starting_obj.coord[idx] - assert isinstance(decoded_obj.atomic_num, int) - assert decoded_obj.atomic_num == starting_obj.atomic_num - assert decoded_obj.atom_type == starting_obj.atom_type - assert decoded_obj.graph == starting_obj.graph - assert isinstance(decoded_obj.is_ignore, bool) - assert decoded_obj.is_ignore == starting_obj.is_ignore - assert isinstance(decoded_obj.is_dummy, bool) - assert decoded_obj.is_dummy == starting_obj.is_dummy - assert isinstance(decoded_obj.is_pseudo_atom, bool) - assert decoded_obj.is_pseudo_atom == starting_obj.is_pseudo_atom - return - - -def check_bond_equality(decoded_obj: Bond, starting_obj: Bond): - """ - Asserts that two Bond objects are equal, and that the decoded_obj input has fields contain correctly typed - data. - - Parameters - ---------- - decoded_obj: Bond - An Bond object that we want to check is correctly typed and contains the correct data. - starting_obj: Bond - An Bond object with the desired values to check the decoded object against. - - Returns - ------- - None - """ - assert isinstance(decoded_obj.canon_id, tuple) - assert isinstance(decoded_obj.canon_id[0], int) - assert isinstance(decoded_obj.canon_id[1], int) - assert decoded_obj.canon_id == starting_obj.canon_id - assert isinstance(decoded_obj.index1, int) - assert decoded_obj.index1 == starting_obj.index1 - assert isinstance(decoded_obj.index2, int) - assert decoded_obj.index2 == starting_obj.index2 - assert isinstance(decoded_obj.rotatable, bool) - assert decoded_obj.rotatable == starting_obj.rotatable - return - - -def check_ring_equality(decoded_obj: Ring, starting_obj: Ring): - """ - Asserts that two Ring objects are equal, and that the decoded_obj input has fields contain correctly typed - data. - - Parameters - ---------- - decoded_obj: Ring - An Ring object that we want to check is correctly typed and contains the correct data. - starting_obj: Ring - An Ring object with the desired values to check the decoded object against. - - Returns - ------- - None - """ - assert isinstance(decoded_obj.ring_id, tuple) - assert decoded_obj.ring_id == starting_obj.ring_id - return - - -def check_restraint_equality(decoded_obj: Restraint, starting_obj: Restraint): - """ - Asserts that two Restraint objects are equal, and that the decoded_obj input has fields contain correctly typed - data. - - Parameters - ---------- - decoded_obj: Restraint - An Restraint object that we want to check is correctly typed and contains the correct data. - starting_obj: Restraint - An Restraint object with the desired values to check the decoded object against. - - Returns - ------- - None - """ - assert isinstance(decoded_obj.atom_index, int) - assert decoded_obj.atom_index == starting_obj.atom_index - assert isinstance(decoded_obj.target_coords, tuple) - assert decoded_obj.target_coords == starting_obj.target_coords - assert isinstance(decoded_obj.kcal_per_angstrom_square, float) - assert decoded_obj.kcal_per_angstrom_square == starting_obj.kcal_per_angstrom_square - assert isinstance(decoded_obj.delay_angstroms, float) - assert decoded_obj.delay_angstroms == starting_obj.delay_angstroms - return - - -def check_monomer_equality(decoded_obj: Monomer, starting_obj: Monomer): - """ - Asserts that two Monomer objects are equal, and that the decoded_obj input has fields contain correctly typed - data. - - Parameters - ---------- - decoded_obj: Monomer - A Monomer object that we want to check is correctly typed and contains the correct data. - starting_obj: Monomer - A Monomer object with the desired values to check the decoded object against. - - Returns - ------- - None - """ - # Goes through the Monomer's fields and checks that they are the expected type and match the Monomer - # object before serialization (that we have effectively rebuilt the Monomer) - - # RDKit Mols - Check whether we can test for equality with RDKit Mols - # assert decoded_monomer.raw_rdkit_mol == starting_residue.raw_rdkit_mol - assert type(decoded_obj.raw_rdkit_mol) == type(starting_obj.raw_rdkit_mol) - if isinstance(decoded_obj.raw_rdkit_mol, Chem.rdchem.Mol): - assert Chem.MolToSmiles(decoded_obj.raw_rdkit_mol) == Chem.MolToSmiles( - starting_obj.raw_rdkit_mol - ) - # assert decoded_monomer.rdkit_mol == starting_monomer.rdkit_mol - assert type(decoded_obj.rdkit_mol) == type(starting_obj.rdkit_mol) - if isinstance(decoded_obj.rdkit_mol, Chem.rdchem.Mol): - assert Chem.MolToSmiles(decoded_obj.rdkit_mol) == Chem.MolToSmiles( - starting_obj.rdkit_mol - ) - # assert decoded_monomer.padded_mol == starting_monomer.padded_mol - assert type(decoded_obj.padded_mol) == type(starting_obj.padded_mol) - if isinstance(decoded_obj.padded_mol, Chem.rdchem.Mol): - assert Chem.MolToSmiles(decoded_obj.padded_mol) == Chem.MolToSmiles( - starting_obj.padded_mol - ) - - # MapIDX - assert decoded_obj.mapidx_to_raw == starting_obj.mapidx_to_raw - assert decoded_obj.mapidx_from_raw == starting_obj.mapidx_from_raw - - # Non-Bool vars - assert decoded_obj.residue_template_key == starting_obj.residue_template_key - assert decoded_obj.input_resname == starting_obj.input_resname - assert decoded_obj.atom_names == starting_obj.atom_names - assert type(decoded_obj.molsetup) == type(starting_obj.molsetup) - if isinstance(decoded_obj.molsetup, RDKitMoleculeSetup): - check_molsetup_equality(decoded_obj.molsetup, starting_obj.molsetup) - - # Bools - assert decoded_obj.is_flexres_atom == starting_obj.is_flexres_atom - assert decoded_obj.is_movable == starting_obj.is_movable - return - - -def check_residue_chem_templates_equality( - decoded_obj: ResidueChemTemplates, starting_obj: ResidueChemTemplates -): - """ - Asserts that two ResidueChemTemplates objects are equal, and that the decoded_obj input has fields contain correctly - typed data. - - Parameters - ---------- - decoded_obj: ResidueChemTemplates - A ResidueChemTemplates object that we want to check is correctly typed and contains the correct data. - starting_obj: ResidueChemTemplates - A ResidueChemTemplates object with the desired values to check the decoded object against. - - Returns - ------- - None - """ - # correct_val_type is used to check that all type conversions for nested data have happened correctly - correct_val_type = True - # Checks residue_templates by ensuring it has the same members as the starting object, that each value in the - # dictionary is a ResidueTemplate object, and that each template is equal to its corresponding ResidueTemplate in - # the starting object. - assert decoded_obj.residue_templates.keys() == starting_obj.residue_templates.keys() - for key in decoded_obj.residue_templates: - correct_val_type = correct_val_type & isinstance( - decoded_obj.residue_templates[key], ResidueTemplate - ) - check_residue_template_equality( - decoded_obj.residue_templates[key], starting_obj.residue_templates[key] - ) - assert correct_val_type - - # Directly compares ambiguous values. - assert decoded_obj.ambiguous == starting_obj.ambiguous - - # Checks padders by ensuring it has the same members as the starting object, that each value in the dictionary is a - # ResiduePadder object, and that each padder is equal to its corresponding ResiduePadder in the starting object. - assert decoded_obj.padders.keys() == starting_obj.padders.keys() - for key in decoded_obj.padders: - correct_val_type = correct_val_type & isinstance( - decoded_obj.padders[key], ResiduePadder - ) - check_residue_padder_equality( - decoded_obj.padders[key], starting_obj.padders[key] - ) - assert correct_val_type - return - - -def check_residue_template_equality( - decoded_obj: ResidueTemplate, starting_obj: ResidueTemplate -): - """ - Asserts that two ResidueTemplate objects are equal, and that the decoded_obj input has fields contain correctly typed - data. - - Parameters - ---------- - decoded_obj: ResidueTemplate - A ResidueTemplate object that we want to check is correctly typed and contains the correct data. - starting_obj: ResidueTemplate - A ResidueTemplate object with the desired values to check the decoded object against. - - Returns - ------- - None - """ - # Goes through the ResidueTemplate's fields and checks that they have the expected type and that they match the - # ResidueTemplate object before serialization - assert isinstance(decoded_obj.mol, Chem.rdchem.Mol) - - assert decoded_obj.link_labels == starting_obj.link_labels - assert decoded_obj.atom_names == starting_obj.atom_names - return - - -def check_residue_padder_equality( - decoded_obj: ResiduePadder, starting_obj: ResiduePadder -): - """ - Asserts that two ResiduePadder objects are equal, and that the decoded_obj input has fields contain correctly typed - data. - - Parameters - ---------- - decoded_obj: ResiduePadder - A ResiduePadder object that we want to check is correctly typed and contains the correct data. - starting_obj: ResiduePadder - A ResiduePadder object with the desired values to check the decoded object against. - - Returns - ------- - None - """ - assert isinstance(decoded_obj.rxn, rdChemReactions.ChemicalReaction) - decoded_obj_rxn_smarts = rdChemReactions.ReactionToSmarts(decoded_obj.rxn) - starting_obj_rxn_smarts = rdChemReactions.ReactionToSmarts(starting_obj.rxn) - assert decoded_obj_rxn_smarts == starting_obj_rxn_smarts - - assert ( - decoded_obj.adjacent_smartsmol_mapidx == starting_obj.adjacent_smartsmol_mapidx - ) - - decoded_adj = decoded_obj.adjacent_smartsmol - starting_adj = starting_obj.adjacent_smartsmol - assert isinstance(decoded_adj, Chem.rdchem.Mol) or decoded_adj is None - if decoded_adj is None: - assert decoded_adj == starting_adj - else: - decoded_adj_smarts = Chem.MolToSmarts(decoded_adj) - starting_adj_smarts = Chem.MolToSmarts(starting_adj) - assert decoded_adj_smarts == starting_adj_smarts - return - - -def check_polymer_equality( - decoded_obj: Polymer, starting_obj: Polymer -): - """ - Asserts that two Polymer objects are equal, and that the decoded_obj input has fields contain correctly - typed data. - - Parameters - ---------- - decoded_obj: Polymer - A Polymer object that we want to check is correctly typed and contains the correct data. - starting_obj: Polymer - A Polymer object with the desired values to check the decoded object against. - - Returns - ------- - None - """ - # correct_val_type is used to check that all type conversions for nested data have happened correctly - correct_val_type = True - # Checks residue_chem_templates equality - check_residue_chem_templates_equality( - decoded_obj.residue_chem_templates, starting_obj.residue_chem_templates - ) - - # Loops through residues, checks that the decoded and starting obj share the same set of keys, that all the residues - # are represented as Monomer objects, and that the decoding and starting obj Monomers are equal. - assert decoded_obj.monomers.keys() == starting_obj.monomers.keys() - for key in decoded_obj.monomers: - correct_val_type = correct_val_type & isinstance( - decoded_obj.monomers[key], Monomer - ) - check_monomer_equality(decoded_obj.monomers[key], starting_obj.monomers[key]) - assert correct_val_type - - # Checks log equality - assert decoded_obj.log == starting_obj.log - return - # endregion