diff --git a/Makefile b/Makefile index 579e6f4..642a223 100644 --- a/Makefile +++ b/Makefile @@ -29,6 +29,11 @@ build: ${VENV} clean ## Clean and build ${PYTHON} setup.py bdist_wheel ; \ ls -lh dist +build-sync: ${VENV} clean ## Clean and build + set -e ; \ + OBX_BUILD_SYNC=1 ${PYTHON} setup.py bdist_wheel ; \ + ls -lh dist + ${VENV}: ${VENVBIN}/activate venv-init: @@ -49,6 +54,10 @@ depend: ${VENV} ## Prepare dependencies set -e ; \ ${PYTHON} download-c-lib.py +depend-sync: ${VENV} ## Prepare dependencies + set -e ; \ + ${PYTHON} download-c-lib.py --sync + test: ${VENV} ## Test all targets set -e ; \ ${PYTHON} -m pytest --capture=no --verbose diff --git a/download-c-lib.py b/download-c-lib.py index 58d711d..b75bfaf 100644 --- a/download-c-lib.py +++ b/download-c-lib.py @@ -2,12 +2,15 @@ import tarfile import zipfile import os +import sys # Script used to download objectbox-c shared libraries for all supported platforms. Execute by running `make get-lib` # on first checkout of this repo and any time after changing the objectbox-c lib version. -version = "v4.0.0" # see objectbox/c.py required_version +version = "v5.0.0" # see objectbox/c.py required_version variant = 'objectbox' # or 'objectbox-sync' +if len(sys.argv) > 1 and sys.argv[1] == '--sync': + variant = 'objectbox-sync' base_url = "https://github.com/objectbox/objectbox-c/releases/download/" @@ -21,7 +24,7 @@ "x86_64/libobjectbox.so": "linux-x64.tar.gz", "aarch64/libobjectbox.so": "linux-aarch64.tar.gz", "armv7l/libobjectbox.so": "linux-armv7hf.tar.gz", - "armv6l/libobjectbox.so": "linux-armv6hf.tar.gz", + #"armv6l/libobjectbox.so": "linux-armv6hf.tar.gz", # mac "macos-universal/libobjectbox.dylib": "macos-universal.zip", diff --git a/objectbox/__init__.py b/objectbox/__init__.py index 71bf846..d9a1b92 100644 --- a/objectbox/__init__.py +++ b/objectbox/__init__.py @@ -16,7 +16,7 @@ from objectbox.store import Store from objectbox.box import Box -from objectbox.model.entity import Entity +from objectbox.model.entity import Entity, SyncEntity from objectbox.model.properties import Id, String, Index, Bool, Int8, Int16, Int32, Int64, Float32, Float64, Bytes, BoolVector, Int8Vector, Int16Vector, Int32Vector, Int64Vector, Float32Vector, Float64Vector, CharVector, BoolList, Int8List, Int16List, Int32List, Int64List, Float32List, Float64List, CharList, Date, DateNano, Flex, HnswIndex, VectorDistanceType, HnswFlags from objectbox.model.model import Model from objectbox.c import version_core, DebugFlags @@ -74,11 +74,12 @@ 'PropertyQueryCondition', 'HnswFlags', 'Query', - 'QueryBuilder' + 'QueryBuilder', + 'SyncEntity' ] # Python binding version -version = Version(4, 0, 0) +version = Version(5, 0, 0) """ObjectBox Python package version""" def version_info(): diff --git a/objectbox/c.py b/objectbox/c.py index 03b7aea..9a8d8b4 100644 --- a/objectbox/c.py +++ b/objectbox/c.py @@ -16,10 +16,12 @@ import ctypes.util import os import platform -from objectbox.version import Version +from ctypes import c_char_p from typing import * + import numpy as np -from enum import IntEnum + +from objectbox.version import Version # This file contains C-API bindings based on lib/objectbox.h, linking to the 'objectbox' shared library. # The bindings are implementing using ctypes, see https://docs.python.org/dev/library/ctypes.html for introduction. @@ -27,7 +29,7 @@ # Version of the library used by the binding. This version is checked at runtime to ensure binary compatibility. # Don't forget to update download-c-lib.py when upgrading to a newer version. -required_version = "4.0.0" +required_version = "5.0.0" def shlib_name(library: str) -> str: @@ -303,6 +305,11 @@ class DbErrorCode(IntEnum): OBX_ERROR_TREE_OTHER = 10699 +class OBXEntityFlags(IntEnum): + SYNC_ENABLED = 2 + SHARED_GLOBAL_IDS = 4 + + def check_obx_err(code: obx_err, func, args) -> obx_err: """ Raises an exception if obx_err is not successful. """ if code != DbErrorCode.OBX_SUCCESS: @@ -310,6 +317,11 @@ def check_obx_err(code: obx_err, func, args) -> obx_err: raise create_db_error(code) return code +def check_obx_success(code: obx_err) -> bool: + if code == DbErrorCode.OBX_NO_SUCCESS: + return False + check_obx_err(code, None, None) + return True def check_obx_qb_cond(qb_cond: obx_qb_cond, func, args) -> obx_qb_cond: """ Raises an exception if obx_qb_cond is not successful. """ @@ -419,6 +431,9 @@ def c_array_pointer(py_list: Union[List[Any], np.ndarray], c_type): obx_model_entity = c_fn_rc('obx_model_entity', [ OBX_model_p, ctypes.c_char_p, obx_schema_id, obx_uid]) +# obx_err obx_model_entity_flags(OBX_model* model, uint32_t flags); +obx_model_entity_flags = c_fn_rc('obx_model_entity_flags', [OBX_model_p, ctypes.c_uint32]) + # obx_err (OBX_model* model, const char* name, OBXPropertyType type, obx_schema_id property_id, obx_uid property_uid); obx_model_property = c_fn_rc('obx_model_property', [OBX_model_p, ctypes.c_char_p, OBXPropertyType, obx_schema_id, obx_uid]) @@ -1068,3 +1083,241 @@ def c_array_pointer(py_list: Union[List[Any], np.ndarray], c_type): OBXBackupRestoreFlags_None = 0 OBXBackupRestoreFlags_OverwriteExistingData = 1 + + +# Sync API + +class OBX_sync(ctypes.Structure): + pass + + +OBX_sync_p = ctypes.POINTER(OBX_sync) + + +class OBX_sync_server(ctypes.Structure): + pass + + +OBX_sync_server_p = ctypes.POINTER(OBX_sync_server) + +OBXSyncCredentialsType = ctypes.c_int +OBXRequestUpdatesMode = ctypes.c_int +OBXSyncState = ctypes.c_int +OBXSyncCode = ctypes.c_int + + +class SyncCredentialsType(IntEnum): + NONE = 1 + SHARED_SECRET = 2 # Deprecated, use SHARED_SECRET_SIPPED instead + GOOGLE_AUTH = 3 + SHARED_SECRET_SIPPED = 4 # Uses shared secret to create a hashed credential + OBX_ADMIN_USER = 5 # ObjectBox admin users (username/password) + USER_PASSWORD = 6 # Generic credential type for admin users + JWT_ID = 7 # JSON Web Token (JWT): ID token with user identity + JWT_ACCESS = 8 # JSON Web Token (JWT): access token for resources + JWT_REFRESH = 9 # JSON Web Token (JWT): refresh token + JWT_CUSTOM = 10 # JSON Web Token (JWT): custom token type + + +class RequestUpdatesMode(IntEnum): + MANUAL = 0 # No updates by default, must call obx_sync_updates_request() manually + AUTO = 1 # Same as calling obx_sync_updates_request(sync, TRUE) + AUTO_NO_PUSHES = 2 # Same as calling obx_sync_updates_request(sync, FALSE) + + +class SyncState(IntEnum): + CREATED = 1 + STARTED = 2 + CONNECTED = 3 + LOGGED_IN = 4 + DISCONNECTED = 5 + STOPPED = 6 + DEAD = 7 + + +class OBXSyncError(IntEnum): + REJECT_TX_NO_PERMISSION = 1 # Sync client received rejection of transaction writes due to missing permissions + + +class OBXSyncObjectType(IntEnum): + FlatBuffers = 1 + String = 2 + Raw = 3 + + +class OBX_sync_change(ctypes.Structure): + _fields_ = [ + ('entity_id', obx_schema_id), + ('puts', ctypes.POINTER(OBX_id_array)), + ('removals', ctypes.POINTER(OBX_id_array)), + ] + + +class OBX_sync_change_array(ctypes.Structure): + _fields_ = [ + ('list', ctypes.POINTER(OBX_sync_change)), + ('count', ctypes.c_size_t), + ] + + +class OBX_sync_object(ctypes.Structure): + _fields_ = [ + ('type', ctypes.c_int), # OBXSyncObjectType + ('id', ctypes.c_uint64), + ('data', ctypes.c_void_p), + ('size', ctypes.c_size_t), + ] + + +class OBX_sync_msg_objects(ctypes.Structure): + _fields_ = [ + ('topic', ctypes.c_void_p), + ('topic_size', ctypes.c_size_t), + ('objects', ctypes.POINTER(OBX_sync_object)), + ('count', ctypes.c_size_t), + ] + + +class OBX_sync_msg_objects_builder(ctypes.Structure): + pass + + +OBX_sync_msg_objects_builder_p = ctypes.POINTER(OBX_sync_msg_objects_builder) + +# Define callback types for sync listeners +OBX_sync_listener_connect = ctypes.CFUNCTYPE(None, ctypes.c_void_p) +OBX_sync_listener_disconnect = ctypes.CFUNCTYPE(None, ctypes.c_void_p) +OBX_sync_listener_login = ctypes.CFUNCTYPE(None, ctypes.c_void_p) +OBX_sync_listener_login_failure = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_int) # arg, OBXSyncCode +OBX_sync_listener_complete = ctypes.CFUNCTYPE(None, ctypes.c_void_p) +OBX_sync_listener_error = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_int) # arg, OBXSyncError +OBX_sync_listener_change = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.POINTER(OBX_sync_change_array)) +OBX_sync_listener_server_time = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_int64) +OBX_sync_listener_msg_objects = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.POINTER(OBX_sync_msg_objects)) + +# OBX_sync* obx_sync(OBX_store* store, const char* server_url); +obx_sync = c_fn("obx_sync", OBX_sync_p, [OBX_store_p, ctypes.c_char_p]) + +# OBX_sync* obx_sync_urls(OBX_store* store, const char* server_urls[], size_t server_urls_count); +obx_sync_urls = c_fn("obx_sync_urls", OBX_sync_p, [OBX_store_p, ctypes.POINTER(ctypes.c_char_p), ctypes.c_size_t]) + +# Client Credentials + +# obx_err obx_sync_credentials(OBX_sync* sync, OBXSyncCredentialsType type, const void* data, size_t size); +obx_sync_credentials = c_fn_rc('obx_sync_credentials', + [OBX_sync_p, OBXSyncCredentialsType, ctypes.c_void_p, ctypes.c_size_t]) + +# obx_err obx_sync_credentials_user_password(OBX_sync* sync, OBXSyncCredentialsType type, const char* username, const char* password); +obx_sync_credentials_user_password = c_fn_rc('obx_sync_credentials_user_password', + [OBX_sync_p, OBXSyncCredentialsType, ctypes.c_char_p, + ctypes.c_char_p]) + +# obx_err obx_sync_credentials_add(OBX_sync* sync, OBXSyncCredentialsType type, const void* data, size_t size, bool complete); +obx_sync_credentials_add = c_fn_rc('obx_sync_credentials_add', + [OBX_sync_p, OBXSyncCredentialsType, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_bool]) + +# obx_err obx_sync_credentials_add_user_password(OBX_sync* sync, OBXSyncCredentialsType type, const char* username, const char* password, bool complete); +obx_sync_credentials_add_user_password = c_fn_rc('obx_sync_credentials_add_user_password', + [OBX_sync_p, OBXSyncCredentialsType, ctypes.c_char_p, ctypes.c_char_p, + ctypes.c_bool]) + +# Sync Control + +# OBXSyncState obx_sync_state(OBX_sync* sync); +obx_sync_state = c_fn('obx_sync_state', OBXSyncState, [OBX_sync_p]) + +# obx_err obx_sync_request_updates_mode(OBX_sync* sync, OBXRequestUpdatesMode mode); +obx_sync_request_updates_mode = c_fn_rc('obx_sync_request_updates_mode', [OBX_sync_p, OBXRequestUpdatesMode]) + +# OBX_C_API obx_err obx_sync_updates_request(OBX_sync* sync, bool subscribe_for_pushes); +obx_sync_updates_request = c_fn_rc('obx_sync_updates_request', [OBX_sync_p, ctypes.c_bool]) + +# OBX_C_API obx_err obx_sync_updates_cancel(OBX_sync* sync); +obx_sync_updates_cancel = c_fn_rc('obx_sync_updates_cancel', [OBX_sync_p]) + +# obx_err obx_sync_start(OBX_sync* sync); +obx_sync_start = c_fn_rc('obx_sync_start', [OBX_sync_p]) + +# obx_err obx_sync_stop(OBX_sync* sync); +obx_sync_stop = c_fn_rc('obx_sync_stop', [OBX_sync_p]) + +# obx_err obx_sync_trigger_reconnect(OBX_sync* sync); +obx_sync_trigger_reconnect = c_fn_rc('obx_sync_trigger_reconnect', [OBX_sync_p]) + +# uint32_t obx_sync_protocol_version(); +obx_sync_protocol_version = c_fn('obx_sync_protocol_version', ctypes.c_uint32, []) + +# uint32_t obx_sync_protocol_version_server(OBX_sync* sync); +obx_sync_protocol_version_server = c_fn('obx_sync_protocol_version_server', ctypes.c_uint32, [OBX_sync_p]) + +# obx_err obx_sync_wait_for_logged_in_state(OBX_sync* sync, uint64_t timeout_millis); +obx_sync_wait_for_logged_in_state = c_fn_rc('obx_sync_wait_for_logged_in_state', [OBX_sync_p, ctypes.c_uint64]) + +# obx_err obx_sync_close(OBX_sync* sync); +obx_sync_close = c_fn_rc('obx_sync_close', [OBX_sync_p]) + +# Listener Callbacks + +# void obx_sync_listener_connect(OBX_sync* sync, OBX_sync_listener_connect* listener, void* listener_arg); +obx_sync_listener_connect = c_fn('obx_sync_listener_connect', None, [OBX_sync_p, OBX_sync_listener_connect, ctypes.c_void_p]) + +# void obx_sync_listener_disconnect(OBX_sync* sync, OBX_sync_listener_disconnect* listener, void* listener_arg); +obx_sync_listener_disconnect = c_fn('obx_sync_listener_disconnect', None, [OBX_sync_p, OBX_sync_listener_disconnect, ctypes.c_void_p]) + +# void obx_sync_listener_login(OBX_sync* sync, OBX_sync_listener_login* listener, void* listener_arg); +obx_sync_listener_login = c_fn('obx_sync_listener_login', None, [OBX_sync_p, OBX_sync_listener_login, ctypes.c_void_p]) + +# void obx_sync_listener_login_failure(OBX_sync* sync, OBX_sync_listener_login_failure* listener, void* listener_arg); +obx_sync_listener_login_failure = c_fn('obx_sync_listener_login_failure', None, [OBX_sync_p, OBX_sync_listener_login_failure, ctypes.c_void_p]) + +# void obx_sync_listener_complete(OBX_sync* sync, OBX_sync_listener_complete* listener, void* listener_arg); +obx_sync_listener_error = c_fn('obx_sync_listener_error', None, [OBX_sync_p, OBX_sync_listener_error, ctypes.c_void_p]) + +# void obx_sync_listener_change(OBX_sync* sync, OBX_sync_listener_change* listener, void* listener_arg); +obx_sync_listener_change = c_fn('obx_sync_listener_change', None, + [OBX_sync_p, OBX_sync_listener_change, ctypes.c_void_p]) + +# Filter Variables + +# obx_err obx_sync_filter_variables_put(OBX_sync* sync, const char* name, const char* value); +obx_sync_filter_variables_put = c_fn_rc('obx_sync_filter_variables_put', + [OBX_sync_p, c_char_p, c_char_p]) + +# obx_err obx_sync_filter_variables_remove(OBX_sync* sync, const char* name); +obx_sync_filter_variables_remove = c_fn_rc('obx_sync_filter_variables_remove', + [OBX_sync_p, c_char_p]) + +# obx_err obx_sync_filter_variables_remove_all(OBX_sync* sync); +obx_sync_filter_variables_remove_all = c_fn_rc('obx_sync_filter_variables_remove_all', + [OBX_sync_p]) + +# OBX_C_API obx_err obx_sync_outgoing_message_count(OBX_sync* sync, uint64_t limit, uint64_t* out_count); +obx_sync_outgoing_message_count = c_fn_rc('obx_sync_outgoing_message_count', + [OBX_sync_p, ctypes.c_uint64, ctypes.POINTER(ctypes.c_uint64)]) + +OBXFeature = ctypes.c_int + +class Feature(IntEnum): + ResultArray = 1 # Functions that are returning multiple results (e.g. multiple objects) can be only used if this is available. + TimeSeries = 2 # TimeSeries support (date/date-nano companion ID and other time-series functionality). + Sync = 3 # Sync client availability. Visit https://objectbox.io/sync for more details. + DebugLog = 4 # Check whether debug log can be enabled during runtime. + Admin = 5 # Admin UI including a database browser, user management, and more. Depends on HttpServer (if Admin is available HttpServer is too). + Tree = 6 # Tree with special GraphQL support + SyncServer = 7 # Sync server availability. Visit https://objectbox.io/sync for more details. + WebSockets = 8 # Implicitly added by Sync or SyncServer; disable via NoWebSockets + Cluster = 9 # Sync Server has cluster functionality. Implicitly added by SyncServer; disable via NoCluster + HttpServer = 10 # Embedded HTTP server. + GraphQL = 11 # Embedded GraphQL server (via HTTP). Depends on HttpServer (if GraphQL is available HttpServer is too). + Backup = 12 # Database Backup functionality; typically only enabled in Sync Server builds. + Lmdb = 13 # The default database "provider"; writes data persistently to disk (ACID). + VectorSearch = 14 # Vector search functionality; enables indexing for nearest neighbor search. + Wal = 15 # WAL (write-ahead logging). + SyncMongoDb = 16 # Sync connector to integrate MongoDB with SyncServer. + Auth = 17 # Enables additional authentication/authorization methods for sync login, e.g. + Trial = 18 # This is a free trial version; only applies to server builds (no trial builds for database and Sync clients). + SyncFilters = 19 # Server-side filters to return individual data for each sync user (user-specific data). + + +# bool obx_has_feature(OBXFeature feature); +obx_has_feature = c_fn('obx_has_feature', ctypes.c_bool, [OBXFeature]) diff --git a/objectbox/model/entity.py b/objectbox/model/entity.py index 9e6d46a..dda6f62 100644 --- a/objectbox/model/entity.py +++ b/objectbox/model/entity.py @@ -38,6 +38,7 @@ def __init__(self, user_type, uid: int = 0): self._id_property = None self._fill_properties() self._tl = threading.local() + self._flags = 0 @property def _id(self) -> int: @@ -320,3 +321,9 @@ def wrapper(class_) -> Callable[[Type], _Entity]: return entity_type return wrapper + + +def SyncEntity(cls): + entity: _Entity = obx_models_by_name["default"][-1] # get the last added entity + entity._flags |= OBXEntityFlags.SYNC_ENABLED + return cls diff --git a/objectbox/model/model.py b/objectbox/model/model.py index c4e4875..2c75ad5 100644 --- a/objectbox/model/model.py +++ b/objectbox/model/model.py @@ -102,6 +102,7 @@ def _create_property(self, prop: Property): def _create_entity(self, entity: _Entity): obx_model_entity(self._c_model, c_str(entity._name), entity._id, entity._uid) + obx_model_entity_flags(self._c_model, entity._flags) for prop in entity._properties: self._create_property(prop) obx_model_entity_last_property_id(self._c_model, entity._last_property_iduid.id, entity._last_property_iduid.uid) diff --git a/objectbox/store.py b/objectbox/store.py index c41e452..2fb07ba 100644 --- a/objectbox/store.py +++ b/objectbox/store.py @@ -127,6 +127,7 @@ def __init__(self, """ self._c_store = None + self._close_listeners: list[Callable[[], None]] = [] if not c_store: options = StoreOptions() try: @@ -272,6 +273,9 @@ def write_tx(self): def close(self): """Close database.""" + for listener in self._close_listeners: + listener() + self._close_listeners.clear() c_store_to_close = self._c_store if c_store_to_close: self._c_store = None @@ -285,3 +289,9 @@ def remove_db_files(db_dir: str): Path to DB directory. """ c.obx_remove_db_files(c.c_str(db_dir)) + + def c_store(self): + return self._c_store + + def add_store_close_listener(self, on_store_close: Callable[[], None]): + self._close_listeners.append(on_store_close) diff --git a/objectbox/sync.py b/objectbox/sync.py new file mode 100644 index 0000000..1549334 --- /dev/null +++ b/objectbox/sync.py @@ -0,0 +1,861 @@ +import ctypes +from enum import Enum, auto, IntEnum + +import objectbox.c as c +from objectbox import Store +from objectbox.c import OBX_sync_change_array + + +class SyncCredentials: + """Credentials used to authenticate a sync client against a server.""" + + def __init__(self, credential_type: c.SyncCredentialsType): + self.type = credential_type + + @staticmethod + def none() -> 'SyncCredentials': + """No credentials - usually only for development purposes with a server + configured to accept all connections without authentication. + + Returns: + A SyncCredentials instance with no authentication. + """ + return SyncCredentialsNone() + + @staticmethod + def shared_secret_string(secret: str) -> 'SyncCredentials': + """Shared secret authentication. + + Args: + secret: The shared secret string. + + Returns: + A SyncCredentials instance for shared secret authentication. + """ + return SyncCredentialsSecret(c.SyncCredentialsType.SHARED_SECRET_SIPPED, secret.encode('utf-8')) + + @staticmethod + def google_auth(secret: str) -> 'SyncCredentials': + """Google authentication. + + Args: + secret: The Google authentication token. + + Returns: + A SyncCredentials instance for Google authentication. + """ + return SyncCredentialsSecret(c.SyncCredentialsType.GOOGLE_AUTH, secret.encode('utf-8')) + + @staticmethod + def user_and_password(username: str, password: str) -> 'SyncCredentials': + """Username and password authentication. + + Args: + username: The username. + password: The password. + + Returns: + A SyncCredentials instance for username/password authentication. + """ + return SyncCredentialsUserPassword(c.SyncCredentialsType.USER_PASSWORD, username, password) + + @staticmethod + def jwt_id_token(jwt_id_token: str) -> 'SyncCredentials': + """JSON Web Token (JWT): an ID token that typically provides identity + information about the authenticated user. + + Args: + jwt_id_token: The JWT ID token. + + Returns: + A SyncCredentials instance for JWT ID token authentication. + """ + return SyncCredentialsSecret(c.SyncCredentialsType.JWT_ID, jwt_id_token.encode('utf-8')) + + @staticmethod + def jwt_access_token(jwt_access_token: str) -> 'SyncCredentials': + """JSON Web Token (JWT): an access token that is used to access resources. + + Args: + jwt_access_token: The JWT access token. + + Returns: + A SyncCredentials instance for JWT access token authentication. + """ + return SyncCredentialsSecret(c.SyncCredentialsType.JWT_ACCESS, jwt_access_token.encode('utf-8')) + + @staticmethod + def jwt_refresh_token(jwt_refresh_token: str) -> 'SyncCredentials': + """JSON Web Token (JWT): a refresh token that is used to obtain a new + access token. + + Args: + jwt_refresh_token: The JWT refresh token. + + Returns: + A SyncCredentials instance for JWT refresh token authentication. + """ + return SyncCredentialsSecret(c.SyncCredentialsType.JWT_REFRESH, jwt_refresh_token.encode('utf-8')) + + @staticmethod + def jwt_custom_token(jwt_custom_token: str) -> 'SyncCredentials': + """JSON Web Token (JWT): a token that is neither an ID, access, + nor refresh token. + + Args: + jwt_custom_token: The custom JWT token. + + Returns: + A SyncCredentials instance for custom JWT token authentication. + """ + return SyncCredentialsSecret(c.SyncCredentialsType.JWT_CUSTOM, jwt_custom_token.encode('utf-8')) + + +class SyncCredentialsNone(SyncCredentials): + """Internal use only. Represents no credentials for authentication.""" + + def __init__(self): + super().__init__(c.SyncCredentialsType.NONE) + + +class SyncCredentialsSecret(SyncCredentials): + """Internal use only. Sync credential that is a single secret string.""" + + def __init__(self, credential_type: c.SyncCredentialsType, secret: bytes): + """Creates a secret-based credential. + + Args: + credential_type: The type of credential. + secret: UTF-8 encoded secret bytes. + """ + super().__init__(credential_type) + self.secret = secret + + +class SyncCredentialsUserPassword(SyncCredentials): + """Internal use only. Sync credential with username and password.""" + + def __init__(self, credential_type: c.SyncCredentialsType, username: str, password: str): + """Creates a username/password credential. + + Args: + credential_type: The type of credential. + username: The username. + password: The password. + """ + super().__init__(credential_type) + self.username = username + self.password = password + + +class SyncState(Enum): + """Current state of the SyncClient.""" + + UNKNOWN = auto() + """State is unknown, e.g. C-API reported a state that's not recognized yet.""" + + CREATED = auto() + """Client created but not yet started.""" + + STARTED = auto() + """Client started and connecting.""" + + CONNECTED = auto() + """Connection with the server established but not authenticated yet.""" + + LOGGED_IN = auto() + """Client authenticated and synchronizing.""" + + DISCONNECTED = auto() + """Lost connection, will try to reconnect if the credentials are valid.""" + + STOPPED = auto() + """Client in the process of being closed.""" + + DEAD = auto() + """Invalid access to the client after it was closed.""" + + +class SyncRequestUpdatesMode: + """Configuration of how SyncClient fetches remote updates from the server.""" + + MANUAL = 'manual' + """No updates, SyncClient.request_updates() must be called manually.""" + + AUTO = 'auto' + """Automatic updates, including subsequent pushes from the server, same as + calling SyncClient.request_updates(True). This is the default unless + changed by SyncClient.set_request_updates_mode().""" + + AUTO_NO_PUSHES = 'auto_no_pushes' + """Automatic update after connection, without subscribing for pushes from the + server. Similar to calling SyncClient.request_updates(False).""" + + +class SyncConnectionEvent: + """Connection state change event.""" + + CONNECTED = 'connected' + """Connection to the server is established.""" + + DISCONNECTED = 'disconnected' + """Connection to the server is lost.""" + + +class SyncLoginEvent: + """Login state change event.""" + + LOGGED_IN = 'logged_in' + """Client has successfully logged in to the server.""" + + CREDENTIALS_REJECTED = 'credentials_rejected' + """Client's credentials have been rejected by the server. + Connection will NOT be retried until new credentials are provided.""" + + UNKNOWN_ERROR = 'unknown_error' + """An unknown error occurred during authentication.""" + + +class SyncCode(IntEnum): + """Sync response/error codes.""" + + OK = 20 + """Operation completed successfully.""" + + REQ_REJECTED = 40 + """Request was rejected.""" + + CREDENTIALS_REJECTED = 43 + """Credentials were rejected by the server.""" + + UNKNOWN = 50 + """Unknown error occurred.""" + + AUTH_UNREACHABLE = 53 + """Authentication server is unreachable.""" + + BAD_VERSION = 55 + """Protocol version mismatch.""" + + CLIENT_ID_TAKEN = 61 + """Client ID is already in use.""" + + TX_VIOLATED_UNIQUE = 71 + """Transaction violated a unique constraint.""" + + +class SyncChange: + """Sync incoming data event.""" + + def __init__(self, entity_id: int, puts: list[int], removals: list[int]): + """Creates a SyncChange event. + + Args: + entity_id: Entity ID this change relates to. + puts: List of "put" (inserted/updated) object IDs. + removals: List of removed object IDs. + """ + self.entity_id = entity_id + """Entity ID this change relates to.""" + + self.puts = puts + """List of "put" (inserted/updated) object IDs.""" + + self.removals = removals + """List of removed object IDs.""" + + +class SyncLoginListener: + """Listener for sync login events. + + Implement this class and pass to SyncClient.set_login_listener() to receive + notifications about login success or failure. + """ + + def on_logged_in(self): + """Called when the client has successfully logged in to the server.""" + pass + + def on_login_failed(self, sync_login_code: SyncCode): + """Called when login has failed. + + Args: + sync_login_code: The error code indicating why login failed. + """ + pass + + +class SyncConnectionListener: + """Listener for sync connection events. + + Implement this class and pass to SyncClient.set_connection_listener() to receive + notifications about connection state changes. + """ + + def on_connected(self): + """Called when the connection to the server is established.""" + pass + + def on_disconnected(self): + """Called when the connection to the server is lost.""" + pass + + +class SyncErrorListener: + """Listener for sync error events. + + Implement this class and pass to SyncClient.set_error_listener() to receive + notifications about sync errors. + """ + + def on_error(self, sync_error_code: int): + """Called when a sync error occurs. + + Args: + sync_error_code: The error code indicating what error occurred. + """ + pass + + +class SyncChangeListener: + + def on_change(self, sync_changes: list[SyncChange]): + """Called when incoming data changes are received from the server. + + Args: + sync_changes: List of SyncChange events representing the changes. + """ + pass + + +class SyncClient: + """Sync client is used to connect to an ObjectBox sync server. + + Use through the Sync class factory methods. + """ + + def __init__(self, store: Store, server_urls: list[str], + filter_variables: dict[str, str] | None = None): + """Creates a Sync client associated with the given store and options. + + This does not initiate any connection attempts yet: call start() to do so. + + Args: + store: The ObjectBox store to sync. + server_urls: List of server URLs to connect to. + filter_variables: Optional dictionary of filter variable names to values. + """ + self.__c_change_listener = None + self.__c_login_listener = None + self.__c_login_failure_listener = None + self.__c_connect_listener = None + self.__c_disconnect_listener = None + self.__c_error_listener = None + if not server_urls: + raise ValueError("Provide at least one server URL") + + if not Sync.is_available(): + raise RuntimeError( + 'Sync is not available in the loaded ObjectBox runtime library. ' + 'Please visit https://objectbox.io/sync/ for options.') + + self.__store = store + self.__server_urls = [url.encode('utf-8') for url in server_urls] + + self.__c_sync_client_ptr = c.obx_sync_urls(store.c_store(), + c.c_array_pointer(self.__server_urls, ctypes.c_char_p), + len(self.__server_urls)) + + for name, value in (filter_variables or {}).items(): + self.add_filter_variable(name, value) + + self.__store.add_store_close_listener(on_store_close=self.__close_sync_client_func()) + + def __close_sync_client_func(self): + def close_sync_client(): + self.close() + + return close_sync_client + + def __check_sync_ptr_not_null(self): + if self.__c_sync_client_ptr is None: + raise ValueError('SyncClient already closed') + + def set_credentials(self, credentials: SyncCredentials): + """Configure authentication credentials, depending on your server config. + + Args: + credentials: The credentials to use for authentication. + """ + self.__check_sync_ptr_not_null() + self.__credentials = credentials + if isinstance(credentials, SyncCredentialsNone): + c.obx_sync_credentials(self.__c_sync_client_ptr, credentials.type, None, 0) + elif isinstance(credentials, SyncCredentialsUserPassword): + c.obx_sync_credentials_user_password(self.__c_sync_client_ptr, + credentials.type, + credentials.username.encode('utf-8'), + credentials.password.encode('utf-8')) + elif isinstance(credentials, SyncCredentialsSecret): + c.obx_sync_credentials(self.__c_sync_client_ptr, credentials.type, + credentials.secret, + len(credentials.secret)) + + def set_multiple_credentials(self, credentials_list: list[SyncCredentials]): + """Like set_credentials, but accepts multiple credentials. + + However, does **not** support SyncCredentials.none(). + + Args: + credentials_list: List of credentials to use for authentication. + + Raises: + ValueError: If credentials_list is empty or contains SyncCredentials.none(). + """ + self.__check_sync_ptr_not_null() + if len(credentials_list) == 0: + raise ValueError("Provide at least one credential") + + for i in range(len(credentials_list)): + is_last = (i == len(credentials_list) - 1) + credentials = credentials_list[i] + + if isinstance(credentials, SyncCredentialsNone): + raise ValueError("SyncCredentials.none() is not supported, use set_credentials() instead") + + if isinstance(credentials, SyncCredentialsUserPassword): + c.obx_sync_credentials_add_user_password(self.__c_sync_client_ptr, + credentials.type, + credentials.username.encode('utf-8'), + credentials.password.encode('utf-8'), + is_last + ) + elif isinstance(credentials, SyncCredentialsSecret): + c.obx_sync_credentials_add(self.__c_sync_client_ptr, + credentials.type, + credentials.secret, + len(credentials.secret), + is_last) + + + def set_request_updates_mode(self, mode: SyncRequestUpdatesMode): + """Configures how sync updates are received from the server. + + If automatic updates are turned off, they will need to be requested manually. + + Args: + mode: The request updates mode to use. + """ + self.__check_sync_ptr_not_null() + if mode == SyncRequestUpdatesMode.MANUAL: + c_mode = c.RequestUpdatesMode.MANUAL + elif mode == SyncRequestUpdatesMode.AUTO: + c_mode = c.RequestUpdatesMode.AUTO + elif mode == SyncRequestUpdatesMode.AUTO_NO_PUSHES: + c_mode = c.RequestUpdatesMode.AUTO_NO_PUSHES + else: + raise ValueError(f"Invalid mode: {mode}") + c.obx_sync_request_updates_mode(self.__c_sync_client_ptr, c_mode) + + def get_sync_state(self) -> SyncState: + """Gets the current sync client state. + + Returns: + The current SyncState of this client. + """ + self.__check_sync_ptr_not_null() + c_state = c.obx_sync_state(self.__c_sync_client_ptr) + if c_state == c.SyncState.CREATED: + return SyncState.CREATED + elif c_state == c.SyncState.STARTED: + return SyncState.STARTED + elif c_state == c.SyncState.CONNECTED: + return SyncState.CONNECTED + elif c_state == c.SyncState.LOGGED_IN: + return SyncState.LOGGED_IN + elif c_state == c.SyncState.DISCONNECTED: + return SyncState.DISCONNECTED + elif c_state == c.SyncState.STOPPED: + return SyncState.STOPPED + elif c_state == c.SyncState.DEAD: + return SyncState.DEAD + else: + return SyncState.UNKNOWN + + def start(self): + """Once the sync client is configured, you can start it to initiate synchronization. + + This method triggers communication in the background and returns immediately. + The background thread will try to connect to the server, log-in and start + syncing data (depends on SyncRequestUpdatesMode). If the device, network or + server is currently offline, connection attempts will be retried later + automatically. If you haven't set the credentials in the options during + construction, call set_credentials() before start(). + """ + self.__check_sync_ptr_not_null() + c.obx_sync_start(self.__c_sync_client_ptr) + + def stop(self): + """Stops this sync client. Does nothing if it is already stopped.""" + self.__check_sync_ptr_not_null() + c.obx_sync_stop(self.__c_sync_client_ptr) + + def trigger_reconnect(self) -> bool: + """Triggers a reconnection attempt immediately. + + By default, an increasing backoff interval is used for reconnection attempts. + But sometimes the code using this API has additional knowledge and can + initiate a reconnection attempt sooner. + + Returns: + True if a reconnect was actually triggered. + """ + self.__check_sync_ptr_not_null() + return c.check_obx_success(c.obx_sync_trigger_reconnect(self.__c_sync_client_ptr)) + + def request_updates(self, subscribe_for_future_pushes: bool) -> bool: + """Request updates since we last synchronized our database. + + Additionally, you can subscribe for future pushes from the server, to let + it send us future updates as they come in. + Call cancel_updates() to stop the updates. + + Args: + subscribe_for_future_pushes: If True, also subscribe for future pushes. + + Returns: + True if the request was successful. + """ + self.__check_sync_ptr_not_null() + return c.check_obx_success(c.obx_sync_updates_request(self.__c_sync_client_ptr, subscribe_for_future_pushes)) + + def cancel_updates(self) -> bool: + """Cancel updates from the server so that it will stop sending updates. + + See also request_updates(). + + Returns: + True if the cancellation was successful. + """ + self.__check_sync_ptr_not_null() + return c.check_obx_success(c.obx_sync_updates_cancel(self.__c_sync_client_ptr)) + + @staticmethod + def protocol_version() -> int: + """Returns the protocol version this client uses.""" + return c.obx_sync_protocol_version() + + def protocol_server_version(self) -> int: + """Returns the protocol version of the server after a connection is + established (or attempted), zero otherwise. + """ + return c.obx_sync_protocol_version_server(self.__c_sync_client_ptr) + + def close(self): + """Closes and cleans up all resources used by this sync client. + + It can no longer be used afterwards, make a new sync client instead. + Does nothing if this sync client has already been closed. + """ + c.obx_sync_listener_error(self.__c_sync_client_ptr, None, None) + c.obx_sync_listener_login(self.__c_sync_client_ptr, None, None) + c.obx_sync_listener_login_failure(self.__c_sync_client_ptr, None, None) + c.obx_sync_listener_connect(self.__c_sync_client_ptr, None, None) + c.obx_sync_listener_disconnect(self.__c_sync_client_ptr, None, None) + c.obx_sync_listener_change(self.__c_sync_client_ptr, None, None) + c.obx_sync_close(self.__c_sync_client_ptr) + self.__c_sync_client_ptr = None + + def is_closed(self) -> bool: + """Returns if this sync client is closed and can no longer be used.""" + return self.__c_sync_client_ptr is None + + def set_login_listener(self, login_listener: SyncLoginListener): + """Sets a listener to observe login events (success/failure). + + Args: + login_listener: The listener to receive login events. + """ + self.__check_sync_ptr_not_null() + self.__c_login_listener = c.OBX_sync_listener_login(lambda arg: login_listener.on_logged_in()) + self.__c_login_failure_listener = c.OBX_sync_listener_login_failure( + lambda arg, sync_login_code: login_listener.on_login_failed(sync_login_code)) + c.obx_sync_listener_login( + self.__c_sync_client_ptr, + self.__c_login_listener, + None + ) + c.obx_sync_listener_login_failure( + self.__c_sync_client_ptr, + self.__c_login_failure_listener, + None + ) + + def set_connection_listener(self, connection_listener: SyncConnectionListener): + """Sets a listener to observe connection state changes (connect/disconnect). + + Args: + connection_listener: The listener to receive connection events. + """ + self.__check_sync_ptr_not_null() + self.__c_connect_listener = c.OBX_sync_listener_connect(lambda arg: connection_listener.on_connected()) + self.__c_disconnect_listener = c.OBX_sync_listener_disconnect(lambda arg: connection_listener.on_disconnected()) + c.obx_sync_listener_connect( + self.__c_sync_client_ptr, + self.__c_connect_listener, + None + ) + c.obx_sync_listener_disconnect( + self.__c_sync_client_ptr, + self.__c_disconnect_listener, + None + ) + + def set_error_listener(self, error_listener: SyncErrorListener): + """Sets a listener to observe sync error events. + + Args: + error_listener: The listener to receive error events. + """ + self.__check_sync_ptr_not_null() + self.__c_error_listener = c.OBX_sync_listener_error( + lambda arg, sync_error_code: error_listener.on_error(sync_error_code)) + c.obx_sync_listener_error( + self.__c_sync_client_ptr, + self.__c_error_listener, + None + ) + + def set_change_listener(self, change_listener: SyncChangeListener): + """Sets a listener to observe incoming data changes from the server. + + Args: + change_listener: The listener to receive change events. + """ + self.__check_sync_ptr_not_null() + + def c_change_callback(arg, sync_change_array_ptr): + sync_change_array = ctypes.cast(sync_change_array_ptr, ctypes.POINTER(OBX_sync_change_array)).contents + changes: list[SyncChange] = [] + for i in range(sync_change_array.count): + c_sync_change: c.OBX_sync_change = sync_change_array.list[i] + puts = [] + if c_sync_change.puts: + c_puts_id_array: c.OBX_id_array = ctypes.cast(c_sync_change.puts, c.OBX_id_array_p).contents + puts = list( + ctypes.cast(c_puts_id_array.ids, ctypes.POINTER(c.obx_id * c_puts_id_array.count)).contents) + removals = [] + if c_sync_change.removals: + c_removals_id_array: c.OBX_id_array = ctypes.cast(c_sync_change.removals, c.OBX_id_array_p).contents + removals = list( + ctypes.cast(c_removals_id_array.ids, + ctypes.POINTER(c.obx_id * c_removals_id_array.count)).contents) + changes.append(SyncChange( + entity_id=c_sync_change.entity_id, + puts=puts, + removals=removals + )) + change_listener.on_change(changes) + + self.__c_change_listener = c.OBX_sync_listener_change(c_change_callback) + c.obx_sync_listener_change( + self.__c_sync_client_ptr, + self.__c_change_listener, + None + ) + + def wait_for_logged_in_state(self, timeout_millis: int): + """Waits for the sync client to reach the logged-in state. + + Args: + timeout_millis: Maximum time to wait in milliseconds. + """ + self.__check_sync_ptr_not_null() + c.obx_sync_wait_for_logged_in_state(self.__c_sync_client_ptr, timeout_millis) + + def add_filter_variable(self, name: str, value: str): + """Adds or replaces a Sync filter variable value for the given name. + + Eventually, existing values for the same name are replaced. + + Sync client filter variables can be used in server-side Sync filters to + filter out objects that do not match the filters. Filter variables must be + added before login, so before calling start(). + + See also remove_filter_variable() and remove_all_filter_variables(). + + Args: + name: The name of the filter variable. + value: The value of the filter variable. + """ + self.__check_sync_ptr_not_null() + c.obx_sync_filter_variables_put(self.__c_sync_client_ptr, name.encode('utf-8'), value.encode('utf-8')) + + def remove_filter_variable(self, name: str): + """Removes a previously added Sync filter variable value. + + See also add_filter_variable() and remove_all_filter_variables(). + + Args: + name: The name of the filter variable to remove. + """ + self.__check_sync_ptr_not_null() + c.obx_sync_filter_variables_remove(self.__c_sync_client_ptr, name.encode('utf-8')) + + def remove_all_filter_variables(self): + """Removes all previously added Sync filter variable values. + + See also add_filter_variable() and remove_filter_variable(). + """ + self.__check_sync_ptr_not_null() + c.obx_sync_filter_variables_remove_all(self.__c_sync_client_ptr) + + def get_outgoing_message_count(self, limit: int = 0) -> int: + """Count the number of messages in the outgoing queue, i.e. those waiting + to be sent to the server. + + By default, counts all messages without any limitation. For a lower number + pass a limit that's enough for your app logic. + + Note: This call uses a (read) transaction internally: + 1) It's not just a "cheap" return of a single number. While this will + still be fast, avoid calling this function excessively. + 2) The result follows transaction view semantics, thus it may not always + match the actual value. + + Args: + limit: Optional limit for counting messages. Default is 0 (no limit). + + Returns: + The number of messages in the outgoing queue. + """ + self.__check_sync_ptr_not_null() + outgoing_message_count = ctypes.c_uint64(0) + c.obx_sync_outgoing_message_count(self.__c_sync_client_ptr, limit, ctypes.byref(outgoing_message_count)) + return outgoing_message_count.value + + +class Sync: + """ObjectBox Sync makes data available and synchronized across devices, + online and offline. + + Start a client using Sync.client() and connect to a remote server. + """ + __sync_clients: dict[Store, SyncClient] = {} + + @staticmethod + def is_available() -> bool: + """Returns True if the loaded ObjectBox native library supports Sync.""" + return c.obx_has_feature(c.Feature.Sync) + + @staticmethod + def client( + store: Store, + server_url: str, + credential: SyncCredentials, + filter_variables: dict[str, str] | None = None + ) -> SyncClient: + """Creates a Sync client associated with the given store and configures it + with the given options. + + This does not initiate any connection attempts yet, call SyncClient.start() + to do so. + + Before SyncClient.start(), you can still configure some aspects of the + client, e.g. its request updates mode. + + To configure Sync filter variables, pass variable names mapped to their + value to filter_variables. Sync client filter variables can be used in + server-side Sync filters to filter out objects that do not match the filter. + + Args: + store: The ObjectBox store to sync. + server_url: The URL of the sync server to connect to. + credential: The credentials to use for authentication. + filter_variables: Optional dictionary of filter variable names to values. + + Returns: + A configured SyncClient instance. + """ + client = SyncClient(store, [server_url], filter_variables) + client.set_credentials(credential) + return client + + @staticmethod + def client_multi_creds( + store: Store, + server_url: str, + credentials_list: list[SyncCredentials], + filter_variables: dict[str, str] | None = None + ) -> SyncClient: + """Like client(), but accepts a list of credentials. + + When passing multiple credentials, does **not** support + SyncCredentials.none(). + + Args: + store: The ObjectBox store to sync. + server_url: The URL of the sync server to connect to. + credentials_list: List of credentials to use for authentication. + filter_variables: Optional dictionary of filter variable names to values. + + Returns: + A configured SyncClient instance. + """ + client = SyncClient(store, [server_url], filter_variables) + client.set_multiple_credentials(credentials_list) + return client + + @staticmethod + def client_multi_urls( + store: Store, + server_urls: list[str], + credential: SyncCredentials, + filter_variables: dict[str, str] | None = None + ) -> SyncClient: + """Like client(), but accepts a list of URLs to work with multiple servers. + + Args: + store: The ObjectBox store to sync. + server_urls: List of server URLs to connect to. + credential: The credentials to use for authentication. + filter_variables: Optional dictionary of filter variable names to values. + + Returns: + A configured SyncClient instance. + """ + client = SyncClient(store, server_urls, filter_variables) + client.set_credentials(credential) + return client + + @staticmethod + def client_multi_creds_multi_urls( + store: Store, + server_urls: list[str], + credentials_list: list[SyncCredentials], + filter_variables: dict[str, str] | None = None + ) -> SyncClient: + """Like client(), but accepts a list of credentials and a list of URLs to + work with multiple servers. + + When passing multiple credentials, does **not** support + SyncCredentials.none(). + + Args: + store: The ObjectBox store to sync. + server_urls: List of server URLs to connect to. + credentials_list: List of credentials to use for authentication. + filter_variables: Optional dictionary of filter variable names to values. + + Returns: + A configured SyncClient instance. + + Raises: + ValueError: If a sync client is already active for the given store. + """ + if store in Sync.__sync_clients: + raise ValueError('Only one sync client can be active for a store') + client = SyncClient(store, server_urls, filter_variables) + client.set_multiple_credentials(credentials_list) + Sync.__sync_clients[store] = client + return client diff --git a/setup.py b/setup.py index bda3122..b328c9f 100644 --- a/setup.py +++ b/setup.py @@ -1,11 +1,17 @@ +import os + import setuptools import objectbox with open("README.md", "r") as fh: long_description = fh.read() +package_name = "objectbox" +if "OBX_BUILD_SYNC" in os.environ: + package_name = "objectbox-sync" + setuptools.setup( - name="objectbox", + name=package_name, version=str(objectbox.version), author="ObjectBox", description="ObjectBox is a superfast lightweight database for objects", diff --git a/tests/conftest.py b/tests/conftest.py index 7c4bca4..8ccaf22 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ import pytest from objectbox.logger import logger +from objectbox.sync import SyncLoginListener, SyncConnectionListener, SyncErrorListener from common import * @@ -19,3 +20,54 @@ def test_store(): store = create_test_store() yield store store.close() + +class TestLoginListener(SyncLoginListener): + def __init__(self): + self.logged_in_called = False + self.login_failure_code = None + + def on_logged_in(self): + self.logged_in_called = True + + def on_login_failed(self, sync_login_code: int): + self.login_failure_code = sync_login_code + + +class TestConnectionListener(SyncConnectionListener): + def __init__(self): + self.connected_called = False + self.disconnected_called = False + + def on_connected(self): + self.connected_called = True + + def on_disconnected(self): + self.disconnected_called = True + + +class TestErrorListener(SyncErrorListener): + def __init__(self): + self.sync_error_code = None + + def on_error(self, sync_error_code: int): + self.sync_error_code = sync_error_code + +@pytest.fixture +def connection_listener(): + listener = TestConnectionListener() + yield listener + listener.connected_called = False + listener.disconnected_called = False + +@pytest.fixture +def login_listener(): + listener = TestLoginListener() + yield listener + listener.logged_in_called = False + listener.login_failure_code = None + +@pytest.fixture +def error_listener(): + listener = TestErrorListener() + yield listener + listener.sync_error_code = None \ No newline at end of file diff --git a/tests/test_sync.py b/tests/test_sync.py new file mode 100644 index 0000000..f5a5b39 --- /dev/null +++ b/tests/test_sync.py @@ -0,0 +1,136 @@ +from collections.abc import Callable +from time import sleep + +import pytest + +from objectbox.exceptions import IllegalArgumentError +from objectbox.sync import * + + +def test_sync_protocol_version(): + version = SyncClient.protocol_version() + assert version >= 1 + +def test_sync_client_states(test_store): + server_urls = ["ws://localhost:9999"] + client = SyncClient(test_store, server_urls) + assert client.get_sync_state() == SyncState.CREATED + client.start() + assert client.get_sync_state() == SyncState.STARTED + client.stop() + assert client.get_sync_state() == SyncState.STOPPED + client.close() + +def test_sync_listener(test_store, login_listener, connection_listener): + server_urls = ["ws://127.0.0.1:9999"] + client = SyncClient(test_store, server_urls) + client.set_credentials(SyncCredentials.shared_secret_string("shared_secret")) + client.set_login_listener(login_listener) + client.set_connection_listener(connection_listener) + + client.start() + sleep(1) + client.stop() + client.close() + + assert login_listener.login_failure_code is not None + assert login_listener.login_failure_code == SyncCode.CREDENTIALS_REJECTED + assert connection_listener.connected_called + assert connection_listener.disconnected_called + + +def test_filter_variables(test_store): + server_urls = ["ws://localhost:9999"] + + filter_vars = { + "name1": "val1", + "name2": "val2" + } + client = SyncClient(test_store, server_urls, filter_vars) + + client.add_filter_variable("name3", "val3") + client.remove_filter_variable("name1") + client.add_filter_variable("name4", "val4") + client.remove_all_filter_variables() + + with pytest.raises(IllegalArgumentError, match="Filter variables must have a name"): + client.add_filter_variable("", "val5") + + client.close() + + +def test_outgoing_message_count(test_store): + server_urls = ["ws://localhost:9999"] + client = SyncClient(test_store, server_urls) + + count = client.get_outgoing_message_count() + assert count == 0 + + count_limited = client.get_outgoing_message_count(limit=10) + assert count_limited == 0 + + client.close() + + with pytest.raises(IllegalArgumentError, match='Argument "sync" must not be null'): + client.get_outgoing_message_count() + + +def test_multiple_credentials(test_store): + server_urls = ["ws://localhost:9999"] + client = SyncClient(test_store, server_urls) + + # empty list should raise ValueError + with pytest.raises(ValueError, match='Provide at least one credential'): + client.set_multiple_credentials([]) + + # SyncCredentials.none() is not supported with multiple credentials + with pytest.raises(ValueError, match=r'SyncCredentials.none\(\) is not supported, use set_credentials\(\) instead'): + client.set_multiple_credentials([SyncCredentials.none()]) + + client.set_multiple_credentials([ + SyncCredentials.google_auth("token_google"), + SyncCredentials.user_and_password("user1", "password"), + SyncCredentials.shared_secret_string("secret1"), + SyncCredentials.jwt_id_token("token1"), + SyncCredentials.jwt_access_token("token2"), + SyncCredentials.jwt_refresh_token("token3"), + SyncCredentials.jwt_custom_token("token4") + ]) + + +def test_client_closed_when_store_closed(test_store): + server_urls = ["ws://localhost:9999"] + client = SyncClient(test_store, server_urls) + + assert not client.is_closed() + test_store.close() + assert client.is_closed() + + +def assert_raises_value_error(fn: Callable[[], object | None], message: str | None = None): + with pytest.raises(ValueError, match=message): + fn() + + +def test_client_access_after_close_throws_error(test_store): + server_urls = ["ws://localhost:9999"] + client = SyncClient(test_store, server_urls) + client.close() + + assert client.is_closed() + + match_error = "SyncClient already closed" + + assert_raises_value_error(message=match_error, fn=lambda: client.start()) + assert_raises_value_error(message=match_error, fn=lambda: client.stop()) + assert_raises_value_error(message=match_error, fn=lambda: client.get_sync_state()) + assert_raises_value_error(message=match_error, fn=lambda: client.get_outgoing_message_count()) + assert_raises_value_error(message=match_error, fn=lambda: client.set_credentials(SyncCredentials.none())) + assert_raises_value_error(message=match_error, + fn=lambda: client.set_credentials(SyncCredentials.google_auth("token_google"))) + assert_raises_value_error(message=match_error, fn=lambda: client.set_multiple_credentials([ + SyncCredentials.google_auth("token_google"), + SyncCredentials.user_and_password("user1", "password") + ])) + assert_raises_value_error(message=match_error, + fn=lambda: client.set_request_updates_mode(SyncRequestUpdatesMode.AUTO))