Skip to content

Commit 3f65a47

Browse files
authored
Merge pull request #37 from Imageomics/30-classes-file
Add --cls-file to predict command
2 parents 452c5c6 + 5317df1 commit 3f65a47

File tree

3 files changed

+63
-4
lines changed

3 files changed

+63
-4
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ Options:
119119
--format=FORMAT format of the output (table or csv) for predict mode [default: csv]
120120
--rank=RANK rank of the classification (kingdom, phylum, class, order, family, genus, species) [default: species]
121121
--k=K number of top predictions to show [default: 5]
122-
--cls=CLS comma separated list of classes to predict, when specified the --rank and --k arguments are not allowed
122+
--cls=CLS classes to predict: either a comma separated list or a path to a text file of classes (one per line), when specified the --rank argument is not allowed.
123123
--device=DEVICE device to use matrix math (cpu or cuda or mps) [default: cpu]
124124
--output=OUTFILE print output to file OUTFILE [default: stdout]
125125
```

src/bioclip/__main__.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from bioclip import TreeOfLifeClassifier, Rank, CustomLabelsClassifier
22
from .predict import BIOCLIP_MODEL_STR
33
import open_clip as oc
4+
import os
45
import json
56
import sys
67
import prettytable as pt
@@ -83,7 +84,9 @@ def create_parser():
8384
predict_parser.add_argument('--rank', choices=['kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'],
8485
help='rank of the classification, default: species (when)')
8586
predict_parser.add_argument('--k', type=int, help='number of top predictions to show, default: 5')
86-
predict_parser.add_argument('--cls', help='comma separated list of classes to predict, when specified the --rank argument is not allowed')
87+
cls_help = "classes to predict: either a comma separated list or a path to a text file of classes (one per line), when specified the --rank argument is not allowed."
88+
predict_parser.add_argument('--cls', help=cls_help)
89+
8790
predict_parser.add_argument('--device', **device_arg)
8891
predict_parser.add_argument('--model', **model_arg)
8992
predict_parser.add_argument('--pretrained', **pretrained_arg)
@@ -128,6 +131,13 @@ def parse_args(input_args=None):
128131
return args
129132

130133

134+
def create_classes_str(cls_file_path):
    """Read a classes file and return its entries as a comma separated string.

    The file is expected to hold one class per line. Surrounding whitespace
    is stripped from every line, and lines that are empty after stripping
    are skipped so that blank or trailing lines in the file do not turn
    into empty class names in the result.

    Args:
        cls_file_path: path to a text file containing one class per line.

    Returns:
        The non-empty, stripped lines joined with "," in file order. An
        empty (or all-blank) file yields an empty string.

    Raises:
        OSError: if the file cannot be opened or read.
    """
    with open(cls_file_path, 'r') as cls_file:
        # Iterate the handle directly instead of readlines() so the raw
        # lines are not all held in memory before being stripped.
        stripped_lines = (line.strip() for line in cls_file)
        # Drop blank lines: "dog\n\nfish\n" must give "dog,fish", not "dog,,fish".
        classes = [name for name in stripped_lines if name]
    return ",".join(classes)
139+
140+
131141
def main():
132142
args = parse_args()
133143
if args.command == 'embed':
@@ -137,10 +147,13 @@ def main():
137147
model_str=args.model,
138148
pretrained_str=args.pretrained)
139149
elif args.command == 'predict':
150+
cls_str = args.cls
151+
if args.cls and os.path.exists(args.cls):
152+
cls_str = create_classes_str(args.cls)
140153
predict(args.image_file,
141154
format=args.format,
142155
output=args.output,
143-
cls_str=args.cls,
156+
cls_str=cls_str,
144157
rank=args.rank,
145158
k=args.k,
146159
device=args.device,

tests/test_main.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import unittest
2-
from bioclip.__main__ import parse_args, Rank
2+
from unittest.mock import mock_open, patch
3+
import argparse
4+
from bioclip.__main__ import parse_args, Rank, create_classes_str, main
35

46

57
class TestParser(unittest.TestCase):
@@ -49,6 +51,10 @@ def test_parse_args(self):
4951
args = parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--k', '10'])
5052
self.assertEqual(args.k, 10)
5153

54+
# example showing filename
55+
args = parse_args(['predict', 'image.jpg', '--cls', 'classes.txt', '--k', '10'])
56+
self.assertEqual(args.cls, 'classes.txt')
57+
5258
args = parse_args(['embed', 'image.jpg'])
5359
self.assertEqual(args.command, 'embed')
5460
self.assertEqual(args.image_file, ['image.jpg'])
@@ -60,3 +66,43 @@ def test_parse_args(self):
6066
self.assertEqual(args.image_file, ['image.jpg', 'image2.png'])
6167
self.assertEqual(args.output, 'data.json')
6268
self.assertEqual(args.device, 'cuda')
69+
70+
def test_create_classes_str(self):
    # A three-line classes file should come back as one comma separated string.
    file_contents = "class1\nclass2\nclass3"
    fake_open = mock_open(read_data=file_contents)
    # Patch the builtin open so no real file is touched.
    with patch("builtins.open", fake_open):
        joined = create_classes_str('path/to/file')
    self.assertEqual(joined, 'class1,class2,class3')
74+
75+
@patch('bioclip.__main__.predict')
@patch('bioclip.__main__.parse_args')
def test_predict_no_class(self, mock_parse_args, mock_predict):
    # With cls=None, main() must pass cls_str=None straight through to predict().
    parsed_args = argparse.Namespace(
        command='predict',
        image_file='image.jpg',
        format='csv',
        output='stdout',
        rank=Rank.SPECIES,
        k=5,
        cls=None,
        device='cpu',
        model=None,
        pretrained=None,
    )
    mock_parse_args.return_value = parsed_args

    main()

    expected_kwargs = {
        'format': 'csv',
        'output': 'stdout',
        'cls_str': None,
        'rank': Rank.SPECIES,
        'k': 5,
        'device': 'cpu',
        'model_str': None,
        'pretrained_str': None,
    }
    mock_predict.assert_called_with('image.jpg', **expected_kwargs)
84+
85+
@patch('bioclip.__main__.predict')
@patch('bioclip.__main__.parse_args')
@patch('bioclip.__main__.os')
def test_predict_class_list(self, mock_os, mock_parse_args, mock_predict):
    # os.path.exists reports False, so the --cls value is treated as a
    # literal comma separated list and forwarded unchanged.
    mock_os.path.exists.return_value = False
    parsed_args = argparse.Namespace(
        command='predict',
        image_file='image.jpg',
        format='csv',
        output='stdout',
        rank=Rank.SPECIES,
        k=5,
        cls='dog,fish,bird',
        device='cpu',
        model=None,
        pretrained=None,
    )
    mock_parse_args.return_value = parsed_args

    main()

    expected_kwargs = {
        'format': 'csv',
        'output': 'stdout',
        'cls_str': 'dog,fish,bird',
        'rank': Rank.SPECIES,
        'k': 5,
        'device': 'cpu',
        'model_str': None,
        'pretrained_str': None,
    }
    mock_predict.assert_called_with('image.jpg', **expected_kwargs)
96+
97+
@patch('bioclip.__main__.predict')
@patch('bioclip.__main__.parse_args')
@patch('bioclip.__main__.os')
def test_predict_class_file(self, mock_os, mock_parse_args, mock_predict):
    # os.path.exists reports True, so the --cls value is treated as a file
    # path and its (mocked) contents are turned into the class string.
    mock_os.path.exists.return_value = True
    parsed_args = argparse.Namespace(
        command='predict',
        image_file='image.jpg',
        format='csv',
        output='stdout',
        rank=Rank.SPECIES,
        k=5,
        cls='somefile.txt',
        device='cpu',
        model=None,
        pretrained=None,
    )
    mock_parse_args.return_value = parsed_args

    fake_open = mock_open(read_data='dog\nfish\nbird')
    with patch("builtins.open", fake_open):
        main()

    expected_kwargs = {
        'format': 'csv',
        'output': 'stdout',
        'cls_str': 'dog,fish,bird',
        'rank': Rank.SPECIES,
        'k': 5,
        'device': 'cpu',
        'model_str': None,
        'pretrained_str': None,
    }
    mock_predict.assert_called_with('image.jpg', **expected_kwargs)

0 commit comments

Comments
 (0)