Skip to content

Commit 20f2e06

Browse files
committed
@ physigym now videos are saved
1 parent 859ec90 commit 20f2e06

File tree

1 file changed

+145
-60
lines changed

1 file changed

+145
-60
lines changed

rl/sb/stable_baselines.py

Lines changed: 145 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
1-
21
import gymnasium as gym
32
import numpy as np
43
import os
54
import wandb
65
import physigym
76
import stable_baselines3
87
import sb3_contrib
9-
from stable_baselines3 import SAC, PPO
108
from sb3_contrib import TQC
119
from stable_baselines3.common.callbacks import BaseCallback
1210
from stable_baselines3.common.logger import configure
1311
import tyro
1412
import time
1513
from gymnasium.spaces import Box
1614
from dataclasses import dataclass
17-
from extending import physicell #from embedding import physicell
15+
from embedding import physicell # from extending import physicell
1816
import matplotlib.pyplot as plt
1917
import pandas as pd
18+
19+
2020
# ----------------------
2121
# 🌟 Dataclass
2222
# ----------------------
@@ -36,6 +36,7 @@ class Args:
3636
seed: int = 1
3737
"""seed"""
3838

39+
3940
args = tyro.cli(Args)
4041
config = vars(args)
4142
# ----------------------
@@ -46,10 +47,10 @@ class Args:
4647
##### 📍 Choose Algorithm (SB3 or SB3-Contrib)
4748
##### ----------------------
4849
# Resolve the algorithm class from its name, looking in SB3-Contrib first and
# falling back to core Stable-Baselines3.
algo_name = args.algo_name
if algo_name in sb3_contrib.__all__:
    algorithm = getattr(sb3_contrib, algo_name)
elif algo_name in stable_baselines3.__all__:
    algorithm = getattr(stable_baselines3, algo_name)
else:
    # Python 3 cannot raise a plain string (that itself fails with a TypeError
    # and hides the message), so raise a real exception carrying the text.
    raise ValueError(f"Algorith name does not exist: {algo_name}")
5556

@@ -83,10 +84,12 @@ def __init__(self, verbose=0, video_frequency=50000):
8384
def _on_step(self) -> bool:
8485
# Get information from the environment
8586
if "reward" in self.locals:
86-
self.logger.record("env/reward_value", self.locals["rewards"][0] )
87+
self.logger.record("env/reward_value", self.locals["rewards"][0])
8788

8889
if "number_cancer_cells" in self.locals["infos"][0]:
89-
self.logger.record("env/cancer_cell_count", self.locals["infos"][0]["number_cancer_cells"])
90+
self.logger.record(
91+
"env/cancer_cell_count", self.locals["infos"][0]["number_cancer_cells"]
92+
)
9093

9194
if "actions" in self.locals:
9295
actions = self.locals["actions"][0]
@@ -97,6 +100,7 @@ def _on_step(self) -> bool:
97100
self.logger.dump(step=self.global_step)
98101
return True
99102

103+
100104
class PhysiCellModelWrapper(gym.Wrapper):
101105
def __init__(
102106
self,
@@ -179,42 +183,59 @@ def step(self, action: np.ndarray):
179183
)
180184
# Preprocess observation (if needed)
181185
o_observation = np.array(o_observation, dtype=float)
182-
info["action"] = d_action
186+
info["action"] = d_action
183187
self.info = info
184188
return o_observation, r_reward, b_terminated, b_truncated, info
185-
186-
def render(self, path="./output/image"):
187-
os.makedirs(path,exist_ok=True)
188-
df_cell = pd.DataFrame(physicell.get_cell(), columns=['ID','x','y','z','dead','cell_type'])
189-
fig, ax = plt.subplots(1, 3, figsize=(10, 6), gridspec_kw={'width_ratios': [1, 0.2, 0.2]})
190-
191-
for s_celltype, s_color in sorted({'cancer_cell': 'gray', 'nurse_cell': 'red'}.items()):
192-
df_celltype = df_cell.loc[(df_cell.z == 0.0) & (df_cell.cell_type == s_celltype), :]
189+
190+
def render(
191+
self,
192+
path="./output/image",
193+
saving_title: str = "output_simulation_image_episode",
194+
):
195+
os.makedirs(path, exist_ok=True)
196+
df_cell = pd.DataFrame(
197+
physicell.get_cell(), columns=["ID", "x", "y", "z", "dead", "cell_type"]
198+
)
199+
fig, ax = plt.subplots(
200+
1, 3, figsize=(10, 6), gridspec_kw={"width_ratios": [1, 0.2, 0.2]}
201+
)
202+
count_cancer_cell = physicell.get_parameter("count_cancer_cell")
203+
204+
for s_celltype, s_color in sorted(
205+
{"cancer_cell": "gray", "nurse_cell": "red"}.items()
206+
):
207+
df_celltype = df_cell.loc[
208+
(df_cell.z == 0.0) & (df_cell.cell_type == s_celltype), :
209+
]
193210
df_celltype.plot(
194-
kind='scatter', x='x', y='y', c=s_color,
211+
kind="scatter",
212+
x="x",
213+
y="y",
214+
c=s_color,
195215
xlim=[
196-
self.x_min,
216+
self.x_min,
197217
self.x_max,
198218
],
199219
ylim=[
200220
self.y_min,
201221
self.y_max,
202222
],
203223
grid=True,
204-
label = s_celltype,
224+
label=s_celltype,
205225
s=100,
206-
title=f"episode step {str(self.unwrapped_env.step_episode).zfill(3)}",
226+
title=f"episode step {str(self.unwrapped_env.step_episode).zfill(3)}, cancer cell: {count_cancer_cell}",
207227
ax=ax[0],
208-
).legend(loc='lower left')
209-
228+
).legend(loc="lower left")
210229

211230
# Create a colormap for the color bars (from -1 to 1)
212-
list_colors = ["royalblue","darkorange"]
231+
list_colors = ["royalblue", "darkorange"]
213232

214233
# Function to create fluid-like color bars
215234
def create_fluid_bar(ax_bar, drug_amount, title, max_amount=30, color="cyan"):
216235
ax_bar.set_xlim(0, 1)
217-
ax_bar.set_ylim(0, 1) # Set y-axis from 0 to 1 for percentage representation
236+
ax_bar.set_ylim(
237+
0, 1
238+
) # Set y-axis from 0 to 1 for percentage representation
218239
ax_bar.set_title(title, fontsize=10)
219240
ax_bar.set_xticks([])
220241
ax_bar.set_yticks(np.linspace(0, 1, 5)) # 0% to 100% scale
@@ -226,35 +247,103 @@ def create_fluid_bar(ax_bar, drug_amount, title, max_amount=30, color="cyan"):
226247
ax_bar.fill_betweenx(np.linspace(0, fill_level, 100), 0, 1, color=color)
227248

228249
# Draw container border
229-
ax_bar.spines['left'].set_visible(False)
230-
ax_bar.spines['right'].set_visible(False)
231-
ax_bar.spines['top'].set_visible(True)
232-
ax_bar.spines['bottom'].set_visible(True)
233-
250+
ax_bar.spines["left"].set_visible(False)
251+
ax_bar.spines["right"].set_visible(False)
252+
ax_bar.spines["top"].set_visible(True)
253+
ax_bar.spines["bottom"].set_visible(True)
234254

235255
action = self.info["action"]
236256
for i, (key, value) in enumerate(action.items(), start=1): # Start index from 1
237-
create_fluid_bar(ax[i], value[0], f"drug_{i}", color=list_colors[i-1])
257+
create_fluid_bar(ax[i], value[0], f"drug_{i}", color=list_colors[i - 1])
238258

239-
# fig.savefig(f"output_image_{self.unwrapped_env.step_episode}.png", bbox_inches='tight')
240-
# Convert figure to NumPy array (store frame)
241-
plt.savefig(path+f"/output_simulation_image_episode step {str(self.unwrapped_env.step_episode).zfill(3)}")
259+
plt.savefig(
260+
path
261+
+ f"/{saving_title} step {str(self.unwrapped_env.step_episode).zfill(3)}"
262+
)
242263
plt.close(fig)
243264

265+
244266
import subprocess
267+
268+
245269
def png_to_video_ffmpeg(image_folder, output_video, fps=10):
    """Encode the ``*.png`` files of ``image_folder`` into a video with the ffmpeg CLI.

    Args:
        image_folder: Directory whose PNG files are matched through an ffmpeg
            glob input pattern.
        output_video: Path of the video file ffmpeg is asked to write.
        fps: Value passed to ffmpeg's ``-framerate`` option.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status
            (the call uses ``check=True``).
    """
    frame_pattern = f"{image_folder}/*.png"
    input_options = [
        "-framerate",
        str(fps),
        "-pattern_type",
        "glob",
        "-i",
        frame_pattern,
    ]
    encoder_options = ["-c:v", "libx264", "-pix_fmt", "yuv420p"]

    subprocess.run(
        ["ffmpeg", *input_options, *encoder_options, output_video],
        check=True,
    )
    print(f"✅ Video saved as {output_video}")
286+
287+
288+
import os
289+
import glob
290+
import imageio
291+
import imageio.v3 as iio # Newer version of imageio
292+
import imageio_ffmpeg # Ensure ffmpeg support
293+
294+
295+
def png_to_video_imageio(image_folder, output_video, fps=10):
    """Assemble the PNG frames found in ``image_folder`` into a video file.

    Frames are appended in sorted filename order.

    Args:
        image_folder: Directory searched (non-recursively) for ``*.png`` files.
        output_video: Path of the video file to write.
        fps: Frame rate handed to the imageio writer.

    Returns:
        None. When no PNG is found, a message is printed and no video is
        written.
    """
    images = sorted(glob.glob(os.path.join(image_folder, "*.png")))

    if not images:
        print("❌ No images found in the directory:", image_folder)
        return

    print(f"🖼️ Found {len(images)} images. First image: {images[0]}")

    # Only the first two axes are needed for the size report. Slicing keeps this
    # working for 2-D (grayscale) frames, where a three-way unpack would raise.
    height, width = iio.imread(images[0]).shape[:2]
    print(f"📏 Image size: {width}x{height}")

    # The context manager closes the writer even when reading or encoding a
    # frame fails part-way through, so no half-open writer is left behind.
    with imageio.get_writer(
        output_video, fps=fps, codec="libx264", format="FFMPEG", pixelformat="yuv420p"
    ) as writer:
        for img in images:
            writer.append_data(iio.imread(img))

    print(f"✅ Video saved as {output_video}")
319+
320+
321+
def _video_save(
    env,
    seed,
    step,
    image_folder="./output/image",
    deterministic=False,
    wandb_path="test/simulation_video",
    wandb=wandb,
):
    """Roll out one episode with the module-level ``model`` and log it as a video.

    The environment is reset with ``seed``, stepped until the episode is
    terminated or truncated (rendering after every step), then the PNG frames in
    ``image_folder`` are encoded to ``seed_{seed}_step_{step}.mp4`` and logged
    to Weights & Biases under ``wandb_path``.

    Args:
        env: Environment to roll out; must support ``reset``, ``step`` and
            ``render``.
        seed: Seed used for the reset before and after the rollout.
        step: Training step used only to name the output video.
        image_folder: Directory the frames are read from when building the video.
        deterministic: Forwarded to ``model.predict``.
        wandb_path: Key under which the video is logged.
        wandb: Object providing ``log`` and ``Video`` (defaults to the wandb
            module).
    """
    # NOTE(review): `model` is read from module scope at call time, and
    # env.render() is called without a path, so `image_folder` must match the
    # directory render() writes to -- confirm against the wrapper's default.
    output_video = f"seed_{seed}_step_{step}.mp4"
    obs, info = env.reset(seed=seed)
    done = False
    while not done:
        action, _states = model.predict(obs, deterministic=deterministic)
        obs, reward, terminated, truncated, info = env.step(action)
        env.render()
        # Update the loop flag: without this the rollout never ends and a new
        # video is encoded at every episode boundary, forever.
        done = terminated or truncated

    png_to_video_imageio(image_folder, output_video, fps=10)
    wandb.log({wandb_path: wandb.Video(output_video, fps=10, format="mp4")})
    # Hand the environment back freshly reset, using the seed this call was
    # given rather than the global CLI arguments.
    env.reset(seed=seed)
341+
342+
254343
# ----------------------
255344
# 🏗️ Environment Setup
256345
# ----------------------
257-
env = gym.make(args.env_id,observation_type=args.observation_type)
346+
env = gym.make(args.env_id, observation_type=args.observation_type)
258347
env = PhysiCellModelWrapper(env)
259348
env = gym.wrappers.RescaleAction(env, min_action=-1, max_action=1)
260349
env = gym.wrappers.GrayscaleObservation(env)
@@ -264,7 +353,7 @@ def png_to_video_ffmpeg(image_folder, output_video, fps=10):
264353
# ----------------------
265354
# 📂 Logging Setup
266355
# ----------------------
267-
log_dir = f"./tensorboard_logs/{algo_name}"
356+
log_dir = f"./tensorboard_logs/{algo_name}"
268357
os.makedirs(log_dir, exist_ok=True)
269358

270359
# ----------------------
@@ -273,28 +362,24 @@ def png_to_video_ffmpeg(image_folder, output_video, fps=10):
273362
model = algorithm("CnnPolicy", env, verbose=1, tensorboard_log=log_dir, seed=args.seed)
new_logger = configure(log_dir, ["tensorboard"])
model.set_logger(new_logger)

# ----------------------
# 🎮 Train in chunks, saving a video before each chunk
# ----------------------
N_TRAINING_CHUNKS = 10
STEPS_PER_CHUNK = 25000

for i in range(N_TRAINING_CHUNKS):
    # Record how the current policy behaves before training this chunk.
    _video_save(env=env, seed=args.seed, step=i * STEPS_PER_CHUNK, wandb=wandb)
    model.learn(
        total_timesteps=STEPS_PER_CHUNK,
        log_interval=1,
        progress_bar=False,
        callback=TensorboardCallback(),
        # learn() resets the timestep counter by default; only do so on the
        # first chunk so later chunks continue the same run instead of
        # restarting their step count from zero.
        reset_num_timesteps=(i == 0),
    )

path_saving_model = run_name + "/model"
model.save(path_saving_model)

print("Finished")
wandb.finish()  # ✅ Finish WandB run

0 commit comments

Comments
 (0)