Skip to content

Commit e109734

Browse files
authored
Merge pull request #425 from aperture-data/release-0.4.25
Release 0.4.25
2 parents 5905396 + 588fd9c commit e109734

File tree

4 files changed

+150
-3
lines changed

4 files changed

+150
-3
lines changed

aperturedb/Descriptors.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
import logging
2+
3+
import numpy as np
4+
5+
from aperturedb.Entities import Entities
6+
from aperturedb.ParallelQuery import execute_batch
7+
8+
# Module-level logger, named after this module per the standard logging convention.
logger = logging.getLogger(__name__)
9+
10+
11+
class Descriptors(Entities):
    """
    Python wrapper for ApertureDB Descriptors API.

    Provides k-NN similarity search over descriptors stored in an ApertureDB
    descriptor set, plus an MMR (Maximal Marginal Relevance) variant that
    trades off relevance against diversity of the returned results.
    """

    db_object = "_Descriptor"

    def __init__(self, db):
        super().__init__(db)

    def find_similar(self,
                     set: str,
                     vector,
                     k_neighbors: int,
                     constraints=None,
                     distances: bool = False,
                     blobs: bool = False,
                     results=None):
        """
        Find descriptors similar to the input vector.

        Populates ``self.response`` with the matching entities; when
        ``blobs`` is True, each entity additionally gets a ``"vector"``
        key holding its float32 descriptor as a numpy array.

        Args:
            set (str): Descriptor set name.
            vector (list): Input descriptor vector.
            k_neighbors (int): Number of neighbors to return.
            constraints: Optional constraints object; its ``.constraints``
                attribute is forwarded to the FindDescriptor command.
            distances (bool): Return similarity metric values.
            blobs (bool): Return vectors of the neighbors.
            results (dict): Dictionary with the results format.
                Defaults to all properties.
        """
        # Fix: the original used a mutable default argument
        # (results={"all_properties": True}), which is shared across calls
        # and can be mutated downstream. Default is resolved here instead;
        # behavior for callers is unchanged.
        if results is None:
            results = {"all_properties": True}

        command = {"FindDescriptor": {
            "set": set,
            "distances": distances,
            "blobs": blobs,
            "results": results,
            "k_neighbors": k_neighbors,
        }}

        if constraints is not None:
            command["FindDescriptor"]["constraints"] = constraints.constraints

        query = [command]
        # The query vector travels as a raw float32 byte blob.
        blobs_in = [np.array(vector, dtype=np.float32).tobytes()]
        _, response, blobs_out = execute_batch(query, blobs_in, self.db)

        self.response = response[0]["FindDescriptor"]["entities"]

        if blobs:
            # Assumes one returned blob per entity, in entity order —
            # matches the server's FindDescriptor blob ordering.
            for i, entity in enumerate(self.response):
                entity["vector"] = np.frombuffer(
                    blobs_out[i], dtype=np.float32)

    def _descriptorset_metric(self, set: str):
        """Return the default (first listed) metric of a descriptor set."""
        command = {"FindDescriptorSet": {"with_name": set, "metrics": True}}
        query = [command]
        response, _ = self.db.query(query)
        logger.debug(response)
        assert self.db.last_query_ok(), response
        return response[0]["FindDescriptorSet"]['entities'][0]["_metrics"][0]

    def _vector_similarity(self, v1, v2):
        """Similarity between two vectors under ``self.metric``.

        Higher values always mean "more similar", regardless of metric.

        Raises:
            ValueError: If ``self.metric`` is not one of CS, L2, IP.
        """
        if self.metric == "CS":
            # Cosine similarity.
            return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
        elif self.metric == "L2":
            # negate to turn distance into similarity
            return -np.linalg.norm(v1 - v2)
        elif self.metric == "IP":
            # Inner product.
            return np.dot(v1, v2)
        else:
            raise ValueError("Unknown metric: %s" % self.metric)

    def find_similar_mmr(self,
                         set: str,
                         vector,
                         k_neighbors: int,
                         fetch_k: int,
                         lambda_mult: float = 0.5,
                         **kwargs):
        """
        As find_similar, but using the MMR algorithm to diversify the results.

        Fetches ``fetch_k`` nearest neighbors, then greedily selects
        ``k_neighbors`` of them maximizing
        ``lambda_mult * relevance - (1 - lambda_mult) * max_sim_to_selected``.

        Args:
            set (str): Descriptor set name.
            vector (list): Input descriptor set vector.
            k_neighbors (int): Number of results to return.
            fetch_k (int): Number of neighbors to fetch from the database.
            lambda_mult (float): Lambda multiplier for the MMR algorithm.
                Defaults to 0.5. 1.0 means no diversity.
        """
        self.metric = self._descriptorset_metric(set)
        vector = np.array(vector, dtype=np.float32)

        kwargs["blobs"] = True  # force vector return
        self.find_similar(set, vector, fetch_k, **kwargs)

        # MMR algorithm
        # Calculate similarity between query and all documents
        query_similarity = [self._vector_similarity(
            vector, d["vector"]) for d in self]
        # Calculate similarity between all pairs of documents
        document_similarity = {}
        for i, d in enumerate(self):
            for j, d2 in enumerate(self[i + 1:], i + 1):
                similarity = self._vector_similarity(d["vector"], d2["vector"])
                document_similarity[(i, j)] = similarity
                document_similarity[(j, i)] = similarity

        # We just gather indexes here, not the actual entities
        selected = []
        unselected = list(range(len(self)))

        while len(selected) < k_neighbors and unselected:
            if not selected:
                # Seed with the top-ranked (most relevant) result.
                selected.append(0)
                unselected.remove(0)
            else:
                selected_unselected_similarity = np.array(
                    [[document_similarity[(i, j)] for j in unselected] for i in selected])
                # For each candidate, its highest similarity to anything
                # already selected (i.e. its redundancy).
                worst_similarity = np.max(
                    selected_unselected_similarity, axis=0)
                relevance_scores = np.array(
                    [query_similarity[i] for i in unselected])
                # Fix: the original computed
                #     (1 - lambda) * worst_similarity + lambda * relevance
                # which REWARDS redundancy. Standard MMR penalizes it:
                scores = lambda_mult * relevance_scores - \
                    (1 - lambda_mult) * worst_similarity
                max_index = unselected[np.argmax(scores)]
                selected.append(max_index)
                unselected.remove(max_index)
        logger.info("Selected indexes: %s; unselected %s",
                    selected, unselected)
        self.response = [self[i] for i in selected]

aperturedb/EntityUpdateDataCSV.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class SingleEntityUpdateDataCSV(CSVParser.CSVParser):
2525
- a series of updateif_ to determine if an update is necessary
2626
2727
Conditionals:
28-
updateif>_prop - updates if the databset value > csv value
28+
updateif>_prop - updates if the database value > csv value
2929
updateif<_prop - updates if the database value < csv value
3030
updateif!_prop - updates if the database value is != csv value
3131

aperturedb/ParallelLoader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def ingest(self, generator, batchsize: int = 1, numthreads: int = 4, stats: bool
7979
8080
Args:
8181
generator (_type_): The list of data, or a class derived from [Subscriptable](/python_sdk/helpers/Subscriptable) to be ingested.
82-
batchsize (int, optional): The size of batch to be ussed. Defaults to 1.
82+
batchsize (int, optional): The size of batch to be used. Defaults to 1.
8383
numthreads (int, optional): Number of workers to create. Defaults to 4.
8484
stats (bool, optional): If stats need to be presented, realtime. Defaults to False.
8585
"""

aperturedb/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
logger = logging.getLogger(__name__)
99

10-
__version__ = "0.4.24"
10+
__version__ = "0.4.25"
1111

1212
# set log level
1313
logger.setLevel(logging.DEBUG)

0 commit comments

Comments
 (0)