diff --git a/src/cleanvision/dataset/base_dataset.py b/src/cleanvision/dataset/base_dataset.py index 9c4cccf5..66a39e57 100644 --- a/src/cleanvision/dataset/base_dataset.py +++ b/src/cleanvision/dataset/base_dataset.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Sized -from typing import List, Union +from typing import List, Union, Optional from PIL import Image @@ -20,7 +20,7 @@ def __len__(self) -> int: """Returns the number of examples in the dataset""" raise NotImplementedError - def __getitem__(self, item: Union[int, str]) -> Image.Image: + def __getitem__(self, item: Union[int, str]) -> Optional[Image.Image]: """Returns the image at a given index""" raise NotImplementedError diff --git a/src/cleanvision/dataset/hf_dataset.py b/src/cleanvision/dataset/hf_dataset.py index c02ab964..1ae6fc10 100644 --- a/src/cleanvision/dataset/hf_dataset.py +++ b/src/cleanvision/dataset/hf_dataset.py @@ -29,7 +29,7 @@ def __len__(self) -> int: def __getitem__(self, item: Union[int, str]) -> Optional[Image.Image]: try: - image = self._data[item][self._image_key] + image: Image.Image = self._data[item][self._image_key] return image except Exception as e: print(f"Could not load image at index: {item}\n", e) diff --git a/src/cleanvision/dataset/torch_dataset.py b/src/cleanvision/dataset/torch_dataset.py index 8e8b7973..4e62ba3e 100644 --- a/src/cleanvision/dataset/torch_dataset.py +++ b/src/cleanvision/dataset/torch_dataset.py @@ -26,7 +26,8 @@ def __len__(self) -> int: return len(self._data) def __getitem__(self, item: Union[int, str]) -> Image.Image: - return self._data[item][self._image_idx] + img: Image.Image = self._data[item][self._image_idx] + return img def get_name(self, index: Union[int, str]) -> str: return f"idx: {index}" diff --git a/src/cleanvision/imagelab.py b/src/cleanvision/imagelab.py index 583710bb..188b3c00 100644 --- a/src/cleanvision/imagelab.py +++ b/src/cleanvision/imagelab.py @@ -7,7 +7,7 @@ from __future__ import annotations import random -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, TypeVar +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, TypeVar, cast import numpy as np import pandas as pd @@ -502,6 +502,7 @@ def _visualize( scores = sorted_df.head(num_images)[get_score_colname(issue_type)] indices = scores.index.tolist() images = [self._dataset[i] for i in indices] + images = cast(list[Image.Image], images) # construct title info title_info = {"scores": [f"score : {x:.4f}" for x in scores]} @@ -526,6 +527,7 @@ def _visualize( image_sets = [] for indices in image_sets_indices: image_sets.append([self._dataset[index] for index in indices]) + image_sets = cast(list[list[Image.Image]], image_sets) title_info_sets = [] for s in image_sets_indices: @@ -620,7 +622,7 @@ def visualize( elif image_files is not None: if len(image_files) == 0: raise ValueError("image_files list is empty.") - images = [Image.open(path) for path in image_files] + images: List[Image.Image] = [Image.open(path) for path in image_files] title_info = {"path": [path.split("/")[-1] for path in image_files]} VizManager.individual_images( images, @@ -629,7 +631,7 @@ def visualize( cell_size=cell_size, ) elif indices: - images = [self._dataset[i] for i in indices] + images = [cast(Image.Image, self._dataset[i]) for i in indices] title_info = {"name": [self._dataset.get_name(i) for i in indices]} VizManager.individual_images( images, @@ -644,7 +646,7 @@ def visualize( image_indices = random.sample( self._dataset.index, min(num_images, len(self._dataset)) ) - images = [self._dataset[i] for i in image_indices] + images = [cast(Image.Image, self._dataset[i]) for i in image_indices] title_info = { "name": [self._dataset.get_name(i) for i in image_indices] } diff --git a/src/cleanvision/issue_managers/duplicate_issue_manager.py b/src/cleanvision/issue_managers/duplicate_issue_manager.py index 81ea3811..6b5f3f72 100644 --- a/src/cleanvision/issue_managers/duplicate_issue_manager.py +++ b/src/cleanvision/issue_managers/duplicate_issue_manager.py @@ -17,10 +17,20 @@ def get_hash(image: Image.Image, params: Dict[str, Any]) -> str: hash_type, hash_size = params["hash_type"], params.get("hash_size", None) + supported_types = ["md5", "whash", "phash", "ahash", "dhash", "chash"] + if hash_type not in supported_types: + raise ValueError( + f"Hash type `{hash_type}` is not supported. Must be one of: {supported_types}" + ) + if hash_type == "md5": pixels = np.asarray(image) return hashlib.md5(pixels.tobytes()).hexdigest() - elif hash_type == "whash": + + if not isinstance(hash_size, int): + raise ValueError("hash_size must be declared as a int in params") + + if hash_type == "whash": return str(imagehash.whash(image, hash_size=hash_size)) elif hash_type == "phash": return str(imagehash.phash(image, hash_size=hash_size)) @@ -31,7 +41,7 @@ def get_hash(image: Image.Image, params: Dict[str, Any]) -> str: elif hash_type == "chash": return str(imagehash.colorhash(image, binbits=hash_size)) else: - raise ValueError("Hash type not supported") + raise ValueError("hash_type not supported") def compute_hash( diff --git a/src/cleanvision/utils/viz_manager.py b/src/cleanvision/utils/viz_manager.py index f0838596..519c1e73 100644 --- a/src/cleanvision/utils/viz_manager.py +++ b/src/cleanvision/utils/viz_manager.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Dict +from typing import List, Tuple, Dict, Sequence import math import matplotlib.axes @@ -9,7 +9,7 @@ class VizManager: @staticmethod def individual_images( - images: List[Image.Image], + images: Sequence[Image.Image], title_info: Dict[str, List[str]], ncols: int, cell_size: Tuple[int, int], @@ -86,7 +86,7 @@ def construct_titles(title_info: Dict[str, List[str]], cell_width: int) -> List[ def plot_image_grid( - images: List[Image.Image], + images: Sequence[Image.Image], title_info: Dict[str, List[str]], ncols: int, cell_size: Tuple[int, int],