|
| 1 | +import logging |
| 2 | + |
| 3 | +import numpy as np |
| 4 | + |
| 5 | +from aperturedb.Entities import Entities |
| 6 | +from aperturedb.ParallelQuery import execute_batch |
| 7 | + |
| 8 | +logger = logging.getLogger(__name__) |
| 9 | + |
| 10 | + |
class Descriptors(Entities):
    """
    Python wrapper for ApertureDB Descriptors API.
    """

    db_object = "_Descriptor"

    def __init__(self, db):
        super().__init__(db)

    def find_similar(self,
                     set: str,
                     vector,
                     k_neighbors: int,
                     constraints=None,
                     distances: bool = False,
                     blobs: bool = False,
                     results=None):
        """
        Find the descriptors most similar to the input vector.

        Args:
            set (str): Descriptor set name.
            vector (list): Input query vector.
            k_neighbors (int): Number of neighbors to return.
            constraints (Constraints): Optional constraints applied to the
                matched descriptors. Defaults to None (no constraints).
            distances (bool): Return similarity metric values.
            blobs (bool): Return vectors of the neighbors.
            results (dict): Dictionary with the results format.
                Defaults to all properties.

        Side effect:
            Stores the matched entities in ``self.response``.
        """
        # None-sentinel instead of a mutable default argument: a shared
        # default dict would leak mutations across calls.
        if results is None:
            results = {"all_properties": True}

        command = {"FindDescriptor": {
            "set": set,
            "distances": distances,
            "blobs": blobs,
            "results": results,
            "k_neighbors": k_neighbors,
        }}

        if constraints is not None:
            command["FindDescriptor"]["constraints"] = constraints.constraints

        query = [command]
        # The query vector travels as a raw float32 blob alongside the command.
        blobs_in = [np.array(vector, dtype=np.float32).tobytes()]
        _, response, blobs_out = execute_batch(query, blobs_in, self.db)

        self.response = response[0]["FindDescriptor"]["entities"]

        if blobs:
            # Returned blobs are assumed to be in the same order as the
            # returned entities.
            for i, entity in enumerate(self.response):
                entity["vector"] = np.frombuffer(
                    blobs_out[i], dtype=np.float32)

    def _descriptorset_metric(self, set: str):
        """Return the default metric (e.g. "CS", "L2", "IP") for descriptor set *set*."""
        command = {"FindDescriptorSet": {"with_name": set, "metrics": True}}
        query = [command]
        response, _ = self.db.query(query)
        logger.debug(response)
        assert self.db.last_query_ok(), response
        return response[0]["FindDescriptorSet"]['entities'][0]["_metrics"][0]

    def _vector_similarity(self, v1, v2):
        """Similarity between two vectors under ``self.metric``; higher means more similar."""
        if self.metric == "CS":
            # cosine similarity
            return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
        elif self.metric == "L2":
            # negate to turn distance into similarity
            return -np.linalg.norm(v1 - v2)
        elif self.metric == "IP":
            # inner product
            return np.dot(v1, v2)
        else:
            raise ValueError("Unknown metric: %s" % self.metric)

    def find_similar_mmr(self,
                         set: str,
                         vector,
                         k_neighbors: int,
                         fetch_k: int,
                         lambda_mult: float = 0.5,
                         **kwargs):
        """
        As find_similar, but using the MMR algorithm to diversify the results.

        Args:
            set (str): Descriptor set name.
            vector (list): Input query vector.
            k_neighbors (int): Number of results to return.
            fetch_k (int): Number of neighbors to fetch from the database.
            lambda_mult (float): Lambda multiplier for the MMR algorithm.
                Defaults to 0.5. 1.0 means no diversity.
        """
        self.metric = self._descriptorset_metric(set)
        vector = np.array(vector, dtype=np.float32)

        kwargs["blobs"] = True  # force vector return; MMR needs the vectors
        self.find_similar(set, vector, fetch_k, **kwargs)

        # MMR algorithm
        # Calculate similarity between query and all documents
        query_similarity = [self._vector_similarity(
            vector, d["vector"]) for d in self]
        # Calculate similarity between all pairs of documents (symmetric)
        document_similarity = {}
        for i, d in enumerate(self):
            for j, d2 in enumerate(self[i + 1:], i + 1):
                similarity = self._vector_similarity(d["vector"], d2["vector"])
                document_similarity[(i, j)] = similarity
                document_similarity[(j, i)] = similarity

        # We just gather indexes here, not the actual entities
        selected = []
        unselected = list(range(len(self)))

        while len(selected) < k_neighbors and unselected:
            if not selected:
                # Seed with the first result (most relevant per the DB's order).
                selected.append(0)
                unselected.remove(0)
            else:
                selected_unselected_similarity = np.array(
                    [[document_similarity[(i, j)] for j in unselected] for i in selected])
                # For each candidate, its worst case: max similarity to any
                # already-selected document (redundancy).
                worst_similarity = np.max(
                    selected_unselected_similarity, axis=0)
                relevance_scores = np.array(
                    [query_similarity[i] for i in unselected])
                # MMR score = lambda * relevance - (1 - lambda) * redundancy.
                # BUG FIX: the redundancy term was previously *added*, which
                # rewarded similarity to already-selected results and made
                # lower lambda_mult values produce *less* diverse output.
                scores = lambda_mult * relevance_scores - \
                    (1 - lambda_mult) * worst_similarity
                max_index = unselected[np.argmax(scores)]
                selected.append(max_index)
                unselected.remove(max_index)
        logger.info("Selected indexes: %s; unselected %s",
                    selected, unselected)
        self.response = [self[i] for i in selected]
0 commit comments