@@ -119,6 +119,11 @@ def get_txt_names():
119
119
return txt_names
120
120
121
121
122
def open_image(image_path: str) -> PIL.Image.Image:
    """Open the image at ``image_path`` and return it converted to RGB.

    Args:
        image_path: Location of the image file, passed to ``PIL.Image.open``.

    Returns:
        A new image in ``"RGB"`` mode, detached from the underlying file.

    Raises:
        Whatever ``PIL.Image.open`` or ``Image.convert`` raise for a missing,
        unreadable or unsupported file; nothing is caught here.
    """
    # Image.open is lazy and keeps the file handle open. The context manager
    # guarantees the handle is released even if convert() fails; convert()
    # loads the pixel data and returns an independent image, so the result
    # stays valid after the source image is closed.
    with PIL.Image.open(image_path) as img:
        return img.convert("RGB")
125
+
126
+
122
127
preprocess_img = transforms .Compose (
123
128
[
124
129
transforms .ToTensor (),
@@ -181,7 +186,7 @@ def get_txt_features(self, classnames):
181
186
182
187
@torch .no_grad ()
183
188
def predict (self , image_path : str , cls_ary : List [str ]) -> dict [str , float ]:
184
- img = PIL . Image . open (image_path )
189
+ img = open_image (image_path )
185
190
classes = [cls .strip () for cls in cls_ary ]
186
191
txt_features = self .get_txt_features (classes )
187
192
@@ -248,7 +253,7 @@ def __init__(self, device: Union[str, torch.device] = 'cpu', model_str: str = MO
248
253
249
254
@torch.no_grad()
def get_image_features(self, image_path: str) -> torch.Tensor:
    """Load the image at ``image_path`` and return its encoded features.

    The file is read through ``open_image`` and the resulting image is
    handed straight to ``self.encode_image``; gradient tracking is disabled
    by the ``torch.no_grad`` decorator.

    Args:
        image_path: Path of the image file to encode.

    Returns:
        The tensor produced by ``self.encode_image`` for the loaded image.
    """
    return self.encode_image(open_image(image_path))
253
258
254
259
def encode_image (self , img : PIL .Image .Image ) -> torch .Tensor :
@@ -295,7 +300,7 @@ def format_grouped_probs(self, image_path: str, probs: torch.Tensor, rank: Rank,
295
300
296
301
@torch .no_grad ()
297
302
def predict (self , image_path : str , rank : Rank , min_prob : float = 1e-9 , k : int = 5 ) -> List [dict [str , float ]]:
298
- img = PIL . Image . open (image_path )
303
+ img = open_image (image_path )
299
304
probs = self .predict_species (img )
300
305
if rank == Rank .SPECIES :
301
306
return self .format_species_probs (image_path , probs , k )
0 commit comments