Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions reinvent/runmodes/RL/data_classes/work_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class WorkPackage:
scoring_function: Scorer
learning_strategy: RLReward
max_steps: int
max_smiles: int
terminator: Callable
diversity_filter: DiversityFilter = None
out_state_filename: str = None
31 changes: 29 additions & 2 deletions reinvent/runmodes/RL/learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class Learning(ABC):
def __init__(
self,
max_steps: int,
max_smiles: int,
stage_no: int,
prior: ModelAdapter,
state: ModelState,
Expand All @@ -67,6 +68,7 @@ def __init__(
"""Setup of the common framework"""

self.max_steps = max_steps
self.max_smiles = max_smiles
self.stage_no = stage_no
self.prior = prior

Expand Down Expand Up @@ -127,11 +129,14 @@ def optimize(self, converged: terminator_callable) -> bool:

step = -1
scaffolds = None
if converged.__class__.__name__ == 'TopkTerminator':
term_mode = 'topk'
else:
term_mode = 'score'
self.start_time = time.time()

for step in range(self.max_steps):
self.sampled = self.sampling_model.sample(self.input_smilies)
self.smiles_memory.update(self.sampled.smilies) # NOTE: global -> only update here!

self.invalid_mask = np.where(self.sampled.states == SmilesState.INVALID, False, True)
self.duplicate_mask = np.where(
Expand Down Expand Up @@ -181,10 +186,32 @@ def optimize(self, converged: terminator_callable) -> bool:
loss=float(loss),
)

if converged(mean_scores, step):
# Consider only valid non-duplicate SMILES for top-k/memory
valid_idx = np.where(self.sampled.states == SmilesState.VALID)[0]

if term_mode == 'topk':
# Consider only scores for new unique molecules
new_scores = [
scores[i] for i in valid_idx
if self.sampled.smilies[i] not in self.smiles_memory
]
is_converged = converged(new_scores, step)
else:
is_converged = converged(mean_scores, step)
print(step, mean_scores)

if is_converged:
logger.info(f"Terminating early in {step = }")
break

# Update memory after checking top-k convergence
memory_update = [self.sampled.smilies[i] for i in valid_idx]
self.smiles_memory.update(memory_update) # NOTE: global -> only update here!

if len(self.smiles_memory) > self.max_smiles:
logger.info(f"Max SMILES ({self.max_smiles}) reached, terminating in {step = }")
break

if self.tb_reporter: # FIXME: context manager?
self.tb_reporter.flush()
self.tb_reporter.close()
Expand Down
1 change: 1 addition & 0 deletions reinvent/runmodes/RL/run_staged_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def run_staged_learning(

optimize = model_learning(
max_steps=package.max_steps,
max_smiles=package.max_smiles,
stage_no=stage_no,
prior=prior,
state=state,
Expand Down
12 changes: 11 additions & 1 deletion reinvent/runmodes/RL/setup/create_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,21 @@ def create_packages(
max_score = stage.max_score
min_steps = stage.min_steps
max_steps = stage.max_steps
patience = stage.patience
topk = stage.topk
max_smiles = stage.max_smiles

terminator_param = stage.termination
terminator_name = terminator_param.lower().title()

try:
terminator: terminator_callable = getattr(terminators, f"{terminator_name}Terminator")
# Uses mean scores
if terminator_name in ["Simple", "Plateau"]:
terminator = terminator(max_score, min_steps)
# Uses batch scores
elif terminator_name == "Topk":
terminator = terminator(patience, min_steps, topk)
except KeyError:
msg = f"Unknown termination criterion: {terminator_name}"
logger.critical(msg)
Expand All @@ -59,7 +68,8 @@ def create_packages(
scoring_function,
reward_strategy,
max_steps,
terminator(max_score, min_steps),
max_smiles,
terminator,
diversity_filter,
chkpt_filename,
)
Expand Down
45 changes: 45 additions & 0 deletions reinvent/runmodes/RL/setup/terminators.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,48 @@ def __call__(self, score: int, step: int) -> bool:
return True

return False


class TopkTerminator:
"""Terminate when the top-k scores of unique molecules no longer improves."""

def __init__(self, patience: int, min_steps: float, topk: int = 10):
"""Parameterise terminator.

:param patience: terminate if top-k stops improving for patience epochs
:param min_steps: minimum number of steps to carry out
:param topk: how many scores to consider
"""
self.min_steps = min_steps
self.topk = topk
self.patience = patience
self.count = 0
self.sum = 0

self.heap = []

def __call__(self, scores: int, step: int) -> bool:
"""Terminate when top-k scores for unique SMILES stops improving

:param scores: current scores
:param step: current step number
"""

for score in scores:
if len(self.heap) < self.topk:
heapq.heappush(self.heap, score)
else:
if score > self.heap[0]:
heapq.heapreplace(self.heap, score)

if step > self.min_steps:
new_sum = sum(self.heap)
if new_sum > self.sum:
self.sum = new_sum
self.count = 0
else:
self.count += 1
if self.count >= self.patience:
return True

return False
3 changes: 3 additions & 0 deletions reinvent/runmodes/RL/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ class SectionInception(GlobalConfig):
class SectionStage(GlobalConfig):
max_steps: int = Field(ge=1)
max_score: Optional[float] = Field(1.0, ge=0.0, le=1.0)
max_smiles: int = Field(100000000, ge=0) # arbitrary maximum
topk: Optional[int] = Field(10, ge=1)
patience: Optional[int] = Field(100, ge=1)
chkpt_file: Optional[str] = None
termination: str = "simple"
min_steps: Optional[int] = Field(50, ge=0)
Expand Down