import os
import time

import gymnasium as gym
import numpy as np
import optuna
import sb3_contrib
import stable_baselines3
import wandb
from rl_zoo3.hyperparams_opt import HYPERPARAMS_SAMPLER
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import configure


class TrackingCallback(BaseCallback):
    def __init__(self, trial, start_tracking_step=50_000, verbose=0, mean_elements=100, eval_frequency=10_000):
        """
        Callback to track episode rewards and store them when an episode ends,
        reporting a rolling mean to Optuna for pruning.

        :param trial: Optuna trial used to report intermediate results and check for pruning.
        :param start_tracking_step: Step after which we start storing rewards.
        :param verbose: Verbosity level.
        :param mean_elements: Number of most recent episodes averaged for the reported reward.
        :param eval_frequency: Number of steps between reports to Optuna.
        """
        super().__init__(verbose)
        self.trial = trial
        self.start_tracking_step = start_tracking_step
        self.ep_rewards = []  # Rewards collected at the end of each episode
        self.global_step = 0  # Global training step counter
        self.mean_elements = mean_elements
        self.eval_frequency = eval_frequency
        self.counter = 0  # Number of reports sent to Optuna so far
        self.mean_ep_reward = float("-inf")
    def _on_step(self) -> bool:
        if self.global_step >= self.start_tracking_step:
            # Store the episodic return whenever an episode finishes
            for info in self.locals["infos"]:
                if "episode" in info:
                    self.ep_rewards.append(info["episode"]["r"])

            if "rewards" in self.locals:
                self.logger.record("env/reward_value", self.locals["rewards"][0])

            if "number_cancer_cells" in self.locals["infos"][0]:
                self.logger.record("env/cancer_cell_count", self.locals["infos"][0]["number_cancer_cells"])

            if "actions" in self.locals:
                actions = self.locals["actions"][0]
                self.logger.record("env/drug_apoptosis", actions[0])
                self.logger.record("env/drug_reducing_antiapoptosis", actions[1])

        self.global_step += 1  # Increment step counter
        self.logger.dump(step=self.global_step)

        # Periodically report the rolling mean episode reward to Optuna
        if len(self.ep_rewards) >= self.mean_elements and self.global_step % self.eval_frequency == 0:
            self.mean_ep_reward = np.mean(self.ep_rewards[-self.mean_elements:])
            self.trial.report(self.mean_ep_reward, self.counter)
            self.counter += 1
            if self.trial.should_prune():  # Check if Optuna wants to prune
                print("⚠️ Trial pruned by Optuna!")
                raise optuna.TrialPruned()

        return True


class RLHyperparamTuner:
    def __init__(self, algo, env_id, n_trials=100, total_timesteps=int(1e6), pruner_type="median",
                 start_tracking_step=50000, mean_elements=int(1e3), policy="CnnPolicy",
                 wandb_project="RL_Optimization", wandb_entity=None, eval_frequency=int(1e4)):
        """
        Class to tune hyperparameters for RL algorithms using Optuna.

        :param algo: Algorithm name (e.g., "ppo", "sac", "tqc")
        :param env_id: Gymnasium environment ID
        :param n_trials: Number of Optuna trials
        :param total_timesteps: Total training timesteps per trial
        :param pruner_type: Type of Optuna pruner ("median", "halving", "hyperband")
        :param start_tracking_step: Number of warmup steps before rewards are tracked
        :param mean_elements: Number of episodes for averaging reward
        :param policy: Policy architecture passed to the algorithm (e.g., "CnnPolicy")
        :param wandb_project: WandB project name
        :param wandb_entity: WandB entity (team/user)
        :param eval_frequency: Number of steps between pruning reports to Optuna
        """
        self.algo = algo.lower()
        self.env_id = env_id
        self.n_trials = n_trials
        self.total_timesteps = total_timesteps
        self.start_tracking_step = start_tracking_step
        self.mean_elements = mean_elements
        self.policy = policy
        self.wandb_project = wandb_project
        self.wandb_entity = wandb_entity
        self.eval_frequency = eval_frequency

        # Validate algorithm
        if self.algo not in HYPERPARAMS_SAMPLER:
            raise ValueError(f"Algorithm {self.algo} not supported. Choose from {list(HYPERPARAMS_SAMPLER.keys())}.")

        # Select Optuna pruning strategy
        if pruner_type == "median":
            self.pruner = optuna.pruners.MedianPruner(n_warmup_steps=3)
        elif pruner_type == "halving":
            self.pruner = optuna.pruners.SuccessiveHalvingPruner()
        elif pruner_type == "hyperband":
            self.pruner = optuna.pruners.HyperbandPruner()
        else:
            raise ValueError("Invalid pruner_type. Choose from 'median', 'halving', or 'hyperband'.")
    def create_env(self):
        """Create and wrap the environment."""
        env = gym.make(self.env_id)
        # Rescale continuous actions to [-1, 1], the range SB3 policies expect
        env = gym.wrappers.RescaleAction(env, min_action=-1, max_action=1)
        # Convert RGB observations to grayscale and add a single-frame stack dimension
        env = gym.wrappers.GrayscaleObservation(env)
        env = gym.wrappers.FrameStackObservation(env, stack_size=1)
        return env
    def objective(self, trial: optuna.Trial):
        """Objective function for Optuna hyperparameter optimization."""
        env = self.create_env()
        hyperparams = HYPERPARAMS_SAMPLER[self.algo](trial, n_actions=env.action_space.shape[0], n_envs=1, additional_args={})

        # Resolve the algorithm class, checking sb3_contrib first, then stable_baselines3
        if self.algo.upper() in sb3_contrib.__all__:
            algorithm = getattr(sb3_contrib, self.algo.upper())
        elif self.algo.upper() in stable_baselines3.__all__:
            algorithm = getattr(stable_baselines3, self.algo.upper())
        else:
            raise ValueError(f"Algorithm name does not exist: {self.algo.upper()}")

        # ----------------------
        # 📂 Logging Setup
        # ----------------------
        # WandB run setup
        run_name = f"{self.env_id}__{self.algo}_{int(time.time())}"
        wandb.init(
            project=self.wandb_project,
            entity=self.wandb_entity,
            name=run_name,
            sync_tensorboard=True,
            config=hyperparams,
            monitor_gym=True,
            save_code=True,
        )

        # Logging directory for TensorBoard
        log_dir = f"/tensorboard_logs/{self.algo}/runs/{run_name}"
        os.makedirs(log_dir, exist_ok=True)

        # Create model
        model = algorithm(self.policy, env, verbose=0, tensorboard_log=log_dir, **hyperparams)
        new_logger = configure(log_dir, ["tensorboard"])
        model.set_logger(new_logger)

        # Create pruning callback
        pruning_callback = TrackingCallback(trial=trial, start_tracking_step=self.start_tracking_step,
                                            mean_elements=self.mean_elements, eval_frequency=self.eval_frequency)

        # Close the WandB run even if the callback raises TrialPruned mid-training
        try:
            model.learn(total_timesteps=self.total_timesteps, callback=pruning_callback)
        finally:
            wandb.finish()

        return pruning_callback.mean_ep_reward  # ✅ Mean reward tracked by the callback
    def run_optimization(self):
        """Run Optuna optimization and return the completed study."""
        study = optuna.create_study(direction="maximize", pruner=self.pruner)
        study.optimize(self.objective, n_trials=self.n_trials)
        print("✅ Best hyperparameters:", study.best_params)
        return study
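

# Minimal usage sketch: instantiate the tuner and launch a small study.
# "YourEnv-v0" is a hypothetical placeholder — substitute a registered Gymnasium
# environment with an image observation (for the default CnnPolicy) and a
# continuous action space of at least two dimensions, as assumed by the callback.
if __name__ == "__main__":
    tuner = RLHyperparamTuner(
        algo="ppo",
        env_id="YourEnv-v0",  # hypothetical env id, replace with your own
        n_trials=20,
        total_timesteps=int(2e5),
        pruner_type="median",
    )
    study = tuner.run_optimization()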