Skip to content

Commit 0de1ca2

Browse files
authored
Improve typing of reaction metadata (#86)
Following a discussion in #80, this PR adjusts the type hints for reaction metadata to be `ReactionMetaData` instead of the generic `Dict[str, Any]` to allow for more precise type checking.
1 parent 9eb973e commit 0de1ca2

File tree

3 files changed

+20
-24
lines changed

3 files changed

+20
-24
lines changed

syntheseus/reaction_prediction/inference/chemformer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from syntheseus.interface.bag import Bag
1616
from syntheseus.interface.models import InputType, ReactionType
1717
from syntheseus.interface.molecule import Molecule
18+
from syntheseus.interface.reaction import ReactionMetaData
1819
from syntheseus.reaction_prediction.inference.base import ExternalReactionModel
1920
from syntheseus.reaction_prediction.utils.inference import (
2021
get_module_path,
@@ -135,7 +136,7 @@ def _get_reactions(
135136
# and [InputType, ReactionType] is not visible to mypy.
136137
if self.is_forward():
137138
process_fn: Callable[
138-
[InputType, List[str], List[Dict[str, Any]]], Sequence[ReactionType]
139+
[InputType, List[str], List[ReactionMetaData]], Sequence[ReactionType]
139140
] = process_raw_smiles_outputs_forwards # type: ignore[assignment]
140141
else:
141142
process_fn = process_raw_smiles_outputs_backwards # type: ignore[assignment]

syntheseus/reaction_prediction/inference/root_aligned.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
import random
1414
import warnings
1515
from collections import defaultdict
16-
from typing import Any, Dict, List, Optional, Sequence
16+
from typing import Any, List, Optional, Sequence
1717

1818
import yaml
1919
from rdkit import Chem
2020

2121
from syntheseus.interface.molecule import Molecule
22-
from syntheseus.interface.reaction import SingleProductReaction
22+
from syntheseus.interface.reaction import ReactionMetaData, SingleProductReaction
2323
from syntheseus.reaction_prediction.inference.base import ExternalBackwardReactionModel
2424
from syntheseus.reaction_prediction.utils.inference import (
2525
get_unique_file_in_dir,
@@ -85,7 +85,7 @@ def _mols_to_batch(self, inputs) -> List[bytes]:
8585
# Example outcome: b'C C ( = O ) c 1 c c c 2 c ( c c n 2 C ( = O ) O C ( C ) ( C ) C ) c 1\n'.
8686
return [bytes(smi_tokenizer(input.smiles) + "\n", "utf-8") for input in inputs]
8787

88-
def _build_kwargs_from_scores(self, scores: List[float]) -> List[Dict[str, Any]]:
88+
def _build_kwargs_from_scores(self, scores: List[float]) -> List[ReactionMetaData]:
8989
"""Compute kwargs to save in the predictions given raw scores from the RootAligned model.
9090
9191
The scores we get from the model cannot be directly interpreted as a (log) probability.
@@ -111,7 +111,7 @@ def _build_kwargs_from_scores(self, scores: List[float]) -> List[Dict[str, Any]]
111111
1.0 / (k + 1) for k in range(self.beam_size)
112112
)
113113

114-
kwargs_list: List[Dict[str, Any]] = []
114+
kwargs_list: List[ReactionMetaData] = []
115115
for score in scores:
116116
best_pos = -math.floor(score / 1e8)
117117
total_rr = score + best_pos * 1e8
@@ -121,14 +121,15 @@ def _build_kwargs_from_scores(self, scores: List[float]) -> List[Dict[str, Any]]
121121

122122
new_score = total_rr - (best_pos + 1) * max_possible_total_rr
123123
assert new_score <= 0.0
124-
metadata = {
125-
"original_score": score,
126-
"best_pos": best_pos,
127-
"total_rr": total_rr,
128-
"score": new_score,
129-
}
130-
131-
kwargs_list.append(metadata)
124+
125+
kwargs_list.append(
126+
{ # type: ignore[typeddict-unknown-key]
127+
"original_score": score,
128+
"best_pos": best_pos,
129+
"total_rr": total_rr,
130+
"score": new_score,
131+
}
132+
)
132133

133134
# Make sure the new scores produce the same ranking.
134135
for kwargs, next_kwargs in zip(kwargs_list, kwargs_list[1:]):

syntheseus/reaction_prediction/utils/inference.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import Any, Dict, List, Sequence, Union, cast
2+
from typing import Any, List, Sequence, Union
33

44
from syntheseus.interface.bag import Bag
55
from syntheseus.interface.molecule import Molecule
@@ -12,7 +12,7 @@
1212

1313

1414
def process_raw_smiles_outputs_backwards(
15-
input: Molecule, output_list: List[str], metadata_list: List[Dict[str, Any]]
15+
input: Molecule, output_list: List[str], metadata_list: List[ReactionMetaData]
1616
) -> Sequence[SingleProductReaction]:
1717
"""Convert raw SMILES outputs into a list of `SingleProductReaction` objects.
1818
@@ -33,16 +33,14 @@ def process_raw_smiles_outputs_backwards(
3333
# Only consider the prediction if the SMILES can be parsed.
3434
if reactants is not None:
3535
predictions.append(
36-
SingleProductReaction(
37-
product=input, reactants=reactants, metadata=cast(ReactionMetaData, metadata)
38-
)
36+
SingleProductReaction(product=input, reactants=reactants, metadata=metadata)
3937
)
4038

4139
return predictions
4240

4341

4442
def process_raw_smiles_outputs_forwards(
45-
input: Bag[Molecule], output_list: List[str], metadata_list: List[Dict[str, Any]]
43+
input: Bag[Molecule], output_list: List[str], metadata_list: List[ReactionMetaData]
4644
) -> Sequence[Reaction]:
4745
"""Convert raw SMILES outputs into a list of `Reaction` objects.
4846
Like method `process_raw_smiles_outputs_backwards`, but for forward models.
@@ -63,11 +61,7 @@ def process_raw_smiles_outputs_forwards(
6361

6462
# Only consider the prediction if the SMILES can be parsed.
6563
if products is not None:
66-
predictions.append(
67-
Reaction(
68-
products=products, reactants=input, metadata=cast(ReactionMetaData, metadata)
69-
)
70-
)
64+
predictions.append(Reaction(products=products, reactants=input, metadata=metadata))
7165

7266
return predictions
7367

0 commit comments

Comments
 (0)