Skip to content

Commit 1abf49e

Browse files
committed
Add --cls-file to predict command
Fixes #30
1 parent 452c5c6 commit 1abf49e

File tree

3 files changed

+30
-5
lines changed

3 files changed

+30
-5
lines changed

README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ bear 1.0
104104

105105
## Command Line Usage
106106
```
107-
bioclip predict [-h] [--format {table,csv}] [--output OUTPUT] [--rank {kingdom,phylum,class,order,family,genus,species}] [--k K] [--cls CLS] [--device DEVICE] image_file [image_file ...]
107+
bioclip predict [-h] [--format {table,csv}] [--output OUTPUT] [--rank {kingdom,phylum,class,order,family,genus,species}] [--k K] [--cls CLS | --cls-file CLS_FILE] [--device DEVICE] image_file [image_file ...]
108108
bioclip embed [-h] [--device=DEVICE] [--output=OUTPUT] [IMAGE_FILE...]
109109
110110
Commands:
@@ -119,7 +119,9 @@ Options:
119119
--format=FORMAT format of the output (table or csv) for predict mode [default: csv]
120120
--rank=RANK rank of the classification (kingdom, phylum, class, order, family, genus, species) [default: species]
121121
--k=K number of top predictions to show [default: 5]
122-
--cls=CLS comma separated list of classes to predict, when specified the --rank and --k arguments are not allowed
122+
--cls=CLS comma separated list of classes to predict, when specified the --rank and --k
123+
arguments are not allowed
124+
--cls-file CLS_FILE path to file with list of classes to predict, one per line, when specified the --rank and --k arguments are not allowed
123125
--device=DEVICE device to use matrix math (cpu or cuda or mps) [default: cpu]
124126
--output=OUTFILE print output to file OUTFILE [default: stdout]
125127
```

src/bioclip/__main__.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@ def create_parser():
8383
predict_parser.add_argument('--rank', choices=['kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'],
8484
help='rank of the classification, default: species (when)')
8585
predict_parser.add_argument('--k', type=int, help='number of top predictions to show, default: 5')
86-
predict_parser.add_argument('--cls', help='comma separated list of classes to predict, when specified the --rank argument is not allowed')
86+
cls_group = predict_parser.add_mutually_exclusive_group(required=False)
87+
cls_group.add_argument('--cls', help='comma separated list of classes to predict, when specified the --rank argument is not allowed')
88+
cls_group.add_argument('--cls-file', help='path to file with list of classes to predict, one per line, when specified the --rank and --k arguments are not allowed')
8789
predict_parser.add_argument('--device', **device_arg)
8890
predict_parser.add_argument('--model', **model_arg)
8991
predict_parser.add_argument('--pretrained', **pretrained_arg)
@@ -128,6 +130,13 @@ def parse_args(input_args=None):
128130
return args
129131

130132

133+
def create_classes_str(cls_file_path):
134+
"""Reads a file with one class per line and returns a comma separated string of classes"""
135+
with open(cls_file_path, 'r') as cls_file:
136+
cls_str = [item.strip() for item in cls_file.readlines()]
137+
return ",".join(cls_str)
138+
139+
131140
def main():
132141
args = parse_args()
133142
if args.command == 'embed':
@@ -137,10 +146,13 @@ def main():
137146
model_str=args.model,
138147
pretrained_str=args.pretrained)
139148
elif args.command == 'predict':
149+
cls_str = args.cls
150+
if args.cls_file:
151+
cls_str = create_classes_str(args.cls_file)
140152
predict(args.image_file,
141153
format=args.format,
142154
output=args.output,
143-
cls_str=args.cls,
155+
cls_str=cls_str,
144156
rank=args.rank,
145157
k=args.k,
146158
device=args.device,

tests/test_main.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import unittest
2-
from bioclip.__main__ import parse_args, Rank
2+
from unittest.mock import mock_open, patch
3+
import argparse
4+
from bioclip.__main__ import parse_args, Rank, create_classes_str
35

46

57
class TestParser(unittest.TestCase):
@@ -49,6 +51,10 @@ def test_parse_args(self):
4951
args = parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--k', '10'])
5052
self.assertEqual(args.k, 10)
5153

54+
args = parse_args(['predict', '--cls-file', 'somefile.txt', 'image.jpg'])
55+
self.assertEqual(args.cls_file, 'somefile.txt')
56+
self.assertEqual(args.cls, None)
57+
5258
args = parse_args(['embed', 'image.jpg'])
5359
self.assertEqual(args.command, 'embed')
5460
self.assertEqual(args.image_file, ['image.jpg'])
@@ -60,3 +66,8 @@ def test_parse_args(self):
6066
self.assertEqual(args.image_file, ['image.jpg', 'image2.png'])
6167
self.assertEqual(args.output, 'data.json')
6268
self.assertEqual(args.device, 'cuda')
69+
70+
def test_create_classes_str(self):
71+
data = "class1\nclass2\nclass3"
72+
with patch("builtins.open", mock_open(read_data=data)) as mock_file:
73+
self.assertEqual(create_classes_str('path/to/file'), 'class1,class2,class3')

0 commit comments

Comments
 (0)