Commit 88eee31

Merge pull request #10 from Imageomics/embed

Add embed command

2 parents 71d8a63 + 12f555f

5 files changed: 195 additions, 53 deletions

README.md

Lines changed: 35 additions & 10 deletions

@@ -103,21 +103,24 @@ bear 1.0
 
 ## Command Line Usage
 ```
-bioclip predict [options] [IMAGE_FILE...]
+bioclip predict [-h] [--format {table,csv}] [--output OUTPUT] [--rank {kingdom,phylum,class,order,family,genus,species}] [--k K] [--cls CLS] [--device DEVICE] image_file [image_file ...]
+bioclip embed [-h] [--device=DEVICE] [--output=OUTPUT] [IMAGE_FILE...]
+
+Commands:
+  predict  Use BioCLIP to generate predictions for image files.
+  embed    Use BioCLIP to generate embeddings for image files.
 
 Arguments:
-  IMAGE_FILE  input image file
+  IMAGE_FILE  input image file
 
 Options:
   -h --help
-  --format=FORMAT  format of the output (table or csv) [default: csv]
-  --rank=RANK  rank of the classification (kingdom, phylum, class, order, family, genus, species)
-               [default: species]
-  --k=K  number of top predictions to show [default: 5]
-  --cls=CLS  comma separated list of classes to predict, when specified the --rank and
-             --k arguments are ignored [default: all]
-  --device=DEVICE  device to use for prediction (cpu or cuda or mps) [default: cpu]
-  --output=OUTFILE  print output to file OUTFILE [default: stdout]
+  --format=FORMAT  format of the output (table or csv) for predict mode [default: csv]
+  --rank=RANK  rank of the classification (kingdom, phylum, class, order, family, genus, species) [default: species]
+  --k=K  number of top predictions to show [default: 5]
+  --cls=CLS  comma separated list of classes to predict, when specified the --rank and --k arguments are not allowed
+  --device=DEVICE  device to use for matrix math (cpu or cuda or mps) [default: cpu]
+  --output=OUTFILE  print output to file OUTFILE [default: stdout]
 ```
 
 ### Predict classification
@@ -191,6 +194,28 @@ Ursus-arctos.jpeg,bird,3.051998476166773e-08
 Ursus-arctos.jpeg,bear,0.9999998807907104
 ```
 
+### Create embeddings
+
+#### Create embedding for an image
+
+```console
+bioclip embed Ursus-arctos.jpeg
+```
+Output:
+```
+{
+    "model": "hf-hub:imageomics/bioclip",
+    "embeddings": {
+        "Ursus-arctos.jpeg": [
+            -0.23633578419685364,
+            -0.28467196226119995,
+            -0.4394485652446747,
+            ...
+        ]
+    }
+}
+```
+
 ### View command line help
 ```console
 bioclip --help
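
The JSON written by `bioclip embed` is plain data and easy to post-process. Below is a minimal sketch of reading it back, assuming it was produced with `bioclip embed --output embeddings.json ...` and that at least two images were embedded; the cosine-similarity comparison is illustrative, not part of this commit.

```python
import json
import math

# Load the file written by: bioclip embed --output embeddings.json IMAGE_FILE...
with open("embeddings.json") as fh:
    data = json.load(fh)

print(data["model"])  # e.g. "hf-hub:imageomics/bioclip"

# Each entry maps an image path to its embedding vector (a list of floats).
vectors = data["embeddings"]

# Illustrative use: cosine similarity between the first two embedded images.
(path_a, a), (path_b, b) = list(vectors.items())[:2]
dot = sum(x * y for x, y in zip(a, b))
norm_a = math.sqrt(sum(x * x for x in a))
norm_b = math.sqrt(sum(x * x for x in b))
print(f"{path_a} vs {path_b}: {dot / (norm_a * norm_b):.4f}")
```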

src/bioclip/__main__.py

Lines changed: 76 additions & 38 deletions

@@ -1,27 +1,9 @@
-"""Usage: bioclip predict [options] [IMAGE_FILE...]
-
-Use BioCLIP to generate predictions for an IMAGE_FILE.
-
-Arguments:
-  IMAGE_FILE  input image file
-
-Options:
-  -h --help
-  --format=FORMAT  format of the output (table or csv) [default: csv]
-  --rank=RANK  rank of the classification (kingdom, phylum, class, order, family, genus, species) [default: species]
-  --k=K  number of top predictions to show [default: 5]
-  --cls=CLS  comma separated list of classes to predict, when specified the --rank and --k arguments are ignored [default: all]
-  --device=DEVICE  device to use for prediction (cpu or cuda or mps) [default: cpu]
-  --output=OUTFILE  print output to file OUTFILE [default: stdout]
-
-"""
-from docopt import docopt
 from bioclip import TreeOfLifeClassifier, Rank, CustomLabelsClassifier
 import json
 import sys
 import prettytable as pt
-import csv
 import pandas as pd
+import argparse
 
 
 def write_results(data, format, output):
@@ -46,33 +28,89 @@ def write_results_to_file(df, format, outfile):
     else:
         raise ValueError(f"Invalid format: {format}")
 
-
-def main():
-    # execute only if run as the entry point into the program
-    x = docopt(__doc__) # parse arguments based on docstring above
-    format = x['--format']
-    output = x['--output']
-    image_file = x['IMAGE_FILE']
-    device = 'cpu'
-    if x['--device']:
-        device = x['--device']
-    cls = x['--cls']
-    if not format in ['table', 'csv']:
-        raise ValueError(f"Invalid format: {format}")
-    rank = Rank[x['--rank'].upper()]
-    if cls == 'all':
-        classifier = TreeOfLifeClassifier(device=device)
+def predict(image_file: list[str], format: str, output: str,
+            cls_str: str, device: str, rank: Rank, k: int):
+    if cls_str:
+        classifier = CustomLabelsClassifier(device=device)
         data = []
         for image_path in image_file:
-            data.extend(classifier.predict(image_path=image_path, rank=rank, k=int(x['--k'])))
+            data.extend(classifier.predict(image_path=image_path, cls_ary=cls_str.split(',')))
         write_results(data, format, output)
     else:
-        classifier = CustomLabelsClassifier(device=device)
+        classifier = TreeOfLifeClassifier(device=device)
         data = []
         for image_path in image_file:
-            data.extend(classifier.predict(image_path=image_path, cls_ary=cls.split(',')))
+            data.extend(classifier.predict(image_path=image_path, rank=rank, k=k))
         write_results(data, format, output)
 
 
+def embed(image_file: list[str], output: str, device: str):
+    classifier = TreeOfLifeClassifier(device=device)
+    images_dict = {}
+    data = {
+        "model": classifier.model_str,
+        "embeddings": images_dict
+    }
+    for image_path in image_file:
+        features = classifier.get_image_features(image_path)[0]
+        images_dict[image_path] = features.tolist()
+    if output == 'stdout':
+        print(json.dumps(data, indent=4))
+    else:
+        with open(output, 'w') as outfile:
+            json.dump(data, outfile, indent=4)
+
+
+def create_parser():
+    parser = argparse.ArgumentParser(prog='bioclip', description='BioCLIP command line interface')
+    subparsers = parser.add_subparsers(title='commands', dest='command')
+
+    # Predict command
+    predict_parser = subparsers.add_parser('predict', help='Use BioCLIP to generate predictions for image files.')
+    predict_parser.add_argument('image_file', nargs='+', help='input image file(s)')
+    predict_parser.add_argument('--format', choices=['table', 'csv'], default='csv', help='format of the output, default: csv')
+    predict_parser.add_argument('--output', default='stdout', help='print output to file, default: stdout')
+    predict_parser.add_argument('--rank', choices=['kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'],
+                                help='rank of the classification, default: species (when)')
+    predict_parser.add_argument('--k', type=int, help='number of top predictions to show, default: 5')
+    predict_parser.add_argument('--cls', help='comma separated list of classes to predict, when specified the --rank and --k arguments are not allowed')
+    predict_parser.add_argument('--device', help='device to use (cpu or cuda or mps), default: cpu', default='cpu')
+
+    # Embed command
+    embed_parser = subparsers.add_parser('embed', help='Use BioCLIP to generate embeddings for image files.')
+    embed_parser.add_argument('image_file', nargs='+', help='input image file(s)')
+    embed_parser.add_argument('--output', default='stdout', help='print output to file, default: stdout')
+    embed_parser.add_argument('--device', help='device to use (cpu or cuda or mps), default: cpu', default='cpu')
+
+    return parser
+
+
+def parse_args(input_args=None):
+    args = create_parser().parse_args(input_args)
+    if args.command == 'predict':
+        if args.cls:
+            # custom class list mode
+            if args.rank or args.k:
+                raise ValueError("Cannot use --cls with --rank or --k")
+        else:
+            # tree of life class list mode
+            if not args.rank:
+                args.rank = 'species'
+            args.rank = Rank[args.rank.upper()]
+            if not args.k:
+                args.k = 5
+    return args
+
+
+def main():
+    args = parse_args()
+    if args.command == 'embed':
+        embed(args.image_file, args.output, args.device)
+    elif args.command == 'predict':
+        predict(args.image_file, args.format, args.output, args.cls, args.device, args.rank, args.k)
+    else:
+        raise ValueError("Invalid command")
+
+
 if __name__ == '__main__':
     main()
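
For readers following the refactor from docopt to argparse: the new `parse_args` and `embed` functions can also be driven directly from Python, mirroring what `main()` does. A small sketch follows; the image file name and output path are placeholders, and calling `embed` will load the BioCLIP model.

```python
from bioclip.__main__ import parse_args, embed

# Parse argv exactly as the console script would.
args = parse_args(['embed', '--output', 'embeddings.json', 'Ursus-arctos.jpeg'])
print(args.command, args.image_file, args.device)  # embed ['Ursus-arctos.jpeg'] cpu

# Dispatch the embed command the same way main() does.
embed(args.image_file, args.output, args.device)
```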

src/bioclip/predict.py

Lines changed: 13 additions & 5 deletions

@@ -14,6 +14,7 @@
 
 HF_DATAFILE_REPO = "imageomics/bioclip-demo"
 HF_DATAFILE_REPO_TYPE = "space"
+MODEL_STR = "hf-hub:imageomics/bioclip"
 PRED_FILENAME_KEY = "file_name"
 PRED_CLASSICATION_KEY = "classification"
 PRED_SCORE_KEY = "score"
@@ -149,7 +150,7 @@ def get_label(self):
 COMMON_NAME_LABEL = "common_name"
 
 
-def create_bioclip_model(model_str="hf-hub:imageomics/bioclip", device="cuda"):
+def create_bioclip_model(model_str, device="cuda"):
     model = create_model(model_str, output_dict=True, require_pretrained=True)
     model = model.to(device)
     return torch.compile(model)
@@ -160,9 +161,10 @@ def create_bioclip_tokenizer(tokenizer_str="ViT-B-16"):
 
 
 class CustomLabelsClassifier(object):
-    def __init__(self, device: Union[str, torch.device] = 'cpu'):
+    def __init__(self, device: Union[str, torch.device] = 'cpu', model_str: str = MODEL_STR):
         self.device = device
-        self.model = create_bioclip_model(device=device)
+        self.model = create_bioclip_model(device=device, model_str=model_str)
+        self.model_str = model_str
         self.tokenizer = create_bioclip_tokenizer()
 
     def get_txt_features(self, classnames):
@@ -237,12 +239,18 @@ def join_names(classification_dict: dict[str, str]) -> str:
 
 
 class TreeOfLifeClassifier(object):
-    def __init__(self, device: Union[str, torch.device] = 'cpu'):
+    def __init__(self, device: Union[str, torch.device] = 'cpu', model_str: str = MODEL_STR):
         self.device = device
-        self.model = create_bioclip_model(device=device)
+        self.model = create_bioclip_model(device=device, model_str=model_str)
+        self.model_str = model_str
         self.txt_emb = get_txt_emb().to(device)
         self.txt_names = get_txt_names()
 
+    @torch.no_grad()
+    def get_image_features(self, image_path: str) -> torch.Tensor:
+        img = PIL.Image.open(image_path)
+        return self.encode_image(img)
+
     def encode_image(self, img: PIL.Image.Image) -> torch.Tensor:
         img = preprocess_img(img).to(self.device)
         img_features = self.model.encode_image(img.unsqueeze(0))
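
Note that `get_image_features` is exposed on the Python API as well, so the `embed` command is essentially a thin wrapper around it. A brief sketch; the image path is a placeholder, and the [1, 512] shape comes from the test below.

```python
from bioclip import TreeOfLifeClassifier

classifier = TreeOfLifeClassifier(device='cpu')
print(classifier.model_str)  # "hf-hub:imageomics/bioclip"

# Returns a torch.Tensor of image features with shape [1, 512].
features = classifier.get_image_features('Ursus-arctos.jpeg')

# The embed command stores features[0].tolist() in its JSON output.
embedding = features[0].tolist()
print(len(embedding))  # 512
```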

tests/test_main.py

Lines changed: 61 additions & 0 deletions

@@ -0,0 +1,61 @@
+import unittest
+from bioclip.__main__ import parse_args, Rank
+
+
+class TestParser(unittest.TestCase):
+    def test_parse_args(self):
+
+        args = parse_args(['predict', 'image.jpg'])
+        self.assertEqual(args.command, 'predict')
+        self.assertEqual(args.image_file, ['image.jpg'])
+        self.assertEqual(args.format, 'csv')
+        self.assertEqual(args.output, 'stdout')
+        self.assertEqual(args.rank, Rank.SPECIES)
+        self.assertEqual(args.k, 5)
+        self.assertEqual(args.cls, None)
+        self.assertEqual(args.device, 'cpu')
+
+        args = parse_args(['predict', 'image.jpg', 'image2.png'])
+        self.assertEqual(args.command, 'predict')
+        self.assertEqual(args.image_file, ['image.jpg', 'image2.png'])
+
+        # test tree of life version of predict
+        args = parse_args(['predict', 'image.jpg', '--format', 'table', '--output', 'output.csv', '--rank', 'genus', '--k', '10', '--device', 'cuda'])
+        self.assertEqual(args.command, 'predict')
+        self.assertEqual(args.image_file, ['image.jpg'])
+        self.assertEqual(args.format, 'table')
+        self.assertEqual(args.output, 'output.csv')
+        self.assertEqual(args.rank, Rank.GENUS)
+        self.assertEqual(args.k, 10)
+        self.assertEqual(args.cls, None)
+        self.assertEqual(args.device, 'cuda')
+
+        # test custom class list version of predict
+        args = parse_args(['predict', 'image.jpg', '--format', 'table', '--output', 'output.csv', '--cls', 'class1,class2', '--device', 'cuda'])
+        self.assertEqual(args.command, 'predict')
+        self.assertEqual(args.image_file, ['image.jpg'])
+        self.assertEqual(args.format, 'table')
+        self.assertEqual(args.output, 'output.csv')
+        self.assertEqual(args.rank, None)  # default ignored for the --cls variation
+        self.assertEqual(args.k, None)
+        self.assertEqual(args.cls, 'class1,class2')
+        self.assertEqual(args.device, 'cuda')
+
+        # test error when using --cls with --rank
+        with self.assertRaises(ValueError):
+            parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--rank', 'genus'])
+        # test error when using --cls with --k
+        with self.assertRaises(ValueError):
+            parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--k', '10'])
+
+        args = parse_args(['embed', 'image.jpg'])
+        self.assertEqual(args.command, 'embed')
+        self.assertEqual(args.image_file, ['image.jpg'])
+        self.assertEqual(args.output, 'stdout')
+        self.assertEqual(args.device, 'cpu')
+
+        args = parse_args(['embed', '--output', 'data.json', '--device', 'cuda', 'image.jpg', 'image2.png'])
+        self.assertEqual(args.command, 'embed')
+        self.assertEqual(args.image_file, ['image.jpg', 'image2.png'])
+        self.assertEqual(args.output, 'data.json')
+        self.assertEqual(args.device, 'cuda')

tests/test_predict.py

Lines changed: 10 additions & 0 deletions

@@ -2,6 +2,8 @@
 from bioclip.predict import TreeOfLifeClassifier, Rank
 from bioclip.predict import CustomLabelsClassifier
 import os
+import torch
+
 
 DIRNAME = os.path.dirname(os.path.realpath(__file__))
 EXAMPLE_CAT_IMAGE = os.path.join(DIRNAME, "images", "mycat.jpg")
@@ -48,3 +50,11 @@ def test_custom_labels_classifier(self):
             {'file_name': EXAMPLE_CAT_IMAGE, 'classification': 'cat', 'score': unittest.mock.ANY},
             {'file_name': EXAMPLE_CAT_IMAGE, 'classification': 'dog', 'score': unittest.mock.ANY},
         ])
+
+
+class TestEmbed(unittest.TestCase):
+    def test_get_image_features(self):
+        classifier = TreeOfLifeClassifier(device='cpu')
+        self.assertEqual(classifier.model_str, 'hf-hub:imageomics/bioclip')
+        features = classifier.get_image_features(EXAMPLE_CAT_IMAGE)
+        self.assertEqual(features.shape, torch.Size([1, 512]))
