Skip to content

Commit efc18e1

Browse files
authored
Merge pull request #604 from aperture-data/release-0.4.49
Release 0.4.49 Includes a change in how GetSchema works for connections with no Classes.
2 parents 123effc + 3f32d38 commit efc18e1

File tree

3 files changed

+136
-31
lines changed

3 files changed

+136
-31
lines changed

aperturedb/Connector.py

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,30 @@
4949

5050

5151
PROTOCOL_VERSION = 1
52+
MESSAGE_LENGTH_FORMAT = '@I'
53+
MESSAGE_LENGTH_SIZE = struct.calcsize(MESSAGE_LENGTH_FORMAT)
54+
55+
# Protocol types
56+
PROTOCOL_TCP = 1
57+
PROTOCOL_SSL = 2
58+
59+
DEFAULT_PORT = 55555
60+
# aperturedb's param ADB_MAX_CONNECTION_MESSAGE_SIZE_MB = 2048 by default
61+
DEFAULT_MAX_MESSAGE_SIZE_MB = 2048
62+
63+
# Query retry constants
64+
DEFAULT_RETRY_INTERVAL_SECONDS = 1
65+
DEFAULT_RETRY_MAX_ATTEMPTS = 3
66+
DEFAULT_SESSION_EXPIRY_OFFSET_SEC = 10
67+
DEFAULT_QUERY_CONNECTION_ERROR_SUPPRESSION_DELTA_SEC = 30
68+
69+
# Session renewal constants
70+
RENEW_SESSION_MAX_ATTEMPTS = 3
71+
RENEW_SESSION_RETRY_INTERVAL_SEC = 1
72+
73+
# Status codes
74+
STATUS_OK = 0
75+
STATUS_ERROR_DEFAULT = -2
5276

5377

5478
class UnauthorizedException(Exception):
@@ -73,7 +97,7 @@ def valid(self) -> bool:
7397

7498
# This triggers refresh if the session is about to expire.
7599
if session_age > self.session_token_ttl - \
76-
int(os.getenv("SESSION_EXPIRY_OFFSET_SEC", 10)):
100+
int(os.getenv("SESSION_EXPIRY_OFFSET_SEC", DEFAULT_SESSION_EXPIRY_OFFSET_SEC)):
77101
return False
78102

79103
return True
@@ -106,14 +130,14 @@ class Connector(object):
106130
str (key): Aperture Key, configuration as a deflated compressed string
107131
"""
108132

109-
def __init__(self, host="localhost", port=55555,
133+
def __init__(self, host="localhost", port=DEFAULT_PORT,
110134
user="", password="", token="",
111135
use_ssl=True,
112136
shared_data=None,
113137
authenticate=True,
114138
use_keepalive=True,
115-
retry_interval_seconds=1,
116-
retry_max_attempts=3,
139+
retry_interval_seconds=DEFAULT_RETRY_INTERVAL_SECONDS,
140+
retry_max_attempts=DEFAULT_RETRY_MAX_ATTEMPTS,
117141
config: Optional[Configuration] = None,
118142
key: Optional[str] = None):
119143
"""
@@ -126,7 +150,8 @@ def __init__(self, host="localhost", port=55555,
126150
self.last_query_timestamp = None
127151
# suppress connection warnings which occur more than this time
128152
# after the last query
129-
self.query_connection_error_suppression_delta = timedelta(seconds=30)
153+
self.query_connection_error_suppression_delta = timedelta(
154+
seconds=DEFAULT_QUERY_CONNECTION_ERROR_SUPPRESSION_DELTA_SEC)
130155

131156
if key is not None:
132157
self.config = Configuration.reinflate(key)
@@ -196,20 +221,20 @@ def __del__(self):
196221
self.connected = False
197222

198223
def _send_msg(self, data):
199-
# aperturedb's param ADB_MAX_CONNECTION_MESSAGE_SIZE_MB = 256 by default
200-
if len(data) > (256 * 2**20):
224+
if len(data) > (DEFAULT_MAX_MESSAGE_SIZE_MB * 2**20):
201225
logger.warning(
202226
"Message sent is larger than default for ApertureDB Server. Server may disconnect.")
203227

204-
sent_len = struct.pack('@I', len(data)) # send size first
228+
sent_len = struct.pack(MESSAGE_LENGTH_FORMAT,
229+
len(data)) # send size first
205230
x = self.conn.send(sent_len + data)
206-
return x == len(data) + 4
231+
return x == len(data) + MESSAGE_LENGTH_SIZE
207232

208233
def _recv_msg(self):
209-
recv_len = self.conn.recv(4) # get message size
234+
recv_len = self.conn.recv(MESSAGE_LENGTH_SIZE) # get message size
210235
if recv_len == b'':
211236
return None
212-
recv_len = struct.unpack('@I', recv_len)[0]
237+
recv_len = struct.unpack(MESSAGE_LENGTH_FORMAT, recv_len)[0]
213238
response = bytearray(recv_len)
214239
read = 0
215240
while read < recv_len:
@@ -250,7 +275,7 @@ def _authenticate(self, user, password="", token=""):
250275
"Unexpected response from server upon authenticate request: " +
251276
str(response))
252277
session_info = response[0]["Authenticate"]
253-
if session_info["status"] != 0:
278+
if session_info["status"] != STATUS_OK:
254279
raise Exception(session_info["info"])
255280

256281
self.shared_data.session = Session(
@@ -280,7 +305,7 @@ def _refresh_token(self):
280305
logger.info(f"Refresh token response: \r\n{response}")
281306
if isinstance(response, list):
282307
session_info = response[0]["RefreshToken"]
283-
if session_info["status"] != 0:
308+
if session_info["status"] != STATUS_OK:
284309
# Refresh token failed, we need to re-authenticate
285310
# This is possible with a long lived connector, where
286311
# the session token and the refresh token have expired.
@@ -320,7 +345,7 @@ def _connect(self):
320345

321346
# Handshake with server to negotiate protocol
322347

323-
protocol = 2 if self.use_ssl else 1
348+
protocol = PROTOCOL_SSL if self.use_ssl else PROTOCOL_TCP
324349

325350
hello_msg = struct.pack('@II', PROTOCOL_VERSION, protocol)
326351

@@ -513,16 +538,16 @@ def query(self, q, blobs=[]):
513538

514539
def _renew_session(self):
515540
count = 0
516-
while count < 3:
541+
while count < RENEW_SESSION_MAX_ATTEMPTS:
517542
try:
518543
self._check_session_status()
519544
break
520545
except UnauthorizedException as e:
521546
logger.warning(
522-
f"[Attempt {count + 1} of 3] Failed to refresh token.",
547+
f"[Attempt {count + 1} of {RENEW_SESSION_MAX_ATTEMPTS}] Failed to refresh token.",
523548
exc_info=True,
524549
stack_info=True)
525-
time.sleep(1)
550+
time.sleep(RENEW_SESSION_RETRY_INTERVAL_SEC)
526551
count += 1
527552

528553
def clone(self) -> Connector:
@@ -583,7 +608,7 @@ def check_status(self, json_res: CommandResponses) -> int:
583608
int: The value received from the server, or -2 if not found.
584609
"""
585610
# Default status is -2, which is an error, but not a server error.
586-
status = -2
611+
status = STATUS_ERROR_DEFAULT
587612
if (isinstance(json_res, dict)):
588613
if ("status" not in json_res):
589614
status = self.check_status(json_res[list(json_res.keys())[0]])

aperturedb/ParallelLoader.py

Lines changed: 92 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
1-
21
from aperturedb import ParallelQuery
32
from aperturedb.Connector import Connector
43
from aperturedb.Utils import Utils
54
from aperturedb.Subscriptable import Subscriptable
65

76
import numpy as np
87
import logging
8+
9+
# For each property of each Entity or Connection,
10+
# the following information is returned as an array of 3 elements of the form [matched, has_index_flag, type]:
11+
# - matched: Number of objects that match the search.
12+
# - has_index_flag: Indicates whether the property is indexed or not.
13+
# - type: Type for the property. See supported types.
14+
# https://docs.aperturedata.io/query_language/Reference/db_commands/GetSchema
15+
PROPERTIES_SCHEMA_INDEX_FLAG = 1
16+
917
logger = logging.getLogger(__name__)
1018

1119

@@ -23,20 +31,92 @@ def __init__(self, client: Connector, dry_run: bool = False):
2331
self.utils = Utils(self.client)
2432
self.type = "element"
2533

34+
def get_entity_indexes(self, schema: dict) -> dict:
    """
    Collect the indexed properties of every entity class in a schema.

    The schema is walked through its "entities" -> "classes" ->
    <class name> -> "properties" keys; a missing or empty level is
    treated as an empty mapping.

    Args:
        schema (dict): The schema dictionary to inspect.

    Returns:
        dict: Empty when no indexed property is found, otherwise
            {"entity": {class_name: {property_name, ...}}}.
    """
    indexed = {}

    entity_section = schema.get("entities") or {}
    class_map = entity_section.get("classes") or {}

    for class_name, class_info in class_map.items():
        properties = class_info.get("properties") or {}
        for property_name, descriptor in properties.items():
            # The descriptor slot at PROPERTIES_SCHEMA_INDEX_FLAG tells
            # whether the property is indexed; skip it when falsy.
            if not descriptor[PROPERTIES_SCHEMA_INDEX_FLAG]:
                continue
            per_class = indexed.setdefault("entity", {})
            per_class.setdefault(class_name, set()).add(property_name)

    return indexed
56+
57+
def get_connection_indexes(self, schema: dict) -> dict:
    """
    Collect the indexed properties of every connection class in a schema.

    A connection class entry may be either a single dict holding a
    "properties" mapping, or a list of such dicts (several connection
    variants registered under the same class name), e.g.:

        "connections": {
            "classes": {
                "SomeConnectionClass": [
                    {"properties": {"prop1": [...], "prop2": [...]}},
                    {"properties": {"prop3": [...], "prop4": [...]}},
                ]
            }
        }

    Args:
        schema (dict): The schema dictionary to inspect.

    Returns:
        dict: Empty when no indexed property is found, otherwise
            {"connection": {class_name: {property_name, ...}}}.

    Raises:
        ValueError: If a connection class entry is neither a dict
            nor a list.
    """
    indexed = {}

    connection_section = schema.get("connections") or {}
    class_map = connection_section.get("classes") or {}

    for class_name, class_info in class_map.items():
        # Normalise both accepted shapes to a list of variants so the
        # property scan below is written only once.
        if isinstance(class_info, dict):
            variants = [class_info]
        elif isinstance(class_info, list):
            variants = class_info
        else:
            message = (
                "Unexpected schema format for connection class "
                f"'{class_name}': {class_info}"
            )
            logger.error(message)
            raise ValueError(message)

        for variant in variants:
            properties = variant.get("properties") or {}
            for property_name, descriptor in properties.items():
                # Only properties whose index flag is truthy are recorded.
                if descriptor[PROPERTIES_SCHEMA_INDEX_FLAG]:
                    per_class = indexed.setdefault("connection", {})
                    per_class.setdefault(
                        class_name, set()).add(property_name)

    return indexed
106+
26107
def get_existing_indices(self):
108+
109+
indexes = {}
27110
schema = self.utils.get_schema()
28-
existing_indices = {}
111+
29112
if schema:
30-
for index_type in (("entity", "entities"), ("connection", "connections")):
31-
foo = schema.get(index_type[1]) or {}
32-
bar = foo.get("classes") or {}
33-
for cls_name, cls_schema in bar.items():
34-
props = cls_schema.get("properties") or {}
35-
for prop_name, prop_schema in props.items():
36-
if prop_schema[1]: # indicates property has an index
37-
existing_indices.setdefault(index_type[0], {}).setdefault(
38-
cls_name, set()).add(prop_name)
39-
return existing_indices
113+
entity_indexes = self.get_entity_indexes(schema)
114+
connection_indexes = self.get_connection_indexes(schema)
115+
116+
# Combine both entity and connection indexes
117+
indexes = {**entity_indexes, **connection_indexes}
118+
119+
return indexes
40120

41121
def query_setup(self, generator: Subscriptable) -> None:
42122
"""

aperturedb/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import signal
1111
import sys
1212

13-
__version__ = "0.4.48"
13+
__version__ = "0.4.49"
1414

1515
logger = logging.getLogger(__name__)
1616

0 commit comments

Comments
 (0)