Skip to content

Commit 001538d

Browse files
committed
fix cli and ensemble imports
1 parent f3b3905 commit 001538d

File tree

2 files changed

+2
-20
lines changed

2 files changed

+2
-20
lines changed

chebifier/cli.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,12 @@
22
import yaml
33

44
from .model_registry import ENSEMBLES
5-
from chebifier.ensemble.base_ensemble import BaseEnsemble
6-
from chebifier.ensemble.weighted_majority_ensemble import WMVwithPPVNPVEnsemble, WMVwithF1Ensemble
7-
85

96
@click.group()
107
def cli():
118
"""Command line interface for Chebifier."""
129
pass
1310

14-
ENSEMBLES = {
15-
"mv": BaseEnsemble,
16-
"wmv-ppvnpv": WMVwithPPVNPVEnsemble,
17-
"wmv-f1": WMVwithF1Ensemble
18-
}
19-
2011
@cli.command()
2112
@click.argument('config_file', type=click.Path(exists=True))
2213
@click.option('--smiles', '-s', multiple=True, help='SMILES strings to predict')

chebifier/ensemble/base_ensemble.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,13 @@
11
import os
2-
from abc import ABC
32
import torch
43
import tqdm
54
from chebai.preprocessing.datasets.chebi import ChEBIOver50
65
from chebai.result.analyse_sem import PredictionSmoother
76

87
from chebifier.prediction_models.base_predictor import BasePredictor
9-
from chebifier.prediction_models.chemlog_predictor import ChemLogPredictor
10-
from chebifier.prediction_models.electra_predictor import ElectraPredictor
11-
from chebifier.prediction_models.gnn_predictor import ResGatedPredictor
12-
13-
MODEL_TYPES = {
14-
"electra": ElectraPredictor,
15-
"resgated": ResGatedPredictor,
16-
"chemlog": ChemLogPredictor
17-
}
188

19-
class BaseEnsemble(ABC):
9+
10+
class BaseEnsemble:
2011

2112
def __init__(self, model_configs: dict, chebi_version: int = 241):
2213
# Deferred Import: To avoid circular import error

0 commit comments

Comments
 (0)