diff --git a/README.md b/README.md index c6428f2..1df3adf 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,86 @@ # python-chebifier An AI ensemble model for predicting chemical classes. + +## Installation + +```bash +# Clone the repository +git clone https://github.com/yourusername/python-chebifier.git +cd python-chebifier + +# Install the package +pip install -e . +``` + +## Usage + +### Command Line Interface + +The package provides a command-line interface (CLI) for making predictions using an ensemble model. + +```bash +# Get help +python -m chebifier.cli --help + +# Make predictions using a configuration file +python -m chebifier.cli predict example_config.yml --smiles "CC(=O)OC1=CC=CC=C1C(=O)O" "C1=CC=C(C=C1)C(=O)O" + +# Make predictions using SMILES from a file +python -m chebifier.cli predict example_config.yml --smiles-file smiles.txt +``` + +### Configuration File + +The CLI requires a YAML configuration file that defines the ensemble model. Here's an example: + +```yaml +# Example configuration file for Chebifier ensemble model + +# Each key in the top-level dictionary is a model name +model1: + # Required: type of model (must be one of the keys in MODEL_TYPES) + type: electra + # Required: name of the model + model_name: electra_model1 + # Required: path to the checkpoint file + ckpt_path: /path/to/checkpoint1.ckpt + # Required: path to the target labels file + target_labels_path: /path/to/target_labels1.txt + # Optional: batch size for predictions (default is likely defined in the model) + batch_size: 32 + +model2: + type: electra + model_name: electra_model2 + ckpt_path: /path/to/checkpoint2.ckpt + target_labels_path: /path/to/target_labels2.txt + batch_size: 64 +``` + +### Python API + +You can also use the package programmatically: + +```python +from chebifier.ensemble.base_ensemble import BaseEnsemble +import yaml + +# Load configuration from YAML file +with open('configs/example_config.yml', 'r') as f: + config = yaml.safe_load(f) + +# Instantiate ensemble model +ensemble = BaseEnsemble(config) + +# Make predictions +smiles_list = ["CC(=O)OC1=CC=CC=C1C(=O)O", "C1=CC=C(C=C1)C(=O)O"] +predictions = ensemble.predict_smiles_list(smiles_list) + +# Print results +for smile, prediction in zip(smiles_list, predictions): + print(f"SMILES: {smile}") + if prediction: + print(f"Predicted classes: {prediction}") + else: + print("No predictions") +``` diff --git a/chebifier/__init__.py b/chebifier/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chebifier/cli.py b/chebifier/cli.py new file mode 100644 index 0000000..704f8a0 --- /dev/null +++ b/chebifier/cli.py @@ -0,0 +1,70 @@ + + + +import click +import yaml +import sys +from chebifier.ensemble.base_ensemble import BaseEnsemble +from chebifier.ensemble.weighted_majority_ensemble import WMVwithPPVNPVEnsemble, WMVwithF1Ensemble + + +@click.group() +def cli(): + """Command line interface for Chebifier.""" + pass + +ENSEMBLES = { + "mv": BaseEnsemble, + "wmv-ppvnpv": WMVwithPPVNPVEnsemble, + "wmv-f1": WMVwithF1Ensemble +} + +@cli.command() +@click.argument('config_file', type=click.Path(exists=True)) +@click.option('--smiles', '-s', multiple=True, help='SMILES strings to predict') +@click.option('--smiles-file', '-f', type=click.Path(exists=True), help='File containing SMILES strings (one per line)') +@click.option('--output', '-o', type=click.Path(), help='Output file to save predictions (optional)') +@click.option('--ensemble-type', '-e', type=click.Choice(ENSEMBLES.keys()), default='mv', help='Type of ensemble to use (default: Majority Voting)') +def predict(config_file, smiles, smiles_file, output, ensemble_type): + """Predict ChEBI classes for SMILES strings using an ensemble model. + + CONFIG_FILE is the path to a YAML configuration file for the ensemble model. + """ + # Load configuration from YAML file + with open(config_file, 'r') as f: + config = yaml.safe_load(f) + + # Instantiate ensemble model + ensemble = ENSEMBLES[ensemble_type](config) + + # Collect SMILES strings from arguments and/or file + smiles_list = list(smiles) + if smiles_file: + with open(smiles_file, 'r') as f: + smiles_list.extend([line.strip() for line in f if line.strip()]) + + if not smiles_list: + click.echo("No SMILES strings provided. Use --smiles or --smiles-file options.") + return + + # Make predictions + predictions = ensemble.predict_smiles_list(smiles_list) + + if output: + # save as json + import json + with open(output, 'w') as f: + json.dump({smiles: pred for smiles, pred in zip(smiles_list, predictions)}, f, indent=2) + + else: + # Print results + for i, (smiles, prediction) in enumerate(zip(smiles_list, predictions)): + click.echo(f"Result for: {smiles}") + if prediction: + click.echo(f" Predicted classes: {', '.join(map(str, prediction))}") + else: + click.echo(" No predictions") + + +if __name__ == '__main__': + cli() diff --git a/chebifier/ensemble/__init__.py b/chebifier/ensemble/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py new file mode 100644 index 0000000..5166f43 --- /dev/null +++ b/chebifier/ensemble/base_ensemble.py @@ -0,0 +1,131 @@ +import os +from abc import ABC +import torch +import tqdm +from rdkit import Chem + +from chebifier.prediction_models.base_predictor import BasePredictor +from chebifier.prediction_models.chemlog_predictor import ChemLogPredictor +from chebifier.prediction_models.electra_predictor import ElectraPredictor +from chebifier.prediction_models.gnn_predictor import ResGatedPredictor + +MODEL_TYPES = { + "electra": ElectraPredictor, + "resgated": ResGatedPredictor, + "chemlog": ChemLogPredictor +} + +class BaseEnsemble(ABC): + + def __init__(self, model_configs: dict): + self.models = [] + self.positive_prediction_threshold = 0.5 + for model_name, model_config in model_configs.items(): + model_cls = MODEL_TYPES[model_config["type"]] + model_instance = model_cls(**model_config) + assert isinstance(model_instance, BasePredictor) + self.models.append(model_instance) + + def gather_predictions(self, smiles_list): + # get predictions from all models for the SMILES list + # order them by alphabetically by label class + model_predictions = [] + predicted_classes = set() + for model in self.models: + model_predictions.append(model.predict_smiles_list(smiles_list)) + for logits_for_smiles in model_predictions[-1]: + if logits_for_smiles is not None: + for cls in logits_for_smiles: + predicted_classes.add(cls) + print(f"Sorting predictions...") + predicted_classes = sorted(list(predicted_classes)) + predicted_classes = {cls: i for i, cls in enumerate(predicted_classes)} + ordered_logits = torch.zeros(len(smiles_list), len(predicted_classes), len(self.models)) * torch.nan + for i, model_prediction in enumerate(model_predictions): + for j, logits_for_smiles in tqdm.tqdm(enumerate(model_prediction), + total=len(model_prediction), + desc=f"Sorting predictions for {self.models[i].model_name}"): + if logits_for_smiles is not None: + for cls in logits_for_smiles: + ordered_logits[j, predicted_classes[cls], i] = logits_for_smiles[cls] + + return ordered_logits, predicted_classes + + + def consolidate_predictions(self, predictions, predicted_classes, classwise_weights, **kwargs): + """ + Aggregates predictions from multiple models using weighted majority voting. + Optimized version using tensor operations instead of for loops. + """ + num_smiles, num_classes, num_models = predictions.shape + + # Create a mapping from class indices to class names for faster lookup + class_names = list(predicted_classes.keys()) + class_indices = {predicted_classes[cls]: cls for cls in class_names} + + # Get predictions for all classes + valid_predictions = ~torch.isnan(predictions) + valid_counts = valid_predictions.sum(dim=2) # Sum over models dimension + + # Skip classes with no valid predictions + has_valid_predictions = valid_counts > 0 + + # Calculate positive and negative predictions for all classes at once + positive_mask = (predictions > 0.5) & valid_predictions + negative_mask = (predictions < 0.5) & valid_predictions + + confidence = 2 * torch.abs(predictions.nan_to_num() - self.positive_prediction_threshold) + + # Extract positive and negative weights + pos_weights = classwise_weights[0] # Shape: (num_classes, num_models) + neg_weights = classwise_weights[1] # Shape: (num_classes, num_models) + + # Calculate weighted predictions using broadcasting + # predictions shape: (num_smiles, num_classes, num_models) + # weights shape: (num_classes, num_models) + positive_weighted = positive_mask.float() * confidence * pos_weights.unsqueeze(0) + negative_weighted = negative_mask.float() * confidence * neg_weights.unsqueeze(0) + + # Sum over models dimension + positive_sum = positive_weighted.sum(dim=2) # Shape: (num_smiles, num_classes) + negative_sum = negative_weighted.sum(dim=2) # Shape: (num_smiles, num_classes) + + # Determine which classes to include for each SMILES + net_score = positive_sum - negative_sum # Shape: (num_smiles, num_classes) + class_decisions = (net_score > 0) & has_valid_predictions # Shape: (num_smiles, num_classes) + + # Convert tensor decisions to result list using list comprehension for efficiency + result = [ + [class_indices[idx.item()] for idx in torch.nonzero(class_decisions[i], as_tuple=True)[0]] + for i in range(num_smiles) + ] + + return result + + + def calculate_classwise_weights(self, predicted_classes): + """No weights, simple majority voting""" + positive_weights = torch.ones(len(predicted_classes), len(self.models)) + negative_weights = torch.ones(len(predicted_classes), len(self.models)) + + return positive_weights, negative_weights + + def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list: + preds_file = f"predictions_by_model_{'_'.join(model.model_name for model in self.models)}.pt" + predicted_classes_file = f"predicted_classes_{'_'.join(model.model_name for model in self.models)}.txt" + if not load_preds_if_possible or not os.path.isfile(preds_file): + ordered_predictions = predicted_classes = self.gather_predictions(smiles_list) + # save predictions + torch.save(ordered_predictions, preds_file) + with open(predicted_classes_file, "w") as f: + for cls in predicted_classes: + f.write(f"{cls}\n") + else: + print(f"Loading predictions from {preds_file} and label indexes from {predicted_classes_file}") + ordered_predictions = torch.load(preds_file) + with open(predicted_classes_file, "r") as f: + predicted_classes = {line.strip(): i for i, line in enumerate(f.readlines())} + + classwise_weights = self.calculate_classwise_weights(predicted_classes) + aggregated_predictions = self.consolidate_predictions(ordered_predictions, predicted_classes, classwise_weights) + return aggregated_predictions diff --git a/chebifier/ensemble/weighted_majority_ensemble.py b/chebifier/ensemble/weighted_majority_ensemble.py new file mode 100644 index 0000000..7100c46 --- /dev/null +++ b/chebifier/ensemble/weighted_majority_ensemble.py @@ -0,0 +1,54 @@ +import torch + +from chebifier.ensemble.base_ensemble import BaseEnsemble + + + +class WMVwithPPVNPVEnsemble(BaseEnsemble): + + def calculate_classwise_weights(self, predicted_classes): + """ + Given the positions of predicted classes in the predictions tensor, assign weights to each class. The + result is two tensors of shape (num_predicted_classes, num_models). The weight for each class is the model_weight + (default: 1) multiplied by the class-specific positive / negative weight (default 1). + """ + positive_weights = torch.ones(len(predicted_classes), len(self.models)) + negative_weights = torch.ones(len(predicted_classes), len(self.models)) + for j, model in enumerate(self.models): + positive_weights[:, j] *= model.model_weight + negative_weights[:, j] *= model.model_weight + if model.classwise_weights is None: + continue + for cls, weights in model.classwise_weights.items(): + positive_weights[predicted_classes[cls], j] *= weights["PPV"] + negative_weights[predicted_classes[cls], j] *= weights["NPV"] + + print(f"Calculated model weightings. The averages for positive / negative weights are:") + for i, model in enumerate(self.models): + print(f"{model.model_name}: {positive_weights[:, i].mean().item():.3f} / {negative_weights[:, i].mean().item():.3f}") + + return positive_weights, negative_weights + + +class WMVwithF1Ensemble(BaseEnsemble): + + def calculate_classwise_weights(self, predicted_classes): + """ + Given the positions of predicted classes in the predictions tensor, assign weights to each class. The + result is two tensors of shape (num_predicted_classes, num_models). The weight for each class is the model_weight + (default: 1) multiplied by the class-specific validation-f1 (default 1). + """ + weights_by_cls = torch.ones(len(predicted_classes), len(self.models)) + for j, model in enumerate(self.models): + weights_by_cls[:, j] *= model.model_weight + if model.classwise_weights is None: + continue + for cls, weights in model.classwise_weights.items(): + f1 = 2 * weights["TP"] / (2 * weights["TP"] + weights["FP"] + weights["FN"]) + weights_by_cls[predicted_classes[cls], j] *= f1 + + print(f"Calculated model weightings. The average weights are:") + for i, model in enumerate(self.models): + print(f"{model.model_name}: {weights_by_cls[:, i].mean().item():.3f}") + + return weights_by_cls, weights_by_cls \ No newline at end of file diff --git a/chebifier/prediction_models/__init__.py b/chebifier/prediction_models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chebifier/prediction_models/base_predictor.py b/chebifier/prediction_models/base_predictor.py new file mode 100644 index 0000000..5633458 --- /dev/null +++ b/chebifier/prediction_models/base_predictor.py @@ -0,0 +1,16 @@ +from abc import ABC +import json + +class BasePredictor(ABC): + + def __init__(self, model_name: str, model_weight: int = 1, classwise_weights_path: str = None, **kwargs): + self.model_name = model_name + self.model_weight = model_weight + if classwise_weights_path is not None: + self.classwise_weights = json.load(open(classwise_weights_path, encoding="utf-8")) + else: + self.classwise_weights = None + + + def predict_smiles_list(self, smiles_list: list[str]) -> dict: + raise NotImplementedError \ No newline at end of file diff --git a/chebifier/prediction_models/chemlog_predictor.py b/chebifier/prediction_models/chemlog_predictor.py new file mode 100644 index 0000000..54b020a --- /dev/null +++ b/chebifier/prediction_models/chemlog_predictor.py @@ -0,0 +1,35 @@ +import tqdm + +from chebifier.prediction_models.base_predictor import BasePredictor +from chemlog.alg_classification.charge_classifier import AlgChargeClassifier +from chemlog.alg_classification.peptide_size_classifier import AlgPeptideSizeClassifier +from chemlog.alg_classification.proteinogenics_classifier import AlgProteinogenicsClassifier +from chemlog.alg_classification.substructure_classifier import AlgSubstructureClassifier +from chemlog.cli import strategy_call, _smiles_to_mol, CLASSIFIERS + +class ChemLogPredictor(BasePredictor): + + def __init__(self, model_name: str, **kwargs): + super().__init__(model_name, **kwargs) + self.strategy = "algo" + self.classifier_instances = { + k: v() for k, v in CLASSIFIERS[self.strategy].items() + } + self.peptide_labels = ["15841", "16670", "24866", "25676", "25696", "25697", "27369", "46761", "47923", + "48030", "48545", "60194", "60334", "60466", "64372", "65061", "90799", "155837"] + + print(f"Initialised ChemLog model {self.model_name}") + + def predict_smiles_list(self, smiles_list: list[str]) -> list: + results = [] + for i, smiles in tqdm.tqdm(enumerate(smiles_list)): + mol = _smiles_to_mol(smiles) + if mol is None: + results.append(None) + else: + results.append({label: 1 if label in strategy_call(self.strategy, self.classifier_instances, mol)["chebi_classes"] else 0 for label in self.peptide_labels}) + + for classifier in self.classifier_instances.values(): + classifier.on_finish() + + return results \ No newline at end of file diff --git a/chebifier/prediction_models/electra_predictor.py b/chebifier/prediction_models/electra_predictor.py new file mode 100644 index 0000000..075eafa --- /dev/null +++ b/chebifier/prediction_models/electra_predictor.py @@ -0,0 +1,22 @@ +from chebifier.prediction_models.nn_predictor import NNPredictor +from chebai.models.electra import Electra +from chebai.preprocessing.reader import ChemDataReader + + +class ElectraPredictor(NNPredictor): + + def __init__(self, model_name: str, ckpt_path: str, **kwargs): + super().__init__(model_name, ckpt_path, reader_cls=ChemDataReader, **kwargs) + print(f"Initialised Electra model {self.model_name} (device: {self.device})") + + def init_model(self, ckpt_path: str, **kwargs) -> Electra: + model = Electra.load_from_checkpoint( + ckpt_path, + map_location=self.device, + criterion=None, strict=False, + metrics=dict(train=dict(), test=dict(), validation=dict()), pretrained_checkpoint=None + ) + model.eval() + return model + + diff --git a/chebifier/prediction_models/gnn_predictor.py b/chebifier/prediction_models/gnn_predictor.py new file mode 100644 index 0000000..b139c6c --- /dev/null +++ b/chebifier/prediction_models/gnn_predictor.py @@ -0,0 +1,96 @@ +from chebifier.prediction_models.nn_predictor import NNPredictor +import chebai_graph.preprocessing.properties as p +import torch +from chebai_graph.models.graph import ResGatedGraphConvNetGraphPred +from chebai_graph.preprocessing.reader import GraphPropertyReader +from chebai_graph.preprocessing.property_encoder import IndexEncoder, OneHotEncoder +from torch_geometric.data.data import Data as GeomData + + +class ResGatedPredictor(NNPredictor): + + def __init__(self, model_name: str, ckpt_path: str, molecular_properties, **kwargs): + super().__init__(model_name, ckpt_path, reader_cls=GraphPropertyReader, **kwargs) + # molecular_properties is a list of class paths + if molecular_properties is not None: + properties = [self.load_class(prop)() for prop in molecular_properties] + properties = sorted( + properties, key=lambda prop: f"{prop.name}_{prop.encoder.name}" + ) + else: + properties = [] + self.molecular_properties = properties + assert isinstance(self.molecular_properties, list) and all( + isinstance(prop, p.MolecularProperty) for prop in self.molecular_properties + ) + print(f"Initialised GNN model {self.model_name} (device: {self.device})") + + def load_class(self, class_path: str): + module_path, class_name = class_path.rsplit(".", 1) + module = __import__(module_path, fromlist=[class_name]) + return getattr(module, class_name) + + def init_model(self, ckpt_path: str, **kwargs) -> ResGatedGraphConvNetGraphPred: + model = ResGatedGraphConvNetGraphPred.load_from_checkpoint( + ckpt_path, map_location=torch.device(self.device), criterion=None, strict=False, + metrics=dict(train=dict(), test=dict(), validation=dict()), pretrained_checkpoint=None, + config={"in_length": 256, "hidden_length": 512, "dropout_rate": 0.1, "n_conv_layers": 3, + "n_linear_layers": 3, "n_atom_properties": 158, "n_bond_properties": 7, + "n_molecule_properties": 200}) + model.eval() + return model + + def read_smiles(self, smiles): + reader = self.reader_cls() + d = reader.to_data(dict(features=smiles, labels=None)) + geom_data = d["features"] + edge_attr = geom_data.edge_attr + x = geom_data.x + molecule_attr = torch.empty((1, 0)) + for prop in self.molecular_properties: + property_values = reader.read_property(smiles, prop) + encoded_values = [] + for value in property_values: + # cant use standard encode for index encoder because model has been trained on a certain range of values + # use default value if we meet an unseen value + if isinstance(prop.encoder, IndexEncoder): + if str(value) in prop.encoder.cache: + index = prop.encoder.cache.index(str(value)) + prop.encoder.offset + else: + index = 0 + print(f"Unknown property value {value} for property {prop} at smiles {smiles}") + if isinstance(prop.encoder, OneHotEncoder): + encoded_values.append(torch.nn.functional.one_hot( + torch.tensor(index), num_classes=prop.encoder.get_encoding_length() + )) + else: + encoded_values.append(torch.tensor([index])) + + else: + encoded_values.append(prop.encoder.encode(value)) + if len(encoded_values) > 0: + encoded_values = torch.stack(encoded_values) + + if isinstance(encoded_values, torch.Tensor): + if len(encoded_values.size()) == 0: + encoded_values = encoded_values.unsqueeze(0) + if len(encoded_values.size()) == 1: + encoded_values = encoded_values.unsqueeze(1) + else: + encoded_values = torch.zeros( + (0, prop.encoder.get_encoding_length()) + ) + if isinstance(prop, p.AtomProperty): + x = torch.cat([x, encoded_values], dim=1) + elif isinstance(prop, p.BondProperty): + edge_attr = torch.cat([edge_attr, encoded_values], dim=1) + else: + molecule_attr = torch.cat([molecule_attr, encoded_values[0]], dim=1) + + d["features"] = GeomData( + x=x, + edge_index=geom_data.edge_index, + edge_attr=edge_attr, + molecule_attr=molecule_attr, + ) + return d \ No newline at end of file diff --git a/chebifier/prediction_models/nn_predictor.py b/chebifier/prediction_models/nn_predictor.py new file mode 100644 index 0000000..1ee5e46 --- /dev/null +++ b/chebifier/prediction_models/nn_predictor.py @@ -0,0 +1,79 @@ +import tqdm + +from chebifier.prediction_models.base_predictor import BasePredictor +from rdkit import Chem +import numpy as np +import torch + +class NNPredictor(BasePredictor): + + def __init__(self, model_name: str, ckpt_path: str, reader_cls, target_labels_path: str, **kwargs): + super().__init__(model_name, **kwargs) + self.reader_cls = reader_cls + + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model = self.init_model(ckpt_path=ckpt_path) + self.target_labels = [line.strip() for line in open(target_labels_path, encoding="utf-8")] + self.batch_size = kwargs.get("batch_size", 1) + + + def init_model(self, ckpt_path: str, **kwargs): + raise NotImplementedError("Model initialization must be implemented in subclasses.") + + def calculate_results(self, batch): + collator = self.reader_cls.COLLATOR() + dat = self.model._process_batch(collator(batch).to(self.device), 0) + return self.model(dat, **dat["model_kwargs"]) + + def batchify(self, batch): + cache = [] + for r in batch: + cache.append(r) + if len(cache) >= self.batch_size: + yield cache + cache = [] + if cache: + yield cache + + def read_smiles(self, smiles): + reader = self.reader_cls() + d = reader.to_data(dict(features=smiles, labels=None)) + return d + + def predict_smiles_list(self, smiles_list) -> list: + """Returns a list with the length of smiles_list, each element is either None (=failure) or a dictionary + Of classes and predicted values.""" + token_dicts = [] + could_not_parse = [] + index_map = dict() + for i, smiles in enumerate(smiles_list): + try: + # Try to parse the smiles string + if not smiles: + raise ValueError() + d = self.read_smiles(smiles) + # This is just for sanity checks + rdmol = Chem.MolFromSmiles(smiles, sanitize=False) + except Exception as e: + # Note if it fails + could_not_parse.append(i) + print(f"Failing to parse {smiles} due to {e}") + else: + if rdmol is None: + could_not_parse.append(i) + else: + index_map[i] = len(token_dicts) + token_dicts.append(d) + results = [] + if token_dicts: + for batch in tqdm.tqdm(self.batchify(token_dicts), desc=f"{self.model_name}", total=len(token_dicts)//self.batch_size): + result = self.calculate_results(batch) + if isinstance(result, dict) and "logits" in result: + result = result["logits"] + results += torch.sigmoid(result).cpu().detach().tolist() + results = np.stack(results, axis=0) + preds = [{self.target_labels[j]: p for j, p in enumerate(results[index_map[i]])} + if i not in could_not_parse else None for i in range(len(smiles_list))] + return preds + else: + return [None for _ in smiles_list] \ No newline at end of file