Skip to content

Commit 50600a5

Browse files
committed
Allow batch processing of images
Instead of processing images one at a time, allow loading multiple images and processing them all at once. When used with a GPU, this may improve processing time. Fixes #11
1 parent a8d3a58 commit 50600a5

File tree

4 files changed

+136
-84
lines changed

4 files changed

+136
-84
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ predictions = classifier.predict("Ursus-arctos.jpeg", Rank.SPECIES)
8585
df = pd.DataFrame(predictions)
8686
```
8787

88+
The first argument of the `predict()` method accepts either a single path or a list of paths.
89+
8890
### Predict from a list of classes
8991
```python
9092
from bioclip import CustomLabelsClassifier

src/bioclip/__main__.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,17 @@ def write_results_to_file(df, format, outfile):
2828
else:
2929
raise ValueError(f"Invalid format: {format}")
3030

31+
3132
def predict(image_file: list[str], format: str, output: str,
3233
cls_str: str, device: str, rank: Rank, k: int):
3334
if cls_str:
3435
classifier = CustomLabelsClassifier(cls_ary=cls_str.split(','), device=device)
35-
data = []
36-
for image_path in image_file:
37-
data.extend(classifier.predict(image_path=image_path))
38-
write_results(data, format, output)
36+
predictions = classifier.predict(image_paths=image_file)
37+
write_results(predictions, format, output)
3938
else:
4039
classifier = TreeOfLifeClassifier(device=device)
41-
data = []
42-
for image_path in image_file:
43-
data.extend(classifier.predict(image_path=image_path, rank=rank, k=k))
44-
write_results(data, format, output)
40+
predictions = classifier.predict(image_paths=image_file, rank=rank, k=k)
41+
write_results(predictions, format, output)
4542

4643

4744
def embed(image_file: list[str], output: str, device: str):
@@ -52,14 +49,14 @@ def embed(image_file: list[str], output: str, device: str):
5249
"embeddings": images_dict
5350
}
5451
for image_path in image_file:
55-
features = classifier.get_image_features(image_path)[0]
52+
features = classifier.create_image_features_for_path(image_path=image_path, normalize=False)
5653
images_dict[image_path] = features.tolist()
5754
if output == 'stdout':
5855
print(json.dumps(data, indent=4))
5956
else:
6057
with open(output, 'w') as outfile:
61-
json.dump(data, outfile, indent=4)
62-
58+
json.dump(data, outfile, indent=4)
59+
6360

6461
def create_parser():
6562
parser = argparse.ArgumentParser(prog='bioclip', description='BioCLIP command line interface')

src/bioclip/predict.py

Lines changed: 88 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -119,22 +119,6 @@ def get_txt_names():
119119
return txt_names
120120

121121

122-
def open_image(image_path):
123-
img = PIL.Image.open(image_path)
124-
return img.convert("RGB")
125-
126-
127-
preprocess_img = transforms.Compose(
128-
[
129-
transforms.ToTensor(),
130-
transforms.Resize((224, 224), antialias=True),
131-
transforms.Normalize(
132-
mean=(0.48145466, 0.4578275, 0.40821073),
133-
std=(0.26862954, 0.26130258, 0.27577711),
134-
),
135-
]
136-
)
137-
138122
class Rank(Enum):
139123
KINGDOM = 0
140124
PHYLUM = 1
@@ -165,11 +149,68 @@ def create_bioclip_tokenizer(tokenizer_str="ViT-B-16"):
165149
return get_tokenizer(tokenizer_str)
166150

167151

168-
class CustomLabelsClassifier(object):
169-
def __init__(self, cls_ary: List[str], device: Union[str, torch.device] = 'cpu', model_str: str = MODEL_STR):
152+
preprocess_img = transforms.Compose(
153+
[
154+
transforms.ToTensor(),
155+
transforms.Resize((224, 224), antialias=True),
156+
transforms.Normalize(
157+
mean=(0.48145466, 0.4578275, 0.40821073),
158+
std=(0.26862954, 0.26130258, 0.27577711),
159+
),
160+
]
161+
)
162+
163+
164+
class BaseClassifier(object):
165+
def __init__(self, device: Union[str, torch.device] = 'cpu', model_str: str = MODEL_STR):
170166
self.device = device
171167
self.model = create_bioclip_model(device=device, model_str=model_str)
172168
self.model_str = model_str
169+
170+
@staticmethod
171+
def open_image(image_path):
172+
img = PIL.Image.open(image_path)
173+
return img.convert("RGB")
174+
175+
@torch.no_grad()
176+
def create_image_features(self, images: List[PIL.Image.Image], normalize : bool = True) -> torch.Tensor:
177+
preprocessed_images = []
178+
for img in images:
179+
prep_img = preprocess_img(img).to(self.device)
180+
preprocessed_images.append(prep_img)
181+
preprocessed_image_tensor = torch.stack(preprocessed_images)
182+
img_features = self.model.encode_image(preprocessed_image_tensor)
183+
if normalize:
184+
return F.normalize(img_features, dim=-1)
185+
else:
186+
return img_features
187+
188+
@torch.no_grad()
189+
def create_image_features_for_path(self, image_path: str, normalize: bool) -> torch.Tensor:
190+
img = self.open_image(image_path)
191+
result = self.create_image_features([img], normalize=normalize)
192+
return result[0]
193+
194+
@torch.no_grad()
195+
def create_probabilities(self, img_features: torch.Tensor,
196+
txt_features: torch.Tensor) -> dict[str, torch.Tensor]:
197+
logits = (self.model.logit_scale.exp() * img_features @ txt_features)
198+
return F.softmax(logits, dim=1)
199+
200+
def create_probabilities_for_image_paths(self, image_paths: List[str] | str,
201+
txt_features: torch.Tensor) -> dict[str, torch.Tensor]:
202+
images = [self.open_image(image_path) for image_path in image_paths]
203+
img_features = self.create_image_features(images)
204+
probs = self.create_probabilities(img_features, txt_features)
205+
result = {}
206+
for i, key in enumerate(image_paths):
207+
result[key] = probs[i]
208+
return result
209+
210+
211+
class CustomLabelsClassifier(BaseClassifier):
212+
def __init__(self, cls_ary: List[str], device: Union[str, torch.device] = 'cpu', model_str: str = MODEL_STR):
213+
super().__init__(device=device, model_str=model_str)
173214
self.tokenizer = create_bioclip_tokenizer()
174215
self.classes = [cls.strip() for cls in cls_ary]
175216
self.txt_features = self._get_txt_features(self.classes)
@@ -188,28 +229,24 @@ def _get_txt_features(self, classnames):
188229
return all_features
189230

190231
@torch.no_grad()
191-
def predict(self, image_path: str) -> dict[str, float]:
192-
img = open_image(image_path)
193-
194-
img = preprocess_img(img).to(self.device)
195-
img_features = self.model.encode_image(img.unsqueeze(0))
196-
img_features = F.normalize(img_features, dim=-1)
197-
198-
logits = (self.model.logit_scale.exp() * img_features @ self.txt_features).squeeze()
199-
probs = F.softmax(logits, dim=0).to("cpu").tolist()
200-
pred_list = []
201-
for cls, prob in zip(self.classes, probs):
202-
pred_list.append({
203-
PRED_FILENAME_KEY: image_path,
204-
PRED_CLASSICATION_KEY: cls,
205-
PRED_SCORE_KEY: prob
206-
})
207-
return pred_list
232+
def predict(self, image_paths: List[str] | str) -> dict[str, float]:
233+
if isinstance(image_paths, str):
234+
image_paths = [image_paths]
235+
probs = self.create_probabilities_for_image_paths(image_paths, self.txt_features)
236+
result = []
237+
for image_path in image_paths:
238+
for cls_str, prob in zip(self.classes, probs[image_path]):
239+
result.append({
240+
PRED_FILENAME_KEY: image_path,
241+
PRED_CLASSICATION_KEY: cls_str,
242+
PRED_SCORE_KEY: prob.item()
243+
})
244+
return result
208245

209246

210247
def predict_classifications_from_list(img: Union[PIL.Image.Image, str], cls_ary: List[str], device: Union[str, torch.device] = 'cpu') -> dict[str, float]:
211248
classifier = CustomLabelsClassifier(cls_ary=cls_ary, device=device)
212-
return classifier.predict(img)
249+
return classifier.predict([img])
213250

214251

215252
def get_tol_classification_labels(rank: Rank) -> List[str]:
@@ -244,31 +281,12 @@ def join_names(classification_dict: dict[str, str]) -> str:
244281
return " ".join(classification_dict.values())
245282

246283

247-
class TreeOfLifeClassifier(object):
284+
class TreeOfLifeClassifier(BaseClassifier):
248285
def __init__(self, device: Union[str, torch.device] = 'cpu', model_str: str = MODEL_STR):
249-
self.device = device
250-
self.model = create_bioclip_model(device=device, model_str=model_str)
251-
self.model_str = model_str
252-
self.txt_emb = get_txt_emb().to(device)
286+
super().__init__(device=device, model_str=model_str)
287+
self.txt_features = get_txt_emb().to(device)
253288
self.txt_names = get_txt_names()
254289

255-
@torch.no_grad()
256-
def get_image_features(self, image_path: str) -> torch.Tensor:
257-
img = open_image(image_path)
258-
return self.encode_image(img)
259-
260-
def encode_image(self, img: PIL.Image.Image) -> torch.Tensor:
261-
img = preprocess_img(img).to(self.device)
262-
img_features = self.model.encode_image(img.unsqueeze(0))
263-
return img_features
264-
265-
def predict_species(self, img: PIL.Image.Image) -> torch.Tensor:
266-
img_features = self.encode_image(img)
267-
img_features = F.normalize(img_features, dim=-1)
268-
logits = (self.model.logit_scale.exp() * img_features @ self.txt_emb).squeeze()
269-
probs = F.softmax(logits, dim=0)
270-
return probs
271-
272290
def format_species_probs(self, image_path: str, probs: torch.Tensor, k: int = 5) -> List[dict[str, float]]:
273291
topk = probs.topk(k)
274292
result = []
@@ -299,12 +317,17 @@ def format_grouped_probs(self, image_path: str, probs: torch.Tensor, rank: Rank,
299317
return prediction_ary
300318

301319
@torch.no_grad()
302-
def predict(self, image_path: str, rank: Rank, min_prob: float = 1e-9, k: int = 5) -> List[dict[str, float]]:
303-
img = open_image(image_path)
304-
probs = self.predict_species(img)
305-
if rank == Rank.SPECIES:
306-
return self.format_species_probs(image_path, probs, k)
307-
return self.format_grouped_probs(image_path, probs, rank, min_prob, k)
320+
def predict(self, image_paths: List[str] | str, rank: Rank, min_prob: float = 1e-9, k: int = 5) -> dict[str, dict[str, float]]:
321+
if isinstance(image_paths, str):
322+
image_paths = [image_paths]
323+
probs = self.create_probabilities_for_image_paths(image_paths, self.txt_features)
324+
result = []
325+
for image_path in image_paths:
326+
if rank == Rank.SPECIES:
327+
result.extend(self.format_species_probs(image_path, probs[image_path], k))
328+
else:
329+
result.extend(self.format_grouped_probs(image_path, probs[image_path], rank, min_prob, k))
330+
return result
308331

309332

310333
def predict_classification(img: str, rank: Rank, device: Union[str, torch.device] = 'cpu',
@@ -315,4 +338,4 @@ def predict_classification(img: str, rank: Rank, device: Union[str, torch.device
315338
species, then sums up species-level probabilities for the given rank.
316339
"""
317340
classifier = TreeOfLifeClassifier(device=device)
318-
return classifier.predict(img, rank, min_prob, k)
341+
return classifier.predict([img], rank, min_prob, k)

tests/test_predict.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
EXAMPLE_CAT_IMAGE2 = os.path.join(DIRNAME, "images", "mycat.png")
1111

1212
class TestPredict(unittest.TestCase):
13-
def test_tree_of_life_classifier_species(self):
13+
def test_tree_of_life_classifier_species_single(self):
1414
classifier = TreeOfLifeClassifier()
15-
prediction_ary = classifier.predict(image_path=EXAMPLE_CAT_IMAGE, rank=Rank.SPECIES)
15+
prediction_ary = classifier.predict(image_paths=EXAMPLE_CAT_IMAGE, rank=Rank.SPECIES)
1616
self.assertEqual(len(prediction_ary), 5)
1717
prediction_dict = {
1818
'file_name': EXAMPLE_CAT_IMAGE,
@@ -29,9 +29,20 @@ def test_tree_of_life_classifier_species(self):
2929
}
3030
self.assertEqual(prediction_ary[0], prediction_dict)
3131

32+
def test_tree_of_life_classifier_species_ary_one(self):
33+
classifier = TreeOfLifeClassifier()
34+
prediction_ary = classifier.predict(image_paths=[EXAMPLE_CAT_IMAGE], rank=Rank.SPECIES)
35+
self.assertEqual(len(prediction_ary), 5)
36+
37+
def test_tree_of_life_classifier_species_ary_multiple(self):
38+
classifier = TreeOfLifeClassifier()
39+
prediction_ary = classifier.predict(image_paths=[EXAMPLE_CAT_IMAGE, EXAMPLE_CAT_IMAGE2],
40+
rank=Rank.SPECIES)
41+
self.assertEqual(len(prediction_ary), 10)
42+
3243
def test_tree_of_life_classifier_family(self):
3344
classifier = TreeOfLifeClassifier()
34-
prediction_ary = classifier.predict(image_path=EXAMPLE_CAT_IMAGE, rank=Rank.FAMILY, k=2)
45+
prediction_ary = classifier.predict(image_paths=[EXAMPLE_CAT_IMAGE], rank=Rank.FAMILY, k=2)
3546
self.assertEqual(len(prediction_ary), 2)
3647
prediction_dict = {
3748
'file_name': EXAMPLE_CAT_IMAGE,
@@ -46,22 +57,41 @@ def test_tree_of_life_classifier_family(self):
4657

4758
def test_custom_labels_classifier(self):
4859
classifier = CustomLabelsClassifier(cls_ary=['cat', 'dog'])
49-
results = classifier.predict(image_path=EXAMPLE_CAT_IMAGE)
50-
self.assertEqual(results, [
60+
prediction_ary = classifier.predict(image_paths=EXAMPLE_CAT_IMAGE)
61+
self.assertEqual(prediction_ary, [
62+
{'file_name': EXAMPLE_CAT_IMAGE, 'classification': 'cat', 'score': unittest.mock.ANY},
63+
{'file_name': EXAMPLE_CAT_IMAGE, 'classification': 'dog', 'score': unittest.mock.ANY},
64+
])
65+
66+
def test_custom_labels_classifier_ary_one(self):
67+
classifier = CustomLabelsClassifier(cls_ary=['cat', 'dog'])
68+
prediction_ary = classifier.predict(image_paths=[EXAMPLE_CAT_IMAGE])
69+
self.assertEqual(prediction_ary, [
70+
{'file_name': EXAMPLE_CAT_IMAGE, 'classification': 'cat', 'score': unittest.mock.ANY},
71+
{'file_name': EXAMPLE_CAT_IMAGE, 'classification': 'dog', 'score': unittest.mock.ANY},
72+
])
73+
74+
def test_custom_labels_classifier_ary_multiple(self):
75+
classifier = CustomLabelsClassifier(cls_ary=['cat', 'dog'])
76+
prediction_ary = classifier.predict(image_paths=[EXAMPLE_CAT_IMAGE, EXAMPLE_CAT_IMAGE2])
77+
self.assertEqual(prediction_ary, [
5178
{'file_name': EXAMPLE_CAT_IMAGE, 'classification': 'cat', 'score': unittest.mock.ANY},
5279
{'file_name': EXAMPLE_CAT_IMAGE, 'classification': 'dog', 'score': unittest.mock.ANY},
80+
{'file_name': EXAMPLE_CAT_IMAGE2, 'classification': 'cat', 'score': unittest.mock.ANY},
81+
{'file_name': EXAMPLE_CAT_IMAGE2, 'classification': 'dog', 'score': unittest.mock.ANY},
5382
])
5483

84+
5585
def test_predict_with_rgba_image(self):
5686
# Ensure that the classifier can handle RGBA images
5787
classifier = TreeOfLifeClassifier()
58-
prediction_ary = classifier.predict(image_path=EXAMPLE_CAT_IMAGE2, rank=Rank.SPECIES)
88+
prediction_ary = classifier.predict(image_paths=[EXAMPLE_CAT_IMAGE2], rank=Rank.SPECIES)
5989
self.assertEqual(len(prediction_ary), 5)
6090

6191

6292
class TestEmbed(unittest.TestCase):
6393
def test_get_image_features(self):
6494
classifier = TreeOfLifeClassifier(device='cpu')
6595
self.assertEqual(classifier.model_str, 'hf-hub:imageomics/bioclip')
66-
features = classifier.get_image_features(EXAMPLE_CAT_IMAGE)
67-
self.assertEqual(features.shape, torch.Size([1, 512]))
96+
features = classifier.create_image_features_for_path(EXAMPLE_CAT_IMAGE, normalize=False)
97+
self.assertEqual(features.shape, torch.Size([512]))

0 commit comments

Comments
 (0)