@@ -184,10 +184,20 @@ def load_pretrained_model(self, model_str: str = BIOCLIP_MODEL_STR, pretrained_s
184
184
self .preprocess = preprocess_img if self .model_str == BIOCLIP_MODEL_STR else preprocess
185
185
186
186
@staticmethod
187
- def open_image (image_path ):
188
- img = PIL .Image .open (image_path )
187
+ def ensure_rgb_image (image : str | PIL .Image .Image ) -> PIL .Image .Image :
188
+ if isinstance (image , PIL .Image .Image ):
189
+ img = image
190
+ else :
191
+ img = PIL .Image .open (image )
189
192
return img .convert ("RGB" )
190
193
194
+ @staticmethod
195
+ def make_key (image : str | PIL .Image .Image , idx : int ) -> str :
196
+ if isinstance (image , PIL .Image .Image ):
197
+ return f"{ idx } "
198
+ else :
199
+ return image
200
+
191
201
@torch .no_grad ()
192
202
def create_image_features (self , images : List [PIL .Image .Image ], normalize : bool = True ) -> torch .Tensor :
193
203
preprocessed_images = []
@@ -202,8 +212,8 @@ def create_image_features(self, images: List[PIL.Image.Image], normalize : bool
202
212
return img_features
203
213
204
214
@torch .no_grad ()
205
- def create_image_features_for_path (self , image_path : str , normalize : bool ) -> torch .Tensor :
206
- img = self .open_image ( image_path )
215
+ def create_image_features_for_image (self , image : str | PIL . Image . Image , normalize : bool ) -> torch .Tensor :
216
+ img = self .ensure_rgb_image ( image )
207
217
result = self .create_image_features ([img ], normalize = normalize )
208
218
return result [0 ]
209
219
@@ -213,13 +223,14 @@ def create_probabilities(self, img_features: torch.Tensor,
213
223
logits = (self .model .logit_scale .exp () * img_features @ txt_features )
214
224
return F .softmax (logits , dim = 1 )
215
225
216
- def create_probabilities_for_image_paths (self , image_paths : List [str ] | str ,
217
- txt_features : torch .Tensor ) -> dict [str , torch .Tensor ]:
218
- images = [self .open_image (image_path ) for image_path in image_paths ]
226
+ def create_probabilities_for_images (self , images : List [str ] | List [PIL .Image .Image ],
227
+ txt_features : torch .Tensor ) -> dict [str , torch .Tensor ]:
228
+ keys = [self .make_key (image , i ) for i ,image in enumerate (images )]
229
+ images = [self .ensure_rgb_image (image ) for image in images ]
219
230
img_features = self .create_image_features (images )
220
231
probs = self .create_probabilities (img_features , txt_features )
221
232
result = {}
222
- for i , key in enumerate (image_paths ):
233
+ for i , key in enumerate (keys ):
223
234
result [key ] = probs [i ]
224
235
return result
225
236
@@ -245,24 +256,25 @@ def _get_txt_features(self, classnames):
245
256
return all_features
246
257
247
258
@torch .no_grad ()
248
- def predict (self , image_paths : List [str ] | str , k : int = None ) -> dict [str , float ]:
249
- if isinstance (image_paths , str ):
250
- image_paths = [image_paths ]
251
- probs = self .create_probabilities_for_image_paths ( image_paths , self .txt_features )
259
+ def predict (self , images : List [str ] | str | List [ PIL . Image . Image ] , k : int = None ) -> dict [str , float ]:
260
+ if isinstance (images , str ):
261
+ images = [images ]
262
+ probs = self .create_probabilities_for_images ( images , self .txt_features )
252
263
result = []
253
- for image_path in image_paths :
254
- img_probs = probs [image_path ]
264
+ for i , image in enumerate (images ):
265
+ key = self .make_key (image , i )
266
+ img_probs = probs [key ]
255
267
if not k or k > len (self .classes ):
256
268
k = len (self .classes )
257
- result .extend (self .group_probs (image_path , img_probs , k ))
269
+ result .extend (self .group_probs (key , img_probs , k ))
258
270
return result
259
271
260
- def group_probs (self , image_path : str , img_probs : torch .Tensor , k : int = None ) -> List [dict [str , float ]]:
272
+ def group_probs (self , image_key : str , img_probs : torch .Tensor , k : int = None ) -> List [dict [str , float ]]:
261
273
result = []
262
274
topk = img_probs .topk (k )
263
275
for i , prob in zip (topk .indices , topk .values ):
264
276
result .append ({
265
- PRED_FILENAME_KEY : image_path ,
277
+ PRED_FILENAME_KEY : image_key ,
266
278
PRED_CLASSICATION_KEY : self .classes [i ],
267
279
PRED_SCORE_KEY : prob .item ()
268
280
})
@@ -276,7 +288,7 @@ def __init__(self, cls_to_bin: dict, **kwargs):
276
288
if any ([pd .isna (x ) or not x for x in cls_to_bin .values ()]):
277
289
raise ValueError ("Empty, null, or nan are not allowed for bin values." )
278
290
279
- def group_probs (self , image_path : str , img_probs : torch .Tensor , k : int = None ) -> List [dict [str , float ]]:
291
+ def group_probs (self , image_key : str , img_probs : torch .Tensor , k : int = None ) -> List [dict [str , float ]]:
280
292
result = []
281
293
output = collections .defaultdict (float )
282
294
for i in range (len (self .classes )):
@@ -285,7 +297,7 @@ def group_probs(self, image_path: str, img_probs: torch.Tensor, k: int = None) -
285
297
topk_names = heapq .nlargest (k , output , key = output .get )
286
298
for name in topk_names :
287
299
result .append ({
288
- PRED_FILENAME_KEY : image_path ,
300
+ PRED_FILENAME_KEY : image_key ,
289
301
PRED_CLASSICATION_KEY : name ,
290
302
PRED_SCORE_KEY : output [name ].item ()
291
303
})
@@ -335,17 +347,17 @@ def __init__(self, **kwargs):
335
347
self .txt_features = get_txt_emb ().to (self .device )
336
348
self .txt_names = get_txt_names ()
337
349
338
- def format_species_probs (self , image_path : str , probs : torch .Tensor , k : int = 5 ) -> List [dict [str , float ]]:
350
+ def format_species_probs (self , image_key : str , probs : torch .Tensor , k : int = 5 ) -> List [dict [str , float ]]:
339
351
topk = probs .topk (k )
340
352
result = []
341
353
for i , prob in zip (topk .indices , topk .values ):
342
- item = { PRED_FILENAME_KEY : image_path }
354
+ item = { PRED_FILENAME_KEY : image_key }
343
355
item .update (create_classification_dict (self .txt_names [i ], Rank .SPECIES ))
344
356
item [PRED_SCORE_KEY ] = prob .item ()
345
357
result .append (item )
346
358
return result
347
359
348
- def format_grouped_probs (self , image_path : str , probs : torch .Tensor , rank : Rank , min_prob : float = 1e-9 , k : int = 5 ) -> List [dict [str , float ]]:
360
+ def format_grouped_probs (self , image_key : str , probs : torch .Tensor , rank : Rank , min_prob : float = 1e-9 , k : int = 5 ) -> List [dict [str , float ]]:
349
361
output = collections .defaultdict (float )
350
362
class_dict_lookup = {}
351
363
name_to_class_dict = {}
@@ -358,27 +370,28 @@ def format_grouped_probs(self, image_path: str, probs: torch.Tensor, rank: Rank,
358
370
topk_names = heapq .nlargest (k , output , key = output .get )
359
371
prediction_ary = []
360
372
for name in topk_names :
361
- item = { PRED_FILENAME_KEY : image_path }
373
+ item = { PRED_FILENAME_KEY : image_key }
362
374
item .update (name_to_class_dict [name ])
363
375
item [PRED_SCORE_KEY ] = output [name ].item ()
364
376
prediction_ary .append (item )
365
377
return prediction_ary
366
378
367
379
@torch .no_grad ()
368
- def predict (self , image_paths : List [str ] | str , rank : Rank , min_prob : float = 1e-9 , k : int = 5 ) -> dict [str , dict [str , float ]]:
369
- if isinstance (image_paths , str ):
370
- image_paths = [image_paths ]
371
- probs = self .create_probabilities_for_image_paths ( image_paths , self .txt_features )
380
+ def predict (self , images : List [str ] | str | List [ PIL . Image . Image ] , rank : Rank , min_prob : float = 1e-9 , k : int = 5 ) -> dict [str , dict [str , float ]]:
381
+ if isinstance (images , str ):
382
+ images = [images ]
383
+ probs = self .create_probabilities_for_images ( images , self .txt_features )
372
384
result = []
373
- for image_path in image_paths :
385
+ for i , image in enumerate (images ):
386
+ key = self .make_key (image , i )
374
387
if rank == Rank .SPECIES :
375
- result .extend (self .format_species_probs (image_path , probs [image_path ], k ))
388
+ result .extend (self .format_species_probs (key , probs [key ], k ))
376
389
else :
377
- result .extend (self .format_grouped_probs (image_path , probs [image_path ], rank , min_prob , k ))
390
+ result .extend (self .format_grouped_probs (key , probs [key ], rank , min_prob , k ))
378
391
return result
379
392
380
393
381
- def predict_classification (img : str , rank : Rank , device : Union [str , torch .device ] = 'cpu' ,
394
+ def predict_classification (img : Union [ PIL . Image . Image , str ] , rank : Rank , device : Union [str , torch .device ] = 'cpu' ,
382
395
min_prob : float = 1e-9 , k : int = 5 ) -> dict [str , float ]:
383
396
"""
384
397
Predicts from the entire tree of life.
0 commit comments