Skip to content

Commit 041a60e

Browse files
authored
fix: improve frames split for workers - minor fixes (#32)
* fix: improve frame splitting into workers
* fix: minor fixes to parameters
1 parent 1514025 commit 041a60e

File tree

3 files changed: 41 additions and 41 deletions

src/movie_barcodes/utility.py

Lines changed: 1 addition & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@ def validate_args(args: argparse.Namespace, frame_count: int, MAX_PROCESSES: int
4040

4141
# Check if the destination path is writable
4242
if args.destination_path is not None:
43-
destination_dir = path.dirname(args.destination_path)
43+
destination_dir = path.dirname(args.destination_path) or "."
4444
if not access(destination_dir, W_OK):
4545
raise PermissionError(f"The specified destination path '{args.destination_path}' is not writable.")
4646

@@ -62,15 +62,10 @@ def validate_args(args: argparse.Namespace, frame_count: int, MAX_PROCESSES: int
6262
if args.height is not None:
6363
if args.height <= 0:
6464
raise ValueError("Height must be greater than 0.")
65-
if args.height > frame_count:
66-
raise ValueError("Height must be less than or equal to the number of frames.")
6765

6866
if frame_count < MIN_FRAME_COUNT:
6967
raise ValueError(f"The video must have at least {MIN_FRAME_COUNT} frames.")
7068

71-
if args.all_methods and args.method is not None:
72-
raise ValueError("The --all_methods flag cannot be used with the --method argument.")
73-
7469

7570
def get_dominant_color_function(method: str) -> Callable:
7671
"""

src/movie_barcodes/video_processing.py

Lines changed: 35 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -51,33 +51,45 @@ def parallel_extract_colors(
5151
if target_frames is None:
5252
target_frames = frame_count
5353

54-
frames_per_worker = frame_count // workers
55-
target_frames_per_worker = target_frames // workers
56-
57-
with Pool(workers) as pool:
58-
args = [
54+
# Cap workers to available work to avoid empty tasks
55+
active_workers = max(1, min(workers, frame_count, target_frames))
56+
57+
# Evenly distribute frame ranges across active workers
58+
base_frames = frame_count // active_workers
59+
remainder_frames = frame_count % active_workers
60+
frame_ranges = []
61+
start = 0
62+
for i in range(active_workers):
63+
length = base_frames + (1 if i < remainder_frames else 0)
64+
end = start + length - 1
65+
frame_ranges.append((start, end))
66+
start = end + 1
67+
68+
# Evenly distribute target samples ensuring total equals target_frames
69+
base_samples = target_frames // active_workers
70+
remainder_samples = target_frames % active_workers
71+
samples_per_worker = [base_samples + (1 if i < remainder_samples else 0) for i in range(active_workers)]
72+
73+
# Build only tasks that have at least one sample (avoid passing 0 -> falsy)
74+
task_args = []
75+
for i in range(active_workers):
76+
if samples_per_worker[i] <= 0:
77+
continue
78+
start_frame, end_frame = frame_ranges[i]
79+
if end_frame < start_frame:
80+
continue
81+
task_args.append(
5982
(
6083
video_path,
61-
i * frames_per_worker,
62-
(i + 1) * frames_per_worker - 1,
63-
color_extractor,
64-
target_frames_per_worker,
65-
False, # disable worker progress bars
66-
)
67-
for i in range(workers)
68-
]
69-
70-
if frame_count % workers != 0 or target_frames % workers != 0:
71-
args[-1] = (
72-
video_path,
73-
args[-1][1],
74-
frame_count - 1,
84+
start_frame,
85+
end_frame,
7586
color_extractor,
76-
target_frames - (workers - 1) * target_frames_per_worker,
77-
False,
87+
samples_per_worker[i],
7888
)
89+
)
7990

80-
results = pool.starmap(extract_colors, args)
91+
with Pool(active_workers) as pool:
92+
results = pool.starmap(extract_colors, task_args)
8193

8294
# Concatenate results from all workers
8395
final_colors = [color for colors in results for color in colors]
@@ -91,7 +103,6 @@ def extract_colors(
91103
end_frame: int,
92104
color_extractor: Callable,
93105
target_frames: Optional[int] = None,
94-
show_progress: bool = True,
95106
) -> List:
96107
"""
97108
Extracts dominant colors from frames in a video file.
@@ -101,7 +112,6 @@ def extract_colors(
101112
:param int end_frame: The index of the last frame to process.
102113
:param Callable color_extractor: A function to extract the dominant color from a frame.
103114
:param Optional[int] target_frames: The total number of frames to sample.
104-
:param bool show_progress: Whether to display a tqdm progress bar during extraction.
105115
:return: List of dominant colors from the sampled frames.
106116
"""
107117
video = cv2.VideoCapture(video_path)
@@ -116,11 +126,7 @@ def extract_colors(
116126

117127
colors = []
118128

119-
iterator = range(target_frames or total_frames)
120-
if show_progress:
121-
iterator = tqdm(iterator, desc="Processing frames")
122-
123-
for _ in iterator:
129+
for _ in tqdm(range(target_frames or total_frames), desc="Processing frames"):
124130
ret, frame = video.read() # Read the first or next frame
125131
if ret:
126132
dominant_color = color_extractor(frame)

tests/test_utility.py

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -114,9 +114,9 @@ def test_invalid_height(self) -> None:
114114
with self.assertRaises(ValueError):
115115
utility.validate_args(self.args, self.frame_count, self.MAX_PROCESSES, self.MIN_FRAME_COUNT)
116116

117-
self.args.height = self.frame_count + 1 # Testing for > frame_count
118-
with self.assertRaises(ValueError):
119-
utility.validate_args(self.args, self.frame_count, self.MAX_PROCESSES, self.MIN_FRAME_COUNT)
117+
# Heights greater than frame_count are now allowed (no exception expected)
118+
self.args.height = self.frame_count + 1
119+
utility.validate_args(self.args, self.frame_count, self.MAX_PROCESSES, self.MIN_FRAME_COUNT)
120120

121121
def test_frame_count_too_low(self) -> None:
122122
"""
@@ -144,13 +144,12 @@ def test_frame_count_too_low_specific_branch(self, mock_exists: MagicMock) -> No
144144

145145
def test_all_methods_and_method_error(self) -> None:
146146
"""
147-
Test that validate_args raises a ValueError when the --all_methods flag is used with the --method argument.
147+
When --all_methods is provided, it should no longer raise even if --method is set.
148148
:return: None
149149
"""
150150
self.args.all_methods = True
151151
self.args.method = "avg" # Explicitly setting method to simulate conflict
152-
with self.assertRaises(ValueError):
153-
utility.validate_args(self.args, self.frame_count, self.MAX_PROCESSES, self.MIN_FRAME_COUNT)
152+
utility.validate_args(self.args, self.frame_count, self.MAX_PROCESSES, self.MIN_FRAME_COUNT)
154153

155154
def test_no_error_raised(self) -> None:
156155
"""

0 commit comments

Comments
 (0)