Skip to content

Commit 3a574b6

Browse files
authored
feat: restructure code (#30)
1 parent 7f8d8a4 commit 3a574b6

15 files changed

+133
-42
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,11 @@ usage: movie-barcodes [-h] -i INPUT_VIDEO_PATH [-d [DESTINATION_PATH]] [-t {hori
6464
# Examples
6565
## Sequential Processing
6666
```python
67-
python -m src.main -i "path/to/video" --width 200 -w 1
67+
python -m movie_barcodes -i "path/to/video" --width 200 -w 1
6868
```
6969
## Parallel Processing
7070
```python
71-
python -m src.main -i "path/to/video" --width 200 -w 8
71+
python -m movie_barcodes -i "path/to/video" --width 200 -w 8
7272
```
7373
7474
# Development Setup
@@ -94,7 +94,7 @@ $ uv pip install pytest pytest-cov
9494
$ uv run pytest tests/
9595

9696
# Run package locally
97-
$ uv run python -m src.main -i "path_to_video.mp4"
97+
$ uv run python -m movie_barcodes -i "path_to_video.mp4"
9898
```
9999
100100
# Todo

pyproject.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,14 @@ Homepage = "https://github.com/Wazzabeee/movie-barcodes"
4545
Repository = "https://github.com/Wazzabeee/movie-barcodes"
4646

4747
[project.scripts]
48-
movie-barcodes = "src.main:main"
48+
movie-barcodes = "movie_barcodes.cli:main"
49+
50+
[tool.setuptools]
51+
package-dir = {"" = "src"}
52+
53+
[tool.setuptools.packages.find]
54+
where = ["src"]
55+
4956

5057
[tool.setuptools_scm]
5158
version_scheme = "guess-next-dev"

src/movie_barcodes/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""Public API for movie_barcodes.
2+
3+
This package provides CLI and library functions to generate movie color barcodes.
4+
"""
5+
6+
from . import barcode_generation as barcode
7+
from . import barcode_generation as barcode_generation
8+
from . import color_extraction
9+
from . import video_processing
10+
from .cli import main as main
11+
from . import utility
12+
13+
__all__ = [
14+
"barcode",
15+
"barcode_generation",
16+
"color_extraction",
17+
"video_processing",
18+
"utility",
19+
"main",
20+
]

src/movie_barcodes/__main__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .cli import main
2+
3+
if __name__ == "__main__":
4+
main()
File renamed without changes.
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,14 @@ def get_dominant_color_kmeans(frame: np.ndarray, k: int = 3) -> np.ndarray:
3030
Gets the dominant color of a frame using KMeans clustering.
3131
3232
:param np.ndarray frame: The frame as a NumPy array.
33-
:param int k: Number of clusters for KMeans algorithm. Defaults to 1.
33+
:param int k: Number of clusters for KMeans algorithm. Defaults to 3.
3434
:return: Dominant color as a NumPy array.
3535
"""
3636
# Reshape the frame to be a list of pixels
3737
pixels = frame.reshape(-1, 3)
3838

3939
# Apply KMeans clustering
40-
kmeans = KMeans(n_clusters=k, n_init=10)
40+
kmeans = KMeans(n_clusters=k, n_init=10, random_state=0)
4141
kmeans.fit(pixels)
4242

4343
# Get the RGB values of the cluster centers
Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,19 +57,13 @@ def validate_args(args: argparse.Namespace, frame_count: int, MAX_PROCESSES: int
5757
if args.width <= 0:
5858
raise ValueError("Width must be greater than 0.")
5959
if args.width > frame_count:
60-
raise ValueError(
61-
f"Specified width ({args.width}) cannot be greater than the number of frames ({frame_count}) in the "
62-
f"video."
63-
)
60+
raise ValueError("Width must be less than or equal to the number of frames.")
6461

6562
if args.height is not None:
6663
if args.height <= 0:
6764
raise ValueError("Height must be greater than 0.")
6865
if args.height > frame_count:
69-
raise ValueError(
70-
f"Specified height ({args.height}) cannot be greater than the number of frames ({frame_count}) in the "
71-
f"video."
72-
)
66+
raise ValueError("Height must be less than or equal to the number of frames.")
7367

7468
if frame_count < MIN_FRAME_COUNT:
7569
raise ValueError(f"The video must have at least {MIN_FRAME_COUNT} frames.")
@@ -142,7 +136,8 @@ def save_barcode_image(barcode: np.ndarray, base_name: str, args: argparse.Names
142136
:param str method: The method used for color extraction.
143137
"""
144138
current_dir = path.dirname(path.abspath(__file__))
145-
project_root = path.dirname(current_dir) # Go up one directory to get to the project root
139+
# Go up two directories to reach the repository root (…/src/movie_barcodes -> …/src -> repo root)
140+
project_root = path.dirname(path.dirname(current_dir))
146141
# If destination_path isn't specified, construct one based on the video's name
147142
if not args.destination_path:
148143
barcode_dir = path.join(project_root, "barcodes")
@@ -159,9 +154,10 @@ def save_barcode_image(barcode: np.ndarray, base_name: str, args: argparse.Names
159154

160155
destination_path = path.join(barcode_dir, destination_name)
161156
else:
162-
# In case a destination_path is provided, consider appending the method
163-
# or managing as per your requirement
164-
destination_path = path.join(project_root, args.destination_path)
157+
# Use absolute path as-is; if relative, make it relative to project root
158+
destination_path = args.destination_path
159+
if not path.isabs(destination_path):
160+
destination_path = path.join(project_root, destination_path)
165161

166162
if barcode.shape[2] == 4: # If the image has an alpha channel (RGBA)
167163
image = Image.fromarray(barcode, "RGBA")

tests/conftest.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import os
2+
import sys
3+
4+
5+
# Ensure the src/ directory is on sys.path so 'movie_barcodes' is importable in tests
6+
_THIS_DIR = os.path.dirname(__file__)
7+
_REPO_ROOT = os.path.abspath(os.path.join(_THIS_DIR, ".."))
8+
_SRC_DIR = os.path.join(_REPO_ROOT, "src")
9+
if _SRC_DIR not in sys.path:
10+
sys.path.insert(0, _SRC_DIR)

0 commit comments

Comments
 (0)