1
1
from bioclip import TreeOfLifeClassifier , Rank , CustomLabelsClassifier
2
+ from .predict import BIOCLIP_MODEL_STR
3
+ import open_clip as oc
2
4
import json
3
5
import sys
4
6
import prettytable as pt
@@ -29,20 +31,25 @@ def write_results_to_file(df, format, outfile):
29
31
raise ValueError (f"Invalid format: { format } " )
30
32
31
33
32
def predict(image_file: list[str],
            format: str,
            output: str,
            cls_str: str,
            rank: Rank,
            k: int,
            **kwargs):
    """Classify the given images and write the predictions.

    Two modes are supported, selected by ``cls_str``:

    * ``cls_str`` is non-empty: it is split on commas and the resulting
      labels are handed to ``CustomLabelsClassifier``; ``rank`` is not used.
    * ``cls_str`` is empty/``None``: ``TreeOfLifeClassifier`` is used and
      predictions are made at the requested ``rank``.

    Args:
        image_file: Paths of the images to classify.
        format: Output format passed through to ``write_results``.
        output: Output destination passed through to ``write_results``.
        cls_str: Comma separated custom class labels, or a falsy value to
            select Tree-of-Life mode.
        rank: Taxonomic rank used in Tree-of-Life mode.
        k: Number of top predictions requested from the classifier.
        **kwargs: Forwarded unchanged to the classifier constructor.
    """
    if not cls_str:
        # Tree-of-Life mode: predictions are made at the requested rank.
        tol_classifier = TreeOfLifeClassifier(**kwargs)
        predictions = tol_classifier.predict(image_paths=image_file, rank=rank, k=k)
    else:
        # Custom label mode: the labels come from the comma separated string.
        labels = cls_str.split(',')
        label_classifier = CustomLabelsClassifier(cls_ary=labels, **kwargs)
        predictions = label_classifier.predict(image_paths=image_file, k=k)

    write_results(predictions, format, output)
42
49
43
50
44
- def embed (image_file : list [str ], output : str , device : str ):
45
- classifier = TreeOfLifeClassifier (device = device )
51
+ def embed (image_file : list [str ], output : str , ** kwargs ):
52
+ classifier = TreeOfLifeClassifier (** kwargs )
46
53
images_dict = {}
47
54
data = {
48
55
"model" : classifier .model_str ,
@@ -62,22 +69,42 @@ def create_parser():
62
69
parser = argparse .ArgumentParser (prog = 'bioclip' , description = 'BioCLIP command line interface' )
63
70
subparsers = parser .add_subparsers (title = 'commands' , dest = 'command' )
64
71
72
+ device_arg = {'default' :'cpu' , 'help' : 'device to use (cpu or cuda or mps), default: cpu' }
73
+ output_arg = {'default' : 'stdout' , 'help' : 'print output to file, default: stdout' }
74
+ model_arg = {'help' : f'model identifier (see command list-models); default: { BIOCLIP_MODEL_STR } ' }
75
+ pretrained_arg = {'help' : 'pretrained model checkpoint as tag or file, depends on model; '
76
+ 'needed only if more than one is available (see command list-models)' }
77
+
65
78
# Predict command
66
79
predict_parser = subparsers .add_parser ('predict' , help = 'Use BioCLIP to generate predictions for image files.' )
67
80
predict_parser .add_argument ('image_file' , nargs = '+' , help = 'input image file(s)' )
68
81
predict_parser .add_argument ('--format' , choices = ['table' , 'csv' ], default = 'csv' , help = 'format of the output, default: csv' )
69
- predict_parser .add_argument ('--output' , default = 'stdout' , help = 'print output to file, default: stdout' )
82
+ predict_parser .add_argument ('--output' , ** output_arg )
70
83
predict_parser .add_argument ('--rank' , choices = ['kingdom' , 'phylum' , 'class' , 'order' , 'family' , 'genus' , 'species' ],
71
84
help = 'rank of the classification, default: species (when)' )
72
85
predict_parser .add_argument ('--k' , type = int , help = 'number of top predictions to show, default: 5' )
73
- predict_parser .add_argument ('--cls' , help = 'comma separated list of classes to predict, when specified the --rank and --k arguments are not allowed' )
74
- predict_parser .add_argument ('--device' , help = 'device to use (cpu or cuda or mps), default: cpu' , default = 'cpu' )
86
+ predict_parser .add_argument ('--cls' , help = 'comma separated list of classes to predict, when specified the --rank argument is not allowed' )
87
+ predict_parser .add_argument ('--device' , ** device_arg )
88
+ predict_parser .add_argument ('--model' , ** model_arg )
89
+ predict_parser .add_argument ('--pretrained' , ** pretrained_arg )
75
90
76
91
# Embed command
77
92
embed_parser = subparsers .add_parser ('embed' , help = 'Use BioCLIP to generate embeddings for image files.' )
78
93
embed_parser .add_argument ('image_file' , nargs = '+' , help = 'input image file(s)' )
79
- embed_parser .add_argument ('--output' , default = 'stdout' , help = 'print output to file, default: stdout' )
80
- embed_parser .add_argument ('--device' , help = 'device to use (cpu or cuda or mps), default: cpu' , default = 'cpu' )
94
+ embed_parser .add_argument ('--output' , ** output_arg )
95
+ embed_parser .add_argument ('--device' , ** device_arg )
96
+ embed_parser .add_argument ('--model' , ** model_arg )
97
+ embed_parser .add_argument ('--pretrained' , ** pretrained_arg )
98
+
99
+ # List command
100
+ list_parser = subparsers .add_parser ('list-models' ,
101
+ help = 'List available models and pretrained model checkpoints.' ,
102
+ description =
103
+ 'Note that this will only list models known to open_clip; '
104
+ 'any model identifier loadable by open_clip, such as from hf-hub, file, etc '
105
+ 'should also be usable for --model in the embed and predict commands. '
106
+ f'(The default model { BIOCLIP_MODEL_STR } is one example.)' )
107
+ list_parser .add_argument ('--model' , help = 'list available tags for pretrained model checkpoint(s) for specified model' )
81
108
82
109
return parser
83
110
@@ -91,6 +118,8 @@ def parse_args(input_args=None):
91
118
raise ValueError ("Cannot use --cls with --rank" )
92
119
else :
93
120
# tree of life class list mode
121
+ if args .model or args .pretrained :
122
+ raise ValueError ("Custom model or checkpoints currently not supported for Tree-of-Life prediction" )
94
123
if not args .rank :
95
124
args .rank = 'species'
96
125
args .rank = Rank [args .rank .upper ()]
@@ -102,9 +131,28 @@ def parse_args(input_args=None):
102
131
def main():
    """Run the command line interface.

    Parses the command line via ``parse_args`` and dispatches on the
    selected sub-command:

    * ``embed``: calls ``embed`` with the image files, output, device,
      model and pretrained checkpoint arguments.
    * ``predict``: calls ``predict`` with the image files, format, output,
      class list, rank, k, device, model and pretrained checkpoint arguments.
    * ``list-models``: prints the pretrained tags open_clip reports for
      ``--model`` when it is given, otherwise prints every model name
      open_clip reports.

    Raises:
        ValueError: If the parsed command is none of the above.
    """
    args = parse_args()
    command = args.command

    if command == 'embed':
        embed(args.image_file,
              args.output,
              device=args.device,
              model_str=args.model,
              pretrained_str=args.pretrained)
        return

    if command == 'predict':
        predict(args.image_file,
                format=args.format,
                output=args.output,
                cls_str=args.cls,
                rank=args.rank,
                k=args.k,
                device=args.device,
                model_str=args.model,
                pretrained_str=args.pretrained)
        return

    if command == 'list-models':
        if not args.model:
            # No model given: list every model name.
            for model_str in oc.list_models():
                print(f"\t { model_str } ")
        else:
            # A model was given: list its pretrained checkpoint tags.
            for tag in oc.list_pretrained_tags_by_model(args.model):
                print(tag)
        return

    raise ValueError("Invalid command")
110
158
0 commit comments