Skip to content

Commit 1f67822

Browse files
authored
fix(video): fix video playback with force_key_frame (#726)
* force_key_frame * fix embedded video playback * disable fix_moov * black; flake8 * fix typo * yuv420p -> yuv444p (lossless)
1 parent 0adcc72 commit 1f67822

File tree

4 files changed

+120
-27
lines changed

4 files changed

+120
-27
lines changed

openadapt/config.defaults.json

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
"RECORD_FULL_VIDEO": false,
2222
"RECORD_IMAGES": false,
2323
"LOG_MEMORY": false,
24-
"VIDEO_PIXEL_FORMAT": "rgb24",
2524
"STOP_SEQUENCES": [
2625
[
2726
"o",

openadapt/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,8 @@ class SegmentationAdapter(str, Enum):
147147
# useful for debugging but expensive computationally
148148
LOG_MEMORY: bool
149149
REPLAY_STRIP_ELEMENT_STATE: bool = True
150-
VIDEO_PIXEL_FORMAT: str = "rgb24"
150+
VIDEO_ENCODING: str = "libx264"
151+
VIDEO_PIXEL_FORMAT: str = "yuv444p"
151152
VIDEO_DIR_PATH: str = str(VIDEO_DIR_PATH)
152153
# sequences that when typed, will stop the recording of ActionEvents in record.py
153154
STOP_SEQUENCES: list[list[str]] = [

openadapt/record.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,7 @@ def video_pre_callback(db: crud.SaSession, recording: Recording) -> dict[str, An
411411
"video_stream": video_stream,
412412
"video_start_timestamp": video_start_timestamp,
413413
"last_pts": 0,
414+
"video_file_path": video_file_path,
414415
}
415416

416417

@@ -423,6 +424,11 @@ def video_post_callback(state: dict) -> None:
423424
video.finalize_video_writer(
424425
state["video_container"],
425426
state["video_stream"],
427+
state["video_start_timestamp"],
428+
state["last_frame"],
429+
state["last_frame_timestamp"],
430+
state["last_pts"],
431+
state["video_file_path"],
426432
)
427433

428434

@@ -435,7 +441,7 @@ def write_video_event(
435441
video_stream: av.stream.Stream,
436442
video_start_timestamp: float,
437443
last_pts: int = 0,
438-
num_copies: int = 2,
444+
**kwargs: dict,
439445
) -> dict[str, Any]:
440446
"""Write a screen event to the video file and update the performance queue.
441447
@@ -450,29 +456,33 @@ def write_video_event(
450456
video_start_timestamp (float): The base timestamp from which the video
451457
recording started.
452458
last_pts: The last presentation timestamp.
453-
num_copies: The number of times to write the first each frame.
454459
455460
Returns:
456461
dict containing state.
457462
"""
458-
if last_pts != 0:
459-
num_copies = 1
460-
# ensure that the first frame is available (otherwise occasionally it is not)
461-
for _ in range(num_copies):
462-
last_pts = video.write_video_frame(
463-
video_container,
464-
video_stream,
465-
event.data,
466-
event.timestamp,
467-
video_start_timestamp,
468-
last_pts,
469-
)
463+
screenshot_image = event.data
464+
screenshot_timestamp = event.timestamp
465+
force_key_frame = last_pts == 0
466+
last_pts = video.write_video_frame(
467+
video_container,
468+
video_stream,
469+
screenshot_image,
470+
screenshot_timestamp,
471+
video_start_timestamp,
472+
last_pts,
473+
force_key_frame,
474+
)
470475
perf_q.put((f"{event.type}(video)", event.timestamp, utils.get_timestamp()))
471476
return {
472-
"video_container": video_container,
473-
"video_stream": video_stream,
474-
"video_start_timestamp": video_start_timestamp,
475-
"last_pts": last_pts,
477+
**kwargs,
478+
**{
479+
"video_container": video_container,
480+
"video_stream": video_stream,
481+
"video_start_timestamp": video_start_timestamp,
482+
"last_frame": screenshot_image,
483+
"last_frame_timestamp": screenshot_timestamp,
484+
"last_pts": last_pts,
485+
},
476486
}
477487

478488

openadapt/video.py

Lines changed: 90 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from fractions import Fraction
44
from pprint import pformat
55
import os
6+
import subprocess
7+
import tempfile
68
import threading
79

810
from loguru import logger
@@ -47,7 +49,7 @@ def initialize_video_writer(
4749
width: int,
4850
height: int,
4951
fps: int = 24,
50-
codec: str = "libx264rgb",
52+
codec: str = config.VIDEO_ENCODING,
5153
pix_fmt: str = config.VIDEO_PIXEL_FORMAT,
5254
crf: int = 0,
5355
preset: str = "veryslow",
@@ -60,8 +62,8 @@ def initialize_video_writer(
6062
height (int): Height of the video.
6163
fps (int, optional): Frames per second of the video. Defaults to 24.
6264
codec (str, optional): Codec used for encoding the video.
63-
Defaults to 'libx264rgb'.
64-
pix_fmt (str, optional): Pixel format of the video. Defaults to 'rgb24'.
65+
Defaults to 'libx264'.
66+
pix_fmt (str, optional): Pixel format of the video. Defaults to 'yuv420p'.
6567
crf (int, optional): Constant Rate Factor for encoding quality.
6668
Defaults to 0 for lossless.
6769
preset (str, optional): Encoding speed/quality trade-off.
@@ -91,6 +93,7 @@ def write_video_frame(
9193
timestamp: float,
9294
video_start_timestamp: float,
9395
last_pts: int,
96+
force_key_frame: bool = False,
9497
) -> int:
9598
"""Encodes and writes a video frame to the output container from a given screenshot.
9699
@@ -108,6 +111,7 @@ def write_video_frame(
108111
video_start_timestamp (float): The base timestamp from which the video
109112
recording started.
110113
last_pts (int): The PTS of the last written frame.
114+
force_key_frame (bool): Whether to force this frame to be a key frame.
111115
112116
Returns:
113117
int: The updated last_pts value, to be used for writing the next frame.
@@ -118,23 +122,28 @@ def write_video_frame(
118122
- The function logs the current timestamp, base timestamp, and
119123
calculated PTS values for debugging purposes.
120124
"""
121-
logger.debug(f"{timestamp=} {video_start_timestamp=}")
122-
123125
# Convert the PIL Image to an AVFrame
124126
av_frame = av.VideoFrame.from_image(screenshot)
125127

128+
# Optionally force a key frame
129+
# TODO: force key frames on active window change?
130+
if force_key_frame:
131+
av_frame.pict_type = "I"
132+
126133
# Calculate the time difference in seconds
127134
time_diff = timestamp - video_start_timestamp
128135

129136
# Calculate PTS, taking into account the fractional average rate
130137
pts = int(time_diff * float(Fraction(video_stream.average_rate)))
131138

132-
logger.debug(f"{time_diff=} {pts=} {video_stream.average_rate=}")
139+
logger.debug(
140+
f"{timestamp=} {video_start_timestamp=} {time_diff=} {pts=} {force_key_frame=}"
141+
)
133142

134143
# Ensure monotonically increasing PTS
135144
if pts <= last_pts:
136145
pts = last_pts + 1
137-
logger.debug("incremented {pts=}")
146+
logger.debug(f"incremented {pts=}")
138147
av_frame.pts = pts
139148
last_pts = pts # Update the last_pts
140149

@@ -149,16 +158,45 @@ def write_video_frame(
149158
def finalize_video_writer(
150159
video_container: av.container.OutputContainer,
151160
video_stream: av.stream.Stream,
161+
video_start_timestamp: float,
162+
last_frame: Image.Image,
163+
last_frame_timestamp: float,
164+
last_pts: int,
165+
video_file_path: str,
166+
fix_moov: bool = False,
152167
) -> None:
153168
"""Finalizes the video writer, ensuring all buffered frames are encoded and written.
154169
155170
Args:
156171
video_container (av.container.OutputContainer): The AV container to finalize.
157172
video_stream (av.stream.Stream): The AV stream to finalize.
173+
video_start_timestamp (float): The base timestamp from which the video
174+
recording started.
175+
last_frame (Image.Image): The last frame that was written (to be written again).
176+
last_frame_timestamp (float): The timestamp of the last frame that was written.
177+
last_pts (int): The last presentation timestamp.
178+
video_file_path (str): The path to the video file.
179+
fix_moov (bool): Whether to move the moov atom to the beginning of the file.
180+
Setting this to True will fix a bug when displaying the video in Github
181+
comments causing the video to appear to start a few seconds after 0:00.
182+
However, this causes extract_frames to fail.
158183
"""
159184
# Closing the container in the main thread leads to a GIL deadlock.
160185
# https://github.com/PyAV-Org/PyAV/issues/1053
161186

187+
# Write a final key frame
188+
last_pts = write_video_frame(
189+
video_container,
190+
video_stream,
191+
last_frame,
192+
last_frame_timestamp,
193+
video_start_timestamp,
194+
last_pts,
195+
force_key_frame=True,
196+
)
197+
198+
# Closing in the same thread sometimes hangs, so do it in a different thread:
199+
162200
# Define a function to close the container
163201
def close_container() -> None:
164202
logger.info("closing video container...")
@@ -177,9 +215,54 @@ def close_container() -> None:
177215

178216
# Wait for the thread to finish execution
179217
close_thread.join()
218+
219+
# Move moov atom to beginning of file
220+
if fix_moov:
221+
# TODO: fix this
222+
logger.warning(f"{fix_moov=} will cause extract_frames() to fail!!!")
223+
move_moov_atom(video_file_path)
224+
180225
logger.info("done")
181226

182227

228+
def move_moov_atom(input_file: str, output_file: str = None) -> None:
229+
"""Moves the moov atom to the beginning of the video file using ffmpeg.
230+
231+
If no output file is specified, modifies the input file in place.
232+
233+
Args:
234+
input_file (str): The path to the input MP4 file.
235+
output_file (str, optional): The path to the output MP4 file where the moov
236+
atom is at the beginning. If None, modifies the input file in place.
237+
"""
238+
if output_file is None:
239+
# Create a temporary file
240+
temp_file = tempfile.NamedTemporaryFile(
241+
delete=False,
242+
suffix=".mp4",
243+
dir=os.path.dirname(input_file),
244+
).name
245+
output_file = temp_file
246+
247+
command = [
248+
"ffmpeg",
249+
"-y", # Automatically overwrite files without asking
250+
"-i",
251+
input_file,
252+
"-codec",
253+
"copy", # Avoid re-encoding; just copy streams
254+
"-movflags",
255+
"faststart", # Move the moov atom to the start
256+
output_file,
257+
]
258+
logger.info(f"{command=}")
259+
subprocess.run(command, check=True)
260+
261+
if temp_file:
262+
# Replace the original file with the modified one
263+
os.replace(temp_file, input_file)
264+
265+
183266
def extract_frames(
184267
video_filename: str,
185268
timestamps: list[str],

0 commit comments

Comments
 (0)