Skip to content

Commit dc6234c

Browse files
authored
type-checker issues (#267)
1 parent 7fd414f commit dc6234c

File tree

6 files changed

+26 -13 lines changed

src/cleanvision/dataset/base_dataset.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from collections.abc import Sized
4-
from typing import List, Union
4+
from typing import List, Union, Optional
55

66
from PIL import Image
77

@@ -20,7 +20,7 @@ def __len__(self) -> int:
2020
"""Returns the number of examples in the dataset"""
2121
raise NotImplementedError
2222

23-
def __getitem__(self, item: Union[int, str]) -> Image.Image:
23+
def __getitem__(self, item: Union[int, str]) -> Optional[Image.Image]:
2424
"""Returns the image at a given index"""
2525
raise NotImplementedError
2626

src/cleanvision/dataset/hf_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@ def __len__(self) -> int:
2929

3030
def __getitem__(self, item: Union[int, str]) -> Optional[Image.Image]:
3131
try:
32-
image = self._data[item][self._image_key]
32+
image: Image.Image = self._data[item][self._image_key]
3333
return image
3434
except Exception as e:
3535
print(f"Could not load image at index: {item}\n", e)

src/cleanvision/dataset/torch_dataset.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,8 @@ def __len__(self) -> int:
2626
return len(self._data)
2727

2828
def __getitem__(self, item: Union[int, str]) -> Image.Image:
29-
return self._data[item][self._image_idx]
29+
img: Image.Image = self._data[item][self._image_idx]
30+
return img
3031

3132
def get_name(self, index: Union[int, str]) -> str:
3233
return f"idx: {index}"

src/cleanvision/imagelab.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
from __future__ import annotations
88

99
import random
10-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, TypeVar
10+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, TypeVar, cast
1111

1212
import numpy as np
1313
import pandas as pd
@@ -502,6 +502,7 @@ def _visualize(
502502
scores = sorted_df.head(num_images)[get_score_colname(issue_type)]
503503
indices = scores.index.tolist()
504504
images = [self._dataset[i] for i in indices]
505+
images = cast(list[Image.Image], images)
505506

506507
# construct title info
507508
title_info = {"scores": [f"score : {x:.4f}" for x in scores]}
@@ -526,6 +527,7 @@ def _visualize(
526527
image_sets = []
527528
for indices in image_sets_indices:
528529
image_sets.append([self._dataset[index] for index in indices])
530+
image_sets = cast(list[list[Image.Image]], image_sets)
529531

530532
title_info_sets = []
531533
for s in image_sets_indices:
@@ -620,7 +622,7 @@ def visualize(
620622
elif image_files is not None:
621623
if len(image_files) == 0:
622624
raise ValueError("image_files list is empty.")
623-
images = [Image.open(path) for path in image_files]
625+
images: List[Image.Image] = [Image.open(path) for path in image_files]
624626
title_info = {"path": [path.split("/")[-1] for path in image_files]}
625627
VizManager.individual_images(
626628
images,
@@ -629,7 +631,7 @@ def visualize(
629631
cell_size=cell_size,
630632
)
631633
elif indices:
632-
images = [self._dataset[i] for i in indices]
634+
images = [cast(Image.Image, self._dataset[i]) for i in indices]
633635
title_info = {"name": [self._dataset.get_name(i) for i in indices]}
634636
VizManager.individual_images(
635637
images,
@@ -644,7 +646,7 @@ def visualize(
644646
image_indices = random.sample(
645647
self._dataset.index, min(num_images, len(self._dataset))
646648
)
647-
images = [self._dataset[i] for i in image_indices]
649+
images = [cast(Image.Image, self._dataset[i]) for i in image_indices]
648650
title_info = {
649651
"name": [self._dataset.get_name(i) for i in image_indices]
650652
}

src/cleanvision/issue_managers/duplicate_issue_manager.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,10 +17,20 @@
1717

1818
def get_hash(image: Image.Image, params: Dict[str, Any]) -> str:
1919
hash_type, hash_size = params["hash_type"], params.get("hash_size", None)
20+
supported_types = ["md5", "whash", "phash", "ahash", "dhash", "chash"]
21+
if hash_type not in supported_types:
22+
raise ValueError(
23+
f"Hash type `{hash_type}` is not supported. Must be one of: {supported_types}"
24+
)
25+
2026
if hash_type == "md5":
2127
pixels = np.asarray(image)
2228
return hashlib.md5(pixels.tobytes()).hexdigest()
23-
elif hash_type == "whash":
29+
30+
if not isinstance(hash_size, int):
31+
raise ValueError("hash_size must be declared as a int in params")
32+
33+
if hash_type == "whash":
2434
return str(imagehash.whash(image, hash_size=hash_size))
2535
elif hash_type == "phash":
2636
return str(imagehash.phash(image, hash_size=hash_size))
@@ -31,7 +41,7 @@ def get_hash(image: Image.Image, params: Dict[str, Any]) -> str:
3141
elif hash_type == "chash":
3242
return str(imagehash.colorhash(image, binbits=hash_size))
3343
else:
34-
raise ValueError("Hash type not supported")
44+
raise ValueError("hash_type not supported")
3545

3646

3747
def compute_hash(

src/cleanvision/utils/viz_manager.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Tuple, Dict
1+
from typing import List, Tuple, Dict, Sequence
22

33
import math
44
import matplotlib.axes
@@ -9,7 +9,7 @@
99
class VizManager:
1010
@staticmethod
1111
def individual_images(
12-
images: List[Image.Image],
12+
images: Sequence[Image.Image],
1313
title_info: Dict[str, List[str]],
1414
ncols: int,
1515
cell_size: Tuple[int, int],
@@ -86,7 +86,7 @@ def construct_titles(title_info: Dict[str, List[str]], cell_width: int) -> List[
8686

8787

8888
def plot_image_grid(
89-
images: List[Image.Image],
89+
images: Sequence[Image.Image],
9090
title_info: Dict[str, List[str]],
9191
ncols: int,
9292
cell_size: Tuple[int, int],

0 commit comments

Comments
 (0)