Skip to content

Commit f9fc345

Browse files
committed
@ physigym: update rl folder; I figured out that TQC works while SAC does not
1 parent 40124af commit f9fc345

File tree

4 files changed

+28
-10
lines changed

4 files changed

+28
-10
lines changed

rl/sb/launch_sb_hyperopt.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22

33
SCRIPT_PATH="rl/sb/sb_hyperopt_own.py"
44
ALGO="tqc"
5-
NUM_INSTANCES=2
5+
NUM_INSTANCES=3
66
NAME="${ALGO}_sb_hyperopt_own"
77

88
for i in $(seq 1 $NUM_INSTANCES); do
99
# Replace 0 by 255 (unclear in your script, so removed it)
10-
nohup python "$SCRIPT_PATH" --algo "$ALGO" > "${NAME}_${i}.log" 2>&1 &
10+
nohup python "$SCRIPT_PATH" --algo "$ALGO" --seed "$i" > "${NAME}_${i}.log" 2>&1 &
1111
echo "Instance $i launched with PID $!"
1212
sleep 10
1313
done

rl/sb/launch_stable_baselines.sh

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#!/bin/bash
2+
3+
SCRIPT_PATH="rl/sb/stable_baselines.py"
4+
ALGO_NAME="SAC"
5+
NUM_INSTANCES=3
6+
NAME="${ALGO_NAME}_sb"
7+
8+
for i in $(seq 1 $NUM_INSTANCES); do
9+
# Replace 0 by 255 (unclear in your script, so removed it)
10+
nohup python "$SCRIPT_PATH" --algo_name "$ALGO_NAME" --seed "$i" > "${NAME}_${i}.log" 2>&1 &
11+
echo "Instance $i launched with PID $!"
12+
sleep 10
13+
done

rl/sb/sb_hyperopt_own.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,12 +160,14 @@ class TunerConfig:
160160
wandb_entity: str = "corporate-manu-sureli"
161161
eval_frequency: int = int(2.5e4)
162162
observation_type: str = "image"
163+
seed: int = 1
163164

164165

165166
class RLHyperparamTuner:
166167
def __init__(self, algo="TQC", env_id="physigym/ModelPhysiCellEnv-v0", n_trials=300, total_timesteps=int(1e6), pruner_type="median",
167168
start_tracking_step=50000, mean_elements=int(1e2), policy="CnnPolicy",
168-
wandb_project_name="IMAGE_TME_PHYSIGYM", wandb_entity="corporate-manu-sureli", eval_frequency=int(2.5e4), observation_type="image"):
169+
wandb_project_name="IMAGE_TME_PHYSIGYM", wandb_entity="corporate-manu-sureli", eval_frequency=int(2.5e4), observation_type="image",
170+
seed = 1):
169171
"""
170172
Class to tune hyperparameters for RL algorithms using Optuna.
171173
@@ -195,6 +197,7 @@ def __init__(self, algo="TQC", env_id="physigym/ModelPhysiCellEnv-v0", n_trials=
195197
os.makedirs(self.log_dir, exist_ok=True)
196198
self.storage_study = self.log_dir +"/"+self.study_name
197199
os.makedirs(self.storage_study, exist_ok=True)
200+
self.seed = seed
198201
# Validate algorithm
199202
if self.algo not in HYPERPARAMS_SAMPLER:
200203
raise ValueError(f"Algorithm {self.algo} not supported. Choose from {list(HYPERPARAMS_SAMPLER.keys())}.")
@@ -249,8 +252,8 @@ def objective(self, trial: optuna.Trial):
249252
save_code=True,
250253
)
251254
os.makedirs(dir, exist_ok=True)
252-
obs, info = env.reset(seed=1)
253-
model = algorithm(self.policy, env, verbose=0, tensorboard_log=dir, **hyperparams, seed=1)
255+
obs, info = env.reset(seed=self.seed)
256+
model = algorithm(self.policy, env, verbose=0, tensorboard_log=dir, **hyperparams, seed=self.seed)
254257
new_logger = configure(dir, ["tensorboard"])
255258
model.set_logger(new_logger)
256259
pruning_callback = TrackingCallback(trial=trial, start_tracking_step=self.start_tracking_step, mean_elements=self.mean_elements, eval_frequency=self.eval_frequency)

rl/sb/stable_baselines.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,17 @@
1717

1818
@dataclass
1919
class Args:
20-
algo_name: str = "TQC"
20+
algo_name: str = "SAC"
2121
"""the name of the algo"""
2222
wandb_project_name: str = "IMAGE_TME_PHYSIGYM"
2323
"""the wandb's project name"""
2424
wandb_entity: str = "corporate-manu-sureli"
25-
# Algorithm specific arguments
25+
# Algorithm specific arguments
2626
env_id: str = "physigym/ModelPhysiCellEnv-v0"
2727
"""the id of the environment"""
2828
observation_type: str = "image"
29+
"""seed"""
30+
seed: int = 1
2931
# ----------------------
3032
# 🏆 Initialize WandB
3133
# ----------------------
@@ -50,7 +52,7 @@ class Args:
5052
wandb.init(
5153
project=args.wandb_project_name,
5254
entity=args.wandb_entity,
53-
name=f"{args.algo_name}: {args.observation_type}",
55+
name=f"{args.algo_name}: observation {args.observation_type}, seed {args.seed}",
5456
sync_tensorboard=True, # Sync TensorBoard logs
5557
config=config,
5658
monitor_gym=True, # Monitor Gym environment
@@ -176,7 +178,7 @@ def step(self, action: np.ndarray):
176178
env = gym.wrappers.RescaleAction(env, min_action=-1, max_action=1)
177179
env = gym.wrappers.GrayscaleObservation(env)
178180
env = gym.wrappers.FrameStackObservation(env, stack_size=1)
179-
obs, info = env.reset()
181+
obs, info = env.reset(seed=args.seed)
180182

181183
# ----------------------
182184
# 📂 Logging Setup
@@ -187,7 +189,7 @@ def step(self, action: np.ndarray):
187189
# ----------------------
188190
# 🏃 Train the Model (with WandB Callback)
189191
# ----------------------
190-
model = TQC("CnnPolicy", env, verbose=1, tensorboard_log=log_dir)
192+
model = algorithm("CnnPolicy", env, verbose=1, tensorboard_log=log_dir, seed=args.seed)
191193
new_logger = configure(log_dir, ["tensorboard"])
192194
model.set_logger(new_logger)
193195
model.learn(total_timesteps=int(2e6), log_interval=1, progress_bar=False, callback=TensorboardCallback())

0 commit comments

Comments
 (0)