1-
21import gymnasium as gym
32import numpy as np
43import os
54import wandb
65import physigym
76import stable_baselines3
87import sb3_contrib
9- from stable_baselines3 import SAC , PPO
108from sb3_contrib import TQC
119from stable_baselines3 .common .callbacks import BaseCallback
1210from stable_baselines3 .common .logger import configure
1311import tyro
1412import time
1513from gymnasium .spaces import Box
1614from dataclasses import dataclass
17- from extending import physicell # from embedding import physicell
15+ from embedding import physicell # from extending import physicell
1816import matplotlib .pyplot as plt
1917import pandas as pd
18+
19+
2020# ----------------------
2121# 🌟 Dataclass
2222# ----------------------
@@ -36,6 +36,7 @@ class Args:
3636 seed : int = 1
3737 """seed"""
3838
39+
3940args = tyro .cli (Args )
4041config = vars (args )
4142# ----------------------
@@ -46,10 +47,10 @@ class Args:
4647##### 📍 Choose Algorithm (SB3 or SB3-Contrib)
4748##### ----------------------
# Resolve the algorithm class by name, looking in sb3_contrib first and then
# in stable_baselines3 (only names exported through each package's __all__).
algo_name = args.algo_name
if algo_name in sb3_contrib.__all__:
    algorithm = getattr(sb3_contrib, algo_name)
elif algo_name in stable_baselines3.__all__:
    algorithm = getattr(stable_baselines3, algo_name)
else:
    # A bare string cannot be raised (it triggers an unrelated TypeError),
    # so raise a real exception that carries the offending name.
    raise ValueError(f"Algorithm name does not exist: {algo_name}")
5556
@@ -83,10 +84,12 @@ def __init__(self, verbose=0, video_frequency=50000):
8384 def _on_step (self ) -> bool :
8485 # Get information from the environment
8586 if "reward" in self .locals :
86- self .logger .record ("env/reward_value" , self .locals ["rewards" ][0 ] )
87+ self .logger .record ("env/reward_value" , self .locals ["rewards" ][0 ])
8788
8889 if "number_cancer_cells" in self .locals ["infos" ][0 ]:
89- self .logger .record ("env/cancer_cell_count" , self .locals ["infos" ][0 ]["number_cancer_cells" ])
90+ self .logger .record (
91+ "env/cancer_cell_count" , self .locals ["infos" ][0 ]["number_cancer_cells" ]
92+ )
9093
9194 if "actions" in self .locals :
9295 actions = self .locals ["actions" ][0 ]
@@ -97,6 +100,7 @@ def _on_step(self) -> bool:
97100 self .logger .dump (step = self .global_step )
98101 return True
99102
103+
100104class PhysiCellModelWrapper (gym .Wrapper ):
101105 def __init__ (
102106 self ,
@@ -179,42 +183,59 @@ def step(self, action: np.ndarray):
179183 )
180184 # Preprocess observation (if needed)
181185 o_observation = np .array (o_observation , dtype = float )
182- info ["action" ] = d_action
186+ info ["action" ] = d_action
183187 self .info = info
184188 return o_observation , r_reward , b_terminated , b_truncated , info
185-
186- def render (self , path = "./output/image" ):
187- os .makedirs (path ,exist_ok = True )
188- df_cell = pd .DataFrame (physicell .get_cell (), columns = ['ID' ,'x' ,'y' ,'z' ,'dead' ,'cell_type' ])
189- fig , ax = plt .subplots (1 , 3 , figsize = (10 , 6 ), gridspec_kw = {'width_ratios' : [1 , 0.2 , 0.2 ]})
190-
191- for s_celltype , s_color in sorted ({'cancer_cell' : 'gray' , 'nurse_cell' : 'red' }.items ()):
192- df_celltype = df_cell .loc [(df_cell .z == 0.0 ) & (df_cell .cell_type == s_celltype ), :]
189+
190+ def render (
191+ self ,
192+ path = "./output/image" ,
193+ saving_title : str = "output_simulation_image_episode" ,
194+ ):
195+ os .makedirs (path , exist_ok = True )
196+ df_cell = pd .DataFrame (
197+ physicell .get_cell (), columns = ["ID" , "x" , "y" , "z" , "dead" , "cell_type" ]
198+ )
199+ fig , ax = plt .subplots (
200+ 1 , 3 , figsize = (10 , 6 ), gridspec_kw = {"width_ratios" : [1 , 0.2 , 0.2 ]}
201+ )
202+ count_cancer_cell = physicell .get_parameter ("count_cancer_cell" )
203+
204+ for s_celltype , s_color in sorted (
205+ {"cancer_cell" : "gray" , "nurse_cell" : "red" }.items ()
206+ ):
207+ df_celltype = df_cell .loc [
208+ (df_cell .z == 0.0 ) & (df_cell .cell_type == s_celltype ), :
209+ ]
193210 df_celltype .plot (
194- kind = 'scatter' , x = 'x' , y = 'y' , c = s_color ,
211+ kind = "scatter" ,
212+ x = "x" ,
213+ y = "y" ,
214+ c = s_color ,
195215 xlim = [
196- self .x_min ,
216+ self .x_min ,
197217 self .x_max ,
198218 ],
199219 ylim = [
200220 self .y_min ,
201221 self .y_max ,
202222 ],
203223 grid = True ,
204- label = s_celltype ,
224+ label = s_celltype ,
205225 s = 100 ,
206- title = f"episode step { str (self .unwrapped_env .step_episode ).zfill (3 )} " ,
226+ title = f"episode step { str (self .unwrapped_env .step_episode ).zfill (3 )} , cancer cell: { count_cancer_cell } " ,
207227 ax = ax [0 ],
208- ).legend (loc = 'lower left' )
209-
228+ ).legend (loc = "lower left" )
210229
211230 # Create a colormap for the color bars (from -1 to 1)
212- list_colors = ["royalblue" ,"darkorange" ]
231+ list_colors = ["royalblue" , "darkorange" ]
213232
214233 # Function to create fluid-like color bars
215234 def create_fluid_bar (ax_bar , drug_amount , title , max_amount = 30 , color = "cyan" ):
216235 ax_bar .set_xlim (0 , 1 )
217- ax_bar .set_ylim (0 , 1 ) # Set y-axis from 0 to 1 for percentage representation
236+ ax_bar .set_ylim (
237+ 0 , 1
238+ ) # Set y-axis from 0 to 1 for percentage representation
218239 ax_bar .set_title (title , fontsize = 10 )
219240 ax_bar .set_xticks ([])
220241 ax_bar .set_yticks (np .linspace (0 , 1 , 5 )) # 0% to 100% scale
@@ -226,35 +247,103 @@ def create_fluid_bar(ax_bar, drug_amount, title, max_amount=30, color="cyan"):
226247 ax_bar .fill_betweenx (np .linspace (0 , fill_level , 100 ), 0 , 1 , color = color )
227248
228249 # Draw container border
229- ax_bar .spines ['left' ].set_visible (False )
230- ax_bar .spines ['right' ].set_visible (False )
231- ax_bar .spines ['top' ].set_visible (True )
232- ax_bar .spines ['bottom' ].set_visible (True )
233-
250+ ax_bar .spines ["left" ].set_visible (False )
251+ ax_bar .spines ["right" ].set_visible (False )
252+ ax_bar .spines ["top" ].set_visible (True )
253+ ax_bar .spines ["bottom" ].set_visible (True )
234254
235255 action = self .info ["action" ]
236256 for i , (key , value ) in enumerate (action .items (), start = 1 ): # Start index from 1
237- create_fluid_bar (ax [i ], value [0 ], f"drug_{ i } " , color = list_colors [i - 1 ])
257+ create_fluid_bar (ax [i ], value [0 ], f"drug_{ i } " , color = list_colors [i - 1 ])
238258
239- # fig.savefig(f"output_image_{self.unwrapped_env.step_episode}.png", bbox_inches='tight')
240- # Convert figure to NumPy array (store frame)
241- plt .savefig (path + f"/output_simulation_image_episode step { str (self .unwrapped_env .step_episode ).zfill (3 )} " )
259+ plt .savefig (
260+ path
261+ + f"/{ saving_title } step { str (self .unwrapped_env .step_episode ).zfill (3 )} "
262+ )
242263 plt .close (fig )
243264
265+
244266import subprocess
267+
268+
def png_to_video_ffmpeg(image_folder, output_video, fps=10):
    """Encode every ``*.png`` in ``image_folder`` into a video with the ffmpeg CLI.

    Args:
        image_folder: Directory holding the PNG frames; they are handed to
            ffmpeg as the glob pattern ``<image_folder>/*.png``.
        output_video: Path of the video file ffmpeg is asked to write.
        fps: Input frame rate passed to ffmpeg via ``-framerate``.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status
            (the call uses ``check=True``).
    """
    # How the input frames are discovered and timed.
    input_args = [
        "-framerate",
        str(fps),
        "-pattern_type",
        "glob",
        "-i",
        f"{image_folder}/*.png",
    ]
    # H.264 encoding with the yuv420p pixel format.
    encoding_args = ["-c:v", "libx264", "-pix_fmt", "yuv420p"]

    subprocess.run(["ffmpeg", *input_args, *encoding_args, output_video], check=True)
    print(f"✅ Video saved as {output_video}")
286+
287+
288+ import os
289+ import glob
290+ import imageio
291+ import imageio .v3 as iio # Newer version of imageio
292+ import imageio_ffmpeg # Ensure ffmpeg support
293+
294+
def png_to_video_imageio(image_folder, output_video, fps=10):
    """Encode the PNG frames found in ``image_folder`` into an H.264 video.

    Frames are appended in sorted filename order. When the folder contains no
    ``*.png`` file, a message is printed and no video is written.

    Args:
        image_folder: Directory that is searched (non-recursively) for PNGs.
        output_video: Path of the video file to create.
        fps: Frame rate given to the imageio writer.

    Returns:
        None.
    """
    images = sorted(glob.glob(os.path.join(image_folder, "*.png")))

    if not images:
        print("❌ No images found in the directory:", image_folder)
        return

    print(f"🖼️ Found {len(images)} images. First image: {images[0]}")

    # Read the first frame once: it is used both for the size report and as
    # the first frame of the video.
    first_frame = iio.imread(images[0])
    # shape[:2] is valid for 2-D (grayscale) as well as 3-D (colour) frames.
    height, width = first_frame.shape[:2]
    print(f"📏 Image size: {width}x{height}")

    writer = imageio.get_writer(
        output_video, fps=fps, codec="libx264", format="FFMPEG", pixelformat="yuv420p"
    )
    try:
        writer.append_data(first_frame)
        for image_path in images[1:]:
            writer.append_data(iio.imread(image_path))
    finally:
        # Always release the writer, even when a frame fails to load or encode.
        writer.close()

    print(f"✅ Video saved as {output_video}")
319+
320+
def _video_save(
    env,
    seed,
    step,
    image_folder="./output/image",
    deterministic=False,
    wandb_path="test/simulation_video",
    wandb=wandb,
):
    """Roll out one episode, turn the rendered frames into a video and log it.

    The policy used is the module-level ``model``. After every environment
    step ``env.render()`` is called; once the episode ends, the PNGs in
    ``image_folder`` are encoded into ``seed_<seed>_step_<step>.mp4`` and that
    file is logged to Weights & Biases under ``wandb_path``.

    Args:
        env: Environment to roll out; reset with ``seed`` before the episode.
        seed: Seed used for the resets and for the video file name.
        step: Training step used only to name the video file.
        image_folder: Directory the frames are read from. ``env.render()`` is
            called without arguments, so this must match the directory the
            environment's render writes to.
        deterministic: Forwarded to ``model.predict``.
        wandb_path: Key under which the video is logged.
        wandb: Object providing ``log`` and ``Video`` (defaults to the wandb
            module).
    """
    output_video = f"seed_{seed}_step_{step}.mp4"
    obs, _ = env.reset(seed=seed)

    done = False
    while not done:
        action, _ = model.predict(obs, deterministic=deterministic)
        obs, _, terminated, truncated, _ = env.step(action)
        env.render()
        # Stop after exactly one episode; without this update the loop never
        # terminates and keeps encoding/logging videos forever.
        done = terminated or truncated

    png_to_video_imageio(image_folder, output_video, fps=10)
    wandb.log({wandb_path: wandb.Video(output_video, fps=10, format="mp4")})
    # Leave the environment freshly reset, using the seed given by the caller
    # rather than the global CLI arguments.
    env.reset(seed=seed)
341+
342+
254343# ----------------------
255344# 🏗️ Environment Setup
256345# ----------------------
257- env = gym .make (args .env_id ,observation_type = args .observation_type )
346+ env = gym .make (args .env_id , observation_type = args .observation_type )
258347env = PhysiCellModelWrapper (env )
259348env = gym .wrappers .RescaleAction (env , min_action = - 1 , max_action = 1 )
260349env = gym .wrappers .GrayscaleObservation (env )
@@ -264,7 +353,7 @@ def png_to_video_ffmpeg(image_folder, output_video, fps=10):
264353# ----------------------
265354# 📂 Logging Setup
266355# ----------------------
267- log_dir = f"./tensorboard_logs/{ algo_name } "
356+ log_dir = f"./tensorboard_logs/{ algo_name } "
268357os .makedirs (log_dir , exist_ok = True )
269358
270359# ----------------------
@@ -273,28 +362,24 @@ def png_to_video_ffmpeg(image_folder, output_video, fps=10):
273362model = algorithm ("CnnPolicy" , env , verbose = 1 , tensorboard_log = log_dir , seed = args .seed )
274363new_logger = configure (log_dir , ["tensorboard" ])
275364model .set_logger (new_logger )
276- model .learn (total_timesteps = int (1e6 ), log_interval = 1 , progress_bar = False , callback = TensorboardCallback ())
277- path_saving_model = run_name + "/model"
278- model .save (path_saving_model )
279365# ✅ Finish WandB run
280- del model # remove to demonstrate saving and loading
281- wandb .finish () # ✅ Finish WandB run
282-
366+ # del model # remove to demonstrate saving and loading
283367# ----------------------
284368# 🎮 Run the Trained Agent
285369# ----------------------
286- model = algorithm .load (path_saving_model ) # load model
287- obs , info = env . reset ()
288- dictionnary = {}
289- for i in range ( 5 ):
290- step = 0
291- while True :
292- action , _states = model . predict ( obs , deterministic = True )
293- obs , reward , terminated , truncated , info = env . step ( action )
294- step += 1
295- if terminated or truncated :
296- png_to_video_ffmpeg ( "./output/image" , f"output_video_ { i } .mp4" , fps = 10 )
297- obs , info = env . reset ()
298- print ( "Finished" )
370+ # model = algorithm.load(path_saving_model) # load model
371+ for i in range ( 10 ):
372+ _video_save ( env = env , seed = args . seed , step = ( i ) * 25000 , wandb = wandb )
373+ model . learn (
374+ total_timesteps = int ( 25000 ),
375+ log_interval = 1 ,
376+ progress_bar = False ,
377+ callback = TensorboardCallback (),
378+ )
379+ # _video_save(env=env,seed=args.seed, step=(i+1)*25000,wandb=wandb)
380+
381+ path_saving_model = run_name + "/model"
382+ model . save ( path_saving_model )
299383
300-
384+ print ("Finished" )
385+ wandb .finish () # ✅ Finish WandB run
0 commit comments