1313import random
1414import warnings
1515from collections import defaultdict
16- from typing import Any , Dict , List , Optional , Sequence
16+ from typing import Any , List , Optional , Sequence
1717
1818import yaml
1919from rdkit import Chem
2020
2121from syntheseus .interface .molecule import Molecule
22- from syntheseus .interface .reaction import SingleProductReaction
22+ from syntheseus .interface .reaction import ReactionMetaData , SingleProductReaction
2323from syntheseus .reaction_prediction .inference .base import ExternalBackwardReactionModel
2424from syntheseus .reaction_prediction .utils .inference import (
2525 get_unique_file_in_dir ,
@@ -85,7 +85,7 @@ def _mols_to_batch(self, inputs) -> List[bytes]:
8585 # Example outcome: b'C C ( = O ) c 1 c c c 2 c ( c c n 2 C ( = O ) O C ( C ) ( C ) C ) c 1\n'.
8686 return [bytes (smi_tokenizer (input .smiles ) + "\n " , "utf-8" ) for input in inputs ]
8787
88- def _build_kwargs_from_scores (self , scores : List [float ]) -> List [Dict [ str , Any ] ]:
88+ def _build_kwargs_from_scores (self , scores : List [float ]) -> List [ReactionMetaData ]:
8989 """Compute kwargs to save in the predictions given raw scores from the RootAligned model.
9090
9191 The scores we get from the model cannot be directly interpreted as a (log) probability.
@@ -111,7 +111,7 @@ def _build_kwargs_from_scores(self, scores: List[float]) -> List[Dict[str, Any]]
111111 1.0 / (k + 1 ) for k in range (self .beam_size )
112112 )
113113
114- kwargs_list : List [Dict [ str , Any ] ] = []
114+ kwargs_list : List [ReactionMetaData ] = []
115115 for score in scores :
116116 best_pos = - math .floor (score / 1e8 )
117117 total_rr = score + best_pos * 1e8
@@ -121,14 +121,15 @@ def _build_kwargs_from_scores(self, scores: List[float]) -> List[Dict[str, Any]]
121121
122122 new_score = total_rr - (best_pos + 1 ) * max_possible_total_rr
123123 assert new_score <= 0.0
124- metadata = {
125- "original_score" : score ,
126- "best_pos" : best_pos ,
127- "total_rr" : total_rr ,
128- "score" : new_score ,
129- }
130-
131- kwargs_list .append (metadata )
124+
125+ kwargs_list .append (
126+ { # type: ignore[typeddict-unknown-key]
127+ "original_score" : score ,
128+ "best_pos" : best_pos ,
129+ "total_rr" : total_rr ,
130+ "score" : new_score ,
131+ }
132+ )
132133
133134 # Make sure the new scores produce the same ranking.
134135 for kwargs , next_kwargs in zip (kwargs_list , kwargs_list [1 :]):
0 commit comments