Skip to content

Commit ae33abf

Browse files
johnbradleyhlappegrace479
committed
Add binning to custom label prediction
Allow users to group custom label predictions into bins. Fixes #29 Co-authored-by: Hilmar Lapp <[email protected]> Co-authored-by: Elizabeth Campolongo <[email protected]>
1 parent 7f2041b commit ae33abf

File tree

6 files changed

+221
-33
lines changed

6 files changed

+221
-33
lines changed

README.md

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,29 @@ fish 2.932403668845507e-12
102102
bear 1.0
103103
```
104104

105+
### Predict from a list of classes with binning
106+
```python
107+
from bioclip import CustomLabelsBinningClassifier
108+
classifier = CustomLabelsBinningClassifier(cls_to_bin={
109+
'dog': 'small',
110+
'fish': 'small',
111+
'bear': 'big',
112+
})
113+
predictions = classifier.predict("Ursus-arctos.jpeg")
114+
for prediction in predictions:
115+
print(prediction["classification"], prediction["score"])
116+
```
117+
Output:
118+
```console
119+
big 0.99992835521698
120+
small 7.165559509303421e-05
121+
```
122+
105123
## Command Line Usage
106124
```
107-
bioclip predict [-h] [--format {table,csv}] [--output OUTPUT] [--rank {kingdom,phylum,class,order,family,genus,species}] [--k K] [--cls CLS] [--device DEVICE] image_file [image_file ...]
125+
bioclip predict [-h] [--format {table,csv}] [--output OUTPUT]
126+
[--rank {kingdom,phylum,class,order,family,genus,species} | --cls CLS | --bins BINS]
127+
[--k K] [--device DEVICE] image_file [image_file ...]
108128
bioclip embed [-h] [--device=DEVICE] [--output=OUTPUT] [IMAGE_FILE...]
109129
110130
Commands:
@@ -117,9 +137,13 @@ Arguments:
117137
Options:
118138
-h --help
119139
--format=FORMAT format of the output (table or csv) for predict mode [default: csv]
120-
--rank=RANK rank of the classification (kingdom, phylum, class, order, family, genus, species) [default: species]
121-
--k=K number of top predictions to show [default: 5]
122-
--cls=CLS classes to predict: either a comma separated list or a path to a text file of classes (one per line), when specified the --rank argument is not allowed.
140+
--rank {kingdom,phylum,class,order,family,genus,species}
141+
rank of the classification, default: species (when)
142+
--cls CLS classes to predict: either a comma separated list or a path to a text file of classes (one per line), when specified the
143+
--rank and --bins arguments are not allowed.
144+
--bins BINS path to CSV file with two columns with the first being classes and second being bin names, when specified the --cls
145+
argument is not allowed.
146+
--k K number of top predictions to show, default: 5
123147
--device=DEVICE device to use matrix math (cpu or cuda or mps) [default: cpu]
124148
--output=OUTFILE print output to file OUTFILE [default: stdout]
125149
```
@@ -195,6 +219,29 @@ Ursus-arctos.jpeg,bird,3.051998476166773e-08
195219
Ursus-arctos.jpeg,bear,0.9999998807907104
196220
```
197221

222+
### Predict from a binning CSV
223+
Create predictions for 3 classes (cat, bird, and bear) with 2 bins (one, two) for image `Ursus-arctos.jpeg`:
224+
225+
Create a CSV file named `bins.csv` with the following contents:
226+
```
227+
cls,bin
228+
cat,one
229+
bird,one
230+
bear,two
231+
```
232+
The names of the columns do not matter. The first column values will be used as the classes. The second column values will be used for bin names.
233+
234+
Run predict command:
235+
```console
236+
bioclip predict --bins bins.csv Ursus-arctos.jpeg
237+
```
238+
239+
Output:
240+
```
241+
Ursus-arctos.jpeg,two,0.9999998807907104
242+
Ursus-arctos.jpeg,one,7.633736487377973e-08
243+
```
244+
198245
### Create embeddings
199246

200247
#### Create embedding for an image

src/bioclip/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# SPDX-FileCopyrightText: 2024-present John Bradley <[email protected]>
22
#
33
# SPDX-License-Identifier: MIT
4-
from bioclip.predict import TreeOfLifeClassifier, Rank, CustomLabelsClassifier
4+
from bioclip.predict import TreeOfLifeClassifier, Rank, CustomLabelsClassifier, CustomLabelsBinningClassifier
55

6-
__all__ = ["TreeOfLifeClassifier", "Rank", "CustomLabelsClassifier"]
6+
__all__ = ["TreeOfLifeClassifier", "Rank", "CustomLabelsClassifier", "CustomLabelsBinningClassifier"]

src/bioclip/__main__.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from bioclip import TreeOfLifeClassifier, Rank, CustomLabelsClassifier
1+
from bioclip import TreeOfLifeClassifier, Rank, CustomLabelsClassifier, CustomLabelsBinningClassifier
22
from .predict import BIOCLIP_MODEL_STR
33
import open_clip as oc
44
import os
@@ -32,17 +32,32 @@ def write_results_to_file(df, format, outfile):
3232
raise ValueError(f"Invalid format: {format}")
3333

3434

35+
def parse_bins_csv(bins_path):
36+
if not os.path.exists(bins_path):
37+
raise FileNotFoundError(f"File not found: {bins_path}")
38+
bin_df = pd.read_csv(bins_path, index_col=0)
39+
if len(bin_df.columns) == 0:
40+
raise ValueError("CSV file must have at least two columns.")
41+
return bin_df[bin_df.columns[0]].to_dict()
42+
43+
3544
def predict(image_file: list[str],
3645
format: str,
3746
output: str,
3847
cls_str: str,
3948
rank: Rank,
49+
bins_path: str,
4050
k: int,
4151
**kwargs):
4252
if cls_str:
4353
classifier = CustomLabelsClassifier(cls_ary=cls_str.split(','), **kwargs)
4454
predictions = classifier.predict(image_paths=image_file, k=k)
4555
write_results(predictions, format, output)
56+
elif bins_path:
57+
cls_to_bin = parse_bins_csv(bins_path)
58+
classifier = CustomLabelsBinningClassifier(cls_to_bin=cls_to_bin, **kwargs)
59+
predictions = classifier.predict(image_paths=image_file, k=k)
60+
write_results(predictions, format, output)
4661
else:
4762
classifier = TreeOfLifeClassifier(**kwargs)
4863
predictions = classifier.predict(image_paths=image_file, rank=rank, k=k)
@@ -81,11 +96,13 @@ def create_parser():
8196
predict_parser.add_argument('image_file', nargs='+', help='input image file(s)')
8297
predict_parser.add_argument('--format', choices=['table', 'csv'], default='csv', help='format of the output, default: csv')
8398
predict_parser.add_argument('--output', **output_arg)
84-
predict_parser.add_argument('--rank', choices=['kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'],
99+
cls_group = predict_parser.add_mutually_exclusive_group(required=False)
100+
cls_group.add_argument('--rank', choices=['kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'],
85101
help='rank of the classification, default: species (when)')
102+
cls_help = "classes to predict: either a comma separated list or a path to a text file of classes (one per line), when specified the --rank and --bins arguments are not allowed."
103+
cls_group.add_argument('--cls', help=cls_help)
104+
cls_group.add_argument('--bins', help='path to CSV file with two columns with the first being classes and second being bin names, when specified the --cls argument is not allowed.')
86105
predict_parser.add_argument('--k', type=int, help='number of top predictions to show, default: 5')
87-
cls_help = "classes to predict: either a comma separated list or a path to a text file of classes (one per line), when specified the --rank argument is not allowed."
88-
predict_parser.add_argument('--cls', help=cls_help)
89106

90107
predict_parser.add_argument('--device', **device_arg)
91108
predict_parser.add_argument('--model', **model_arg)
@@ -115,11 +132,7 @@ def create_parser():
115132
def parse_args(input_args=None):
116133
args = create_parser().parse_args(input_args)
117134
if args.command == 'predict':
118-
if args.cls:
119-
# custom class list mode
120-
if args.rank:
121-
raise ValueError("Cannot use --cls with --rank")
122-
else:
135+
if not args.cls and not args.bins:
123136
# tree of life class list mode
124137
if args.model or args.pretrained:
125138
raise ValueError("Custom model or checkpoints currently not supported for Tree-of-Life prediction")
@@ -155,6 +168,7 @@ def main():
155168
output=args.output,
156169
cls_str=cls_str,
157170
rank=args.rank,
171+
bins_path=args.bins,
158172
k=args.k,
159173
device=args.device,
160174
model_str=args.model,
@@ -167,7 +181,7 @@ def main():
167181
for model_str in oc.list_models():
168182
print(f"\t{model_str}")
169183
else:
170-
raise ValueError("Invalid command")
184+
create_parser().print_help()
171185

172186

173187
if __name__ == '__main__':

src/bioclip/predict.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import open_clip as oc
66
import torch.nn.functional as F
77
import numpy as np
8+
import pandas as pd
89
import collections
910
import heapq
1011
import PIL.Image
@@ -253,13 +254,41 @@ def predict(self, image_paths: List[str] | str, k: int = None) -> dict[str, floa
253254
img_probs = probs[image_path]
254255
if not k or k > len(self.classes):
255256
k = len(self.classes)
256-
topk = img_probs.topk(k)
257-
for i, prob in zip(topk.indices, topk.values):
258-
result.append({
259-
PRED_FILENAME_KEY: image_path,
260-
PRED_CLASSICATION_KEY: self.classes[i],
261-
PRED_SCORE_KEY: prob.item()
262-
})
257+
result.extend(self.group_probs(image_path, img_probs, k))
258+
return result
259+
260+
def group_probs(self, image_path: str, img_probs: torch.Tensor, k: int = None) -> List[dict[str, float]]:
261+
result = []
262+
topk = img_probs.topk(k)
263+
for i, prob in zip(topk.indices, topk.values):
264+
result.append({
265+
PRED_FILENAME_KEY: image_path,
266+
PRED_CLASSICATION_KEY: self.classes[i],
267+
PRED_SCORE_KEY: prob.item()
268+
})
269+
return result
270+
271+
272+
class CustomLabelsBinningClassifier(CustomLabelsClassifier):
273+
def __init__(self, cls_to_bin: dict, **kwargs):
274+
super().__init__(cls_ary=cls_to_bin.keys(), **kwargs)
275+
self.cls_to_bin = cls_to_bin
276+
if any([pd.isna(x) or not x for x in cls_to_bin.values()]):
277+
raise ValueError("Empty, null, or nan are not allowed for bin values.")
278+
279+
def group_probs(self, image_path: str, img_probs: torch.Tensor, k: int = None) -> List[dict[str, float]]:
280+
result = []
281+
output = collections.defaultdict(float)
282+
for i in range(len(self.classes)):
283+
name = self.cls_to_bin[self.classes[i]]
284+
output[name] += img_probs[i]
285+
topk_names = heapq.nlargest(k, output, key=output.get)
286+
for name in topk_names:
287+
result.append({
288+
PRED_FILENAME_KEY: image_path,
289+
PRED_CLASSICATION_KEY: name,
290+
PRED_SCORE_KEY: output[name].item()
291+
})
263292
return result
264293

265294

tests/test_main.py

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import unittest
22
from unittest.mock import mock_open, patch
33
import argparse
4-
from bioclip.__main__ import parse_args, Rank, create_classes_str, main
4+
import pandas as pd
5+
from bioclip.__main__ import parse_args, Rank, create_classes_str, main, parse_bins_csv
56

67

78
class TestParser(unittest.TestCase):
@@ -15,6 +16,7 @@ def test_parse_args(self):
1516
self.assertEqual(args.rank, Rank.SPECIES)
1617
self.assertEqual(args.k, 5)
1718
self.assertEqual(args.cls, None)
19+
self.assertEqual(args.bins, None)
1820
self.assertEqual(args.device, 'cpu')
1921

2022
args = parse_args(['predict', 'image.jpg', 'image2.png'])
@@ -41,12 +43,29 @@ def test_parse_args(self):
4143
self.assertEqual(args.rank, None) # default ignored for the --cls variation
4244
self.assertEqual(args.k, None)
4345
self.assertEqual(args.cls, 'class1,class2')
46+
self.assertEqual(args.bins, None)
47+
self.assertEqual(args.device, 'cuda')
48+
49+
# test binning version of predict
50+
args = parse_args(['predict', 'image.jpg', '--format', 'table', '--output', 'output.csv', '--bins', 'bins.csv', '--device', 'cuda'])
51+
self.assertEqual(args.command, 'predict')
52+
self.assertEqual(args.image_file, ['image.jpg'])
53+
self.assertEqual(args.format, 'table')
54+
self.assertEqual(args.output, 'output.csv')
55+
self.assertEqual(args.rank, None) # default ignored for the --cls variation
56+
self.assertEqual(args.k, None)
57+
self.assertEqual(args.cls, None)
58+
self.assertEqual(args.bins, 'bins.csv')
4459
self.assertEqual(args.device, 'cuda')
4560

4661
# test error when using --cls with --rank
47-
with self.assertRaises(ValueError):
62+
with self.assertRaises(SystemExit):
4863
parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--rank', 'genus'])
4964

65+
# test error when using --cls with --bins
66+
with self.assertRaises(SystemExit):
67+
parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--bins', 'somefile.csv'])
68+
5069
# not an error when using --cls with --k
5170
args = parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--k', '10'])
5271
self.assertEqual(args.k, 10)
@@ -77,10 +96,10 @@ def test_create_classes_str(self):
7796
def test_predict_no_class(self, mock_parse_args, mock_predict):
7897
mock_parse_args.return_value = argparse.Namespace(command='predict', image_file='image.jpg', format='csv',
7998
output='stdout', rank=Rank.SPECIES, k=5, cls=None, device='cpu',
80-
model=None, pretrained=None)
99+
model=None, pretrained=None, bins=None)
81100
main()
82-
mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str=None, rank=Rank.SPECIES, k=5,
83-
device='cpu', model_str=None, pretrained_str=None)
101+
mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str=None, rank=Rank.SPECIES,
102+
bins_path=None, k=5, device='cpu', model_str=None, pretrained_str=None)
84103

85104
@patch('bioclip.__main__.predict')
86105
@patch('bioclip.__main__.parse_args')
@@ -89,10 +108,10 @@ def test_predict_class_list(self, mock_os, mock_parse_args, mock_predict):
89108
mock_os.path.exists.return_value = False
90109
mock_parse_args.return_value = argparse.Namespace(command='predict', image_file='image.jpg', format='csv',
91110
output='stdout', rank=Rank.SPECIES, k=5, cls='dog,fish,bird',
92-
device='cpu', model=None, pretrained=None)
111+
device='cpu', model=None, pretrained=None, bins=None)
93112
main()
94113
mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str='dog,fish,bird', rank=Rank.SPECIES,
95-
k=5, device='cpu', model_str=None, pretrained_str=None)
114+
bins_path=None, k=5, device='cpu', model_str=None, pretrained_str=None)
96115

97116
@patch('bioclip.__main__.predict')
98117
@patch('bioclip.__main__.parse_args')
@@ -101,8 +120,38 @@ def test_predict_class_file(self, mock_os, mock_parse_args, mock_predict):
101120
mock_os.path.exists.return_value = True
102121
mock_parse_args.return_value = argparse.Namespace(command='predict', image_file='image.jpg', format='csv',
103122
output='stdout', rank=Rank.SPECIES, k=5, cls='somefile.txt',
104-
device='cpu', model=None, pretrained=None)
123+
device='cpu', model=None, pretrained=None, bins=None)
105124
with patch("builtins.open", mock_open(read_data='dog\nfish\nbird')) as mock_file:
106125
main()
107126
mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str='dog,fish,bird', rank=Rank.SPECIES,
108-
k=5, device='cpu', model_str=None, pretrained_str=None)
127+
bins_path=None, k=5, device='cpu', model_str=None, pretrained_str=None)
128+
129+
@patch('bioclip.__main__.predict')
130+
@patch('bioclip.__main__.parse_args')
131+
@patch('bioclip.__main__.os')
132+
def test_predict_bins(self, mock_os, mock_parse_args, mock_predict):
133+
mock_os.path.exists.return_value = True
134+
mock_parse_args.return_value = argparse.Namespace(command='predict', image_file='image.jpg', format='csv',
135+
output='stdout', rank=None, k=5, cls=None,
136+
device='cpu', model=None, pretrained=None,
137+
bins='some.csv')
138+
with patch("builtins.open", mock_open(read_data='dog\nfish\nbird')) as mock_file:
139+
main()
140+
mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str=None, rank=None,
141+
bins_path='some.csv', k=5, device='cpu', model_str=None, pretrained_str=None)
142+
@patch('bioclip.__main__.os.path')
143+
def test_parse_bins_csv_file_missing(self, mock_path):
144+
mock_path.exists.return_value = False
145+
with self.assertRaises(FileNotFoundError) as raised_exception:
146+
parse_bins_csv("somefile.csv")
147+
self.assertEqual(str(raised_exception.exception), 'File not found: somefile.csv')
148+
149+
@patch('bioclip.__main__.pd')
150+
@patch('bioclip.__main__.os.path')
151+
def test_parse_bins_csv(self, mock_path, mock_pd):
152+
mock_path.exists.return_value = True
153+
data = {'bin': ['a', 'b']}
154+
mock_pd.read_csv.return_value = pd.DataFrame(data=data, index=['dog', 'cat'])
155+
with patch("builtins.open", mock_open(read_data='dog\nfish\nbird')) as mock_file:
156+
cls_to_bin = parse_bins_csv("somefile.csv")
157+
self.assertEqual(cls_to_bin, {'cat': 'b', 'dog': 'a'})

0 commit comments

Comments
 (0)