Skip to content

Commit 802712a

Browse files
authored
Fix issues with GLN (#92)
The GLN model is the only model class that uses a bespoke environment based on Python 3.7 instead of the shared environment that uses Python 3.9. Due to this, it is not tested in CI, and is generally prone to bugs being introduced over time. After trying to run the GLN model using the weights uploaded to figshare, I found that the upload was incomplete, and some files were missing from the checkpoint, preventing the model from being loaded. On top of this, there were a few small issues in the wrapper that accumulated over time, which also needed to be addressed to get the model to run again. This PR changes the weights link to a complete upload and also fixes the small issues in the model wrapper.
1 parent e01914f commit 802712a

File tree

4 files changed

+10
-3
lines changed

4 files changed

+10
-3
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ and the project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.
1010
### Fixed
1111

1212
- Fix incorrectly uploaded RetroKNN weights ([#91](https://github.com/microsoft/syntheseus/pull/91)) ([@kmaziarz])
13+
- Fix GLN weights and issues in its model wrapper ([#92](https://github.com/microsoft/syntheseus/pull/92)) ([@kmaziarz])
1314

1415
## [0.4.0] - 2024-04-10
1516

docs/single_step.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ See table below for the links to the default checkpoints.
88
| Model checkpoint link | Source |
99
|----------------------------------------------------------------|--------|
1010
| [Chemformer](https://figshare.com/ndownloader/files/42009888) | finetuned by us starting from checkpoint released by authors |
11-
| [GLN](https://figshare.com/ndownloader/files/42012720) | released by authors |
11+
| [GLN](https://figshare.com/ndownloader/files/45882867) | released by authors |
1212
| [Graph2Edits](https://figshare.com/ndownloader/files/44194301) | released by authors |
1313
| [LocalRetro](https://figshare.com/ndownloader/files/42287319) | trained by us |
1414
| [MEGAN](https://figshare.com/ndownloader/files/42012732) | trained by us |

syntheseus/reaction_prediction/inference/default_checkpoint_links.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
backward:
22
Chemformer: https://figshare.com/ndownloader/files/42009888
3-
GLN: https://figshare.com/ndownloader/files/42012720
3+
GLN: https://figshare.com/ndownloader/files/45882867
44
Graph2Edits: https://figshare.com/ndownloader/files/44194301
55
LocalRetro: https://figshare.com/ndownloader/files/42287319
66
MEGAN: https://figshare.com/ndownloader/files/42012732

syntheseus/reaction_prediction/inference/gln.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ def __init__(self, *args, dataset_name: str = "schneider50k", **kwargs) -> None:
5757

5858
self.model = RetroGLN(self.model_dir, chkpt_path)
5959

60+
@property
61+
def name(self) -> str:
62+
return "GLN"
63+
6064
def get_parameters(self):
6165
return self.model.gln.parameters()
6266

@@ -73,7 +77,9 @@ def _get_model_predictions(
7377
return process_raw_smiles_outputs_backwards(
7478
input=input,
7579
output_list=result["reactants"],
76-
metadata_list=[{"probability": probability} for probability in result["scores"]],
80+
metadata_list=[
81+
{"probability": probability.item()} for probability in result["scores"]
82+
],
7783
)
7884

7985
def _get_reactions(

0 commit comments

Comments
 (0)