Skip to content

Commit a8d3a58

Browse files
authored
Merge pull request #16 from Imageomics/20-txt-cache
Avoid recomputing text embeddings for custom labels
2 parents: d6ebb1d + 0912555; commit: a8d3a58

File tree

4 files changed: 16 additions and 16 deletions.

README.md

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -89,8 +89,8 @@ df = pd.DataFrame(predictions)
8989
```python
9090
from bioclip import CustomLabelsClassifier
9191

92-
classifier = CustomLabelsClassifier()
93-
predictions = classifier.predict("Ursus-arctos.jpeg", ["duck","fish","bear"])
92+
classifier = CustomLabelsClassifier(["duck","fish","bear"])
93+
predictions = classifier.predict("Ursus-arctos.jpeg")
9494
for prediction in predictions:
9595
print(prediction["classification"], prediction["score"])
9696
```

src/bioclip/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,10 +31,10 @@ def write_results_to_file(df, format, outfile):
3131
def predict(image_file: list[str], format: str, output: str,
3232
cls_str: str, device: str, rank: Rank, k: int):
3333
if cls_str:
34-
classifier = CustomLabelsClassifier(device=device)
34+
classifier = CustomLabelsClassifier(cls_ary=cls_str.split(','), device=device)
3535
data = []
3636
for image_path in image_file:
37-
data.extend(classifier.predict(image_path=image_path, cls_ary=cls_str.split(',')))
37+
data.extend(classifier.predict(image_path=image_path))
3838
write_results(data, format, output)
3939
else:
4040
classifier = TreeOfLifeClassifier(device=device)

src/bioclip/predict.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -166,13 +166,16 @@ def create_bioclip_tokenizer(tokenizer_str="ViT-B-16"):
166166

167167

168168
class CustomLabelsClassifier(object):
169-
def __init__(self, device: Union[str, torch.device] = 'cpu', model_str: str = MODEL_STR):
169+
def __init__(self, cls_ary: List[str], device: Union[str, torch.device] = 'cpu', model_str: str = MODEL_STR):
170170
self.device = device
171171
self.model = create_bioclip_model(device=device, model_str=model_str)
172172
self.model_str = model_str
173173
self.tokenizer = create_bioclip_tokenizer()
174+
self.classes = [cls.strip() for cls in cls_ary]
175+
self.txt_features = self._get_txt_features(self.classes)
174176

175-
def get_txt_features(self, classnames):
177+
@torch.no_grad()
178+
def _get_txt_features(self, classnames):
176179
all_features = []
177180
for classname in classnames:
178181
txts = [template(classname) for template in OPENA_AI_IMAGENET_TEMPLATE]
@@ -185,19 +188,17 @@ def get_txt_features(self, classnames):
185188
return all_features
186189

187190
@torch.no_grad()
188-
def predict(self, image_path: str, cls_ary: List[str]) -> dict[str, float]:
191+
def predict(self, image_path: str) -> dict[str, float]:
189192
img = open_image(image_path)
190-
classes = [cls.strip() for cls in cls_ary]
191-
txt_features = self.get_txt_features(classes)
192193

193194
img = preprocess_img(img).to(self.device)
194195
img_features = self.model.encode_image(img.unsqueeze(0))
195196
img_features = F.normalize(img_features, dim=-1)
196197

197-
logits = (self.model.logit_scale.exp() * img_features @ txt_features).squeeze()
198+
logits = (self.model.logit_scale.exp() * img_features @ self.txt_features).squeeze()
198199
probs = F.softmax(logits, dim=0).to("cpu").tolist()
199200
pred_list = []
200-
for cls, prob in zip(classes, probs):
201+
for cls, prob in zip(self.classes, probs):
201202
pred_list.append({
202203
PRED_FILENAME_KEY: image_path,
203204
PRED_CLASSICATION_KEY: cls,
@@ -207,8 +208,8 @@ def predict(self, image_path: str, cls_ary: List[str]) -> dict[str, float]:
207208

208209

209210
def predict_classifications_from_list(img: Union[PIL.Image.Image, str], cls_ary: List[str], device: Union[str, torch.device] = 'cpu') -> dict[str, float]:
210-
classifier = CustomLabelsClassifier(device=device)
211-
return classifier.predict(img, cls_ary)
211+
classifier = CustomLabelsClassifier(cls_ary=cls_ary, device=device)
212+
return classifier.predict(img)
212213

213214

214215
def get_tol_classification_labels(rank: Rank) -> List[str]:
@@ -293,7 +294,6 @@ def format_grouped_probs(self, image_path: str, probs: torch.Tensor, rank: Rank,
293294
for name in topk_names:
294295
item = { PRED_FILENAME_KEY: image_path }
295296
item.update(name_to_class_dict[name])
296-
#item.update(class_dict_lookup)
297297
item[PRED_SCORE_KEY] = output[name].item()
298298
prediction_ary.append(item)
299299
return prediction_ary

tests/test_predict.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -45,8 +45,8 @@ def test_tree_of_life_classifier_family(self):
4545
self.assertEqual(prediction_ary[0], prediction_dict)
4646

4747
def test_custom_labels_classifier(self):
48-
classifier = CustomLabelsClassifier()
49-
results = classifier.predict(image_path=EXAMPLE_CAT_IMAGE, cls_ary=['cat', 'dog'])
48+
classifier = CustomLabelsClassifier(cls_ary=['cat', 'dog'])
49+
results = classifier.predict(image_path=EXAMPLE_CAT_IMAGE)
5050
self.assertEqual(results, [
5151
{'file_name': EXAMPLE_CAT_IMAGE, 'classification': 'cat', 'score': unittest.mock.ANY},
5252
{'file_name': EXAMPLE_CAT_IMAGE, 'classification': 'dog', 'score': unittest.mock.ANY},

0 commit comments

Comments (0)