diff --git a/.gitignore b/.gitignore index e8060b3..715369f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,9 +1,8 @@ lightning_logs/ notebooks/_*.ipynb - +wandb/ # vscode .vscode - # jupyter MANIFEST build @@ -157,3 +156,12 @@ venv.bak/ # mypy .mypy_cache/ +molexpress/**/*.ckpt +molexpress/**/*.pth +molexpress/**/*.txt +molexpress/**/*.csv +molexpress/**/*.zip + + + + diff --git a/molexpress/datasets/encoders.py b/molexpress/datasets/encoders.py index 76db488..3962e2f 100644 --- a/molexpress/datasets/encoders.py +++ b/molexpress/datasets/encoders.py @@ -1,6 +1,6 @@ from __future__ import annotations -from functools import lru_cache +import logging from typing import Dict, Tuple, Union import numpy as np @@ -9,6 +9,8 @@ from molexpress.datasets import featurizers from molexpress.ops import chem_ops +LOGGER = logging.getLogger(__name__) + class PeptideGraphEncoder: def __init__( @@ -19,27 +21,39 @@ def __init__( supports_masking: bool = False, ) -> None: self.node_encoder = MolecularNodeEncoder( - atom_featurizers, supports_masking=supports_masking) + atom_featurizers, supports_masking=supports_masking + ) self.edge_encoder = MolecularEdgeEncoder( - bond_featurizers, self_loops=self_loops, supports_masking=supports_masking) + bond_featurizers, self_loops=self_loops, supports_masking=supports_masking + ) def __call__(self, residues: list[types.Molecule | types.SMILES | types.InChI]) -> np.ndarray: residue_graphs = [] residue_sizes = [] for residue in residues: - residue_graph, residue_size = self._encode_residue( - residue, self.node_encoder, self.edge_encoder - ) - residue_graphs.append(residue_graph) - residue_sizes.append(residue_size) + try: + residue_graph, residue_size = self._encode_residue( + residue, self.node_encoder, self.edge_encoder + ) + except AttributeError as e: + raise GraphConstructionError( + f"Could not construct graph from residue: {residue}" + ) from e + else: + residue_graphs.append(residue_graph) + residue_sizes.append(residue_size) disjoint_peptide_graph = self._merge_molecular_graphs(residue_graphs) - disjoint_peptide_graph["residue_size"] = np.array(residue_sizes) + + try: + disjoint_peptide_graph["residue_size"] = np.array(residue_sizes) + except Exception as e: + raise GraphConstructionError("Cannot construct disjoint graph") from e + disjoint_peptide_graph["peptide_size"] = np.array([len(residues)], dtype="int32") return disjoint_peptide_graph @staticmethod - @lru_cache(maxsize=None) def _encode_residue( residue: types.Molecule | types.SMILES | types.InChI, node_encoder: MolecularNodeEncoder, @@ -99,25 +113,41 @@ def masked_collate_fn( disjoint_peptide_batch_graph = PeptideGraphEncoder._merge_molecular_graphs( disjoint_peptide_graphs ) + disjoint_peptide_batch_graph["peptide_size"] = np.concatenate( + [g["residue_size"].shape[:1] for g in disjoint_peptide_graphs] + ).astype("int32") + disjoint_peptide_batch_graph["residue_size"] = np.concatenate( + [g["residue_size"] for g in disjoint_peptide_graphs] + ).astype("int32") - node_state = disjoint_peptide_batch_graph['node_state'] + node_state = disjoint_peptide_batch_graph["node_state"] node_mask = np.random.uniform(size=node_state.shape[0]) < node_masking_rate - disjoint_peptide_batch_graph['node_loss_weight'] = np.copy(node_mask.astype(node_state.dtype)) - disjoint_peptide_batch_graph['node_label'] = np.copy(disjoint_peptide_batch_graph['node_state']) + disjoint_peptide_batch_graph["node_loss_weight"] = np.copy( + node_mask.astype(node_state.dtype) + ) + disjoint_peptide_batch_graph["node_label"] = np.copy( + disjoint_peptide_batch_graph["node_state"] + ) mask_state = np.zeros_like(node_state) - mask_state[:, -1] = 1. - disjoint_peptide_batch_graph['node_state'] = np.where( - node_mask[:, None], mask_state, node_state) - - edge_state = disjoint_peptide_batch_graph['edge_state'] + mask_state[:, -1] = 1.0 + disjoint_peptide_batch_graph["node_state"] = np.where( + node_mask[:, None], mask_state, node_state + ) + + edge_state = disjoint_peptide_batch_graph["edge_state"] edge_mask = np.random.uniform(size=edge_state.shape[0]) < edge_masking_rate - disjoint_peptide_batch_graph['edge_loss_weight'] = np.copy(edge_mask.astype(edge_state.dtype)) - disjoint_peptide_batch_graph['edge_label'] = np.copy(disjoint_peptide_batch_graph['edge_state']) + disjoint_peptide_batch_graph["edge_loss_weight"] = np.copy( + edge_mask.astype(edge_state.dtype) + ) + disjoint_peptide_batch_graph["edge_label"] = np.copy( + disjoint_peptide_batch_graph["edge_state"] + ) mask_state = np.zeros_like(edge_state) - mask_state[:, -1] = 1. - disjoint_peptide_batch_graph['edge_state'] = np.where( - edge_mask[:, None], mask_state, edge_state) - + mask_state[:, -1] = 1.0 + disjoint_peptide_batch_graph["edge_state"] = np.where( + edge_mask[:, None], mask_state, edge_state + ) + return disjoint_peptide_batch_graph @staticmethod @@ -128,14 +158,23 @@ def _merge_molecular_graphs( disjoint_molecular_graph = {} + if len(molecular_graphs) == 0: + raise GraphMergingError("No graphs to merge.") + disjoint_molecular_graph["node_state"] = np.concatenate( [g["node_state"] for g in molecular_graphs] ) if "edge_state" in molecular_graphs[0]: - disjoint_molecular_graph["edge_state"] = np.concatenate( - [g["edge_state"] for g in molecular_graphs] - ) + try: + disjoint_molecular_graph["edge_state"] = np.concatenate( + [g["edge_state"] for g in molecular_graphs] + ) + except ValueError as e: + raise GraphMergingError( + "Error during concatenation. Structure without bonds? Shapes of edge_state arrays: " + f"{[g['edge_state'].shape for g in molecular_graphs]}" + ) from e edge_src = np.concatenate([graph["edge_src"] for graph in molecular_graphs]) edge_dst = np.concatenate([graph["edge_dst"] for graph in molecular_graphs]) @@ -181,9 +220,9 @@ def output_dtype(self): class MolecularEdgeEncoder: def __init__( - self, - featurizers: list[featurizers.Featurizer], - self_loops: bool = False, + self, + featurizers: list[featurizers.Featurizer], + self_loops: bool = False, supports_masking: bool = False, ) -> None: self.featurizer = Composer(featurizers) @@ -217,13 +256,15 @@ def __call__(self, molecule: types.Molecule) -> np.ndarray: if bond is None: assert self.self_loops, "Found a bond to be None." bond_encoding = np.zeros( - self.output_dim + int(self.self_loops) + int(self.supports_masking), - dtype=self.output_dtype) + self.output_dim + int(self.self_loops) + int(self.supports_masking), + dtype=self.output_dtype, + ) bond_encoding[-(int(self.self_loops) + int(self.supports_masking))] = 1 else: bond_encoding = self.featurizer(bond) bond_encoding = np.pad( - bond_encoding, (0, int(self.self_loops) + int(self.supports_masking))) + bond_encoding, (0, int(self.self_loops) + int(self.supports_masking)) + ) bond_encodings.append(bond_encoding) @@ -244,9 +285,23 @@ def __init__( self.supports_masking = supports_masking def __call__(self, molecule: types.Molecule) -> np.ndarray: - node_encodings = np.stack([self.featurizer(atom) for atom in molecule.GetAtoms()], axis=0) + node_encodings = np.stack( + [self.featurizer(atom) for atom in molecule.GetAtoms()], axis=0 + ) if self.supports_masking: node_encodings = np.pad(node_encodings, [(0, 0), (0, 1)]) return { "node_state": np.stack(node_encodings), } + + +class GraphConstructionError(Exception): + """Error during graph construction.""" + + pass + + +class GraphMergingError(Exception): + """Error during graph merging.""" + + pass diff --git a/molexpress/datasets/featurizers.py b/molexpress/datasets/featurizers.py index 88d09cd..a4de2d8 100644 --- a/molexpress/datasets/featurizers.py +++ b/molexpress/datasets/featurizers.py @@ -8,7 +8,6 @@ from molexpress import types - DEFAULT_VOCABULARY = { "AtomType": { 'H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', diff --git a/molexpress/layers/gcn_conv.py b/molexpress/layers/gcn_conv.py index d85dc3b..5e83989 100644 --- a/molexpress/layers/gcn_conv.py +++ b/molexpress/layers/gcn_conv.py @@ -101,7 +101,7 @@ def call(self, inputs: types.MolecularGraph) -> types.MolecularGraph: if self.skip_connection: if self._transform_skip_connection: node_state = gnn_ops.transform(state=node_state, kernel=self.skip_connect_kernel) - node_state_updated += node_state + node_state_updated = node_state_updated + node_state if self.dropout_rate: node_state_updated = self.dropout(node_state_updated) diff --git a/molexpress/layers/gin_conv.py b/molexpress/layers/gin_conv.py index 77c35a0..8c59655 100644 --- a/molexpress/layers/gin_conv.py +++ b/molexpress/layers/gin_conv.py @@ -107,7 +107,7 @@ def call(self, inputs: types.MolecularGraph) -> types.MolecularGraph: edge_weight=edge_weight, ) - node_state_updated += (1 + self.epsilon) * node_state + node_state_updated = node_state_updated + (1 + self.epsilon) * node_state node_state_updated = gnn_ops.transform( state=node_state_updated, kernel=self.node_kernel_1, bias=self.node_bias_1 diff --git a/molexpress/ops/chem_ops.py b/molexpress/ops/chem_ops.py index c71c810..066cf9b 100644 --- a/molexpress/ops/chem_ops.py +++ b/molexpress/ops/chem_ops.py @@ -7,31 +7,30 @@ def get_molecule( - molecule: types.Molecule | types.SMILES | types.InChI, + input_molecule: types.Molecule | types.SMILES | types.InChI, catch_errors: bool = False, ) -> Chem.Mol | None: """Generates an molecule object.""" - if isinstance(molecule, Chem.Mol): - return molecule + if isinstance(input_molecule, Chem.Mol): + return input_molecule - string = molecule - - if string.startswith("InChI"): - molecule = Chem.MolFromInchi(string, sanitize=False) + if input_molecule.startswith("InChI"): + molecule = Chem.MolFromInchi(input_molecule, sanitize=False) else: - molecule = Chem.MolFromSmiles(string, sanitize=False) + molecule = Chem.MolFromSmiles(input_molecule, sanitize=False) - if molecule is None: - raise ValueError(f"{string!r} is invalid.") + if not molecule: + raise ValueError(f"{input_molecule!r} is invalid.") flag = Chem.SanitizeMol(molecule, catchErrors=True) if flag != Chem.SanitizeFlags.SANITIZE_NONE: - if not catch_errors: - return None - # Sanitize molecule again, without the sanitization step that caused - # the error previously. Unrealistic molecules might pass without an error. - Chem.SanitizeMol(molecule, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL ^ flag) + if catch_errors: + raise ValueError(f"{input_molecule!r} is invalid.") + else: + # Sanitize molecule again, without the sanitization step that caused + # the error previously. Unrealistic molecules might pass without an error. + Chem.SanitizeMol(molecule, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL ^ flag) Chem.AssignStereochemistry(molecule, cleanIt=True, force=True, flagPossibleStereoCenters=True) diff --git a/molexpress/ops/gnn_ops.py b/molexpress/ops/gnn_ops.py index 8e009dd..a1f099a 100644 --- a/molexpress/ops/gnn_ops.py +++ b/molexpress/ops/gnn_ops.py @@ -23,19 +23,20 @@ def transform( Returns: A transformed node state. """ - if len(keras.ops.shape(kernel)) == 2: + if len(keras.ops.shape(kernel)) == 2: # kernel.rank == state.rank == 2 state_transformed = keras.ops.matmul(state, kernel) elif len(keras.ops.shape(kernel)) == len(keras.ops.shape(state)): - # kernel.rank == state.rank == 3 - state_transformed = keras.ops.einsum('ijh,jkh->ikh', state, kernel) + # kernel.rank == state.rank == 3 + state_transformed = keras.ops.einsum("ijh,jkh->ikh", state, kernel) else: # kernel.rank == 3 and state.rank == 2 - state_transformed = keras.ops.einsum('ij,jkh->ikh', state, kernel) + state_transformed = keras.ops.einsum("ij,jkh->ikh", state, kernel) if bias is not None: - state_transformed += bias + state_transformed = state_transformed + bias return state_transformed + def aggregate( node_state: types.Array, edge_src: types.Array, @@ -72,12 +73,12 @@ def aggregate( edge_dst = keras.ops.expand_dims(edge_dst, axis=-1) node_state_src = keras.ops.take_along_axis(node_state, edge_src, axis=0) - + if edge_weight is not None: node_state_src *= edge_weight if edge_state is not None: - node_state_src += edge_state + node_state_src = node_state_src + edge_state edge_dst = keras.ops.squeeze(edge_dst) @@ -86,6 +87,7 @@ def aggregate( ) return node_state_updated + def edge_softmax(score, edge_dst): num_segments = keras.ops.maximum(keras.ops.max(edge_dst) + 1, 1) score_max = keras.ops.segment_max(score, edge_dst, num_segments, sorted=False) @@ -95,9 +97,10 @@ def edge_softmax(score, edge_dst): denominator = gather(denominator, edge_dst) return numerator / denominator + def gather( node_state: types.Array, - edge: types.Array, + edge: types.Array, ) -> types.Array: expected_rank = len(keras.ops.shape(node_state)) current_rank = len(keras.ops.shape(edge)) @@ -106,6 +109,7 @@ def gather( node_state_edge = keras.ops.take_along_axis(node_state, edge, axis=0) return node_state_edge + def segment_mean( data: types.Array, segment_ids: types.Array,