Skip to content

Commit f8583cb

Browse files
committed
add huggingface download to cli
1 parent 001538d commit f8583cb

File tree

6 files changed

+49
-16
lines changed

6 files changed

+49
-16
lines changed

chebifier/__main__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from chebifier.cli import cli
2+
3+
if __name__ == '__main__':
4+
cli()

chebifier/cli.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
13
import click
24
import yaml
35

@@ -9,14 +11,14 @@ def cli():
911
pass
1012

1113
@cli.command()
12-
@click.argument('config_file', type=click.Path(exists=True))
14+
@click.option('--config_file', type=click.Path(exists=True), default=os.path.join('configs', 'huggingface_config.yml'), help="Configuration file for ensemble models")
1315
@click.option('--smiles', '-s', multiple=True, help='SMILES strings to predict')
1416
@click.option('--smiles-file', '-f', type=click.Path(exists=True), help='File containing SMILES strings (one per line)')
1517
@click.option('--output', '-o', type=click.Path(), help='Output file to save predictions (optional)')
1618
@click.option('--ensemble-type', '-e', type=click.Choice(ENSEMBLES.keys()), default='mv', help='Type of ensemble to use (default: Majority Voting)')
1719
@click.option("--chebi-version", "-v", type=int, default=241, help="ChEBI version to use for checking consistency (default: 241)")
1820
@click.option("--use-confidence", "-c", is_flag=True, default=True, help="Weight predictions based on how 'confident' a model is in its prediction (default: True)")
19-
def predict(config_file, smiles, smiles_file, output, ensemble_type, chebi_version):
21+
def predict(config_file, smiles, smiles_file, output, ensemble_type, chebi_version, use_confidence):
2022
"""Predict ChEBI classes for SMILES strings using an ensemble model.
2123
2224
CONFIG_FILE is the path to a YAML configuration file for the ensemble model.
@@ -39,7 +41,7 @@ def predict(config_file, smiles, smiles_file, output, ensemble_type, chebi_versi
3941
return
4042

4143
# Make predictions
42-
predictions = ensemble.predict_smiles_list(smiles_list)
44+
predictions = ensemble.predict_smiles_list(smiles_list, use_confidence=use_confidence)
4345

4446
if output:
4547
# save as json

chebifier/ensemble/base_ensemble.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from chebai.preprocessing.datasets.chebi import ChEBIOver50
55
from chebai.result.analyse_sem import PredictionSmoother
66

7+
from api.hugging_face import download_model_files
78
from chebifier.prediction_models.base_predictor import BasePredictor
89

910

@@ -17,14 +18,20 @@ def __init__(self, model_configs: dict, chebi_version: int = 241):
1718
self.positive_prediction_threshold = 0.5
1819
for model_name, model_config in model_configs.items():
1920
model_cls = MODEL_TYPES[model_config["type"]]
20-
model_instance = model_cls(model_name, **model_config)
21+
if "hugging_face" in model_config:
22+
hugging_face_kwargs = download_model_files(model_config["hugging_face"])
23+
else:
24+
hugging_face_kwargs = {}
25+
model_instance = model_cls(model_name, **model_config, **hugging_face_kwargs)
2126
assert isinstance(model_instance, BasePredictor)
2227
self.models.append(model_instance)
2328

24-
self.smoother = PredictionSmoother(ChEBIOver50(chebi_version=chebi_version), disjoint_files=[
29+
self.chebi_dataset = ChEBIOver50(chebi_version=chebi_version)
30+
self.chebi_dataset._download_required_data() # download chebi if not already downloaded
31+
self.disjoint_files=[
2532
os.path.join("data", "disjoint_chebi.csv"),
2633
os.path.join("data", "disjoint_additional.csv")
27-
])
34+
]
2835

2936

3037
def gather_predictions(self, smiles_list):
@@ -110,7 +117,7 @@ def calculate_classwise_weights(self, predicted_classes):
110117

111118
return positive_weights, negative_weights
112119

113-
def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list:
120+
def predict_smiles_list(self, smiles_list, load_preds_if_possible=True, **kwargs) -> list:
114121
preds_file = f"predictions_by_model_{'_'.join(model.model_name for model in self.models)}.pt"
115122
predicted_classes_file = f"predicted_classes_{'_'.join(model.model_name for model in self.models)}.txt"
116123
if not load_preds_if_possible or not os.path.isfile(preds_file):
@@ -128,11 +135,12 @@ def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list:
128135
predicted_classes = {line.strip(): i for i, line in enumerate(f.readlines())}
129136

130137
classwise_weights = self.calculate_classwise_weights(predicted_classes)
131-
class_decisions = self.consolidate_predictions(ordered_predictions, classwise_weights)
138+
class_decisions = self.consolidate_predictions(ordered_predictions, classwise_weights, **kwargs)
132139
# Smooth predictions
133140
class_names = list(predicted_classes.keys())
134-
self.smoother.label_names = class_names
135-
class_decisions = self.smoother(class_decisions)
141+
# create a new PredictionSmoother instance for each call, since the label names are not known beforehand (this could be more efficient)
142+
new_smoother = PredictionSmoother(self.chebi_dataset, label_names=class_names, disjoint_files=self.disjoint_files)
143+
class_decisions = new_smoother(class_decisions)
136144

137145
class_names = list(predicted_classes.keys())
138146
class_indices = {predicted_classes[cls]: cls for cls in class_names}

chebifier/model_registry.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
)
1111

1212
ENSEMBLES = {
13-
"en_mv": BaseEnsemble,
14-
"en_wmv-ppvnpv": WMVwithPPVNPVEnsemble,
15-
"en_wmv-f1": WMVwithF1Ensemble,
13+
"mv": BaseEnsemble,
14+
"wmv-ppvnpv": WMVwithPPVNPVEnsemble,
15+
"wmv-f1": WMVwithF1Ensemble,
1616
}
1717

1818

configs/huggingface_config.yml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
2+
chemlog_peptides:
3+
type: chemlog
4+
model_weight: 100
5+
6+
#resgated_huggingface:
7+
# type: resgated
8+
# hugging_face:
9+
# repo_id: aditya0by0/python-chebifier
10+
# subfolder: resgated
11+
# files:
12+
# ckpt: resgated.ckpt
13+
# labels: classes.txt
14+
15+
electra_huggingface:
16+
type: electra
17+
hugging_face:
18+
repo_id: aditya0by0/python-chebifier
19+
subfolder: electra
20+
files:
21+
ckpt: electra.ckpt
22+
labels: classes.txt

pyproject.toml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,6 @@ dependencies = [
2727
"chemlog>=1.0.4"
2828
]
2929

30-
[project.scripts]
31-
chebifier = "chebifier.cli:cli"
32-
3330

3431
[tool.setuptools]
3532
packages = ["chebifier", "chebifier.ensemble", "chebifier.prediction_models"]

0 commit comments

Comments
 (0)