@@ -119,22 +119,6 @@ def get_txt_names():
119
119
return txt_names
120
120
121
121
122
- def open_image (image_path ):
123
- img = PIL .Image .open (image_path )
124
- return img .convert ("RGB" )
125
-
126
-
127
- preprocess_img = transforms .Compose (
128
- [
129
- transforms .ToTensor (),
130
- transforms .Resize ((224 , 224 ), antialias = True ),
131
- transforms .Normalize (
132
- mean = (0.48145466 , 0.4578275 , 0.40821073 ),
133
- std = (0.26862954 , 0.26130258 , 0.27577711 ),
134
- ),
135
- ]
136
- )
137
-
138
122
class Rank (Enum ):
139
123
KINGDOM = 0
140
124
PHYLUM = 1
@@ -165,11 +149,68 @@ def create_bioclip_tokenizer(tokenizer_str="ViT-B-16"):
165
149
return get_tokenizer (tokenizer_str )
166
150
167
151
168
- class CustomLabelsClassifier (object ):
169
- def __init__ (self , cls_ary : List [str ], device : Union [str , torch .device ] = 'cpu' , model_str : str = MODEL_STR ):
152
# Image preprocessing pipeline shared by the classifiers in this module.
# Each step is applied in order to a single image.
# NOTE(review): the mean/std values look like the standard CLIP
# normalization statistics -- confirm against the model's own config.
preprocess_img = transforms.Compose([
    # Convert the input image to a tensor.
    transforms.ToTensor(),
    # Resize to a fixed 224x224 input, with antialiasing enabled.
    transforms.Resize((224, 224), antialias=True),
    # Per-channel normalization with fixed mean / std.
    transforms.Normalize(
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711),
    ),
])
162
+
163
+
164
class BaseClassifier(object):
    """Shared image-encoding and probability helpers for the classifiers.

    Holds the model returned by ``create_bioclip_model`` together with the
    device and model string it was created with. Subclasses supply the text
    features that image features are compared against.
    """

    def __init__(self, device: Union[str, torch.device] = 'cpu', model_str: str = MODEL_STR):
        """Create the model for ``model_str`` on ``device`` and remember both."""
        self.device = device
        self.model = create_bioclip_model(device=device, model_str=model_str)
        self.model_str = model_str

    @staticmethod
    def open_image(image_path):
        """Open the image at ``image_path`` and return it converted to RGB.

        The source file is opened in a ``with`` block so its handle is
        released as soon as the converted copy has been produced, which
        matters when many images are processed in one batch.
        """
        with PIL.Image.open(image_path) as img:
            return img.convert("RGB")

    @torch.no_grad()
    def create_image_features(self, images: List[PIL.Image.Image], normalize: bool = True) -> torch.Tensor:
        """Encode a batch of images with the model.

        Args:
            images: Images to encode; each one is run through
                ``preprocess_img`` and the results are stacked into a batch.
            normalize: When true, the features are L2-normalized along the
                last dimension before being returned.

        Returns:
            The output of ``self.model.encode_image`` for the stacked batch,
            one row per input image.
        """
        # Stack first, then move the whole batch to the device in one transfer
        # instead of one transfer per image.
        batch = torch.stack([preprocess_img(img) for img in images]).to(self.device)
        img_features = self.model.encode_image(batch)
        if normalize:
            return F.normalize(img_features, dim=-1)
        return img_features

    @torch.no_grad()
    def create_image_features_for_path(self, image_path: str, normalize: bool) -> torch.Tensor:
        """Encode the single image at ``image_path``.

        Returns:
            The first (and only) row of the batch produced by
            ``create_image_features``.
        """
        img = self.open_image(image_path)
        result = self.create_image_features([img], normalize=normalize)
        return result[0]

    @torch.no_grad()
    def create_probabilities(self, img_features: torch.Tensor,
                             txt_features: torch.Tensor) -> torch.Tensor:
        """Turn image/text feature similarities into probabilities.

        Args:
            img_features: Image features; either a batch or a single feature
                vector such as the one returned by
                ``create_image_features_for_path``.
            txt_features: Text features that ``img_features`` is
                right-multiplied by, so its first dimension must match the
                last dimension of ``img_features``.

        Returns:
            Softmax over the last dimension of the scaled similarity logits.
        """
        logits = self.model.logit_scale.exp() * img_features @ txt_features
        # dim=-1 equals dim=1 for batched input and also accepts a single
        # un-batched feature vector.
        return F.softmax(logits, dim=-1)

    def create_probabilities_for_image_paths(self, image_paths: Union[List[str], str],
                                             txt_features: torch.Tensor) -> dict[str, torch.Tensor]:
        """Compute probabilities for one or more image files.

        Args:
            image_paths: A single path or a collection of paths. A bare
                string is treated as one path rather than being iterated
                character by character.
            txt_features: Text features passed to ``create_probabilities``.

        Returns:
            A dict mapping each image path to its row of probabilities.
            Empty when no paths are given.
        """
        if isinstance(image_paths, str):
            image_paths = [image_paths]
        else:
            # Materialize once: the paths are walked twice below.
            image_paths = list(image_paths)
        if not image_paths:
            # Nothing to encode; stacking an empty batch would fail.
            return {}
        images = [self.open_image(image_path) for image_path in image_paths]
        img_features = self.create_image_features(images)
        probs = self.create_probabilities(img_features, txt_features)
        # Row i of ``probs`` belongs to image_paths[i].
        return dict(zip(image_paths, probs))
209
+
210
+
211
+ class CustomLabelsClassifier (BaseClassifier ):
212
+ def __init__ (self , cls_ary : List [str ], device : Union [str , torch .device ] = 'cpu' , model_str : str = MODEL_STR ):
213
+ super ().__init__ (device = device , model_str = model_str )
173
214
self .tokenizer = create_bioclip_tokenizer ()
174
215
self .classes = [cls .strip () for cls in cls_ary ]
175
216
self .txt_features = self ._get_txt_features (self .classes )
@@ -188,28 +229,24 @@ def _get_txt_features(self, classnames):
188
229
return all_features
189
230
190
231
@torch.no_grad()
def predict(self, image_paths: Union[List[str], str]) -> List[dict]:
    """Score every custom class label for each image.

    Args:
        image_paths: A single image path or a list of image paths.

    Returns:
        A flat list with one dict per (image, class) pair, holding the
        image path under ``PRED_FILENAME_KEY``, the class label under
        ``PRED_CLASSICATION_KEY`` and its probability (a Python float)
        under ``PRED_SCORE_KEY``. Entries follow the order of
        ``image_paths`` and, within an image, the order of ``self.classes``.
    """
    if isinstance(image_paths, str):
        image_paths = [image_paths]
    probs = self.create_probabilities_for_image_paths(image_paths, self.txt_features)
    result = []
    for image_path in image_paths:
        # Convert the whole probability row to Python floats in one call
        # rather than calling .item() once per class.
        scores = probs[image_path].tolist()
        for cls_str, score in zip(self.classes, scores):
            result.append({
                PRED_FILENAME_KEY: image_path,
                PRED_CLASSICATION_KEY: cls_str,
                PRED_SCORE_KEY: score,
            })
    return result
208
245
209
246
210
247
def predict_classifications_from_list(img: Union[PIL.Image.Image, str], cls_ary: List[str], device: Union[str, torch.device] = 'cpu') -> List[dict]:
    """Classify one image against a custom list of class labels.

    Builds a ``CustomLabelsClassifier`` for ``cls_ary`` on ``device`` and
    returns the result of its ``predict`` for ``img``.

    Args:
        img: The image to classify.
            NOTE(review): the annotation admits a PIL image, but ``predict``
            takes image paths -- confirm whether non-path input is
            actually supported.
        cls_ary: Class labels to score the image against.
        device: Device the classifier's model is created on.

    Returns:
        The list of per-class prediction dicts produced by
        ``CustomLabelsClassifier.predict``.
    """
    classifier = CustomLabelsClassifier(cls_ary=cls_ary, device=device)
    return classifier.predict([img])
213
250
214
251
215
252
def get_tol_classification_labels (rank : Rank ) -> List [str ]:
@@ -244,31 +281,12 @@ def join_names(classification_dict: dict[str, str]) -> str:
244
281
return " " .join (classification_dict .values ())
245
282
246
283
247
- class TreeOfLifeClassifier (object ):
284
+ class TreeOfLifeClassifier (BaseClassifier ):
248
285
def __init__(self, device: Union[str, torch.device] = 'cpu', model_str: str = MODEL_STR):
    """Set up the base classifier and load the text features and names.

    Args:
        device: Device the model is created on; the text features are
            moved to the same device.
        model_str: Model identifier forwarded to the base class.
    """
    super().__init__(device=device, model_str=model_str)
    # Text features from get_txt_emb(), placed on the model's device.
    txt_emb = get_txt_emb()
    self.txt_features = txt_emb.to(device)
    # Names from get_txt_names(); presumably aligned with the text
    # features above -- verify against the formatting helpers.
    self.txt_names = get_txt_names()
254
289
255
- @torch .no_grad ()
256
- def get_image_features (self , image_path : str ) -> torch .Tensor :
257
- img = open_image (image_path )
258
- return self .encode_image (img )
259
-
260
- def encode_image (self , img : PIL .Image .Image ) -> torch .Tensor :
261
- img = preprocess_img (img ).to (self .device )
262
- img_features = self .model .encode_image (img .unsqueeze (0 ))
263
- return img_features
264
-
265
- def predict_species (self , img : PIL .Image .Image ) -> torch .Tensor :
266
- img_features = self .encode_image (img )
267
- img_features = F .normalize (img_features , dim = - 1 )
268
- logits = (self .model .logit_scale .exp () * img_features @ self .txt_emb ).squeeze ()
269
- probs = F .softmax (logits , dim = 0 )
270
- return probs
271
-
272
290
def format_species_probs (self , image_path : str , probs : torch .Tensor , k : int = 5 ) -> List [dict [str , float ]]:
273
291
topk = probs .topk (k )
274
292
result = []
@@ -299,12 +317,17 @@ def format_grouped_probs(self, image_path: str, probs: torch.Tensor, rank: Rank,
299
317
return prediction_ary
300
318
301
319
@torch.no_grad()
def predict(self, image_paths: Union[List[str], str], rank: Rank, min_prob: float = 1e-9, k: int = 5) -> List[dict[str, float]]:
    """Predict classifications at ``rank`` for one or more images.

    Args:
        image_paths: A single image path or a list of image paths.
        rank: Rank to report. ``Rank.SPECIES`` uses
            ``format_species_probs``; any other rank uses
            ``format_grouped_probs``.
        min_prob: Forwarded to ``format_grouped_probs``; unused for
            ``Rank.SPECIES``.
        k: Number of predictions requested per image from the formatter.

    Returns:
        A single flat list holding the formatted predictions of every
        image, in the order of ``image_paths``.
    """
    if isinstance(image_paths, str):
        image_paths = [image_paths]
    probs = self.create_probabilities_for_image_paths(image_paths, self.txt_features)
    result = []
    for image_path in image_paths:
        image_probs = probs[image_path]
        if rank == Rank.SPECIES:
            result.extend(self.format_species_probs(image_path, image_probs, k))
        else:
            result.extend(self.format_grouped_probs(image_path, image_probs, rank, min_prob, k))
    return result
308
331
309
332
310
333
def predict_classification (img : str , rank : Rank , device : Union [str , torch .device ] = 'cpu' ,
@@ -315,4 +338,4 @@ def predict_classification(img: str, rank: Rank, device: Union[str, torch.device
315
338
species, then sums up species-level probabilities for the given rank.
316
339
"""
317
340
classifier = TreeOfLifeClassifier (device = device )
318
- return classifier .predict (img , rank , min_prob , k )
341
+ return classifier .predict ([ img ] , rank , min_prob , k )
0 commit comments