Skip to content

Commit 9eb973e

Browse files
AustinTkmaziarz
andauthored
Add option to stop search after search graph exceeds a certain size (#85)
This option is called `limit_graph_nodes`. I also added a basic test for this. --------- Co-authored-by: Krzysztof Maziarz <krzysztof.maziarz@microsoft.com>
1 parent 47933d3 commit 9eb973e

File tree

5 files changed

+52
-3
lines changed

5 files changed

+52
-3
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ and the project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.
2121
- Integrate the Graph2Edits model ([#65](https://github.com/microsoft/syntheseus/pull/65), [#66](https://github.com/microsoft/syntheseus/pull/66)) ([@kmaziarz])
2222
- Improve the docs and add tutorials ([#54](https://github.com/microsoft/syntheseus/pull/54), [#77](https://github.com/microsoft/syntheseus/pull/77), [#78](https://github.com/microsoft/syntheseus/pull/78), [#79](https://github.com/microsoft/syntheseus/pull/79), [#82](https://github.com/microsoft/syntheseus/pull/82)) ([@kmaziarz], [@austint])
2323
- Add random search algorithm as a simple baseline ([#83](https://github.com/microsoft/syntheseus/pull/83)) ([@austint])
24+
- Add optional argument `limit_graph_nodes` to base search algorithm class to stop search after the search graph exceeds a certain number of nodes ([#85](https://github.com/microsoft/syntheseus/pull/85)) ([@austint])
2425

2526
### Fixed
2627

syntheseus/cli/search.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from syntheseus.reaction_prediction.utils.config import get_config as cli_get_config
3535
from syntheseus.reaction_prediction.utils.misc import set_random_seed
3636
from syntheseus.reaction_prediction.utils.model_loading import get_model
37+
from syntheseus.search import INT_INF
3738
from syntheseus.search.algorithms.best_first.retro_star import RetroStarSearch
3839
from syntheseus.search.algorithms.mcts import base as mcts_base
3940
from syntheseus.search.algorithms.mcts.molset import MolSetMCTS
@@ -114,6 +115,7 @@ class SearchConfig(BackwardModelConfig):
114115
time_limit_s: float = 600
115116
limit_reaction_model_calls: int = 1_000_000
116117
limit_iterations: int = 1_000_000
118+
limit_graph_nodes: int = INT_INF
117119
prevent_repeat_mol_in_trees: bool = True
118120

119121
use_gpu: bool = True # Whether to use a GPU
@@ -186,6 +188,7 @@ def run_from_config(config: SearchConfig) -> Path:
186188
"time_limit_s",
187189
"limit_reaction_model_calls",
188190
"limit_iterations",
191+
"limit_graph_nodes",
189192
"prevent_repeat_mol_in_trees",
190193
]
191194
}

syntheseus/search/algorithms/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def __init__(
7070
self,
7171
limit_iterations: int = INT_INF,
7272
limit_reaction_model_calls: int = INT_INF,
73+
limit_graph_nodes: int = INT_INF,
7374
time_limit_s: float = math.inf,
7475
max_expansion_depth: int = 50,
7576
expand_purchasable_mols: bool = False,
@@ -85,6 +86,7 @@ def __init__(
8586

8687
self.limit_iterations = limit_iterations
8788
self.limit_reaction_model_calls = limit_reaction_model_calls
89+
self.limit_graph_nodes = limit_graph_nodes
8890
self.time_limit_s = time_limit_s
8991
self.max_expansion_depth = max_expansion_depth
9092
self.expand_purchasable_mols = expand_purchasable_mols
@@ -176,6 +178,7 @@ def should_stop_search(self, graph) -> bool:
176178
return (
177179
(elapsed_time >= self.time_limit_s)
178180
or (self.reaction_model.num_calls() >= self.limit_reaction_model_calls)
181+
or (len(graph) >= self.limit_graph_nodes)
179182
or (self.stop_on_first_solution and graph.root_node.has_solution)
180183
)
181184

syntheseus/search/algorithms/mcts/base.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,15 +137,18 @@ def _check_infinite_runtime(self) -> None:
137137
no_limit_iter = self.limit_iterations >= INT_INF
138138
no_limit_rxn = self.limit_reaction_model_calls >= INT_INF
139139
no_limit_time = self.time_limit_s >= math.inf
140+
no_limit_nodes = self.limit_graph_nodes >= INT_INF
140141

141-
if no_limit_iter and no_limit_rxn and no_limit_time:
142+
if no_limit_iter and no_limit_rxn and no_limit_time and no_limit_nodes:
142143
warnings.warn(
143144
"No kind of run limit set. This algorithm will almost certainty run forever."
144145
)
145146
elif no_limit_iter and no_limit_time:
146147
warnings.warn(
147-
"No iteration or time limit set. It is possible (but not certain) that MCTS "
148-
"will run forever with these settings. At the very least, it could run for a very long time."
148+
"No iteration or time limit set (although a reaction model call and/or graph node limit was set)."
149+
" Under these conditions, it is possible (but not certain) that MCTS "
150+
"will run forever (for example if there are no leaf nodes eligible for expansion in the graph)."
151+
" At the very least, it could run for an unexpectedly long time."
149152
" It is recommended to set either an iteration limit or a time limit."
150153
)
151154

syntheseus/tests/search/algorithms/test_base.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,45 @@ def test_limit_reaction_model_calls(
157157
# Default case is to match exactly
158158
assert alg.reaction_model.num_calls() == limit
159159

160+
@pytest.mark.parametrize("limit", [0, 1, 2, 25, 100])
161+
def test_limit_graph_nodes(
162+
self,
163+
retrosynthesis_task6: RetrosynthesisTask,
164+
limit: int,
165+
) -> None:
166+
"""
167+
Test that limiting the number of nodes in the graph works as intended.
168+
The algorithm should run until the node limit is reached, and then stop
169+
(without adding *too many* extra nodes).
170+
171+
`retrosynthesis_task6` is chosen because it can create a very large graph.
172+
"""
173+
174+
# Run algorithm
175+
alg = self.setup_algorithm(
176+
reaction_model=retrosynthesis_task6.reaction_model,
177+
mol_inventory=retrosynthesis_task6.inventory,
178+
limit_graph_nodes=limit,
179+
limit_iterations=int(1e6), # a very high limit, but avoids MCTS warnings
180+
)
181+
output_graph, _ = alg.run_from_mol(retrosynthesis_task6.target_mol)
182+
183+
# The algorithm will stop running when the graph size meets or exceeds the limit.
184+
# However, since multiple nodes are added during each expansion, the node count might
185+
# not exactly equal the limit. Therefore, we choose a variable tolerance.
186+
# "Tolerance" here is len(graph) - limit
187+
if limit == 0:
188+
tolerance = 1 # will stop search immediately with only the root node
189+
elif limit == 1:
190+
tolerance = 0 # should not expand root node
191+
elif limit == 2:
192+
tolerance = 19 # a very high number, since first expansion brings node count to 21 for AND/OR graphs
193+
else:
194+
tolerance = 20 # a fairly high tolerance (should always be enough for one expansion)
195+
196+
# The actual test
197+
assert limit <= len(output_graph) <= tolerance + limit
198+
160199
@pytest.mark.parametrize("limit", [0, 1, 2, 100])
161200
def test_limit_iterations(
162201
self,

0 commit comments

Comments
 (0)