Skip to content

Commit 2aca6dd

Browse files
authored
Allows specifying top k results for custom class list (#31)
This change also changes the predictions to be returned sorted in descending order for each image, like in the case of no custom classes. Arguably this ordering is a more useful output whether asking for the top k predictions or all of them.
1 parent 57f6522 commit 2aca6dd

File tree

3 files changed

+14
-9
lines changed

3 files changed

+14
-9
lines changed

src/bioclip/__main__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def predict(image_file: list[str], format: str, output: str,
3333
cls_str: str, device: str, rank: Rank, k: int):
3434
if cls_str:
3535
classifier = CustomLabelsClassifier(cls_ary=cls_str.split(','), device=device)
36-
predictions = classifier.predict(image_paths=image_file)
36+
predictions = classifier.predict(image_paths=image_file, k=k)
3737
write_results(predictions, format, output)
3838
else:
3939
classifier = TreeOfLifeClassifier(device=device)
@@ -87,8 +87,8 @@ def parse_args(input_args=None):
8787
if args.command == 'predict':
8888
if args.cls:
8989
# custom class list mode
90-
if args.rank or args.k:
91-
raise ValueError("Cannot use --cls with --rank or --k")
90+
if args.rank:
91+
raise ValueError("Cannot use --cls with --rank")
9292
else:
9393
# tree of life class list mode
9494
if not args.rank:

src/bioclip/predict.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -229,16 +229,20 @@ def _get_txt_features(self, classnames):
229229
return all_features
230230

231231
@torch.no_grad()
232-
def predict(self, image_paths: List[str] | str) -> dict[str, float]:
232+
def predict(self, image_paths: List[str] | str, k: int = None) -> dict[str, float]:
233233
if isinstance(image_paths, str):
234234
image_paths = [image_paths]
235235
probs = self.create_probabilities_for_image_paths(image_paths, self.txt_features)
236236
result = []
237237
for image_path in image_paths:
238-
for cls_str, prob in zip(self.classes, probs[image_path]):
238+
img_probs = probs[image_path]
239+
if not k or k > len(self.classes):
240+
k = len(self.classes)
241+
topk = img_probs.topk(k)
242+
for i, prob in zip(topk.indices, topk.values):
239243
result.append({
240244
PRED_FILENAME_KEY: image_path,
241-
PRED_CLASSICATION_KEY: cls_str,
245+
PRED_CLASSICATION_KEY: self.classes[i],
242246
PRED_SCORE_KEY: prob.item()
243247
})
244248
return result

tests/test_main.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,10 @@ def test_parse_args(self):
4444
# test error when using --cls with --rank
4545
with self.assertRaises(ValueError):
4646
parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--rank', 'genus'])
47-
# test error when using --cls with --k
48-
with self.assertRaises(ValueError):
49-
parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--k', '10'])
47+
48+
# not an error when using --cls with --k
49+
args = parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--k', '10'])
50+
self.assertEqual(args.k, 10)
5051

5152
args = parse_args(['embed', 'image.jpg'])
5253
self.assertEqual(args.command, 'embed')

0 commit comments

Comments
 (0)