Skip to content
This repository was archived by the owner on May 24, 2025. It is now read-only.

Commit e2463b8

Browse files
committed
Merge branch 'reinvent.3.1'
2 parents 7ec8d30 + c8355fd commit e2463b8

File tree

6 files changed

+18
-14
lines changed

6 files changed

+18
-14
lines changed

running_modes/curriculum_learning/logging/base_curriculum_logger.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def save_checkpoint(self, step, scaffold_filter, agent):
4646
actual_step = step + 1
4747
if self._log_config.logging_frequency > 0 and actual_step % self._log_config.logging_frequency == 0:
4848
self.save_diversity_memory(scaffold_filter)
49-
agent.save_to_file(os.path.join(self._log_config.result_folder, f'Agent.{actual_step}.ckpt'))
49+
agent.save(os.path.join(self._log_config.result_folder, f'Agent.{actual_step}.ckpt'))
5050

5151
@abstractmethod
5252
def save_final_state(self, agent, scaffold_filter):

running_modes/reinforcement_learning/scoring_strategy/lib_invent_scoring_strategy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@ def __init__(self, strategy_configuration: LibInventScoringStrategyConfiguration
1919
self.reaction_filter = ReactionFilter(strategy_configuration.reaction_filter)
2020

2121
def evaluate(self, sampled_sequences: List[SampledSequencesDTO], step) -> FinalSummary:
22-
score_summary = self._apply_scoring_function(sampled_sequences)
22+
score_summary = self._apply_scoring_function(sampled_sequences, step)
2323

2424
score_summary.total_score = self.diversity_filter.update_score(score_summary, sampled_sequences, step)
2525
return score_summary
2626

27-
def _apply_scoring_function(self, sampled_sequences: List[SampledSequencesDTO]) -> FinalSummary:
27+
def _apply_scoring_function(self, sampled_sequences: List[SampledSequencesDTO], step:int) -> FinalSummary:
2828
molecules = self._join_scaffolds_and_decorations(sampled_sequences)
2929
smiles = [self._conversion.mol_to_smiles(molecule) if molecule else "INVALID" for molecule in molecules]
30-
final_score: FinalSummary = self.scoring_function.get_final_score(smiles)
30+
final_score: FinalSummary = self.scoring_function.get_final_score_for_step(smiles, step)
3131
final_score = self._apply_reaction_filters(molecules, final_score)
3232
return final_score
3333

running_modes/reinforcement_learning/scoring_strategy/link_invent_scoring_strategy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ def __init__(self, strategy_configuration: ScoringStrategyConfiguration, diversi
1313
super().__init__(strategy_configuration, diversity_filter, logger)
1414

1515
def evaluate(self, sampled_sequences: List[SampledSequencesDTO], step) -> FinalSummary:
16-
score_summary = self._apply_scoring_function(sampled_sequences)
16+
score_summary = self._apply_scoring_function(sampled_sequences, step)
1717
score_summary = self._clean_scored_smiles(score_summary)
1818
score_summary.total_score = self.diversity_filter.update_score(score_summary, sampled_sequences, step)
1919
return score_summary
2020

21-
def _apply_scoring_function(self, sampled_sequences: List[SampledSequencesDTO]) -> FinalSummary:
21+
def _apply_scoring_function(self, sampled_sequences: List[SampledSequencesDTO], step) -> FinalSummary:
2222
molecules = self._join_linker_and_warheads(sampled_sequences, keep_labels=True)
2323
smiles = []
2424
for idx, molecule in enumerate(molecules):
@@ -33,7 +33,7 @@ def _apply_scoring_function(self, sampled_sequences: List[SampledSequencesDTO])
3333
f'\n\toutput: {sampled_sequences[idx].output}\n')
3434
finally:
3535
smiles.append(smiles_str)
36-
final_score: FinalSummary = self.scoring_function.get_final_score(smiles)
36+
final_score: FinalSummary = self.scoring_function.get_final_score_for_step(smiles, step)
3737
return final_score
3838

3939
def _join_linker_and_warheads(self, sampled_sequences: List[SampledSequencesDTO], keep_labels=False):

running_modes/sampling/logging/local_sampling_logger.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@ def log_message(self, message: str):
1818
def timestep_report(self, smiles: [], likelihoods: np.array):
1919
self._log_timestep(smiles, likelihoods)
2020

21-
def __del__(self):
22-
self._summary_writer.close()
23-
2421
def _log_timestep(self, smiles: np.array, likelihoods: np.array):
2522
valid_smiles_fraction = fraction_valid_smiles(smiles)
2623
fraction_unique_entries = self._get_unique_entires_fraction(likelihoods)

running_modes/transfer_learning/logging/local_transfer_learning_logger.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,6 @@ def __init__(self, configuration: GeneralConfigurationEnvelope):
1616
super().__init__(configuration)
1717
self._summary_writer = SummaryWriter(log_dir=self._log_config.logging_path)
1818

19-
def __del__(self):
20-
self._summary_writer.close()
21-
2219
def log_out_input_configuration(self):
2320
file = os.path.join(self._log_config.logging_path, "input.json")
2421
jsonstr = json.dumps(self._configuration, default=lambda x: x.__dict__, sort_keys=True, indent=4,

running_modes/validation/logging/remote_validation_logger.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,15 @@
22

33
from running_modes.configurations.general_configuration_envelope import GeneralConfigurationEnvelope
44
from running_modes.validation.logging.base_validation_logger import BaseValidationLogger
5+
import running_modes.utils.configuration as utils_log
6+
7+
from running_modes.configurations.logging import get_remote_logging_auth_token
58

69

710
class RemoteValidationLogger(BaseValidationLogger):
811
def __init__(self, configuration: GeneralConfigurationEnvelope):
912
super().__init__(configuration)
13+
self._is_dev = utils_log._is_development_environment()
1014

1115
def log_message(self, message: str):
1216
data = {"valid": self.model_is_valid, "message": message}
@@ -15,13 +19,19 @@ def log_message(self, message: str):
1519
def _notify_server(self, data, to_address):
1620
"""This is called every time we are posting data to server"""
1721
try:
22+
headers = {
23+
'Accept': 'application/json', 'Content-Type': 'application/json',
24+
'Authorization': get_remote_logging_auth_token()
25+
}
1826
self._common_logger.warning(f"posting to {to_address}")
19-
response = requests.post(to_address, data=data)
27+
response = requests.post(to_address, json=data, headers=headers)
2028

2129
if response.status_code == requests.codes.ok:
2230
self._common_logger.info(f"SUCCESS: {response.status_code}")
31+
self._common_logger.info(response.content)
2332
else:
2433
self._common_logger.info(f"PROBLEM: {response.status_code}")
34+
self._common_logger.exception(data, exc_info=False)
2535
except Exception as e:
2636
self._common_logger.exception("Exception occurred", exc_info=True)
2737
self._common_logger.exception(f"Attempted posting the following data:")

0 commit comments

Comments
 (0)