Skip to content

Commit c8d2051

Browse files
authored
Adds ability to provide other OpenCLIP models and checkpoints (#33)
For now this will only work with predicting on custom classes. For the full tree-of-life, the embeddings for all text labels would need to have been pre-computed and made available for download and caching. Also updates documentation for command line arguments.
1 parent 2aca6dd commit c8d2051

File tree

2 files changed

+88
-33
lines changed

2 files changed

+88
-33
lines changed

src/bioclip/__main__.py

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from bioclip import TreeOfLifeClassifier, Rank, CustomLabelsClassifier
2+
from .predict import BIOCLIP_MODEL_STR
3+
import open_clip as oc
24
import json
35
import sys
46
import prettytable as pt
@@ -29,20 +31,25 @@ def write_results_to_file(df, format, outfile):
2931
raise ValueError(f"Invalid format: {format}")
3032

3133

32-
def predict(image_file: list[str], format: str, output: str,
33-
cls_str: str, device: str, rank: Rank, k: int):
34+
def predict(image_file: list[str],
35+
format: str,
36+
output: str,
37+
cls_str: str,
38+
rank: Rank,
39+
k: int,
40+
**kwargs):
3441
if cls_str:
35-
classifier = CustomLabelsClassifier(cls_ary=cls_str.split(','), device=device)
42+
classifier = CustomLabelsClassifier(cls_ary=cls_str.split(','), **kwargs)
3643
predictions = classifier.predict(image_paths=image_file, k=k)
3744
write_results(predictions, format, output)
3845
else:
39-
classifier = TreeOfLifeClassifier(device=device)
46+
classifier = TreeOfLifeClassifier(**kwargs)
4047
predictions = classifier.predict(image_paths=image_file, rank=rank, k=k)
4148
write_results(predictions, format, output)
4249

4350

44-
def embed(image_file: list[str], output: str, device: str):
45-
classifier = TreeOfLifeClassifier(device=device)
51+
def embed(image_file: list[str], output: str, **kwargs):
52+
classifier = TreeOfLifeClassifier(**kwargs)
4653
images_dict = {}
4754
data = {
4855
"model": classifier.model_str,
@@ -62,22 +69,42 @@ def create_parser():
6269
parser = argparse.ArgumentParser(prog='bioclip', description='BioCLIP command line interface')
6370
subparsers = parser.add_subparsers(title='commands', dest='command')
6471

72+
device_arg = {'default':'cpu', 'help': 'device to use (cpu or cuda or mps), default: cpu'}
73+
output_arg = {'default': 'stdout', 'help': 'print output to file, default: stdout'}
74+
model_arg = {'help': f'model identifier (see command list-models); default: {BIOCLIP_MODEL_STR}'}
75+
pretrained_arg = {'help': 'pretrained model checkpoint as tag or file, depends on model; '
76+
'needed only if more than one is available (see command list-models)'}
77+
6578
# Predict command
6679
predict_parser = subparsers.add_parser('predict', help='Use BioCLIP to generate predictions for image files.')
6780
predict_parser.add_argument('image_file', nargs='+', help='input image file(s)')
6881
predict_parser.add_argument('--format', choices=['table', 'csv'], default='csv', help='format of the output, default: csv')
69-
predict_parser.add_argument('--output', default='stdout', help='print output to file, default: stdout')
82+
predict_parser.add_argument('--output', **output_arg)
7083
predict_parser.add_argument('--rank', choices=['kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'],
7184
help='rank of the classification, default: species (when)')
7285
predict_parser.add_argument('--k', type=int, help='number of top predictions to show, default: 5')
73-
predict_parser.add_argument('--cls', help='comma separated list of classes to predict, when specified the --rank and --k arguments are not allowed')
74-
predict_parser.add_argument('--device', help='device to use (cpu or cuda or mps), default: cpu', default='cpu')
86+
predict_parser.add_argument('--cls', help='comma separated list of classes to predict, when specified the --rank argument is not allowed')
87+
predict_parser.add_argument('--device', **device_arg)
88+
predict_parser.add_argument('--model', **model_arg)
89+
predict_parser.add_argument('--pretrained', **pretrained_arg)
7590

7691
# Embed command
7792
embed_parser = subparsers.add_parser('embed', help='Use BioCLIP to generate embeddings for image files.')
7893
embed_parser.add_argument('image_file', nargs='+', help='input image file(s)')
79-
embed_parser.add_argument('--output', default='stdout', help='print output to file, default: stdout')
80-
embed_parser.add_argument('--device', help='device to use (cpu or cuda or mps), default: cpu', default='cpu')
94+
embed_parser.add_argument('--output', **output_arg)
95+
embed_parser.add_argument('--device', **device_arg)
96+
embed_parser.add_argument('--model', **model_arg)
97+
embed_parser.add_argument('--pretrained', **pretrained_arg)
98+
99+
# List command
100+
list_parser = subparsers.add_parser('list-models',
101+
help='List available models and pretrained model checkpoints.',
102+
description=
103+
'Note that this will only list models known to open_clip; '
104+
'any model identifier loadable by open_clip, such as from hf-hub, file, etc '
105+
'should also be usable for --model in the embed and predict commands. '
106+
f'(The default model {BIOCLIP_MODEL_STR} is one example.)')
107+
list_parser.add_argument('--model', help='list available tags for pretrained model checkpoint(s) for specified model')
81108

82109
return parser
83110

@@ -91,6 +118,8 @@ def parse_args(input_args=None):
91118
raise ValueError("Cannot use --cls with --rank")
92119
else:
93120
# tree of life class list mode
121+
if args.model or args.pretrained:
122+
raise ValueError("Custom model or checkpoints currently not supported for Tree-of-Life prediction")
94123
if not args.rank:
95124
args.rank = 'species'
96125
args.rank = Rank[args.rank.upper()]
@@ -102,9 +131,28 @@ def parse_args(input_args=None):
102131
def main():
103132
args = parse_args()
104133
if args.command == 'embed':
105-
embed(args.image_file, args.output, args.device)
134+
embed(args.image_file,
135+
args.output,
136+
device=args.device,
137+
model_str=args.model,
138+
pretrained_str=args.pretrained)
106139
elif args.command == 'predict':
107-
predict(args.image_file, args.format, args.output, args.cls, args.device, args.rank, args.k)
140+
predict(args.image_file,
141+
format=args.format,
142+
output=args.output,
143+
cls_str=args.cls,
144+
rank=args.rank,
145+
k=args.k,
146+
device=args.device,
147+
model_str=args.model,
148+
pretrained_str=args.pretrained)
149+
elif args.command == 'list-models':
150+
if args.model:
151+
for tag in oc.list_pretrained_tags_by_model(args.model):
152+
print(tag)
153+
else:
154+
for model_str in oc.list_models():
155+
print(f"\t{model_str}")
108156
else:
109157
raise ValueError("Invalid command")
110158

src/bioclip/predict.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import torch
33
from torchvision import transforms
4-
from open_clip import create_model, get_tokenizer
4+
import open_clip as oc
55
import torch.nn.functional as F
66
import numpy as np
77
import collections
@@ -14,7 +14,7 @@
1414

1515
HF_DATAFILE_REPO = "imageomics/bioclip-demo"
1616
HF_DATAFILE_REPO_TYPE = "space"
17-
MODEL_STR = "hf-hub:imageomics/bioclip"
17+
BIOCLIP_MODEL_STR = "hf-hub:imageomics/bioclip"
1818
PRED_FILENAME_KEY = "file_name"
1919
PRED_CLASSICATION_KEY = "classification"
2020
PRED_SCORE_KEY = "score"
@@ -139,14 +139,8 @@ def get_label(self):
139139
COMMON_NAME_LABEL = "common_name"
140140

141141

142-
def create_bioclip_model(model_str, device="cuda"):
143-
model = create_model(model_str, output_dict=True, require_pretrained=True)
144-
model = model.to(device)
145-
return torch.compile(model)
146-
147-
148-
def create_bioclip_tokenizer(tokenizer_str="ViT-B-16"):
149-
return get_tokenizer(tokenizer_str)
142+
def create_bioclip_tokenizer(model_name="ViT-B-16"):
143+
return oc.get_tokenizer(model_name=model_name)
150144

151145

152146
preprocess_img = transforms.Compose(
@@ -162,10 +156,23 @@ def create_bioclip_tokenizer(tokenizer_str="ViT-B-16"):
162156

163157

164158
class BaseClassifier(object):
165-
def __init__(self, device: Union[str, torch.device] = 'cpu', model_str: str = MODEL_STR):
159+
def __init__(self, model_str: str = BIOCLIP_MODEL_STR, pretrained_str: str | None = None, device: Union[str, torch.device] = 'cpu'):
166160
self.device = device
167-
self.model = create_bioclip_model(device=device, model_str=model_str)
168-
self.model_str = model_str
161+
self.load_pretrained_model(model_str=model_str, pretrained_str=pretrained_str)
162+
163+
def load_pretrained_model(self, model_str: str = BIOCLIP_MODEL_STR, pretrained_str: str | None = None):
164+
self.model_str = model_str or BIOCLIP_MODEL_STR
165+
pretrained_tags = oc.list_pretrained_tags_by_model(self.model_str)
166+
if pretrained_str is None and len(pretrained_tags) > 0:
167+
if len(pretrained_tags) > 1:
168+
raise ValueError(f"Multiple pretrained tags available {pretrained_tags}, must provide one")
169+
pretrained_str = pretrained_tags[0]
170+
model, preprocess = oc.create_model_from_pretrained(self.model_str,
171+
pretrained=pretrained_str,
172+
device=self.device,
173+
return_transform=True)
174+
self.model = torch.compile(model.to(self.device))
175+
self.preprocess = preprocess_img if self.model_str == BIOCLIP_MODEL_STR else preprocess
169176

170177
@staticmethod
171178
def open_image(image_path):
@@ -176,7 +183,7 @@ def open_image(image_path):
176183
def create_image_features(self, images: List[PIL.Image.Image], normalize : bool = True) -> torch.Tensor:
177184
preprocessed_images = []
178185
for img in images:
179-
prep_img = preprocess_img(img).to(self.device)
186+
prep_img = self.preprocess(img).to(self.device)
180187
preprocessed_images.append(prep_img)
181188
preprocessed_image_tensor = torch.stack(preprocessed_images)
182189
img_features = self.model.encode_image(preprocessed_image_tensor)
@@ -209,9 +216,9 @@ def create_probabilities_for_image_paths(self, image_paths: List[str] | str,
209216

210217

211218
class CustomLabelsClassifier(BaseClassifier):
212-
def __init__(self, cls_ary: List[str], device: Union[str, torch.device] = 'cpu', model_str: str = MODEL_STR):
213-
super().__init__(device=device, model_str=model_str)
214-
self.tokenizer = create_bioclip_tokenizer()
219+
def __init__(self, cls_ary: List[str], **kwargs):
220+
super().__init__(**kwargs)
221+
self.tokenizer = create_bioclip_tokenizer(self.model_str)
215222
self.classes = [cls.strip() for cls in cls_ary]
216223
self.txt_features = self._get_txt_features(self.classes)
217224

@@ -286,9 +293,9 @@ def join_names(classification_dict: dict[str, str]) -> str:
286293

287294

288295
class TreeOfLifeClassifier(BaseClassifier):
289-
def __init__(self, device: Union[str, torch.device] = 'cpu', model_str: str = MODEL_STR):
290-
super().__init__(device=device, model_str=model_str)
291-
self.txt_features = get_txt_emb().to(device)
296+
def __init__(self, **kwargs):
297+
super().__init__(**kwargs)
298+
self.txt_features = get_txt_emb().to(self.device)
292299
self.txt_names = get_txt_names()
293300

294301
def format_species_probs(self, image_path: str, probs: torch.Tensor, k: int = 5) -> List[dict[str, float]]:

0 commit comments

Comments
 (0)