Skip to content

Commit 2dbbe71

Browse files
committed
Replace gifter script with videomaker
1 parent 194e18f commit 2dbbe71

File tree

1 file changed

+75
-52
lines changed

1 file changed

+75
-52
lines changed

utils/gifter.py renamed to utils/videomaker.py

Lines changed: 75 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,16 @@
1313
from matplotlib import colormaps
1414
import numpy as np
1515
from tqdm import tqdm
16-
from PIL import Image, ImageFont
16+
from PIL import ImageFont
1717
import pandas as pd
1818
import geopandas as gpd
1919
import contextily as ctx
2020
from functions import create_graph_from_adj
21+
import cv2
2122

2223
# Constants
2324
COLORMAP = colormaps["RdYlGn_r"]
2425

25-
OUTPUT_FILE_NAME = "evolution.gif"
26-
2726
# if on wsl
2827
FONT_PATH = ""
2928
if platform.system() == "Linux":
@@ -32,59 +31,61 @@
3231
FONT_PATH = "/System/Library/Fonts/Supplemental/Arial.ttf"
3332

3433

35-
def create_image(__df, __time, _graph, _pos, _edges, _n, _gdf, _day):
34+
def create_image(__row: pd.Series, __graph, __pos, __edges, __n, __gdf, __day):
3635
"""
3736
Generates and saves an image of a graph with edges colored based on density.
3837
3938
Parameters:
40-
__df (DataFrame): A pandas DataFrame containing the data.
39+
__row (Series): A pandas DataFrame containing the data.
4140
__time (int): The specific time (in seconds) for which the graph is to be generated.
42-
_graph (Graph): A networkx Graph object.
43-
_pos (dict): A dictionary containing the positions of the nodes.
44-
_edges (list): A list containing the edges.
45-
_n (int): The number of nodes in the graph.
46-
_gdf (GeoDataFrame): A geopandas GeoDataFrame containing the coordinates of the nodes.
41+
__graph (Graph): A networkx Graph object.
42+
__pos (dict): A dictionary containing the positions of the nodes.
43+
__edges (list): A list containing the edges.
44+
__n (int): The number of nodes in the graph.
45+
__gdf (GeoDataFrame): A geopandas GeoDataFrame containing the coordinates of the nodes.
4746
4847
Returns:
4948
tuple: A tuple containing the time in seconds and the path to the saved image.
5049
"""
51-
for col in __df.columns:
50+
time = __row.name
51+
for col in __row.index:
5252
index = int(col)
53-
density = __df.loc[__time][col] # / (225 / 2000)
54-
src = index // _n
55-
dst = index % _n
53+
density = __row[col]
54+
src = index // __n
55+
dst = index % __n
5656
# set color of edge based on density using a colormap from green to red
57-
_graph[src][dst]["color"] = COLORMAP(density)
57+
__graph[src][dst]["color"] = COLORMAP(density)
5858
# draw graph with colors
59-
colors = [_graph[u][v]["color"] for u, v in _edges]
59+
colors = [__graph[u][v]["color"] for u, v in __edges]
6060
# draw graph
6161
_, ax = plt.subplots()
62-
if _gdf is not None:
63-
limits = _gdf.total_bounds + np.array([-0.001, -0.001, 0.001, 0.001])
62+
if __gdf is not None:
63+
limits = __gdf.total_bounds + np.array([-0.001, -0.001, 0.001, 0.001])
6464
ax.set_xlim(limits[0], limits[2])
6565
ax.set_ylim(limits[1], limits[3])
6666
nx.draw_networkx_edges(
67-
_graph,
68-
_pos,
69-
edgelist=_edges,
67+
__graph,
68+
__pos,
69+
edgelist=__edges,
7070
edge_color=colors,
7171
ax=ax,
7272
connectionstyle="arc3,rad=0.05",
7373
arrowsize=5,
7474
arrowstyle="->",
7575
)
76-
nx.draw_networkx_nodes(_graph, _pos, ax=ax, node_size=69)
77-
nx.draw_networkx_labels(_graph, _pos, ax=ax, font_size=5)
78-
if _gdf is not None:
76+
nx.draw_networkx_nodes(__graph, __pos, ax=ax, node_size=69)
77+
nx.draw_networkx_labels(__graph, __pos, ax=ax, font_size=5)
78+
if __gdf is not None:
7979
# _gdf.plot(ax=ax)
8080
ctx.add_basemap(
81-
ax, crs=_gdf.crs.to_string(), source=ctx.providers.OpenStreetMap.Mapnik
81+
ax, crs=__gdf.crs.to_string(), source=ctx.providers.OpenStreetMap.Mapnik
8282
)
8383
plt.box(False)
84-
h_time = f"{(__time / 3600):.2f}"
85-
plt.title(f"Time: {(__time // 3600):02d}:{(__time % 3600) // 60:02d} {_day}")
84+
h_time = f"{(time / 3600):.2f}"
85+
plt.title(f"Time: {(time // 3600):02d}:{(time % 3600) // 60:02d} {__day}")
8686
plt.savefig(f"./temp_img/{h_time}.png", dpi=300, bbox_inches="tight")
87-
return (__time, f"./temp_img/{h_time}.png")
87+
plt.close()
88+
return (time, f"./temp_img/{h_time}.png")
8889

8990

9091
if __name__ == "__main__":
@@ -136,7 +137,7 @@ def create_image(__df, __time, _graph, _pos, _edges, _n, _gdf, _day):
136137
parser.add_argument(
137138
"--n-frames",
138139
type=int,
139-
default=10,
140+
default=None,
140141
required=False,
141142
help="Number of frames to generate.",
142143
)
@@ -147,6 +148,20 @@ def create_image(__df, __time, _graph, _pos, _edges, _n, _gdf, _day):
147148
required=False,
148149
help="Day to plot.",
149150
)
151+
parser.add_argument(
152+
"--fps",
153+
type=int,
154+
default=1,
155+
required=False,
156+
help="Frames per second for the output video.",
157+
)
158+
parser.add_argument(
159+
"--output-file",
160+
type=str,
161+
default="evolution.mp4",
162+
required=False,
163+
help="Output file name for the video.",
164+
)
150165
args = parser.parse_args()
151166
# Load the graph
152167
# read the adjacency matrix discarding the first line
@@ -193,26 +208,30 @@ def create_image(__df, __time, _graph, _pos, _edges, _n, _gdf, _day):
193208
df = df[df.index % args.time_granularity == 0]
194209
if args.time_begin is not None:
195210
df = df[df.index > args.time_begin]
196-
# take N_FRAMES from the beginning
197-
df = df.head(args.n_frames)
198-
else:
211+
if args.n_frames is not None:
212+
# take N_FRAMES from the beginning
213+
df = df.head(args.n_frames)
214+
elif args.n_frames is not None:
199215
# take the last N_FRAMES
200216
df = df.tail(args.n_frames)
201217

202218
# check if the temp_img folder exists, if not create it
203-
pathlib.Path("./temp_img").mkdir(parents=True, exist_ok=True)
219+
# force delete the folder if it exists
220+
if pathlib.Path("./temp_img").exists():
221+
for file in pathlib.Path("./temp_img").iterdir():
222+
file.unlink()
223+
else:
224+
pathlib.Path("./temp_img").mkdir(parents=True, exist_ok=True)
204225

205226
with mp.Pool() as pool:
206-
frames = []
207227
jobs = []
208228

209229
for time in df.index:
210230
jobs.append(
211231
pool.apply_async(
212232
create_image,
213233
(
214-
df,
215-
time,
234+
df.loc[time],
216235
G,
217236
pos,
218237
edges,
@@ -225,19 +244,23 @@ def create_image(__df, __time, _graph, _pos, _edges, _n, _gdf, _day):
225244

226245
# use tqdm and take results:
227246
results = [job.get() for job in tqdm(jobs)]
228-
results = sorted(results, key=lambda x: x[0])
229-
frames = [Image.open(result[1]) for result in tqdm(results)]
230-
231-
# if NFRAMES is 1, save a png image
232-
if args.n_frames == 1:
233-
frames[0].save(OUTPUT_FILE_NAME.replace(".gif", ".png"), format="PNG")
234-
else:
235-
# Save into a GIF file that loops forever
236-
frames[0].save(
237-
OUTPUT_FILE_NAME,
238-
format="GIF",
239-
append_images=frames[1:],
240-
save_all=True,
241-
duration=300,
242-
loop=0,
243-
)
247+
248+
results = sorted(results, key=lambda x: x[0])
249+
250+
# Get the dimensions of the first frame
251+
first_frame_path = results[0][1]
252+
frame = cv2.imread(first_frame_path)
253+
frame_height, frame_width, _ = frame.shape
254+
255+
# Create video writer
256+
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
257+
video_writer = cv2.VideoWriter(
258+
args.output_file, fourcc, args.fps, (frame_width, frame_height)
259+
)
260+
261+
for file_path in tqdm(results):
262+
frame = cv2.imread(file_path[1])
263+
video_writer.write(frame)
264+
265+
video_writer.release()
266+
print(f"MP4 video saved to: {args.output_file}")

0 commit comments

Comments
 (0)