Commit 8520cd5

johnbradley, thompsonmj, and egrace479 committed

Allow using PIL with predict functions

Fixes #54
Co-authored-by: Matt Thompson <[email protected]>
Co-authored-by: Elizabeth Campolongo <[email protected]>

1 parent d8035ef commit 8520cd5

4 files changed: +105 -46 lines changed

4 files changed

+105
-46
lines changed

README.md

Lines changed: 5 additions & 0 deletions
@@ -120,6 +120,11 @@ big 0.99992835521698
 small 7.165559509303421e-05
 ```
 
+### PIL Images
+The predict() functions used in all the examples above allow passing a list of paths or a list of [PIL Images](https://pillow.readthedocs.io/en/stable/reference/Image.html).
+When a list of PIL images is passed the index of the image will be filled in for `file_name`. This is because PIL images may not have an associated file name.
+
+
 ## Command Line Usage
 ```
 bioclip predict [-h] [--format {table,csv}] [--output OUTPUT]
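
The documented behavior can be exercised roughly as follows. This is a minimal sketch, assuming the package-level `bioclip` imports used in the README examples; `cat.jpg` is a placeholder path:

```python
import PIL.Image
from bioclip import TreeOfLifeClassifier, Rank  # package-level imports assumed, as in the README examples

# Open the image with PIL instead of handing predict() a path.
img = PIL.Image.open("cat.jpg")  # placeholder path

classifier = TreeOfLifeClassifier()
# Because PIL images carry no path, each prediction's 'file_name' is the
# image's index in the list ("0" for the single image passed here).
for pred in classifier.predict(images=[img], rank=Rank.SPECIES):
    print(pred["file_name"], pred["score"])
```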

src/bioclip/__main__.py

Lines changed: 4 additions & 4 deletions
@@ -51,16 +51,16 @@ def predict(image_file: list[str],
             **kwargs):
     if cls_str:
         classifier = CustomLabelsClassifier(cls_ary=cls_str.split(','), **kwargs)
-        predictions = classifier.predict(image_paths=image_file, k=k)
+        predictions = classifier.predict(images=image_file, k=k)
         write_results(predictions, format, output)
     elif bins_path:
         cls_to_bin = parse_bins_csv(bins_path)
         classifier = CustomLabelsBinningClassifier(cls_to_bin=cls_to_bin, **kwargs)
-        predictions = classifier.predict(image_paths=image_file, k=k)
+        predictions = classifier.predict(images=image_file, k=k)
         write_results(predictions, format, output)
     else:
         classifier = TreeOfLifeClassifier(**kwargs)
-        predictions = classifier.predict(image_paths=image_file, rank=rank, k=k)
+        predictions = classifier.predict(images=image_file, rank=rank, k=k)
         write_results(predictions, format, output)
 
 
@@ -72,7 +72,7 @@ def embed(image_file: list[str], output: str, **kwargs):
         "embeddings": images_dict
     }
     for image_path in image_file:
-        features = classifier.create_image_features_for_path(image_path=image_path, normalize=False)
+        features = classifier.create_image_features_for_image(image=image_path, normalize=False)
         images_dict[image_path] = features.tolist()
     if output == 'stdout':
         print(json.dumps(data, indent=4))
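
The renamed helper used by `embed` accepts either form as well. A minimal sketch, assuming the package-level import and using a placeholder path:

```python
import PIL.Image
from bioclip import TreeOfLifeClassifier  # package-level import assumed

classifier = TreeOfLifeClassifier(device='cpu')

# Either a path or an already-opened PIL image can be embedded; per the
# tests below, the result is a 512-dimensional feature tensor.
features_from_path = classifier.create_image_features_for_image("cat.jpg", normalize=False)  # placeholder path
features_from_pil = classifier.create_image_features_for_image(PIL.Image.open("cat.jpg"), normalize=False)
print(features_from_path.shape)  # torch.Size([512])
```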

src/bioclip/predict.py

Lines changed: 44 additions & 31 deletions
@@ -184,10 +184,20 @@ def load_pretrained_model(self, model_str: str = BIOCLIP_MODEL_STR, pretrained_s
         self.preprocess = preprocess_img if self.model_str == BIOCLIP_MODEL_STR else preprocess
 
     @staticmethod
-    def open_image(image_path):
-        img = PIL.Image.open(image_path)
+    def ensure_rgb_image(image: str | PIL.Image.Image) -> PIL.Image.Image:
+        if isinstance(image, PIL.Image.Image):
+            img = image
+        else:
+            img = PIL.Image.open(image)
         return img.convert("RGB")
 
+    @staticmethod
+    def make_key(image: str | PIL.Image.Image, idx: int) -> str:
+        if isinstance(image, PIL.Image.Image):
+            return f"{idx}"
+        else:
+            return image
+
     @torch.no_grad()
     def create_image_features(self, images: List[PIL.Image.Image], normalize : bool = True) -> torch.Tensor:
         preprocessed_images = []
@@ -202,8 +212,8 @@ def create_image_features(self, images: List[PIL.Image.Image], normalize : bool
         return img_features
 
     @torch.no_grad()
-    def create_image_features_for_path(self, image_path: str, normalize: bool) -> torch.Tensor:
-        img = self.open_image(image_path)
+    def create_image_features_for_image(self, image: str | PIL.Image.Image, normalize: bool) -> torch.Tensor:
+        img = self.ensure_rgb_image(image)
         result = self.create_image_features([img], normalize=normalize)
         return result[0]
 
@@ -213,13 +223,14 @@ def create_probabilities(self, img_features: torch.Tensor,
         logits = (self.model.logit_scale.exp() * img_features @ txt_features)
         return F.softmax(logits, dim=1)
 
-    def create_probabilities_for_image_paths(self, image_paths: List[str] | str,
-                                             txt_features: torch.Tensor) -> dict[str, torch.Tensor]:
-        images = [self.open_image(image_path) for image_path in image_paths]
+    def create_probabilities_for_images(self, images: List[str] | List[PIL.Image.Image],
+                                        txt_features: torch.Tensor) -> dict[str, torch.Tensor]:
+        keys = [self.make_key(image, i) for i,image in enumerate(images)]
+        images = [self.ensure_rgb_image(image) for image in images]
         img_features = self.create_image_features(images)
         probs = self.create_probabilities(img_features, txt_features)
         result = {}
-        for i, key in enumerate(image_paths):
+        for i, key in enumerate(keys):
             result[key] = probs[i]
         return result
 
@@ -245,24 +256,25 @@ def _get_txt_features(self, classnames):
         return all_features
 
     @torch.no_grad()
-    def predict(self, image_paths: List[str] | str, k: int = None) -> dict[str, float]:
-        if isinstance(image_paths, str):
-            image_paths = [image_paths]
-        probs = self.create_probabilities_for_image_paths(image_paths, self.txt_features)
+    def predict(self, images: List[str] | str | List[PIL.Image.Image], k: int = None) -> dict[str, float]:
+        if isinstance(images, str):
+            images = [images]
+        probs = self.create_probabilities_for_images(images, self.txt_features)
         result = []
-        for image_path in image_paths:
-            img_probs = probs[image_path]
+        for i, image in enumerate(images):
+            key = self.make_key(image, i)
+            img_probs = probs[key]
             if not k or k > len(self.classes):
                 k = len(self.classes)
-            result.extend(self.group_probs(image_path, img_probs, k))
+            result.extend(self.group_probs(key, img_probs, k))
         return result
 
-    def group_probs(self, image_path: str, img_probs: torch.Tensor, k: int = None) -> List[dict[str, float]]:
+    def group_probs(self, image_key: str, img_probs: torch.Tensor, k: int = None) -> List[dict[str, float]]:
         result = []
         topk = img_probs.topk(k)
         for i, prob in zip(topk.indices, topk.values):
             result.append({
-                PRED_FILENAME_KEY: image_path,
+                PRED_FILENAME_KEY: image_key,
                 PRED_CLASSICATION_KEY: self.classes[i],
                 PRED_SCORE_KEY: prob.item()
             })
@@ -276,7 +288,7 @@ def __init__(self, cls_to_bin: dict, **kwargs):
         if any([pd.isna(x) or not x for x in cls_to_bin.values()]):
             raise ValueError("Empty, null, or nan are not allowed for bin values.")
 
-    def group_probs(self, image_path: str, img_probs: torch.Tensor, k: int = None) -> List[dict[str, float]]:
+    def group_probs(self, image_key: str, img_probs: torch.Tensor, k: int = None) -> List[dict[str, float]]:
         result = []
         output = collections.defaultdict(float)
         for i in range(len(self.classes)):
@@ -285,7 +297,7 @@ def group_probs(self, image_path: str, img_probs: torch.Tensor, k: int = None) -
         topk_names = heapq.nlargest(k, output, key=output.get)
         for name in topk_names:
             result.append({
-                PRED_FILENAME_KEY: image_path,
+                PRED_FILENAME_KEY: image_key,
                 PRED_CLASSICATION_KEY: name,
                 PRED_SCORE_KEY: output[name].item()
             })
@@ -335,17 +347,17 @@ def __init__(self, **kwargs):
         self.txt_features = get_txt_emb().to(self.device)
         self.txt_names = get_txt_names()
 
-    def format_species_probs(self, image_path: str, probs: torch.Tensor, k: int = 5) -> List[dict[str, float]]:
+    def format_species_probs(self, image_key: str, probs: torch.Tensor, k: int = 5) -> List[dict[str, float]]:
         topk = probs.topk(k)
         result = []
         for i, prob in zip(topk.indices, topk.values):
-            item = { PRED_FILENAME_KEY: image_path }
+            item = { PRED_FILENAME_KEY: image_key }
             item.update(create_classification_dict(self.txt_names[i], Rank.SPECIES))
             item[PRED_SCORE_KEY] = prob.item()
             result.append(item)
         return result
 
-    def format_grouped_probs(self, image_path: str, probs: torch.Tensor, rank: Rank, min_prob: float = 1e-9, k: int = 5) -> List[dict[str, float]]:
+    def format_grouped_probs(self, image_key: str, probs: torch.Tensor, rank: Rank, min_prob: float = 1e-9, k: int = 5) -> List[dict[str, float]]:
         output = collections.defaultdict(float)
         class_dict_lookup = {}
         name_to_class_dict = {}
@@ -358,27 +370,28 @@ def format_grouped_probs(self, image_path: str, probs: torch.Tensor, rank: Rank,
         topk_names = heapq.nlargest(k, output, key=output.get)
         prediction_ary = []
         for name in topk_names:
-            item = { PRED_FILENAME_KEY: image_path }
+            item = { PRED_FILENAME_KEY: image_key }
             item.update(name_to_class_dict[name])
             item[PRED_SCORE_KEY] = output[name].item()
             prediction_ary.append(item)
         return prediction_ary
 
     @torch.no_grad()
-    def predict(self, image_paths: List[str] | str, rank: Rank, min_prob: float = 1e-9, k: int = 5) -> dict[str, dict[str, float]]:
-        if isinstance(image_paths, str):
-            image_paths = [image_paths]
-        probs = self.create_probabilities_for_image_paths(image_paths, self.txt_features)
+    def predict(self, images: List[str] | str | List[PIL.Image.Image], rank: Rank, min_prob: float = 1e-9, k: int = 5) -> dict[str, dict[str, float]]:
+        if isinstance(images, str):
+            images = [images]
+        probs = self.create_probabilities_for_images(images, self.txt_features)
         result = []
-        for image_path in image_paths:
+        for i, image in enumerate(images):
+            key = self.make_key(image, i)
             if rank == Rank.SPECIES:
-                result.extend(self.format_species_probs(image_path, probs[image_path], k))
+                result.extend(self.format_species_probs(key, probs[key], k))
             else:
-                result.extend(self.format_grouped_probs(image_path, probs[image_path], rank, min_prob, k))
+                result.extend(self.format_grouped_probs(key, probs[key], rank, min_prob, k))
         return result
 
 
-def predict_classification(img: str, rank: Rank, device: Union[str, torch.device] = 'cpu',
+def predict_classification(img: Union[PIL.Image.Image, str], rank: Rank, device: Union[str, torch.device] = 'cpu',
                            min_prob: float = 1e-9, k: int = 5) -> dict[str, float]:
     """
     Predicts from the entire tree of life.
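
The effect of `make_key` on results is easiest to see with `CustomLabelsClassifier`. A minimal sketch mirroring `test_custom_labels_classifier_ary_multiple_pil` below, with placeholder paths and the package-level import assumed:

```python
import PIL.Image
from bioclip import CustomLabelsClassifier  # package-level import assumed

classifier = CustomLabelsClassifier(cls_ary=['cat', 'dog'])
imgs = [PIL.Image.open("cat1.jpg"), PIL.Image.open("cat2.jpg")]  # placeholder paths

# make_key() returns the stringified list index for PIL images, so the
# predictions report 'file_name' values of "0" and "1" rather than paths.
for pred in classifier.predict(images=imgs):
    print(pred["file_name"], pred["classification"], pred["score"])
```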

tests/test_predict.py

Lines changed: 52 additions & 11 deletions
@@ -5,6 +5,7 @@
 import os
 import torch
 import pandas as pd
+import PIL.Image
 
 
 DIRNAME = os.path.dirname(os.path.realpath(__file__))
@@ -14,7 +15,7 @@
 class TestPredict(unittest.TestCase):
     def test_tree_of_life_classifier_species_single(self):
         classifier = TreeOfLifeClassifier()
-        prediction_ary = classifier.predict(image_paths=EXAMPLE_CAT_IMAGE, rank=Rank.SPECIES)
+        prediction_ary = classifier.predict(images=EXAMPLE_CAT_IMAGE, rank=Rank.SPECIES)
         self.assertEqual(len(prediction_ary), 5)
         prediction_dict = {
             'file_name': EXAMPLE_CAT_IMAGE,
@@ -33,18 +34,26 @@ def test_tree_of_life_classifier_species_single(self):
 
     def test_tree_of_life_classifier_species_ary_one(self):
         classifier = TreeOfLifeClassifier()
-        prediction_ary = classifier.predict(image_paths=[EXAMPLE_CAT_IMAGE], rank=Rank.SPECIES)
+        prediction_ary = classifier.predict(images=[EXAMPLE_CAT_IMAGE], rank=Rank.SPECIES)
         self.assertEqual(len(prediction_ary), 5)
 
     def test_tree_of_life_classifier_species_ary_multiple(self):
         classifier = TreeOfLifeClassifier()
-        prediction_ary = classifier.predict(image_paths=[EXAMPLE_CAT_IMAGE, EXAMPLE_CAT_IMAGE2],
+        prediction_ary = classifier.predict(images=[EXAMPLE_CAT_IMAGE, EXAMPLE_CAT_IMAGE2],
+                                            rank=Rank.SPECIES)
+        self.assertEqual(len(prediction_ary), 10)
+
+    def test_tree_of_life_classifier_species_ary_multiple_pil(self):
+        classifier = TreeOfLifeClassifier()
+        img1 = PIL.Image.open(EXAMPLE_CAT_IMAGE)
+        img2 = PIL.Image.open(EXAMPLE_CAT_IMAGE2)
+        prediction_ary = classifier.predict(images=[img1, img2],
                                             rank=Rank.SPECIES)
         self.assertEqual(len(prediction_ary), 10)
 
     def test_tree_of_life_classifier_family(self):
         classifier = TreeOfLifeClassifier()
-        prediction_ary = classifier.predict(image_paths=[EXAMPLE_CAT_IMAGE], rank=Rank.FAMILY, k=2)
+        prediction_ary = classifier.predict(images=[EXAMPLE_CAT_IMAGE], rank=Rank.FAMILY, k=2)
         self.assertEqual(len(prediction_ary), 2)
         prediction_dict = {
             'file_name': EXAMPLE_CAT_IMAGE,
@@ -59,34 +68,46 @@ def test_tree_of_life_classifier_family(self):
 
     def test_custom_labels_classifier(self):
         classifier = CustomLabelsClassifier(cls_ary=['cat', 'dog'])
-        prediction_ary = classifier.predict(image_paths=EXAMPLE_CAT_IMAGE)
+        prediction_ary = classifier.predict(images=EXAMPLE_CAT_IMAGE)
         self.assertEqual(prediction_ary, [
             {'file_name': EXAMPLE_CAT_IMAGE, 'classification': 'cat', 'score': unittest.mock.ANY},
             {'file_name': EXAMPLE_CAT_IMAGE, 'classification': 'dog', 'score': unittest.mock.ANY},
         ])
 
     def test_custom_labels_classifier_ary_one(self):
         classifier = CustomLabelsClassifier(cls_ary=['cat', 'dog'])
-        prediction_ary = classifier.predict(image_paths=[EXAMPLE_CAT_IMAGE])
+        prediction_ary = classifier.predict(images=[EXAMPLE_CAT_IMAGE])
         self.assertEqual(prediction_ary, [
             {'file_name': EXAMPLE_CAT_IMAGE, 'classification': 'cat', 'score': unittest.mock.ANY},
             {'file_name': EXAMPLE_CAT_IMAGE, 'classification': 'dog', 'score': unittest.mock.ANY},
         ])
 
     def test_custom_labels_classifier_ary_multiple(self):
         classifier = CustomLabelsClassifier(cls_ary=['cat', 'dog'])
-        prediction_ary = classifier.predict(image_paths=[EXAMPLE_CAT_IMAGE, EXAMPLE_CAT_IMAGE2])
+        prediction_ary = classifier.predict(images=[EXAMPLE_CAT_IMAGE, EXAMPLE_CAT_IMAGE2])
         self.assertEqual(prediction_ary, [
             {'file_name': EXAMPLE_CAT_IMAGE, 'classification': 'cat', 'score': unittest.mock.ANY},
             {'file_name': EXAMPLE_CAT_IMAGE, 'classification': 'dog', 'score': unittest.mock.ANY},
             {'file_name': EXAMPLE_CAT_IMAGE2, 'classification': 'cat', 'score': unittest.mock.ANY},
             {'file_name': EXAMPLE_CAT_IMAGE2, 'classification': 'dog', 'score': unittest.mock.ANY},
         ])
 
+    def test_custom_labels_classifier_ary_multiple_pil(self):
+        classifier = CustomLabelsClassifier(cls_ary=['cat', 'dog'])
+        img1 = PIL.Image.open(EXAMPLE_CAT_IMAGE)
+        img2 = PIL.Image.open(EXAMPLE_CAT_IMAGE2)
+        prediction_ary = classifier.predict(images=[img1, img2])
+        self.assertEqual(prediction_ary, [
+            {'file_name': '0', 'classification': 'cat', 'score': unittest.mock.ANY},
+            {'file_name': '0', 'classification': 'dog', 'score': unittest.mock.ANY},
+            {'file_name': '1', 'classification': 'cat', 'score': unittest.mock.ANY},
+            {'file_name': '1', 'classification': 'dog', 'score': unittest.mock.ANY},
+        ])
+
     def test_predict_with_rgba_image(self):
         # Ensure that the classifier can handle RGBA images
         classifier = TreeOfLifeClassifier()
-        prediction_ary = classifier.predict(image_paths=[EXAMPLE_CAT_IMAGE2], rank=Rank.SPECIES)
+        prediction_ary = classifier.predict(images=[EXAMPLE_CAT_IMAGE2], rank=Rank.SPECIES)
         self.assertEqual(len(prediction_ary), 5)
 
     def test_predict_with_bins(self):
@@ -95,7 +116,7 @@ def test_predict_with_bins(self):
             'mouse': 'two',
             'fish': 'two',
         })
-        prediction_ary = classifier.predict(image_paths=[EXAMPLE_CAT_IMAGE2])
+        prediction_ary = classifier.predict(images=[EXAMPLE_CAT_IMAGE2])
         self.assertEqual(len(prediction_ary), 2)
         self.assertEqual(prediction_ary[0]['file_name'], EXAMPLE_CAT_IMAGE2)
         names = set([pred['classification'] for pred in prediction_ary])
@@ -106,7 +127,7 @@ def test_predict_with_bins(self):
             'mouse': 'two',
             'fish': 'three',
         })
-        prediction_ary = classifier.predict(image_paths=[EXAMPLE_CAT_IMAGE2])
+        prediction_ary = classifier.predict(images=[EXAMPLE_CAT_IMAGE2])
        self.assertEqual(len(prediction_ary), 3)
         self.assertEqual(prediction_ary[0]['file_name'], EXAMPLE_CAT_IMAGE2)
         names = set([pred['classification'] for pred in prediction_ary])
@@ -138,9 +159,29 @@ def test_predict_with_bins_bad_values(self):
         self.assertEqual(str(raised_exceptions.exception),
                          "Empty, null, or nan are not allowed for bin values.")
 
+    def test_predict_with_bins_pil(self):
+        classifier = CustomLabelsBinningClassifier(cls_to_bin={
+            'cat': 'one',
+            'mouse': 'two',
+            'fish': 'two',
+        })
+        img1 = PIL.Image.open(EXAMPLE_CAT_IMAGE)
+        prediction_ary = classifier.predict(images=[img1])
+        self.assertEqual(len(prediction_ary), 2)
+        self.assertEqual(prediction_ary[0]['file_name'], '0')
+        names = set([pred['classification'] for pred in prediction_ary])
+        self.assertEqual(names, set(['one', 'two']))
+
+
 class TestEmbed(unittest.TestCase):
     def test_get_image_features(self):
         classifier = TreeOfLifeClassifier(device='cpu')
         self.assertEqual(classifier.model_str, 'hf-hub:imageomics/bioclip')
-        features = classifier.create_image_features_for_path(EXAMPLE_CAT_IMAGE, normalize=False)
+        features = classifier.create_image_features_for_image(EXAMPLE_CAT_IMAGE, normalize=False)
+        self.assertEqual(features.shape, torch.Size([512]))
+
+    def test_get_image_features_pil(self):
+        classifier = TreeOfLifeClassifier(device='cpu')
+        self.assertEqual(classifier.model_str, 'hf-hub:imageomics/bioclip')
+        features = classifier.create_image_features_for_image(PIL.Image.open(EXAMPLE_CAT_IMAGE), normalize=False)
         self.assertEqual(features.shape, torch.Size([512]))
