@@ -166,13 +166,16 @@ def create_bioclip_tokenizer(tokenizer_str="ViT-B-16"):
166
166
167
167
168
168
class CustomLabelsClassifier (object ):
169
- def __init__ (self , device : Union [str , torch .device ] = 'cpu' , model_str : str = MODEL_STR ):
169
+ def __init__ (self , cls_ary : List [ str ], device : Union [str , torch .device ] = 'cpu' , model_str : str = MODEL_STR ):
170
170
self .device = device
171
171
self .model = create_bioclip_model (device = device , model_str = model_str )
172
172
self .model_str = model_str
173
173
self .tokenizer = create_bioclip_tokenizer ()
174
+ self .classes = [cls .strip () for cls in cls_ary ]
175
+ self .txt_features = self ._get_txt_features (self .classes )
174
176
175
- def get_txt_features (self , classnames ):
177
+ @torch .no_grad ()
178
+ def _get_txt_features (self , classnames ):
176
179
all_features = []
177
180
for classname in classnames :
178
181
txts = [template (classname ) for template in OPENA_AI_IMAGENET_TEMPLATE ]
@@ -185,19 +188,17 @@ def get_txt_features(self, classnames):
185
188
return all_features
186
189
187
190
@torch .no_grad ()
188
- def predict (self , image_path : str , cls_ary : List [ str ] ) -> dict [str , float ]:
191
+ def predict (self , image_path : str ) -> dict [str , float ]:
189
192
img = open_image (image_path )
190
- classes = [cls .strip () for cls in cls_ary ]
191
- txt_features = self .get_txt_features (classes )
192
193
193
194
img = preprocess_img (img ).to (self .device )
194
195
img_features = self .model .encode_image (img .unsqueeze (0 ))
195
196
img_features = F .normalize (img_features , dim = - 1 )
196
197
197
- logits = (self .model .logit_scale .exp () * img_features @ txt_features ).squeeze ()
198
+ logits = (self .model .logit_scale .exp () * img_features @ self . txt_features ).squeeze ()
198
199
probs = F .softmax (logits , dim = 0 ).to ("cpu" ).tolist ()
199
200
pred_list = []
200
- for cls , prob in zip (classes , probs ):
201
+ for cls , prob in zip (self . classes , probs ):
201
202
pred_list .append ({
202
203
PRED_FILENAME_KEY : image_path ,
203
204
PRED_CLASSICATION_KEY : cls ,
@@ -207,8 +208,8 @@ def predict(self, image_path: str, cls_ary: List[str]) -> dict[str, float]:
207
208
208
209
209
210
def predict_classifications_from_list (img : Union [PIL .Image .Image , str ], cls_ary : List [str ], device : Union [str , torch .device ] = 'cpu' ) -> dict [str , float ]:
210
- classifier = CustomLabelsClassifier (device = device )
211
- return classifier .predict (img , cls_ary )
211
+ classifier = CustomLabelsClassifier (cls_ary = cls_ary , device = device )
212
+ return classifier .predict (img )
212
213
213
214
214
215
def get_tol_classification_labels (rank : Rank ) -> List [str ]:
@@ -293,7 +294,6 @@ def format_grouped_probs(self, image_path: str, probs: torch.Tensor, rank: Rank,
293
294
for name in topk_names :
294
295
item = { PRED_FILENAME_KEY : image_path }
295
296
item .update (name_to_class_dict [name ])
296
- #item.update(class_dict_lookup)
297
297
item [PRED_SCORE_KEY ] = output [name ].item ()
298
298
prediction_ary .append (item )
299
299
return prediction_ary
0 commit comments