Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
lightning_logs/
notebooks/_*.ipynb

wandb/
# vscode
.vscode

# jupyter
MANIFEST
build
Expand Down Expand Up @@ -157,3 +156,12 @@ venv.bak/

# mypy
.mypy_cache/
molexpress/**/*.ckpt
molexpress/**/*.pth
molexpress/**/*.txt
molexpress/**/*.csv
molexpress/**/*.zip




123 changes: 89 additions & 34 deletions molexpress/datasets/encoders.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from functools import lru_cache
import logging
from typing import Dict, Tuple, Union

import numpy as np
Expand All @@ -9,6 +9,8 @@
from molexpress.datasets import featurizers
from molexpress.ops import chem_ops

LOGGER = logging.getLogger(__name__)


class PeptideGraphEncoder:
def __init__(
Expand All @@ -19,27 +21,39 @@ def __init__(
supports_masking: bool = False,
) -> None:
self.node_encoder = MolecularNodeEncoder(
atom_featurizers, supports_masking=supports_masking)
atom_featurizers, supports_masking=supports_masking
)
self.edge_encoder = MolecularEdgeEncoder(
bond_featurizers, self_loops=self_loops, supports_masking=supports_masking)
bond_featurizers, self_loops=self_loops, supports_masking=supports_masking
)

def __call__(self, residues: list[types.Molecule | types.SMILES | types.InChI]) -> np.ndarray:
residue_graphs = []
residue_sizes = []
for residue in residues:
residue_graph, residue_size = self._encode_residue(
residue, self.node_encoder, self.edge_encoder
)
residue_graphs.append(residue_graph)
residue_sizes.append(residue_size)
try:
residue_graph, residue_size = self._encode_residue(
residue, self.node_encoder, self.edge_encoder
)
except AttributeError as e:
raise GraphConstructionError(
f"Could not construct graph from residue: {residue}"
) from e
else:
residue_graphs.append(residue_graph)
residue_sizes.append(residue_size)

disjoint_peptide_graph = self._merge_molecular_graphs(residue_graphs)
disjoint_peptide_graph["residue_size"] = np.array(residue_sizes)

try:
disjoint_peptide_graph["residue_size"] = np.array(residue_sizes)
except Exception as e:
raise GraphConstructionError("Cannot construct disjoint graph") from e

disjoint_peptide_graph["peptide_size"] = np.array([len(residues)], dtype="int32")
return disjoint_peptide_graph

@staticmethod
@lru_cache(maxsize=None)
def _encode_residue(
residue: types.Molecule | types.SMILES | types.InChI,
node_encoder: MolecularNodeEncoder,
Expand Down Expand Up @@ -99,25 +113,41 @@ def masked_collate_fn(
disjoint_peptide_batch_graph = PeptideGraphEncoder._merge_molecular_graphs(
disjoint_peptide_graphs
)
disjoint_peptide_batch_graph["peptide_size"] = np.concatenate(
[g["residue_size"].shape[:1] for g in disjoint_peptide_graphs]
).astype("int32")
disjoint_peptide_batch_graph["residue_size"] = np.concatenate(
[g["residue_size"] for g in disjoint_peptide_graphs]
).astype("int32")

node_state = disjoint_peptide_batch_graph['node_state']
node_state = disjoint_peptide_batch_graph["node_state"]
node_mask = np.random.uniform(size=node_state.shape[0]) < node_masking_rate
disjoint_peptide_batch_graph['node_loss_weight'] = np.copy(node_mask.astype(node_state.dtype))
disjoint_peptide_batch_graph['node_label'] = np.copy(disjoint_peptide_batch_graph['node_state'])
disjoint_peptide_batch_graph["node_loss_weight"] = np.copy(
node_mask.astype(node_state.dtype)
)
disjoint_peptide_batch_graph["node_label"] = np.copy(
disjoint_peptide_batch_graph["node_state"]
)
mask_state = np.zeros_like(node_state)
mask_state[:, -1] = 1.
disjoint_peptide_batch_graph['node_state'] = np.where(
node_mask[:, None], mask_state, node_state)

edge_state = disjoint_peptide_batch_graph['edge_state']
mask_state[:, -1] = 1.0
disjoint_peptide_batch_graph["node_state"] = np.where(
node_mask[:, None], mask_state, node_state
)

edge_state = disjoint_peptide_batch_graph["edge_state"]
edge_mask = np.random.uniform(size=edge_state.shape[0]) < edge_masking_rate
disjoint_peptide_batch_graph['edge_loss_weight'] = np.copy(edge_mask.astype(edge_state.dtype))
disjoint_peptide_batch_graph['edge_label'] = np.copy(disjoint_peptide_batch_graph['edge_state'])
disjoint_peptide_batch_graph["edge_loss_weight"] = np.copy(
edge_mask.astype(edge_state.dtype)
)
disjoint_peptide_batch_graph["edge_label"] = np.copy(
disjoint_peptide_batch_graph["edge_state"]
)
mask_state = np.zeros_like(edge_state)
mask_state[:, -1] = 1.
disjoint_peptide_batch_graph['edge_state'] = np.where(
edge_mask[:, None], mask_state, edge_state)

mask_state[:, -1] = 1.0
disjoint_peptide_batch_graph["edge_state"] = np.where(
edge_mask[:, None], mask_state, edge_state
)

return disjoint_peptide_batch_graph

@staticmethod
Expand All @@ -128,14 +158,23 @@ def _merge_molecular_graphs(

disjoint_molecular_graph = {}

if len(molecular_graphs) == 0:
raise GraphMergingError("No graphs to merge.")

disjoint_molecular_graph["node_state"] = np.concatenate(
[g["node_state"] for g in molecular_graphs]
)

if "edge_state" in molecular_graphs[0]:
disjoint_molecular_graph["edge_state"] = np.concatenate(
[g["edge_state"] for g in molecular_graphs]
)
try:
disjoint_molecular_graph["edge_state"] = np.concatenate(
[g["edge_state"] for g in molecular_graphs]
)
except ValueError as e:
raise GraphMergingError(
"Error during concatenation. Structure without bonds? Shapes of edge_state arrays: "
f"{[g['edge_state'].shape for g in molecular_graphs]}"
) from e

edge_src = np.concatenate([graph["edge_src"] for graph in molecular_graphs])
edge_dst = np.concatenate([graph["edge_dst"] for graph in molecular_graphs])
Expand Down Expand Up @@ -181,9 +220,9 @@ def output_dtype(self):

class MolecularEdgeEncoder:
def __init__(
self,
featurizers: list[featurizers.Featurizer],
self_loops: bool = False,
self,
featurizers: list[featurizers.Featurizer],
self_loops: bool = False,
supports_masking: bool = False,
) -> None:
self.featurizer = Composer(featurizers)
Expand Down Expand Up @@ -217,13 +256,15 @@ def __call__(self, molecule: types.Molecule) -> np.ndarray:
if bond is None:
assert self.self_loops, "Found a bond to be None."
bond_encoding = np.zeros(
self.output_dim + int(self.self_loops) + int(self.supports_masking),
dtype=self.output_dtype)
self.output_dim + int(self.self_loops) + int(self.supports_masking),
dtype=self.output_dtype,
)
bond_encoding[-(int(self.self_loops) + int(self.supports_masking))] = 1
else:
bond_encoding = self.featurizer(bond)
bond_encoding = np.pad(
bond_encoding, (0, int(self.self_loops) + int(self.supports_masking)))
bond_encoding, (0, int(self.self_loops) + int(self.supports_masking))
)

bond_encodings.append(bond_encoding)

Expand All @@ -244,9 +285,23 @@ def __init__(
self.supports_masking = supports_masking

def __call__(self, molecule: types.Molecule) -> np.ndarray:
node_encodings = np.stack([self.featurizer(atom) for atom in molecule.GetAtoms()], axis=0)
node_encodings = np.stack(
[self.featurizer(atom) for atom in molecule.GetAtoms()], axis=0
)
if self.supports_masking:
node_encodings = np.pad(node_encodings, [(0, 0), (0, 1)])
return {
"node_state": np.stack(node_encodings),
}


class GraphConstructionError(Exception):
    """Signal that a graph could not be built from the given input.

    Within this module it is raised by ``PeptideGraphEncoder.__call__``:
    once when encoding a single residue fails with an ``AttributeError``,
    and once when the per-residue sizes cannot be attached to the merged
    peptide graph. The triggering exception is chained via ``from`` so the
    root cause stays available on ``__cause__``.
    """


class GraphMergingError(Exception):
    """Signal that several molecular graphs could not be merged into one.

    Within this module it is raised by ``_merge_molecular_graphs`` when the
    list of graphs to merge is empty, and when concatenating the
    ``edge_state`` arrays fails with a ``ValueError`` (the original error is
    chained via ``from``).
    """
1 change: 0 additions & 1 deletion molexpress/datasets/featurizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from molexpress import types


DEFAULT_VOCABULARY = {
"AtomType": {
'H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne',
Expand Down
2 changes: 1 addition & 1 deletion molexpress/layers/gcn_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def call(self, inputs: types.MolecularGraph) -> types.MolecularGraph:
if self.skip_connection:
if self._transform_skip_connection:
node_state = gnn_ops.transform(state=node_state, kernel=self.skip_connect_kernel)
node_state_updated += node_state
node_state_updated = node_state_updated + node_state

if self.dropout_rate:
node_state_updated = self.dropout(node_state_updated)
Expand Down
2 changes: 1 addition & 1 deletion molexpress/layers/gin_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def call(self, inputs: types.MolecularGraph) -> types.MolecularGraph:
edge_weight=edge_weight,
)

node_state_updated += (1 + self.epsilon) * node_state
node_state_updated = node_state_updated + (1 + self.epsilon) * node_state

node_state_updated = gnn_ops.transform(
state=node_state_updated, kernel=self.node_kernel_1, bias=self.node_bias_1
Expand Down
29 changes: 14 additions & 15 deletions molexpress/ops/chem_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,30 @@


def get_molecule(
molecule: types.Molecule | types.SMILES | types.InChI,
input_molecule: types.Molecule | types.SMILES | types.InChI,
catch_errors: bool = False,
) -> Chem.Mol | None:
"""Generates an molecule object."""

if isinstance(molecule, Chem.Mol):
return molecule
if isinstance(input_molecule, Chem.Mol):
return input_molecule

string = molecule

if string.startswith("InChI"):
molecule = Chem.MolFromInchi(string, sanitize=False)
if input_molecule.startswith("InChI"):
molecule = Chem.MolFromInchi(input_molecule, sanitize=False)
else:
molecule = Chem.MolFromSmiles(string, sanitize=False)
molecule = Chem.MolFromSmiles(input_molecule, sanitize=False)

if molecule is None:
raise ValueError(f"{string!r} is invalid.")
if not molecule:
raise ValueError(f"{input_molecule!r} is invalid.")

flag = Chem.SanitizeMol(molecule, catchErrors=True)
if flag != Chem.SanitizeFlags.SANITIZE_NONE:
if not catch_errors:
return None
# Sanitize molecule again, without the sanitization step that caused
# the error previously. Unrealistic molecules might pass without an error.
Chem.SanitizeMol(molecule, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL ^ flag)
if catch_errors:
raise ValueError(f"{input_molecule!r} is invalid.")
else:
# Sanitize molecule again, without the sanitization step that caused
# the error previously. Unrealistic molecules might pass without an error.
Chem.SanitizeMol(molecule, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL ^ flag)

Chem.AssignStereochemistry(molecule, cleanIt=True, force=True, flagPossibleStereoCenters=True)

Expand Down
20 changes: 12 additions & 8 deletions molexpress/ops/gnn_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,20 @@ def transform(
Returns:
A transformed node state.
"""
if len(keras.ops.shape(kernel)) == 2:
if len(keras.ops.shape(kernel)) == 2:
# kernel.rank == state.rank == 2
state_transformed = keras.ops.matmul(state, kernel)
elif len(keras.ops.shape(kernel)) == len(keras.ops.shape(state)):
# kernel.rank == state.rank == 3
state_transformed = keras.ops.einsum('ijh,jkh->ikh', state, kernel)
# kernel.rank == state.rank == 3
state_transformed = keras.ops.einsum("ijh,jkh->ikh", state, kernel)
else:
# kernel.rank == 3 and state.rank == 2
state_transformed = keras.ops.einsum('ij,jkh->ikh', state, kernel)
state_transformed = keras.ops.einsum("ij,jkh->ikh", state, kernel)
if bias is not None:
state_transformed += bias
state_transformed = state_transformed + bias
return state_transformed


def aggregate(
node_state: types.Array,
edge_src: types.Array,
Expand Down Expand Up @@ -72,12 +73,12 @@ def aggregate(
edge_dst = keras.ops.expand_dims(edge_dst, axis=-1)

node_state_src = keras.ops.take_along_axis(node_state, edge_src, axis=0)

if edge_weight is not None:
node_state_src *= edge_weight

if edge_state is not None:
node_state_src += edge_state
node_state_src = node_state_src + edge_state

edge_dst = keras.ops.squeeze(edge_dst)

Expand All @@ -86,6 +87,7 @@ def aggregate(
)
return node_state_updated


def edge_softmax(score, edge_dst):
num_segments = keras.ops.maximum(keras.ops.max(edge_dst) + 1, 1)
score_max = keras.ops.segment_max(score, edge_dst, num_segments, sorted=False)
Expand All @@ -95,9 +97,10 @@ def edge_softmax(score, edge_dst):
denominator = gather(denominator, edge_dst)
return numerator / denominator


def gather(
node_state: types.Array,
edge: types.Array,
edge: types.Array,
) -> types.Array:
expected_rank = len(keras.ops.shape(node_state))
current_rank = len(keras.ops.shape(edge))
Expand All @@ -106,6 +109,7 @@ def gather(
node_state_edge = keras.ops.take_along_axis(node_state, edge, axis=0)
return node_state_edge


def segment_mean(
data: types.Array,
segment_ids: types.Array,
Expand Down