@@ -18,12 +18,30 @@ def __init__(self, model_name: str, program_directory: Optional[Path]=None, chem
1818 self .chebi_graph = kwargs .get ("chebi_graph" , None )
1919
def predict_smiles_list(self, smiles_list: list[str]) -> list:
    """Classify every SMILES string with C3P and return one dict per input.

    Args:
        smiles_list: SMILES strings to classify. The same string may appear
            more than once; each occurrence gets its own populated dict.

    Returns:
        A list aligned with ``smiles_list``. Each element maps the part of
        the class id after the ``":"`` to the ``is_match`` value reported by
        C3P. When a class matches and ``self.chebi_graph`` is set, every
        predecessor of that class in the graph is also recorded with value
        ``1`` (keyed by ``str(parent)``).

    Raises:
        ValueError: if C3P reports a result for a SMILES string that is not
            in ``smiles_list``.
    """
    result_list = c3p_classifier.classify(
        smiles_list, self.program_directory, self.chemical_classes, strict=False
    )

    # Record every position of each SMILES string once. ``list.index`` only
    # finds the first occurrence, which would leave the dicts of repeated
    # inputs empty, and it rescans the list for every single result.
    positions: dict[str, list[int]] = {}
    for idx, smiles in enumerate(smiles_list):
        positions.setdefault(smiles, []).append(idx)

    result_reformatted = [{} for _ in smiles_list]
    for result in result_list:
        targets = positions.get(result.input_smiles)
        if targets is None:
            # Same exception type that ``list.index`` raised before.
            raise ValueError(f"{result.input_smiles!r} is not in list")

        chebi_id = result.class_id.split(":")[1]
        # NOTE(review): assumes chebi_graph nodes are integer ChEBI ids and
        # that predecessors() yields the parent classes - confirm.
        if result.is_match and self.chebi_graph is not None:
            parents = [str(parent) for parent in self.chebi_graph.predecessors(int(chebi_id))]
        else:
            parents = []

        for idx in targets:
            entry = result_reformatted[idx]
            entry[chebi_id] = result.is_match
            for parent in parents:
                entry[parent] = 1
    return result_reformatted
30+
def explain_smiles(self, smiles):
    """Gather C3P's textual explanations for the classes it predicts as matching.

    C3P attaches a natural-language reason to every prediction, positive or
    negative. With more than 300 classes that is too much to show, so only
    the positive predictions are kept.

    Args:
        smiles: A single SMILES string to classify.

    Returns:
        A dict with one key, ``"highlights"``, holding a list of
        ``("text", message)`` tuples: a summary line with the number of
        positive predictions, followed by one entry per matching class.
    """
    outcomes = c3p_classifier.classify(
        [smiles], self.program_directory, self.chemical_classes, strict=False
    )
    explanations = [
        (
            "text",
            f"For class {outcome.class_name} ({outcome.class_id}), C3P gave the following explanation: {outcome.reason}",
        )
        for outcome in outcomes
        if outcome.is_match
    ]
    # The summary is built after filtering so it can report the final count.
    summary = (
        "text",
        f"C3P made positive predictions for {len(explanations)} classes. The explanations are as follows:",
    )
    return {"highlights": [summary, *explanations]}
0 commit comments