Skip to content

Commit d6ebb1d

Browse files
authored
Merge pull request #13 from Imageomics/12-rgba-fix
Fix for RGBA png files
2 parents 88eee31 + 2179742 commit d6ebb1d

File tree

3 files changed

+15
-3
lines changed

3 files changed

+15
-3
lines changed

src/bioclip/predict.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,11 @@ def get_txt_names():
119119
return txt_names
120120

121121

122+
def open_image(image_path):
123+
img = PIL.Image.open(image_path)
124+
return img.convert("RGB")
125+
126+
122127
preprocess_img = transforms.Compose(
123128
[
124129
transforms.ToTensor(),
@@ -181,7 +186,7 @@ def get_txt_features(self, classnames):
181186

182187
@torch.no_grad()
183188
def predict(self, image_path: str, cls_ary: List[str]) -> dict[str, float]:
184-
img = PIL.Image.open(image_path)
189+
img = open_image(image_path)
185190
classes = [cls.strip() for cls in cls_ary]
186191
txt_features = self.get_txt_features(classes)
187192

@@ -248,7 +253,7 @@ def __init__(self, device: Union[str, torch.device] = 'cpu', model_str: str = MO
248253

249254
@torch.no_grad()
250255
def get_image_features(self, image_path: str) -> torch.Tensor:
251-
img = PIL.Image.open(image_path)
256+
img = open_image(image_path)
252257
return self.encode_image(img)
253258

254259
def encode_image(self, img: PIL.Image.Image) -> torch.Tensor:
@@ -295,7 +300,7 @@ def format_grouped_probs(self, image_path: str, probs: torch.Tensor, rank: Rank,
295300

296301
@torch.no_grad()
297302
def predict(self, image_path: str, rank: Rank, min_prob: float = 1e-9, k: int = 5) -> List[dict[str, float]]:
298-
img = PIL.Image.open(image_path)
303+
img = open_image(image_path)
299304
probs = self.predict_species(img)
300305
if rank == Rank.SPECIES:
301306
return self.format_species_probs(image_path, probs, k)

tests/images/mycat.png

112 KB
Loading

tests/test_predict.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
DIRNAME = os.path.dirname(os.path.realpath(__file__))
99
EXAMPLE_CAT_IMAGE = os.path.join(DIRNAME, "images", "mycat.jpg")
10+
EXAMPLE_CAT_IMAGE2 = os.path.join(DIRNAME, "images", "mycat.png")
1011

1112
class TestPredict(unittest.TestCase):
1213
def test_tree_of_life_classifier_species(self):
@@ -51,6 +52,12 @@ def test_custom_labels_classifier(self):
5152
{'file_name': EXAMPLE_CAT_IMAGE, 'classification': 'dog', 'score': unittest.mock.ANY},
5253
])
5354

55+
def test_predict_with_rgba_image(self):
56+
# Ensure that the classifier can handle RGBA images
57+
classifier = TreeOfLifeClassifier()
58+
prediction_ary = classifier.predict(image_path=EXAMPLE_CAT_IMAGE2, rank=Rank.SPECIES)
59+
self.assertEqual(len(prediction_ary), 5)
60+
5461

5562
class TestEmbed(unittest.TestCase):
5663
def test_get_image_features(self):

0 commit comments

Comments
 (0)