Skip to content

Commit 3d4d647

Browse files
committed
Allow predict --cls to receive a file path
1 parent 1abf49e commit 3d4d647

File tree

3 files changed

+11
-11
lines changed

3 files changed

+11
-11
lines changed

README.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,8 @@ Options:
119119
--format=FORMAT format of the output (table or csv) for predict mode [default: csv]
120120
--rank=RANK rank of the classification (kingdom, phylum, class, order, family, genus, species) [default: species]
121121
--k=K number of top predictions to show [default: 5]
122-
--cls=CLS comma separated list of classes to predict, when specified the --rank and --k
123-
arguments are not allowed
124-
--cls-file CLS_FILE path to file with list of classes to predict, one per line, when specified the --rank and --k arguments are not allowed
122+
--cls=CLS classes to predict either a comma separated list or a path to a text file of classes (one per line), when specified the --rank argument is not allowed.
123+
--rank and --k arguments are not allowed
125124
--device=DEVICE device to use matrix math (cpu or cuda or mps) [default: cpu]
126125
--output=OUTFILE print output to file OUTFILE [default: stdout]
127126
```

src/bioclip/__main__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from bioclip import TreeOfLifeClassifier, Rank, CustomLabelsClassifier
22
from .predict import BIOCLIP_MODEL_STR
33
import open_clip as oc
4+
import os
45
import json
56
import sys
67
import prettytable as pt
@@ -83,9 +84,9 @@ def create_parser():
8384
predict_parser.add_argument('--rank', choices=['kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'],
8485
help='rank of the classification, default: species (when)')
8586
predict_parser.add_argument('--k', type=int, help='number of top predictions to show, default: 5')
86-
cls_group = predict_parser.add_mutually_exclusive_group(required=False)
87-
cls_group.add_argument('--cls', help='comma separated list of classes to predict, when specified the --rank argument is not allowed')
88-
cls_group.add_argument('--cls-file', help='path to file with list of classes to predict, one per line, when specified the --rank and --k arguments are not allowed')
87+
cls_help = "classes to predict either a comma separated list or a path to a text file of classes (one per line), when specified the --rank argument is not allowed."
88+
predict_parser.add_argument('--cls', help=cls_help)
89+
8990
predict_parser.add_argument('--device', **device_arg)
9091
predict_parser.add_argument('--model', **model_arg)
9192
predict_parser.add_argument('--pretrained', **pretrained_arg)
@@ -147,8 +148,8 @@ def main():
147148
pretrained_str=args.pretrained)
148149
elif args.command == 'predict':
149150
cls_str = args.cls
150-
if args.cls_file:
151-
cls_str = create_classes_str(args.cls_file)
151+
if os.path.exists(args.cls):
152+
cls_str = create_classes_str(args.cls)
152153
predict(args.image_file,
153154
format=args.format,
154155
output=args.output,

tests/test_main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@ def test_parse_args(self):
5151
args = parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--k', '10'])
5252
self.assertEqual(args.k, 10)
5353

54-
args = parse_args(['predict', '--cls-file', 'somefile.txt', 'image.jpg'])
55-
self.assertEqual(args.cls_file, 'somefile.txt')
56-
self.assertEqual(args.cls, None)
54+
# example showing filename
55+
args = parse_args(['predict', 'image.jpg', '--cls', 'classes.txt', '--k', '10'])
56+
self.assertEqual(args.cls, 'classes.txt')
5757

5858
args = parse_args(['embed', 'image.jpg'])
5959
self.assertEqual(args.command, 'embed')

0 commit comments

Comments
 (0)