Skip to content

Commit 824a7a6

Browse files
committed
@ physigym new file for hyperopt search
1 parent b8ec4c1 commit 824a7a6

File tree

1 file changed

+170
-0
lines changed

1 file changed

+170
-0
lines changed

rl/sb/sb_hyperopt_own.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
import optuna
2+
import numpy as np
3+
import gymnasium as gym
4+
import stable_baselines3 as sb3
5+
import sb3_contrib
6+
from stable_baselines3.common.callbacks import BaseCallback
7+
import os
8+
import wandb
9+
from rl_zoo3.hyperparams_opt import HYPERPARAMS_SAMPLER
10+
from stable_baselines3.common.logger import configure
11+
import stable_baselines3
12+
import sb3_contrib
13+
import time
14+
15+
16+
class TrackingCallback(BaseCallback):
    """SB3 callback that records per-episode rewards and reports an
    intermediate score to an Optuna trial so the pruner can stop bad runs.

    Rewards are only collected after ``start_tracking_step`` warmup steps.
    Every ``eval_frequency`` steps (once at least ``mean_elements`` episodes
    have finished) the mean of the last ``mean_elements`` episode rewards is
    reported to the trial; if the pruner asks, ``optuna.TrialPruned`` is
    raised to abort training.
    """

    def __init__(self, trial, start_tracking_step=50_000, verbose=0, mean_elements=100, eval_frequency=10000):
        """
        Callback to track episode rewards and store them when an episode ends.

        :param trial: Optuna trial to report intermediate values to.
        :param start_tracking_step: Step after which we start storing rewards.
        :param verbose: Verbosity level.
        :param mean_elements: Number of most recent episodes averaged per report.
        :param eval_frequency: Steps between intermediate reports to Optuna.
        """
        super().__init__(verbose)
        self.trial = trial
        self.start_tracking_step = start_tracking_step
        self.ep_rewards = []  # rewards collected at the end of each episode
        self.last_eval_step = 0
        self.global_step = 0  # manual step counter (one increment per _on_step call; assumes n_envs=1)
        self.mean_elements = mean_elements
        self.eval_frequency = eval_frequency
        self.counter = 0  # Optuna report index (monotonically increasing "step" for trial.report)
        self.mean_ep_reward = float("-inf")  # last reported mean; returned by the objective

    def _on_step(self) -> bool:
        if self.global_step >= self.start_tracking_step:
            for info in self.locals["infos"]:
                if "episode" in info:  # Monitor wrapper marks episode end with an "episode" dict
                    self.ep_rewards.append(info["episode"]["r"])

            # BUGFIX: the guard previously tested "reward" (singular), which is
            # never a key in SB3's locals, so reward logging silently never ran.
            if "rewards" in self.locals:
                self.logger.record("env/reward_value", self.locals["rewards"][0])

            if "number_cancer_cells" in self.locals["infos"][0]:
                self.logger.record("env/cancer_cell_count", self.locals["infos"][0]["number_cancer_cells"])

            if "actions" in self.locals:
                actions = self.locals["actions"][0]
                # assumes a 2-dim continuous action: [apoptosis drug, anti-apoptosis-reducing drug]
                self.logger.record("env/drug_apoptosis", actions[0])
                self.logger.record("env/drug_reducing_antiapoptosis", actions[1])

        self.global_step += 1  # Increment step counter
        self.logger.dump(step=self.global_step)

        if len(self.ep_rewards) >= self.mean_elements and self.global_step % self.eval_frequency == 0:
            self.mean_ep_reward = np.mean(self.ep_rewards[-self.mean_elements:])
            self.trial.report(self.mean_ep_reward, self.counter)
            self.counter += 1
            if self.trial.should_prune():  # Check if Optuna wants to prune
                print("⚠️ Trial pruned by Optuna!")
                raise optuna.TrialPruned()

        return True
65+
66+
67+
68+
69+
class RLHyperparamTuner:
    """Tune hyperparameters of a Stable-Baselines3 / sb3-contrib algorithm
    with Optuna, logging each trial to TensorBoard and Weights & Biases.

    Each Optuna trial samples hyperparameters via rl_zoo3's
    ``HYPERPARAMS_SAMPLER``, trains a fresh model, and returns the mean
    episode reward tracked by :class:`TrackingCallback`.
    """

    def __init__(self, algo, env_id, n_trials=100, total_timesteps=int(1e6), pruner_type="median",
                 start_tracking_step=50000, mean_elements=int(1e3), policy="CnnPolicy",
                 wandb_project="RL_Optimization", wandb_entity=None, eval_frequency=int(1e4)):
        """
        Class to tune hyperparameters for RL algorithms using Optuna.

        :param algo: Algorithm name (e.g., "ppo", "sac", "tqc")
        :param env_id: Gymnasium environment ID
        :param n_trials: Number of Optuna trials
        :param total_timesteps: Total training timesteps per trial
        :param pruner_type: Type of Optuna pruner ("median", "halving", "hyperband")
        :param start_tracking_step: Number of warmup steps
        :param mean_elements: Number of episodes for averaging reward
        :param policy: SB3 policy class name passed to the algorithm constructor
        :param wandb_project: WandB project name
        :param wandb_entity: WandB entity (team/user)
        :param eval_frequency: Steps between intermediate Optuna reports
        :raises ValueError: if ``algo`` or ``pruner_type`` is not supported
        """
        self.algo = algo.lower()
        self.env_id = env_id
        self.n_trials = n_trials
        self.total_timesteps = total_timesteps
        self.start_tracking_step = start_tracking_step
        self.mean_elements = mean_elements
        self.policy = policy
        self.wandb_project = wandb_project
        self.wandb_entity = wandb_entity
        self.eval_frequency = eval_frequency

        # Validate algorithm early so a typo fails before any training starts.
        if self.algo not in HYPERPARAMS_SAMPLER:
            raise ValueError(f"Algorithm {self.algo} not supported. Choose from {list(HYPERPARAMS_SAMPLER.keys())}.")

        # Select Optuna pruning strategy
        if pruner_type == "median":
            self.pruner = optuna.pruners.MedianPruner(n_warmup_steps=3)
        elif pruner_type == "halving":
            self.pruner = optuna.pruners.SuccessiveHalvingPruner()
        elif pruner_type == "hyperband":
            self.pruner = optuna.pruners.HyperbandPruner()
        else:
            raise ValueError("Invalid pruner_type. Choose from 'median', 'halving', or 'hyperband'.")

    def create_env(self):
        """Create and wrap the environment (rescaled actions, grayscale, 1-frame stack)."""
        env = gym.make(self.env_id)
        env = gym.wrappers.RescaleAction(env, min_action=-1, max_action=1)
        env = gym.wrappers.GrayscaleObservation(env)
        env = gym.wrappers.FrameStackObservation(env, stack_size=1)
        return env

    def _resolve_algorithm(self):
        """Return the algorithm class for ``self.algo`` from sb3-contrib or SB3.

        :raises ValueError: if the name exists in neither package.
        """
        if self.algo in sb3_contrib.__all__:
            return getattr(sb3_contrib, self.algo.upper())
        if self.algo in stable_baselines3.__all__:
            return getattr(stable_baselines3, self.algo.upper())
        # BUGFIX: the original executed `raise f"..."`, which raises a
        # TypeError ("exceptions must derive from BaseException") instead of
        # a meaningful error.
        raise ValueError(f"Algorithm name does not exist: {self.algo.upper()}")

    def objective(self, trial: optuna.Trial):
        """Objective function for Optuna hyperparameter optimization.

        :param trial: the current Optuna trial.
        :return: mean episode reward achieved by the trained model.
        """
        env = self.create_env()
        hyperparams = HYPERPARAMS_SAMPLER[self.algo](trial, n_actions=env.action_space.shape[0], n_envs=1, additional_args={})
        algorithm = self._resolve_algorithm()

        # ----------------------
        # 📂 Logging Setup
        # ----------------------
        # WandB run setup
        run_name = f"{self.env_id}__{self.algo}_{int(time.time())}"
        wandb.init(
            project=self.wandb_project,
            entity=self.wandb_entity,
            name=run_name,
            sync_tensorboard=True,
            config=hyperparams,
            monitor_gym=True,
            save_code=True,
        )

        # Logging directory for TensorBoard
        # NOTE(review): this is an absolute path at the filesystem root —
        # confirm this is intentional (e.g. a container mount) and writable.
        log_dir = f"/tensorboard_logs/{self.algo}/runs/{run_name}"
        os.makedirs(log_dir, exist_ok=True)

        # Create model and route its logs to TensorBoard only.
        model = algorithm(self.policy, env, verbose=0, tensorboard_log=log_dir, **hyperparams)
        model.set_logger(configure(log_dir, ["tensorboard"]))

        # Create pruning callback
        pruning_callback = TrackingCallback(trial=trial, start_tracking_step=self.start_tracking_step,
                                            mean_elements=self.mean_elements, eval_frequency=self.eval_frequency)

        # BUGFIX: close the W&B run even when the callback raises
        # optuna.TrialPruned — previously pruned trials leaked open runs.
        try:
            model.learn(total_timesteps=self.total_timesteps, callback=pruning_callback)
        finally:
            wandb.finish()

        # The attribute always exists (initialized to -inf), so the former
        # bare `except: return None` was dead code that could only mask bugs
        # and hand Optuna an invalid None objective value.
        return pruning_callback.mean_ep_reward

    def run_optimization(self):
        """Run Optuna optimization and print the best hyperparameters."""
        study = optuna.create_study(direction="maximize", pruner=self.pruner)
        study.optimize(self.objective, n_trials=self.n_trials)
        print("✅ Best hyperparameters:", study.best_params)

0 commit comments

Comments
 (0)