|
1 | 1 | """ Module containing classes to store and manipulate collections of synthetic routes. |
2 | 2 | """ |
| 3 | + |
3 | 4 | from __future__ import annotations |
4 | 5 |
|
5 | 6 | import copy |
6 | 7 | from typing import TYPE_CHECKING |
7 | 8 |
|
8 | 9 | import numpy as np |
| 10 | +from rxnutils.routes.comparison import simple_route_similarity |
| 11 | +from rxnutils.routes.readers import read_aizynthfinder_dict |
9 | 12 |
|
10 | 13 | try: |
11 | 14 | from route_distances.clustering import ClusteringHelper |
12 | | - from route_distances.route_distances import route_distances_calculator |
13 | 15 | except ImportError: |
14 | | - pass |
| 16 | + SUPPORT_CLUSTERING = False |
| 17 | +else: |
| 18 | + SUPPORT_CLUSTERING = True |
15 | 19 |
|
16 | 20 | from aizynthfinder.analysis.utils import CombinedReactionTrees, RouteSelectionArguments |
17 | | -from aizynthfinder.reactiontree import SUPPORT_DISTANCES, ReactionTree |
| 21 | +from aizynthfinder.reactiontree import ReactionTree |
18 | 22 | from aizynthfinder.search.mcts import MctsNode, MctsSearchTree |
19 | 23 |
|
20 | 24 | if TYPE_CHECKING: |
|
27 | 31 | PilImage, |
28 | 32 | Sequence, |
29 | 33 | StrDict, |
30 | | - Union, |
31 | 34 | ) |
32 | 35 |
|
33 | 36 |
|
@@ -77,7 +80,7 @@ def __init__(self, reaction_trees: Sequence[ReactionTree], **kwargs) -> None: |
77 | 80 | self.clusters: Optional[Sequence[RouteCollection]] = self._unpack_kwarg( |
78 | 81 | "clusters", **kwargs |
79 | 82 | ) |
80 | | - self._distance_matrix: Dict[str, np.ndarray] = {} |
| 83 | + self._distance_matrix: Optional[np.ndarray] = None |
81 | 84 | self._combined_reaction_trees: Optional[CombinedReactionTrees] = None |
82 | 85 |
|
83 | 86 | @classmethod |
@@ -142,41 +145,30 @@ def cluster( |
142 | 145 | self, |
143 | 146 | n_clusters: int, |
144 | 147 | max_clusters: int = 5, |
145 | | - distances_model: str = "ted", |
146 | 148 | **kwargs: Any, |
147 | 149 | ) -> np.ndarray: |
148 | 150 | """ |
149 | 151 | Cluster the route collection into a number of clusters. |
150 | 152 |
|
151 | | - Additional arguments to the distance or clustering algorithm |
152 | | - can be passed in as key-word arguments. |
153 | | -
|
154 | | - When `distances_model` is "lstm", a key-word argument `model_path` needs to be given |
155 | | - when `distances_model` is "ted", two optional key-word arguments `timeout` and `content` |
156 | | - can be given. |
| 153 | + Additional arguments to the clustering algorithm can be passed in as key-word arguments. |
157 | 154 |
|
158 | 155 | If the number of reaction trees is less than 3, no clustering will be performed |
159 | 156 |
|
160 | 157 | :param n_clusters: the desired number of clusters, if less than 2 triggers optimization |
161 | 158 | :param max_clusters: the maximum number of clusters to consider |
162 | | - :param distances_model: can be ted or lstm and determines how the route distances are computed |
163 | 159 | :return: the cluster labels |
| 160 | + :raises ValueError: if the route_distances package is not installed |
164 | 161 | """ |
165 | | - if not SUPPORT_DISTANCES: |
| 162 | + if not SUPPORT_CLUSTERING: |
166 | 163 | raise ValueError( |
167 | 164 | "Clustering is not supported by this installation." |
168 | 165 | " Please install aizynthfinder with extras dependencies." |
169 | 166 | ) |
170 | 167 |
|
171 | 168 | if len(self.reaction_trees) < 3: |
172 | 169 | return np.asarray([]) |
173 | | - dist_kwargs = { |
174 | | - "content": kwargs.pop("content", "both"), |
175 | | - "timeout": kwargs.pop("timeout", None), |
176 | | - "model_path": kwargs.pop("model_path", None), |
177 | | - } |
178 | 170 | try: |
179 | | - distances = self.distance_matrix(model=distances_model, **dist_kwargs) |
| 171 | + distances = self.distance_matrix() |
180 | 172 | except ValueError: |
181 | 173 | return np.asarray([]) |
182 | 174 |
|
@@ -213,7 +205,7 @@ def compute_scores(self, *scorers: Scorer) -> None: |
213 | 205 | for scorer in scorers: |
214 | 206 | for idx, score in enumerate(scorer(list_)): # type: ignore |
215 | 207 | self.all_scores[idx][repr(scorer)] = score |
216 | | - self._update_route_dict(self.all_scores, "all_score") |
| 208 | + self._update_route_dict(self.all_scores, "all_scores") |
217 | 209 |
|
218 | 210 | def dict_with_extra( |
219 | 211 | self, include_scores=False, include_metadata=False |
@@ -244,41 +236,19 @@ def dict_with_scores(self) -> Sequence[StrDict]: |
244 | 236 | """ |
245 | 237 | return self.dict_with_extra(include_scores=True) |
246 | 238 |
|
247 | | - def distance_matrix( |
248 | | - self, recreate: bool = False, model: str = "ted", **kwargs: Any |
249 | | - ) -> np.ndarray: |
| 239 | + def distance_matrix(self, recreate: bool = False) -> np.ndarray: |
250 | 240 | """ |
251 | 241 | Compute the distance matrix between each pair of reaction trees |
252 | 242 |
|
253 | | - All key-word arguments are passed along to the `route_distance_calculator` |
254 | | - function from the `route_distances` package. |
255 | | -
|
256 | | - When `model` is "lstm", a key-word argument `model_path` needs to be given |
257 | | - when `model` is "ted", two optional key-word arguments `timeout` and `content` |
258 | | - can be given. |
259 | | -
|
260 | 243 | :param recreate: if False, use a cached one if available |
261 | | - :param model: the type of model to use "ted" or "lstm" |
262 | 244 | :return: the square distance matrix |
263 | 245 | """ |
264 | | - if not SUPPORT_DISTANCES: |
265 | | - raise ValueError( |
266 | | - "Distance calculations are not supported by this installation." |
267 | | - " Please install aizynthfinder with extras dependencies." |
268 | | - ) |
269 | | - |
270 | | - if model == "lstm" and not kwargs.get("model_path"): |
271 | | - raise KeyError( |
272 | | - "Need to provide 'model_path' argument when using LSTM model for computing distances" |
273 | | - ) |
274 | | - content = kwargs.get("content", "both") |
275 | | - cache_key = kwargs.get("model_path", "") if model == "lstm" else content |
276 | | - if self._distance_matrix.get(cache_key) is not None and not recreate: |
277 | | - return self._distance_matrix[cache_key] |
278 | | - calculator = route_distances_calculator(model, **kwargs) |
279 | | - distances = calculator(self.dicts) |
280 | | - self._distance_matrix[cache_key] = distances |
281 | | - return distances |
| 246 | + if self._distance_matrix is not None and not recreate: |
| 247 | + return self._distance_matrix |
| 248 | + routes = [read_aizynthfinder_dict(dict_) for dict_ in self.dicts] |
| 249 | + self._distance_matrix = 1.0 - simple_route_similarity(routes) |
| 250 | + assert self._distance_matrix is not None |
| 251 | + return self._distance_matrix |
282 | 252 |
|
283 | 253 | def make_dicts(self) -> Sequence[StrDict]: |
284 | 254 | """Convert all reaction trees to dictionaries""" |
|
0 commit comments