Skip to content

Commit 5317df1

Browse files
committed
Fix empty cls
1 parent 147eb57 commit 5317df1

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

src/bioclip/__main__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def main():
148148
pretrained_str=args.pretrained)
149149
elif args.command == 'predict':
150150
cls_str = args.cls
151-
if os.path.exists(args.cls):
151+
if args.cls and os.path.exists(args.cls):
152152
cls_str = create_classes_str(args.cls)
153153
predict(args.image_file,
154154
format=args.format,

tests/test_main.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import unittest
22
from unittest.mock import mock_open, patch
33
import argparse
4-
from bioclip.__main__ import parse_args, Rank, create_classes_str
4+
from bioclip.__main__ import parse_args, Rank, create_classes_str, main
55

66

77
class TestParser(unittest.TestCase):
@@ -71,3 +71,38 @@ def test_create_classes_str(self):
7171
data = "class1\nclass2\nclass3"
7272
with patch("builtins.open", mock_open(read_data=data)) as mock_file:
7373
self.assertEqual(create_classes_str('path/to/file'), 'class1,class2,class3')
74+
75+
@patch('bioclip.__main__.predict')
76+
@patch('bioclip.__main__.parse_args')
77+
def test_predict_no_class(self, mock_parse_args, mock_predict):
78+
mock_parse_args.return_value = argparse.Namespace(command='predict', image_file='image.jpg', format='csv',
79+
output='stdout', rank=Rank.SPECIES, k=5, cls=None, device='cpu',
80+
model=None, pretrained=None)
81+
main()
82+
mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str=None, rank=Rank.SPECIES, k=5,
83+
device='cpu', model_str=None, pretrained_str=None)
84+
85+
@patch('bioclip.__main__.predict')
86+
@patch('bioclip.__main__.parse_args')
87+
@patch('bioclip.__main__.os')
88+
def test_predict_class_list(self, mock_os, mock_parse_args, mock_predict):
89+
mock_os.path.exists.return_value = False
90+
mock_parse_args.return_value = argparse.Namespace(command='predict', image_file='image.jpg', format='csv',
91+
output='stdout', rank=Rank.SPECIES, k=5, cls='dog,fish,bird',
92+
device='cpu', model=None, pretrained=None)
93+
main()
94+
mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str='dog,fish,bird', rank=Rank.SPECIES,
95+
k=5, device='cpu', model_str=None, pretrained_str=None)
96+
97+
@patch('bioclip.__main__.predict')
98+
@patch('bioclip.__main__.parse_args')
99+
@patch('bioclip.__main__.os')
100+
def test_predict_class_file(self, mock_os, mock_parse_args, mock_predict):
101+
mock_os.path.exists.return_value = True
102+
mock_parse_args.return_value = argparse.Namespace(command='predict', image_file='image.jpg', format='csv',
103+
output='stdout', rank=Rank.SPECIES, k=5, cls='somefile.txt',
104+
device='cpu', model=None, pretrained=None)
105+
with patch("builtins.open", mock_open(read_data='dog\nfish\nbird')) as mock_file:
106+
main()
107+
mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str='dog,fish,bird', rank=Rank.SPECIES,
108+
k=5, device='cpu', model_str=None, pretrained_str=None)

0 commit comments

Comments
 (0)