Skip to content

Commit 53c8c15

Browse files
Add bioimage search utils and example notebook (#553)
Add bioimage search utils and example notebook
1 parent bab5ada commit 53c8c15

File tree

6 files changed

+494
-175
lines changed

6 files changed

+494
-175
lines changed

apis/python/examples/object_api/bioimg_similarity_search.ipynb

Lines changed: 340 additions & 0 deletions
Large diffs are not rendered by default.

apis/python/src/tiledb/vector_search/embeddings/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .huggingface_auto_image_embedding import HuggingfaceAutoImageEmbedding
12
from .image_resnetv2_embedding import ImageResNetV2Embedding
23
from .langchain_embedding import LangChainEmbedding
34
from .object_embedding import ObjectEmbedding
@@ -11,6 +12,7 @@
1112
"ObjectEmbedding",
1213
"SomaGenePTwEmbedding",
1314
"ImageResNetV2Embedding",
15+
"HuggingfaceAutoImageEmbedding",
1416
"RandomEmbedding",
1517
"SentenceTransformersEmbedding",
1618
"LangChainEmbedding",
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from typing import Dict, Optional, OrderedDict
2+
3+
import numpy as np
4+
5+
6+
class HuggingfaceAutoImageEmbedding:
    """Image embedding that runs a Hugging Face ``AutoModel`` checkpoint.

    The processor and model are loaded lazily by :meth:`load`; until then
    ``self.processor`` and ``self.model`` are ``None``.

    NOTE(review): ``device`` is stored and round-tripped through
    :meth:`init_kwargs`, but nothing in this class moves the model or the
    inputs to it — confirm whether that is intentional.
    """

    def __init__(
        self,
        model_name_or_path: str,
        dimensions: int,
        device: Optional[str] = None,
        cache_folder: Optional[str] = None,
        batch_size: int = 64,
    ):
        """Store the configuration; no model is loaded here.

        Args:
            model_name_or_path: Identifier passed to ``from_pretrained``.
            dimensions: Width of each embedding vector produced by the model.
            device: Stored only (see the class note).
            cache_folder: Directory handed to ``from_pretrained`` as
                ``cache_dir``; ``None`` keeps the library default.
            batch_size: Number of images sent through the model per call.
        """
        self.model_name_or_path = model_name_or_path
        self.dim_num = dimensions
        self.device = device
        self.cache_folder = cache_folder
        self.batch_size = batch_size
        self.processor = None
        self.model = None

    def init_kwargs(self) -> Dict:
        """Return the keyword arguments needed to re-create this object."""
        return {
            "model_name_or_path": self.model_name_or_path,
            "dimensions": self.dim_num,
            "device": self.device,
            "cache_folder": self.cache_folder,
            "batch_size": self.batch_size,
        }

    def dimensions(self) -> int:
        """Return the number of dimensions of each embedding vector."""
        return self.dim_num

    def vector_type(self) -> np.dtype:
        """Return the element type of the embeddings (``np.float32``)."""
        return np.float32

    def load(self) -> None:
        """Load the image processor and the model for ``model_name_or_path``."""
        from transformers import AutoImageProcessor
        from transformers import AutoModel

        # Previously ``cache_folder`` was accepted but silently ignored.
        self.processor = AutoImageProcessor.from_pretrained(
            self.model_name_or_path, cache_dir=self.cache_folder
        )
        self.model = AutoModel.from_pretrained(
            self.model_name_or_path, cache_dir=self.cache_folder
        )

    def _embed_batch(self, image_batch) -> np.ndarray:
        """Run one batch of images through the processor and the model."""
        inputs = self.processor(images=image_batch, return_tensors="pt")
        # Index 0 along the second axis of ``last_hidden_state`` is taken as
        # the image embedding (presumably the CLS token — confirm per model).
        return self.model(**inputs).last_hidden_state[:, 0].cpu().detach().numpy()

    def embed(self, objects: OrderedDict, metadata: OrderedDict) -> np.ndarray:
        """Embed every image in ``objects``.

        Args:
            objects: Mapping with ``"image"`` (flattened pixel arrays) and
                ``"shape"`` (the shape each flattened array is reshaped to).
            metadata: Unused by this embedding.

        Returns:
            A ``float32`` array of shape ``(len(objects["image"]), dimensions)``
            whose row ``i`` is the embedding of image ``i``.
        """
        from PIL import Image

        size = len(objects["image"])
        embeddings = np.zeros((size, self.dim_num), dtype=np.float32)
        # A non-positive batch size degrades to one image per batch, which is
        # what the original counter-based loop did.
        step = max(1, self.batch_size)
        for start in range(0, size, step):
            stop = min(start + step, size)
            image_batch = [
                Image.fromarray(
                    np.reshape(objects["image"][image_id], objects["shape"][image_id])
                )
                for image_id in range(start, stop)
            ]
            # Write each batch at its own offset; the earlier code never
            # advanced the write position, so batches overwrote each other.
            embeddings[start:stop] = self._embed_batch(image_batch)
        return embeddings

apis/python/src/tiledb/vector_search/object_readers/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from .bioimage_reader import BioImagePartition
2-
from .bioimage_reader import BioImageReader
1+
from .bioimage_reader import BioImageDirectoryReader
32
from .directory_reader import DirectoryImageReader
43
from .directory_reader import DirectoryPartition
54
from .directory_reader import DirectoryReader
@@ -18,8 +17,7 @@
1817
"SomaAnnDataReader",
1918
"TileDB1DArrayPartition",
2019
"TileDB1DArrayReader",
21-
"BioImagePartition",
22-
"BioImageReader",
20+
"BioImageDirectoryReader",
2321
"DirectoryReader",
2422
"DirectoryTextReader",
2523
"DirectoryImageReader",
Lines changed: 62 additions & 171 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,52 @@
1-
from typing import Any, Dict, List, Mapping, Optional, OrderedDict, Tuple
1+
from typing import Dict, List, Optional, OrderedDict, Sequence, Tuple
22

33
import numpy as np
44

55
import tiledb
6-
7-
# from tiledb.vector_search.object_readers import ObjectPartition, ObjectReader
6+
from tiledb.vector_search.object_readers.directory_reader import DirectoryImageReader
7+
from tiledb.vector_search.object_readers.directory_reader import DirectoryPartition
88

99
MAX_IMAGE_CROPS_PER_IMAGE = 10000
1010

1111

12-
# class BioImagePartition(ObjectPartition):
13-
class BioImagePartition:
12+
class BioImageDirectoryReader(DirectoryImageReader):
1413
def __init__(
    self,
    search_uri: str,
    include: str = "*",
    exclude: Sequence[str] = ("[.]*", "*/[.]*"),
    suffixes: Optional[Sequence[str]] = None,
    max_files: Optional[int] = None,
    level: int = -1,
    object_crop_shape: Optional[Tuple[int, int]] = None,
    timestamp=None,
):
    """Configure a directory reader that yields (optionally cropped) slides.

    Args:
        search_uri: Forwarded to ``DirectoryImageReader``.
        include: Forwarded to ``DirectoryImageReader``.
        exclude: Forwarded to ``DirectoryImageReader`` as a fresh list.
        suffixes: Forwarded to ``DirectoryImageReader``.
        max_files: Forwarded to ``DirectoryImageReader``.
        level: Index into the slide's ``level_dimensions`` used when reading.
        object_crop_shape: Crop size along (dim 0, dim 1); ``None`` means one
            object per whole image.
        timestamp: Stored on the instance and returned by ``init_kwargs``.
    """
    super().__init__(
        search_uri=search_uri,
        include=include,
        # The default is an immutable tuple; copy into a new list so no two
        # readers ever share (and accidentally mutate) the same object.
        exclude=list(exclude),
        suffixes=suffixes,
        max_files=max_files,
    )
    self.level = level
    self.object_crop_shape = object_crop_shape
    self.timestamp = timestamp
    self.images = None
5135

5236
def init_kwargs(self) -> Dict:
    """Return the constructor keyword arguments that rebuild this reader."""
    # Constructor parameter names double as attribute names, in this order.
    field_names = (
        "search_uri",
        "include",
        "exclude",
        "suffixes",
        "max_files",
        "level",
        "object_crop_shape",
        "timestamp",
    )
    return {field: getattr(self, field) for field in field_names}
6047

6148
def partition_class_name(self) -> str:
    """Return the name of the partition class this reader works with."""
    partition_name = "DirectoryPartition"
    return partition_name
6350

6451
def metadata_array_uri(self) -> Optional[str]:
    """Return the metadata array URI; this reader always returns ``None``."""
    return None
@@ -76,157 +63,61 @@ def metadata_attributes(self) -> List[tiledb.Attr]:
7663
)
7764
return [image_uri_attr, location_attr]
7865

79-
def get_partitions(
80-
self, images_per_partitions: int = -1, **kwargs
81-
) -> List[BioImagePartition]:
82-
if images_per_partitions == -1:
83-
images_per_partitions = 1
84-
if self.images is None:
85-
vfs = tiledb.VFS(config=self.config)
86-
self.images = vfs.ls(self.uri)[1:]
87-
num_images = len(self.images)
88-
partitions = []
89-
partition_id = 0
90-
for start in range(0, num_images, images_per_partitions):
91-
end = start + images_per_partitions
92-
if end > num_images:
93-
end = num_images
94-
partitions.append(
95-
BioImagePartition(
96-
partition_id,
97-
image_uris=self.images[start:end],
98-
image_id_start=start,
99-
)
100-
)
101-
partition_id += 1
102-
return partitions
103-
10466
def read_objects(
    self, partition: DirectoryPartition
) -> Tuple[OrderedDict, OrderedDict]:
    """Read every slide in ``partition`` and return its crops with metadata.

    Each path in ``partition.paths`` is opened with ``TileDBOpenSlide`` and
    read at ``self.level``. When ``self.object_crop_shape`` is ``None`` the
    whole image is one object; otherwise the image is tiled into crops of
    that shape, with edge crops clipped to the image bounds.

    Returns:
        A pair ``(objects, metadata)``. ``objects`` has ``"image"``
        (flattened crops), ``"shape"`` (crop shapes) and ``"external_id"``;
        ``metadata`` has ``"image_uri"``, ``"location"`` (the
        ``[dim_0_start, dim_0_end, dim_1_start, dim_1_end]`` of each crop)
        and ``"external_id"``.
    """
    import hashlib

    from tiledb.bioimg.openslide import TileDBOpenSlide

    def stable_external_id(path, dim_0_start, dim_1_start) -> int:
        # Built-in ``hash()`` on str is salted per process, so it cannot give
        # ids that agree across workers or runs. A fixed digest can; dropping
        # one bit keeps ids below 2**63, the range ``abs(hash(...))`` had.
        key = f"{path}_{dim_0_start}_{dim_1_start}".encode("utf-8")
        digest = hashlib.blake2b(key, digest_size=8).digest()
        return int.from_bytes(digest, "big") >> 1

    def as_object_array(items) -> np.ndarray:
        # Fill element by element: ``np.array(items, dtype="O")`` would build
        # a 2-D array whenever all items happen to share a shape.
        out = np.empty(len(items), dtype="O")
        for position, item in enumerate(items):
            out[position] = item
        return out

    # Growable lists replace buffers pre-sized to a fixed per-image crop
    # budget, which raised IndexError once a partition exceeded it.
    images = []
    shapes = []
    image_uris = []
    locations = []
    external_ids = []

    def add_crop(image, path, dim_0_start, dim_0_end, dim_1_start, dim_1_end):
        cropped_image = image[dim_0_start:dim_0_end, dim_1_start:dim_1_end]
        images.append(cropped_image.flatten())
        shapes.append(np.array(cropped_image.shape, dtype=np.uint32))
        image_uris.append(path)
        locations.append(
            np.array(
                [dim_0_start, dim_0_end, dim_1_start, dim_1_end], dtype=np.uint32
            )
        )
        external_ids.append(stable_external_id(path, dim_0_start, dim_1_start))

    for path in partition.paths:
        slide = TileDBOpenSlide(path)
        level_dimensions = slide.level_dimensions[self.level]
        image = slide.read_region((0, 0), self.level, level_dimensions)
        # ``level_dimensions[1]`` bounds array dim 0 and ``[0]`` bounds dim 1
        # (presumably OpenSlide's (width, height) ordering — confirm).
        dim_0_size = level_dimensions[1]
        dim_1_size = level_dimensions[0]
        if self.object_crop_shape is None:
            add_crop(image, path, 0, dim_0_size, 0, dim_1_size)
            continue
        crop_0, crop_1 = self.object_crop_shape[0], self.object_crop_shape[1]
        for dim_0_start in range(0, dim_0_size, crop_0):
            dim_0_end = min(dim_0_start + crop_0, dim_0_size)
            for dim_1_start in range(0, dim_1_size, crop_1):
                dim_1_end = min(dim_1_start + crop_1, dim_1_size)
                add_crop(
                    image, path, dim_0_start, dim_0_end, dim_1_start, dim_1_end
                )

    external_id_array = np.array(external_ids, dtype=np.uint64)
    return (
        {
            "image": as_object_array(images),
            "shape": as_object_array(shapes),
            "external_id": external_id_array,
        },
        {
            "image_uri": as_object_array(image_uris),
            "location": as_object_array(locations),
            "external_id": external_id_array,
        },
    )

0 commit comments

Comments
 (0)