diff --git a/reinvent/runmodes/RL/data_classes/work_package.py b/reinvent/runmodes/RL/data_classes/work_package.py index e31abadc..d48dc326 100644 --- a/reinvent/runmodes/RL/data_classes/work_package.py +++ b/reinvent/runmodes/RL/data_classes/work_package.py @@ -22,6 +22,7 @@ class WorkPackage: scoring_function: Scorer learning_strategy: RLReward max_steps: int + max_smiles: int terminator: Callable diversity_filter: DiversityFilter = None out_state_filename: str = None diff --git a/reinvent/runmodes/RL/learning.py b/reinvent/runmodes/RL/learning.py index 0a0dc960..8abf23ef 100644 --- a/reinvent/runmodes/RL/learning.py +++ b/reinvent/runmodes/RL/learning.py @@ -49,6 +49,7 @@ class Learning(ABC): def __init__( self, max_steps: int, + max_smiles: int, stage_no: int, prior: ModelAdapter, state: ModelState, @@ -67,6 +68,7 @@ def __init__( """Setup of the common framework""" self.max_steps = max_steps + self.max_smiles = max_smiles self.stage_no = stage_no self.prior = prior @@ -127,11 +129,14 @@ def optimize(self, converged: terminator_callable) -> bool: step = -1 scaffolds = None + if converged.__class__.__name__ == 'TopkTerminator': + term_mode = 'topk' + else: + term_mode = 'score' self.start_time = time.time() for step in range(self.max_steps): self.sampled = self.sampling_model.sample(self.input_smilies) - self.smiles_memory.update(self.sampled.smilies) # NOTE: global -> only update here! self.invalid_mask = np.where(self.sampled.states == SmilesState.INVALID, False, True) self.duplicate_mask = np.where( @@ -181,10 +186,32 @@ def optimize(self, converged: terminator_callable) -> bool: loss=float(loss), ) - if converged(mean_scores, step): + # Consider only valid non-duplicate SMILES for top-k/memory + valid_idx = np.where(self.sampled.states == SmilesState.VALID)[0] + + if term_mode == 'topk': + # Consider only scores for new unique molecules + new_scores = [ + scores[i] for i in valid_idx + if self.sampled.smilies[i] not in self.smiles_memory + ] + is_converged = converged(new_scores, step) + else: + is_converged = converged(mean_scores, step) + print(step, mean_scores) + + if is_converged: logger.info(f"Terminating early in {step = }") break + # Update memory after checking top-k convergence + memory_update = [self.sampled.smilies[i] for i in valid_idx] + self.smiles_memory.update(memory_update) # NOTE: global -> only update here! + + if len(self.smiles_memory) > self.max_smiles: + logger.info(f"Max SMILES ({self.max_smiles}) reached, terminating in {step = }") + break + if self.tb_reporter: # FIXME: context manager? self.tb_reporter.flush() self.tb_reporter.close() diff --git a/reinvent/runmodes/RL/run_staged_learning.py b/reinvent/runmodes/RL/run_staged_learning.py index dacda1d4..edb3631f 100644 --- a/reinvent/runmodes/RL/run_staged_learning.py +++ b/reinvent/runmodes/RL/run_staged_learning.py @@ -170,6 +170,7 @@ def run_staged_learning( optimize = model_learning( max_steps=package.max_steps, + max_smiles=package.max_smiles, stage_no=stage_no, prior=prior, state=state, diff --git a/reinvent/runmodes/RL/setup/create_packages.py b/reinvent/runmodes/RL/setup/create_packages.py index ca2debb3..697275e1 100644 --- a/reinvent/runmodes/RL/setup/create_packages.py +++ b/reinvent/runmodes/RL/setup/create_packages.py @@ -38,12 +38,21 @@ def create_packages( max_score = stage.max_score min_steps = stage.min_steps max_steps = stage.max_steps + patience = stage.patience + topk = stage.topk + max_smiles = stage.max_smiles terminator_param = stage.termination terminator_name = terminator_param.lower().title() try: terminator: terminator_callable = getattr(terminators, f"{terminator_name}Terminator") + # Uses mean scores + if terminator_name in ["Simple", "Plateau"]: + terminator = terminator(max_score, min_steps) + # Uses batch scores + elif terminator_name == "Topk": + terminator = terminator(patience, min_steps, topk) except KeyError: msg = f"Unknown termination criterion: {terminator_name}" logger.critical(msg) @@ -59,7 +68,8 @@ def create_packages( scoring_function, reward_strategy, max_steps, - terminator(max_score, min_steps), + max_smiles, + terminator, diversity_filter, chkpt_filename, ) diff --git a/reinvent/runmodes/RL/setup/terminators.py b/reinvent/runmodes/RL/setup/terminators.py index 0867d0ed..a64b3d0a 100644 --- a/reinvent/runmodes/RL/setup/terminators.py +++ b/reinvent/runmodes/RL/setup/terminators.py @@ -96,3 +96,48 @@ def __call__(self, score: int, step: int) -> bool: return True return False + + +class TopkTerminator: + """Terminate when the top-k scores of unique molecules no longer improves.""" + + def __init__(self, patience: int, min_steps: float, topk: int = 10): + """Parameterise terminator. + + :param patience: terminate if top-k stops improving for patience epochs + :param min_steps: minimum number of steps to carry out + :param topk: how many scores to consider + """ + self.min_steps = min_steps + self.topk = topk + self.patience = patience + self.count = 0 + self.sum = 0 + + self.heap = [] + + def __call__(self, scores: int, step: int) -> bool: + """Terminate when top-k scores for unique SMILES stops improving + + :param scores: current scores + :param step: current step number + """ + + for score in scores: + if len(self.heap) < self.topk: + heapq.heappush(self.heap, score) + else: + if score > self.heap[0]: + heapq.heapreplace(self.heap, score) + + if step > self.min_steps: + new_sum = sum(self.heap) + if new_sum > self.sum: + self.sum = new_sum + self.count = 0 + else: + self.count += 1 + if self.count >= self.patience: + return True + + return False diff --git a/reinvent/runmodes/RL/validation.py b/reinvent/runmodes/RL/validation.py index b82e56e0..6025aa9f 100644 --- a/reinvent/runmodes/RL/validation.py +++ b/reinvent/runmodes/RL/validation.py @@ -58,6 +58,9 @@ class SectionInception(GlobalConfig): class SectionStage(GlobalConfig): max_steps: int = Field(ge=1) max_score: Optional[float] = Field(1.0, ge=0.0, le=1.0) + max_smiles: int = Field(100000000, ge=0) # arbitrary maximum + topk: Optional[int] = Field(10, ge=1) + patience: Optional[int] = Field(100, ge=1) chkpt_file: Optional[str] = None termination: str = "simple" min_steps: Optional[int] = Field(50, ge=0)