1313from matplotlib import colormaps
1414import numpy as np
1515from tqdm import tqdm
16- from PIL import Image , ImageFont
16+ from PIL import ImageFont
1717import pandas as pd
1818import geopandas as gpd
1919import contextily as ctx
2020from functions import create_graph_from_adj
21+ import cv2
2122
2223# Constants
2324COLORMAP = colormaps ["RdYlGn_r" ]
2425
25- OUTPUT_FILE_NAME = "evolution.gif"
26-
2726# if on wsl
2827FONT_PATH = ""
2928if platform .system () == "Linux" :
3231 FONT_PATH = "/System/Library/Fonts/Supplemental/Arial.ttf"
3332
3433
35- def create_image (__df , __time , _graph , _pos , _edges , _n , _gdf , _day ):
34+ def create_image (__row : pd . Series , __graph , __pos , __edges , __n , __gdf , __day ):
3635 """
3736 Generates and saves an image of a graph with edges colored based on density.
3837
3938 Parameters:
40- __df (DataFrame ): A pandas DataFrame containing the data.
39+ __row (Series ): A pandas DataFrame containing the data.
4140 __time (int): The specific time (in seconds) for which the graph is to be generated.
42- _graph (Graph): A networkx Graph object.
43- _pos (dict): A dictionary containing the positions of the nodes.
44- _edges (list): A list containing the edges.
45- _n (int): The number of nodes in the graph.
46- _gdf (GeoDataFrame): A geopandas GeoDataFrame containing the coordinates of the nodes.
41+ __graph (Graph): A networkx Graph object.
42+ __pos (dict): A dictionary containing the positions of the nodes.
43+ __edges (list): A list containing the edges.
44+ __n (int): The number of nodes in the graph.
45+ __gdf (GeoDataFrame): A geopandas GeoDataFrame containing the coordinates of the nodes.
4746
4847 Returns:
4948 tuple: A tuple containing the time in seconds and the path to the saved image.
5049 """
51- for col in __df .columns :
50+ time = __row .name
51+ for col in __row .index :
5252 index = int (col )
53- density = __df . loc [ __time ][ col ] # / (225 / 2000)
54- src = index // _n
55- dst = index % _n
53+ density = __row [ col ]
54+ src = index // __n
55+ dst = index % __n
5656 # set color of edge based on density using a colormap from green to red
57- _graph [src ][dst ]["color" ] = COLORMAP (density )
57+ __graph [src ][dst ]["color" ] = COLORMAP (density )
5858 # draw graph with colors
59- colors = [_graph [u ][v ]["color" ] for u , v in _edges ]
59+ colors = [__graph [u ][v ]["color" ] for u , v in __edges ]
6060 # draw graph
6161 _ , ax = plt .subplots ()
62- if _gdf is not None :
63- limits = _gdf .total_bounds + np .array ([- 0.001 , - 0.001 , 0.001 , 0.001 ])
62+ if __gdf is not None :
63+ limits = __gdf .total_bounds + np .array ([- 0.001 , - 0.001 , 0.001 , 0.001 ])
6464 ax .set_xlim (limits [0 ], limits [2 ])
6565 ax .set_ylim (limits [1 ], limits [3 ])
6666 nx .draw_networkx_edges (
67- _graph ,
68- _pos ,
69- edgelist = _edges ,
67+ __graph ,
68+ __pos ,
69+ edgelist = __edges ,
7070 edge_color = colors ,
7171 ax = ax ,
7272 connectionstyle = "arc3,rad=0.05" ,
7373 arrowsize = 5 ,
7474 arrowstyle = "->" ,
7575 )
76- nx .draw_networkx_nodes (_graph , _pos , ax = ax , node_size = 69 )
77- nx .draw_networkx_labels (_graph , _pos , ax = ax , font_size = 5 )
78- if _gdf is not None :
76+ nx .draw_networkx_nodes (__graph , __pos , ax = ax , node_size = 69 )
77+ nx .draw_networkx_labels (__graph , __pos , ax = ax , font_size = 5 )
78+ if __gdf is not None :
7979 # _gdf.plot(ax=ax)
8080 ctx .add_basemap (
81- ax , crs = _gdf .crs .to_string (), source = ctx .providers .OpenStreetMap .Mapnik
81+ ax , crs = __gdf .crs .to_string (), source = ctx .providers .OpenStreetMap .Mapnik
8282 )
8383 plt .box (False )
84- h_time = f"{ (__time / 3600 ):.2f} "
85- plt .title (f"Time: { (__time // 3600 ):02d} :{ (__time % 3600 ) // 60 :02d} { _day } " )
84+ h_time = f"{ (time / 3600 ):.2f} "
85+ plt .title (f"Time: { (time // 3600 ):02d} :{ (time % 3600 ) // 60 :02d} { __day } " )
8686 plt .savefig (f"./temp_img/{ h_time } .png" , dpi = 300 , bbox_inches = "tight" )
87- return (__time , f"./temp_img/{ h_time } .png" )
87+ plt .close ()
88+ return (time , f"./temp_img/{ h_time } .png" )
8889
8990
9091if __name__ == "__main__" :
@@ -136,7 +137,7 @@ def create_image(__df, __time, _graph, _pos, _edges, _n, _gdf, _day):
136137 parser .add_argument (
137138 "--n-frames" ,
138139 type = int ,
139- default = 10 ,
140+ default = None ,
140141 required = False ,
141142 help = "Number of frames to generate." ,
142143 )
@@ -147,6 +148,20 @@ def create_image(__df, __time, _graph, _pos, _edges, _n, _gdf, _day):
147148 required = False ,
148149 help = "Day to plot." ,
149150 )
151+ parser .add_argument (
152+ "--fps" ,
153+ type = int ,
154+ default = 1 ,
155+ required = False ,
156+ help = "Frames per second for the output video." ,
157+ )
158+ parser .add_argument (
159+ "--output-file" ,
160+ type = str ,
161+ default = "evolution.mp4" ,
162+ required = False ,
163+ help = "Output file name for the video." ,
164+ )
150165 args = parser .parse_args ()
151166 # Load the graph
152167 # read the adjacency matrix discarding the first line
@@ -193,26 +208,30 @@ def create_image(__df, __time, _graph, _pos, _edges, _n, _gdf, _day):
193208 df = df [df .index % args .time_granularity == 0 ]
194209 if args .time_begin is not None :
195210 df = df [df .index > args .time_begin ]
196- # take N_FRAMES from the beginning
197- df = df .head (args .n_frames )
198- else :
211+ if args .n_frames is not None :
212+ # take N_FRAMES from the beginning
213+ df = df .head (args .n_frames )
214+ elif args .n_frames is not None :
199215 # take the last N_FRAMES
200216 df = df .tail (args .n_frames )
201217
202218 # check if the temp_img folder exists, if not create it
203- pathlib .Path ("./temp_img" ).mkdir (parents = True , exist_ok = True )
219+ # force delete the folder if it exists
220+ if pathlib .Path ("./temp_img" ).exists ():
221+ for file in pathlib .Path ("./temp_img" ).iterdir ():
222+ file .unlink ()
223+ else :
224+ pathlib .Path ("./temp_img" ).mkdir (parents = True , exist_ok = True )
204225
205226 with mp .Pool () as pool :
206- frames = []
207227 jobs = []
208228
209229 for time in df .index :
210230 jobs .append (
211231 pool .apply_async (
212232 create_image ,
213233 (
214- df ,
215- time ,
234+ df .loc [time ],
216235 G ,
217236 pos ,
218237 edges ,
@@ -225,19 +244,23 @@ def create_image(__df, __time, _graph, _pos, _edges, _n, _gdf, _day):
225244
226245 # use tqdm and take results:
227246 results = [job .get () for job in tqdm (jobs )]
228- results = sorted (results , key = lambda x : x [0 ])
229- frames = [Image .open (result [1 ]) for result in tqdm (results )]
230-
231- # if NFRAMES is 1, save a png image
232- if args .n_frames == 1 :
233- frames [0 ].save (OUTPUT_FILE_NAME .replace (".gif" , ".png" ), format = "PNG" )
234- else :
235- # Save into a GIF file that loops forever
236- frames [0 ].save (
237- OUTPUT_FILE_NAME ,
238- format = "GIF" ,
239- append_images = frames [1 :],
240- save_all = True ,
241- duration = 300 ,
242- loop = 0 ,
243- )
247+
248+ results = sorted (results , key = lambda x : x [0 ])
249+
250+ # Get the dimensions of the first frame
251+ first_frame_path = results [0 ][1 ]
252+ frame = cv2 .imread (first_frame_path )
253+ frame_height , frame_width , _ = frame .shape
254+
255+ # Create video writer
256+ fourcc = cv2 .VideoWriter_fourcc (* "mp4v" )
257+ video_writer = cv2 .VideoWriter (
258+ args .output_file , fourcc , args .fps , (frame_width , frame_height )
259+ )
260+
261+ for file_path in tqdm (results ):
262+ frame = cv2 .imread (file_path [1 ])
263+ video_writer .write (frame )
264+
265+ video_writer .release ()
266+ print (f"MP4 video saved to: { args .output_file } " )
0 commit comments