Skip to content

Commit 6832586

Browse files
authored
Merge pull request #453 from aperture-data/release-0.4.30
Release 0.4.30
2 parents a6ad860 + 75d03da commit 6832586

File tree

7 files changed

+193
-63
lines changed

7 files changed

+193
-63
lines changed

.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ ci:
1010
skip: []
1111
submodules: false
1212
repos:
13-
- repo: https://github.com/pre-commit/mirrors-autopep8
14-
rev: 898691a
13+
- repo: https://github.com/hhatto/autopep8
14+
rev: 8b75604
1515
hooks:
1616
- id: autopep8
1717
exclude: _pb2.py$
18-
args: ["--ignore", "E251,E241,E221", "-i"]
18+
args: ["--ignore", "E251,E241,E221,E402,E265,E275", "-i"]

aperturedb/Descriptors.py

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,16 @@ class Descriptors(Entities):
1818
def __init__(self, db):
1919
super().__init__(db)
2020

21-
def find_similar(self,
22-
set: str,
23-
vector,
24-
k_neighbors: int,
25-
constraints=None,
26-
distances: bool = False,
27-
blobs: bool = False,
28-
results={"all_properties": True}):
21+
def find_similar(
22+
self,
23+
set: str,
24+
vector,
25+
k_neighbors: int,
26+
constraints=None,
27+
distances: bool = False,
28+
blobs: bool = False,
29+
results={"all_properties": True},
30+
):
2931
"""
3032
Find similar descriptor sets to the input descriptor set.
3133
@@ -42,13 +44,15 @@ def find_similar(self,
4244
results: Response from the server.
4345
"""
4446

45-
command = {"FindDescriptor": {
46-
"set": set,
47-
"distances": distances,
48-
"blobs": blobs,
49-
"results": results,
50-
"k_neighbors": k_neighbors,
51-
}}
47+
command = {
48+
"FindDescriptor": {
49+
"set": set,
50+
"distances": distances,
51+
"blobs": blobs,
52+
"results": results,
53+
"k_neighbors": k_neighbors,
54+
}
55+
}
5256

5357
if constraints is not None:
5458
command["FindDescriptor"]["constraints"] = constraints.constraints
@@ -57,7 +61,7 @@ def find_similar(self,
5761
blobs_in = [np.array(vector, dtype=np.float32).tobytes()]
5862
_, response, blobs_out = execute_batch(query, blobs_in, self.db)
5963

60-
self.response = response[0]["FindDescriptor"]["entities"]
64+
self.response = response[0]["FindDescriptor"].get("entities", [])
6165

6266
if blobs:
6367
for i, entity in enumerate(self.response):
@@ -71,7 +75,7 @@ def _descriptorset_metric(self, set: str):
7175
response, _ = self.db.query(query)
7276
logger.debug(response)
7377
assert self.db.last_query_ok(), response
74-
return response[0]["FindDescriptorSet"]['entities'][0]["_metrics"][0]
78+
return response[0]["FindDescriptorSet"]["entities"][0]["_metrics"][0]
7579

7680
def _vector_similarity(self, v1, v2):
7781
"""Find similarity between two vectors using the metric of the descriptor set."""
@@ -85,13 +89,15 @@ def _vector_similarity(self, v1, v2):
8589
else:
8690
raise ValueError("Unknown metric: %s" % self.metric)
8791

88-
def find_similar_mmr(self,
89-
set: str,
90-
vector,
91-
k_neighbors: int,
92-
fetch_k: int,
93-
lambda_mult: float = 0.5,
94-
**kwargs):
92+
def find_similar_mmr(
93+
self,
94+
set: str,
95+
vector,
96+
k_neighbors: int,
97+
fetch_k: int,
98+
lambda_mult: float = 0.5,
99+
**kwargs,
100+
):
95101
"""
96102
As find_similar, but using the MMR algorithm to diversify the results.
97103
@@ -132,13 +138,18 @@ def find_similar_mmr(self,
132138
unselected.remove(0)
133139
else:
134140
selected_unselected_similarity = np.array(
135-
[[document_similarity[(i, j)] for j in unselected] for i in selected])
141+
[
142+
[document_similarity[(i, j)] for j in unselected]
143+
for i in selected
144+
]
145+
)
136146
worst_similarity = np.max(
137147
selected_unselected_similarity, axis=0)
138148
relevance_scores = np.array(
139149
[query_similarity[i] for i in unselected])
140-
scores = (1 - lambda_mult) * worst_similarity + \
141-
lambda_mult * relevance_scores
150+
scores = (
151+
1 - lambda_mult
152+
) * worst_similarity + lambda_mult * relevance_scores
142153
max_index = unselected[np.argmax(scores)]
143154
selected.append(max_index)
144155
unselected.remove(max_index)

aperturedb/ParallelQuery.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import json
66
import logging
77
import math
8+
import inspect
89

910

1011
from aperturedb.DaskManager import DaskManager
@@ -17,7 +18,7 @@
1718
def execute_batch(q: Commands, blobs: Blobs, db: Connector,
1819
success_statuses: list[int] = [0],
1920
response_handler: Optional[Callable] = None, commands_per_query: int = 1, blobs_per_query: int = 0,
20-
strict_response_validation: bool = False) -> Tuple[int, CommandResponses, Blobs]:
21+
strict_response_validation: bool = False, cmd_index=None) -> Tuple[int, CommandResponses, Blobs]:
2122
"""
2223
Execute a batch of queries, doing useful logging around it.
2324
Calls the response handler if provided.
@@ -50,7 +51,8 @@ def execute_batch(q: Commands, blobs: Blobs, db: Connector,
5051
if response_handler is not None:
5152
try:
5253
ParallelQuery.map_response_to_handler(response_handler,
53-
q, blobs, r, b, commands_per_query, blobs_per_query)
54+
q, blobs, r, b, commands_per_query, blobs_per_query,
55+
cmd_index)
5456
except BaseException as e:
5557
logger.exception(e)
5658
if strict_response_validation:
@@ -112,7 +114,7 @@ def getSuccessStatus(cls):
112114

113115
@classmethod
114116
def map_response_to_handler(cls, handler, query, query_blobs, response, response_blobs,
115-
commands_per_query, blobs_per_query):
117+
commands_per_query, blobs_per_query, cmd_index_offset):
116118
# We could potentially always call this handler function
117119
# and let the user deal with the error cases.
118120
blobs_returned = 0
@@ -140,7 +142,8 @@ def map_response_to_handler(cls, handler, query, query_blobs, response, respons
140142
response[start:end] if issubclass(
141143
type(response), list) else response,
142144
response_blobs[blobs_returned:blobs_returned + b_count] if
143-
len(response_blobs) >= blobs_returned + b_count else None)
145+
len(response_blobs) >= blobs_returned + b_count else None,
146+
None if cmd_index_offset is None else cmd_index_offset + i)
144147
blobs_returned += b_count
145148

146149
def __init__(self, db: Connector, dry_run: bool = False):
@@ -218,7 +221,7 @@ def call_response_handler(self, q: Commands, blobs: Blobs, r: CommandResponses,
218221
except BaseException as e:
219222
logger.exception(e)
220223

221-
def do_batch(self, db: Connector, data: List[Tuple[Commands, Blobs]]) -> None:
224+
def do_batch(self, db: Connector, batch_start: int, data: List[Tuple[Commands, Blobs]]) -> None:
222225
"""
223226
Executes batch of queries and blobs in the database.
224227
@@ -257,6 +260,19 @@ def process_responses(requests, input_blobs, responses, output_blobs):
257260
response_handler = self.generator.response_handler
258261
if hasattr(self.generator, "strict_response_validation") and isinstance(self.generator.strict_response_validation, bool):
259262
strict_response_validation = self.generator.strict_response_validation
263+
264+
# if response_handler doesn't support index, just discard the index with a wrapper.
265+
if response_handler is not None:
266+
parameter_count = len(inspect.signature(
267+
response_handler).parameters)
268+
if parameter_count < 4 or parameter_count > 5:
269+
raise Exception("Bad Signature for response_handler :"
270+
f"expected 6 > args > 3, got {parameter_count}")
271+
if parameter_count == 4:
272+
indexless_handler = response_handler
273+
def response_handler(query, qblobs, resp, rblobs, qindex): return indexless_handler(
274+
query, qblobs, resp, rblobs)
275+
260276
result, r, b = self.batch_command(
261277
q,
262278
blobs,
@@ -265,7 +281,8 @@ def process_responses(requests, input_blobs, responses, output_blobs):
265281
response_handler,
266282
self.commands_per_query,
267283
self.blobs_per_query,
268-
strict_response_validation=strict_response_validation)
284+
strict_response_validation=strict_response_validation,
285+
cmd_index=batch_start)
269286
if result == 0:
270287
query_time = db.get_last_query_time()
271288
worker_stats["succeeded_commands"] = len(q)
@@ -316,7 +333,8 @@ def worker(self, thid: int, generator, start: int, end: int):
316333
batch_end = min(batch_start + self.batchsize, end)
317334

318335
try:
319-
self.do_batch(db, generator[batch_start:batch_end])
336+
self.do_batch(db, batch_start,
337+
generator[batch_start:batch_end])
320338
except Exception as e:
321339
logger.exception(e)
322340
logger.warning(

aperturedb/ParallelQuerySet.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ def gen_execute_batch_sets(base_executor):
4949
#
5050
def execute_batch_sets(query_set, blob_set, db, success_statuses: list[int] = [0],
5151
response_handler: Optional[Callable] = None, commands_per_query: list[int] = -1,
52-
blobs_per_query: list[int] = -1, strict_response_validation: bool = False):
52+
blobs_per_query: list[int] = -1,
53+
strict_response_validation: bool = False, cmd_index: int = None):
5354

5455
logger.info("Execute Batch Sets = Batch Size {0} Comands Per Query {1} Blobs Per Query {2}".format(
5556
len(query_set), commands_per_query, blobs_per_query))
@@ -297,7 +298,11 @@ def constraint_filter(single_line, single_results):
297298
if len(executable_queries) > 0:
298299
result_code, db_results, db_blobs = base_executor(executable_queries, used_blobs,
299300
db, local_success_statuses,
300-
None, commands_per_query[i], blobs_per_query[i], strict_response_validation=strict_response_validation)
301+
None,
302+
commands_per_query[i],
303+
blobs_per_query[i],
304+
strict_response_validation=strict_response_validation,
305+
cmd_index=cmd_index)
301306
if response_handler != None and db.last_query_ok():
302307
def map_to_set(query, query_blobs, resp, resp_blobs):
303308
response_handler(
@@ -365,7 +370,7 @@ def verify_generator(self, generator) -> bool:
365370
logger.error(type(generator[0]))
366371
return False
367372

368-
def do_batch(self, db: Connector, data: List[Tuple[Commands, Blobs]]) -> None:
373+
def do_batch(self, db: Connector, batch_start: int, data: List[Tuple[Commands, Blobs]]) -> None:
369374
"""
370375
This is an override of ParallelQuery.do_batch.
371376
@@ -387,7 +392,7 @@ def do_batch(self, db: Connector, data: List[Tuple[Commands, Blobs]]) -> None:
387392
self.batch_command = gen_execute_batch_sets(
388393
self.base_batch_command)
389394

390-
ParallelQuery.do_batch(self, db, data)
395+
ParallelQuery.do_batch(self, db, batch_start, data)
391396

392397
def print_stats(self) -> None:
393398

aperturedb/Utils.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,32 @@
11
"""
22
Miscellaneous utility functions for ApertureDB.
33
"""
4+
from aperturedb.Query import QueryBuilder
5+
from aperturedb.cli.configure import ls
6+
from aperturedb.Configuration import Configuration
7+
from aperturedb.ParallelQuery import execute_batch
8+
from aperturedb import ProgressBar
9+
from aperturedb.ConnectorRest import ConnectorRest
10+
from aperturedb.Connector import Connector
411
import logging
512
import json
613
import os
714
import importlib
815
import sys
916
from typing import List, Optional, Dict
1017

11-
from graphviz import Source, Digraph
18+
HAS_GRAPHVIZ = True
19+
try:
20+
from graphviz import Source, Digraph
21+
except:
22+
HAS_GRAPHVIZ = False
23+
24+
class Source:
25+
pass
26+
27+
class Digraph:
28+
pass
1229

13-
from aperturedb.Connector import Connector
14-
from aperturedb.ConnectorRest import ConnectorRest
15-
from aperturedb import ProgressBar
16-
from aperturedb.ParallelQuery import execute_batch
17-
from aperturedb.Configuration import Configuration
18-
from aperturedb.cli.configure import ls
19-
from aperturedb.Query import QueryBuilder
2030

2131
logger = logging.getLogger(__name__)
2232

@@ -213,10 +223,24 @@ def visualize_schema(self, filename: str = None, format: str = "png") -> Source:
213223
Returns:
214224
source: The visualization of the schema.
215225
"""
226+
if not HAS_GRAPHVIZ:
227+
raise Exception("graphviz not installed.")
216228
r = self.get_schema()
217229

230+
colors = dict(
231+
edge="#3A3B9C",
232+
entity_background="#2A2E78",
233+
entity_foreground="#E2E0F1",
234+
property_background="#337EC0",
235+
property_foreground="#E2E0F1",
236+
connection_background="#5956F1",
237+
connection_foreground="#E2E0F1",
238+
connection_property_background="#33E1FF",
239+
connection_property_foreground="#2A2E78"
240+
)
241+
218242
dot = Digraph(comment='ApertureDB Schema Diagram', node_attr={
219-
'shape': 'none', 'fontcolor': '#E2E0F1'}, graph_attr={'rankdir': 'LR'}, edge_attr={'color': '#3A3B9C'})
243+
'shape': 'none'}, graph_attr={'rankdir': 'LR'}, edge_attr={'color': colors['edge']})
220244

221245
# Add entities as nodes and connections as edges
222246
entities = r['entities']['classes']
@@ -228,25 +252,25 @@ def visualize_schema(self, filename: str = None, format: str = "png") -> Source:
228252
properties = data["properties"]
229253
table = f'''<
230254
<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0">
231-
<TR><TD BGCOLOR="#2A2E78" COLSPAN="3"><B>{entity}</B> ({matched:,})</TD></TR>
255+
<TR><TD BGCOLOR="{colors["entity_background"]}" COLSPAN="3"><FONT COLOR="{colors["entity_foreground"]}"><B>{entity}</B> ({matched:,})</FONT></TD></TR>
232256
'''
233257
for prop, (matched, indexed, typ) in properties.items():
234-
table += f'<TR><TD BGCOLOR="#337EC0"><B>{prop.strip()}</B></TD> <TD BGCOLOR="#337EC0">{matched:,}</TD> <TD BGCOLOR="#337EC0">{"Indexed" if indexed else "Unindexed"}, {typ}</TD></TR>'
258+
table += f'<TR><TD BGCOLOR="{colors["property_background"]}"><FONT COLOR="{colors["property_foreground"]}"><B>{prop.strip()}</B></FONT></TD> <TD BGCOLOR="{colors["property_background"]}"><FONT COLOR="{colors["property_foreground"]}">{matched:,}</FONT></TD> <TD BGCOLOR="{colors["property_background"]}"><FONT COLOR="{colors["property_foreground"]}">{"Indexed" if indexed else "Unindexed"}, {typ}</FONT></TD></TR>'
235259
for connection, data in connections.items():
236260
if data['src'] == entity:
237261
matched = data["matched"]
238262
# dictionary from name to (matched, indexed, type)
239263
properties = data["properties"]
240-
table += f'<TR><TD BGCOLOR="#5956F1" COLSPAN="3" PORT="{connection}"><B>{connection}</B> ({matched:,})</TD></TR>'
264+
table += f'<TR><TD BGCOLOR="{colors["connection_background"]}" COLSPAN="3" PORT="{connection}"><FONT COLOR="{colors["connection_foreground"]}"><B>{connection}</B> ({matched:,})</FONT></TD></TR>'
241265
if properties:
242266
for prop, (matched, indexed, typ) in properties.items():
243-
table += f'<TR><TD BGCOLOR="#33E1FF"><B>{prop.strip()}</B></TD> <TD BGCOLOR="#33E1FF">{matched:,}</TD> <TD BGCOLOR="#33E1FF">{"Indexed" if indexed else "Unindexed"}, {typ}</TD></TR>'
267+
table += f'<TR><TD BGCOLOR="{colors["connection_property_background"]}"><FONT COLOR="{colors["connection_property_foreground"]}"><B>{prop.strip()}</B></FONT></TD> <TD BGCOLOR="{colors["connection_property_background"]}"><FONT COLOR="{colors["connection_property_foreground"]}">{matched:,}</FONT></TD> <TD BGCOLOR="{colors["connection_property_background"]}"><FONT COLOR="{colors["connection_property_foreground"]}">{"Indexed" if indexed else "Unindexed"}, {typ}</FONT></TD></TR>'
244268
table += '</TABLE>>'
245269
dot.node(entity, label=table)
246270

247271
for connection, data in connections.items():
248272
dot.edge(f'{data["src"]}:{connection}',
249-
f'{data["dst"]}:{connection}')
273+
f'{data["dst"]}')
250274

251275
# Render the diagram inline
252276
s = Source(dot.source, filename="schema_diagram.gv", format="png")

0 commit comments

Comments
 (0)