diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 21816edc..eb41c043 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -15,7 +15,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.13", "3.12", "3.11", "3.10", "3.9"] + python-version: ["3.13", "3.12", "3.11", "3.10"] neo4j-version: ["community", "enterprise", "5.5-enterprise", "4.4-enterprise", "4.4-community"] steps: diff --git a/.gitignore b/.gitignore index 8ff170e3..b4e6978b 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,4 @@ coverage_report/ .coverage* .DS_STORE cov.xml +test/data/model_diagram.* diff --git a/.sonarcloud.properties b/.sonarcloud.properties index 05fed61f..bd5fb8ad 100644 --- a/.sonarcloud.properties +++ b/.sonarcloud.properties @@ -1,3 +1,6 @@ sonar.sources = neomodel/ sonar.tests = test/ -sonar.python.version = 3.9, 3.10, 3.11, 3.12, 3.13 \ No newline at end of file +sonar.python.version = 3.10, 3.11, 3.12, 3.13 +sonar.issue.ignore.multicriteria=e1 +sonar.issue.ignore.multicriteria.e1.ruleKey=python:S100 +sonar.issue.ignore.multicriteria.e1.resourceKey=**/neomodel/config.py \ No newline at end of file diff --git a/Changelog b/Changelog index 2af22d7c..349f712c 100644 --- a/Changelog +++ b/Changelog @@ -1,3 +1,14 @@ +Version 6.0.0 2025-xx +* Modernize the config object, using a dataclass with typing, runtime and update validation rules, and environment variable support +* Fix async support of parallel transactions, using ContextVar +* Introduce the merge_by parameter for batch operations to customize merge behaviour (label and property keys) +* Enforce strict cardinality check by default +* Refactor internal code: core.py is now split into smaller files for database, node and transaction +* Fix object resolution for map and list Cypher objects, even when nested. This changes the way you can access lists in your Cypher results; see the documentation for more info +* Make AsyncDatabase / Database a true singleton for clarity +* Remove deprecated methods (including fetch_relations & traverse_relations, replaced with traverse; database operations like clear_neo4j_database or change_neo4j_password have been moved to db/adb singleton internal methods) +* Housekeeping and bug fixes + Version 5.5.3 2025-09 * Fix duplicated code issue in the advanced querying methods * Remove py.typed - this was a premature change, we should write stubs for full typing support first diff --git a/README.md b/README.md index c7dc01d0..bc62c0ae 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,12 @@ GitHub repo found at . # Requirements +**For neomodel releases 6.x:** + +- Python 3.10+ +- Neo4j 2025.x.x, 5.x, 4.4 (LTS) +- Neo4j Enterprise, Community and Aura are supported + **For neomodel releases 5.x :** - Python 3.8+ @@ -37,27 +43,19 @@ GitHub repo found at . Available on [readthedocs](http://neomodel.readthedocs.org). -# New in 5.4.0 - -This version adds many new features, expanding neomodel's querying capabilities. Those features were kindly contributed back by the [OpenStudyBuilder team](https://openstudybuilder.com/). A VERY special thanks to [@tonioo](https://github.com/tonioo) for the integration work. - -There are too many new capabilities here, so I advise you to start by looking at the full summary example in the [Getting Started guide](https://neomodel.readthedocs.io/en/latest/getting_started.html#full-example). It will then point you to the various relevant sections.
+# New in 6.0.0 -We also validated support for [Python 3.13](https://docs.python.org/3/whatsnew/3.13.html). +From now on, neomodel will use **SemVer (major.minor.patch)** for versioning. -# New in 5.3.0 +This version introduces a modern configuration system, using a dataclass with typing, runtime and update validation rules, and environment variable support. +See the [documentation](https://neomodel.readthedocs.io/en/latest/configuration.html) section for more details. -neomodel now supports asynchronous programming, thanks to the [Neo4j driver async API](https://neo4j.com/docs/api/python-driver/current/async_api.html). The [documentation](http://neomodel.readthedocs.org) has been updated accordingly, with an updated getting started section, and some specific documentation for the async API. +[Semantic Indexes](https://neomodel.readthedocs.io/en/latest/semantic_indexes.html) (Vector and Full-text) are now natively supported so you do not have to use a custom Cypher query. Special thanks to @greengori11a for this. -# Breaking changes in 5.3.0 +### Breaking changes -- config.AUTO_INSTALL_LABELS has been removed. Please use the `neomodel_install_labels` script instead. _Note : this is because of the addition of async, but also because it might lead to uncontrolled creation of indexes/constraints. The script makes you more in control of said creation._ -- The Database class has been moved into neomodel.sync_.core - and a new AsyncDatabase introduced into neomodel.async_.core -- Based on Python version [status](https://devguide.python.org/versions/), -neomodel will be dropping support for Python 3.7 in an upcoming release -(5.3 or later). _This does not mean neomodel will stop working on Python 3.7, but -it will no longer be tested against it_ -- Some standalone methods have been refactored into the Database() class. Check the [documentation](http://neomodel.readthedocs.org) for a full list. +* List object resolution from Cypher was creating "2-depth" lists for no apparent reason. This release fixes this so that, for example, "RETURN collect(node)" will return the nodes directly as a list in the result. In other words, you can extract this list at `results[0][0]` instead of `results[0][0][0]`. +* See more breaking changes in the [documentation](http://neomodel.readthedocs.org). # Installation @@ -84,7 +82,7 @@ You can find some performance tests made using Locust [in this repo](https://git Two learnings from this : * The wrapping of the driver made by neomodel is very thin performance-wise : it does not add a lot of overhead ; -* When used in a concurrent fashion, async neomodel is faster than concurrent sync neomodel, and a lot of faster than serial queries. +* When used in a concurrent fashion, async neomodel is faster than concurrent sync neomodel, and a lot faster than serial queries. # Contributing diff --git a/doc/source/batch.rst b/doc/source/batch.rst index 641ef25b..31cf5d9e 100644 --- a/doc/source/batch.rst +++ b/doc/source/batch.rst @@ -22,48 +22,200 @@ Create multiple nodes at once in a single transaction:: create_or_update() ------------------ -Atomically create or update nodes in a single operation:: +Atomically create or update nodes in a single operation. +The **required** and **unique** properties are used as keys to match nodes; +all other properties are only used in the resulting write operation.
+For example:: + + class Person(StructuredNode): + name = StringProperty(required=True) + age = IntegerProperty() people = Person.create_or_update( - {'name': 'Tim', 'age': 83}, - {'name': 'Bob', 'age': 23}, - {'name': 'Jill', 'age': 34}, + {'name': 'Tim', 'age': 83}, # created + {'name': 'Bob', 'age': 23}, # created + {'name': 'Jill', 'age': 34}, # created ) more_people = Person.create_or_update( - {'name': 'Tim', 'age': 73}, - {'name': 'Bob', 'age': 35}, - {'name': 'Jane', 'age': 24}, + {'name': 'Tim', 'age': 73}, # updated + {'name': 'Bob', 'age': 35}, # updated + {'name': 'Jane', 'age': 24}, # created ) -This is useful for ensuring data is up to date, each node is matched by its required and/or unique properties. Any -additional properties will be set on a newly created or an existing node. +Custom Merge Keys +~~~~~~~~~~~~~~~~~ +By default, neomodel uses all required properties as merge keys. +However, you can specify custom merge criteria using the ``merge_by`` parameter:: + + class User(StructuredNode): + username = StringProperty(required=True) + email = StringProperty(required=True) + full_name = StringProperty() + age = IntegerProperty() + + # Default behavior (merge by username + email) + users = User.create_or_update({ + 'username': 'johndoe', + 'email': 'john@example.com', + 'age': 30 + }) + + # Custom merge by email only + users = User.create_or_update({ + 'username': 'johndoe', + 'email': 'john@example.com', + 'age': 31 + }, merge_by={'keys': ['email']}) + + # Custom merge by username only + users = User.create_or_update({ + 'username': 'johndoe', + 'email': 'john.doe@newcompany.com', + 'age': 32 + }, merge_by={'label': 'User', 'keys': ['username']}) + +The ``merge_by`` parameter accepts a dictionary with: +- ``label``: The Neo4j label to match against (optional, defaults to the node's inherited labels). +- ``keys``: The property name(s) to use as the merge key(s). + +This is particularly useful when you want to merge nodes based on specific properties +rather than all required properties, or when you need to merge based on properties +that are not required. + +Examples of different merge key configurations:: + + # Single key + users = User.create_or_update({ + 'username': 'johndoe', + 'email': 'john@example.com', + 'age': 30 + }, merge_by={'keys': ['email']}) + + # Multiple keys (list) + users = User.create_or_update({ + 'username': 'johndoe', + 'email': 'john@example.com', + 'age': 30 + }, merge_by={'label': 'User', 'keys': ['username', 'email']}) + + # Multiple keys with different label + users = User.create_or_update({ + 'username': 'johndoe', + 'email': 'john@example.com', + 'age': 30 + }, merge_by={'label': 'Person', 'keys': ['email', 'age']}) # For when your node has multiple labels + + +In all other cases, only explicitly provided properties will be updated on the node; omitted properties keep their stored values and defaults are not re-applied:: + class NodeWithDefaultProp(AsyncStructuredNode): + name = StringProperty(required=True) + age = IntegerProperty(default=30) + other_prop = StringProperty() + + node = await NodeWithDefaultProp.create_or_update({"name": "Tania", "age": 20}) + assert node[0].name == "Tania" + assert node[0].age == 20 + + node = await NodeWithDefaultProp.create_or_update( + {"name": "Tania", "other_prop": "other"} + ) + assert node[0].name == "Tania" + assert ( + node[0].age == 20 + ) # Tania is still 20 even though default says she should be 30 + assert ( + node[0].other_prop == "other" + ) # She does have a brand new other_prop, lucky her!
+ + +However, if fields used as keys have default values, those default values will be used if the property is omitted in your call. +This means that when using `UniqueIdProperty`, which is both unique and has a default value, if you do not pass it explicitly, +it will generate a new (random) value for it, and thus create a new node instead of updating an existing one:: + + class UniquePerson(StructuredNode): + uid = UniqueIdProperty() + name = StringProperty(required=True) + + unique_person = UniquePerson.create_or_update({"name": "Tim"}) # created + unique_person = UniquePerson.create_or_update({"name": "Tim"}) # created again with a new uid + +.. attention:: + This has been raised as an `issue in GitHub <https://github.com/neo4j-contrib/neomodel/issues/807>`_. + While it is not a bug in itself, it is a deviation from the expected behavior of the function, and thus may be unexpected. + The ``merge_by`` parameter described above addresses this by letting you specify which properties are used as keys to match nodes. -It is important to provide unique identifiers where known, any fields with default values that are omitted will be generated. get_or_create() --------------- -Atomically get or create nodes in a single operation:: +Atomically get or create nodes in a single operation. +For example:: people = Person.get_or_create( - {'name': 'Tim'}, - {'name': 'Bob'}, + {'name': 'Tim'}, # created + {'name': 'Bob'}, # created ) people_with_jill = Person.get_or_create( - {'name': 'Tim'}, - {'name': 'Bob'}, - {'name': 'Jill'}, + {'name': 'Tim'}, # fetched + {'name': 'Bob'}, # fetched + {'name': 'Jill'}, # created ) # are same nodes assert people[0] == people_with_jill[0] assert people[1] == people_with_jill[1] -This is useful for ensuring specific nodes exist only and all required properties must be specified to ensure -uniqueness. In this example 'Tim' and 'Bob' are created on the first call, and are retrieved in the second call. +The **required** and **unique** properties are used as keys to match nodes; +all other properties are only used when a new node is created. +For example:: + class Person(StructuredNode): + name = StringProperty(required=True) + age = IntegerProperty() + + node = Person.get_or_create({"name": "Tania", "age": 20}) + assert node[0].name == "Tania" + assert node[0].age == 20 + + node = Person.get_or_create({"name": "Tania", "age": 30}) + assert node[0].name == "Tania" + assert node[0].age == 20 # Tania was fetched and not created, age is still 20 + +Custom Merge Keys +~~~~~~~~~~~~~~~~~ +The ``get_or_create()`` method also supports the ``merge_by`` parameter for custom merge criteria:: + + class User(StructuredNode): + username = StringProperty(required=True, unique_index=True) + email = StringProperty(required=True, unique_index=True) + full_name = StringProperty() + age = IntegerProperty() + + # Default behavior (merge by username + email) + users = User.get_or_create({ + 'username': 'johndoe', + 'email': 'john@example.com', + 'age': 30 + }) + + # Custom merge by email only + users = User.get_or_create({ + 'username': 'johndoe', + 'email': 'john@example.com', + 'age': 31 + }, merge_by={'keys': ['email']}) + + # Custom merge by username only + users = User.get_or_create({ + 'username': 'johndoe', + 'email': 'john.doe@newcompany.com', + 'age': 32 + }, merge_by={'label': 'User', 'keys': ['username']}) + +The same ``merge_by`` parameter format applies to both ``create_or_update()`` and ``get_or_create()`` methods.
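+The same call shape carries over to the async API; below is a minimal sketch (the ``AsyncUser`` class is illustrative, not part of the library)::
+
+    from neomodel import AsyncStructuredNode, StringProperty
+
+    class AsyncUser(AsyncStructuredNode):
+        username = StringProperty(required=True)
+        email = StringProperty(required=True)
+
+    # Inside an async function: merge on email only, as in the sync examples
+    users = await AsyncUser.get_or_create(
+        {'username': 'johndoe', 'email': 'john@example.com'},
+        merge_by={'keys': ['email']},
+    )
+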
+ Additionally, get_or_create() allows the "relationship" parameter to be passed. When a relationship is specified, the -matching is done based on that relationship and not globally:: +matching is done based on that relationship and not globally. The relationship becomes one of the keys to match nodes:: class Dog(StructuredNode): name = StringProperty(required=True) @@ -81,6 +233,3 @@ matching is done based on that relationship and not globally:: # not the same gizmo assert bobs_gizmo[0] != tims_gizmo[0] - -In case when the only required property is unique, the operation is redundant. However with simple required properties, -the relationship becomes a part of the unique identifier. diff --git a/doc/source/configuration.rst b/doc/source/configuration.rst index e7e38607..d4f74cce 100644 --- a/doc/source/configuration.rst +++ b/doc/source/configuration.rst @@ -1,133 +1,354 @@ Configuration ============= -This section is covering the Neomodel 'config' module and its variables. +Neomodel provides a modern, type-safe configuration system for connecting to your Neo4j database. This guide covers the recommended approach using the new dataclass-based configuration system (available from version 6.0), with backward compatibility information for existing code. -.. _connection_options_doc: +.. _configuration_options_doc: -Connection ----------- +Database Connection Setup +------------------------- -There are two ways to define your connection to the database : +The primary way to configure neomodel is to set up your database connection. There are two approaches: -1. Provide a Neo4j URL and some options - Driver will be managed by neomodel -2. Create your own Neo4j driver and pass it to neomodel +1. **Neomodel-managed connection** (recommended) - Let neomodel handle the driver lifecycle +2. **Self-managed connection** - Provide your own Neo4j driver instance -neomodel-managed (default) --------------------------- +Neomodel-managed Connection (Recommended) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Set the connection URL:: +This is the simplest and most common approach. Neomodel will create and manage the Neo4j driver for you. 
- config.DATABASE_URL = 'bolt://neo4j:neo4j@localhost:7687` +Basic connection setup:: -Adjust driver configuration - these options are only available for this connection method:: + from neomodel import get_config + + config = get_config() + config.database_url = 'bolt://neo4j:password@localhost:7687' - config.MAX_CONNECTION_POOL_SIZE = 100 # default - config.CONNECTION_ACQUISITION_TIMEOUT = 60.0 # default - config.CONNECTION_TIMEOUT = 30.0 # default - config.ENCRYPTED = False # default - config.KEEP_ALIVE = True # default - config.MAX_CONNECTION_LIFETIME = 3600 # default - config.MAX_CONNECTION_POOL_SIZE = 100 # default - config.MAX_TRANSACTION_RETRY_TIME = 30.0 # default - config.RESOLVER = None # default - config.TRUST = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES # default - config.USER_AGENT = neomodel/v5.5.1 # default +You can also set the database name separately:: -Setting the database name, if different from the default one:: + config.database_url = 'bolt://neo4j:password@localhost:7687' + config.database_name = 'mydatabase' - # Using the URL only - config.DATABASE_URL = 'bolt://neo4j:neo4j@localhost:7687/mydb` +Advanced driver configuration, for example:: + + config.connection_timeout = 60.0 + config.max_connection_pool_size = 50 + config.encrypted = True + config.keep_alive = True + +Self-managed Connection +~~~~~~~~~~~~~~~~~~~~~~~ + +If you need more control over the driver configuration, you can provide your own Neo4j driver:: + + from neo4j import GraphDatabase + from neomodel import get_config + + # Create your own driver + driver = GraphDatabase.driver( + 'bolt://localhost:7687', + auth=('neo4j', 'password'), + encrypted=True, + max_connection_lifetime=3600 + ) + + # Pass it to neomodel + config = get_config() + config.driver = driver + +.. note:: + When using a self-managed driver, you are responsible for closing it when your application shuts down. 
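+For example, a minimal lifecycle sketch for a self-managed driver (URL and credentials are placeholders)::
+
+    from neo4j import GraphDatabase
+    from neomodel import get_config
+
+    driver = GraphDatabase.driver(
+        'bolt://localhost:7687', auth=('neo4j', 'password')
+    )
+    config = get_config()
+    config.driver = driver
+
+    # ... application code using neomodel ...
+
+    # At shutdown, a self-managed driver must be closed explicitly
+    driver.close()
+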
+Modern Configuration System (Version 6.0+) +------------------------------------------ + +Neomodel 6.0 introduces a modern dataclass-based configuration system with the following benefits: + +* **Type Safety**: All configuration values are properly typed +* **Validation**: Configuration values are validated at startup and when changed +* **Environment Variables**: Automatic loading from environment variables +* **IDE Support**: Better autocomplete and type checking + +Using the Modern Configuration API +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Access and modify configuration:: + + from neomodel import get_config, set_config, reset_config + + # Get the current configuration + config = get_config() + print(config.database_url) + print(config.force_timezone) + + # Update configuration + config.update(database_url='bolt://new:url@localhost:7687') + + # Set a custom configuration + from neomodel import NeomodelConfig + custom_config = NeomodelConfig( + database_url='bolt://custom:url@localhost:7687', + force_timezone=True + ) + set_config(custom_config) + + # Reset to defaults (loads from environment variables or defaults) + reset_config() + +Environment Variable Support +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Configuration is automatically loaded from environment variables using the ``NEOMODEL_`` prefix: + +* ``NEOMODEL_DATABASE_URL`` - Database connection URL +* ``NEOMODEL_DATABASE_NAME`` - Database name for custom driver +* ``NEOMODEL_CONNECTION_ACQUISITION_TIMEOUT`` - Connection acquisition timeout +* ``NEOMODEL_CONNECTION_TIMEOUT`` - Connection timeout +* ``NEOMODEL_ENCRYPTED`` - Enable encrypted connections +* ``NEOMODEL_KEEP_ALIVE`` - Enable keep-alive connections +* ``NEOMODEL_MAX_CONNECTION_LIFETIME`` - Maximum connection lifetime +* ``NEOMODEL_MAX_CONNECTION_POOL_SIZE`` - Maximum connection pool size +* ``NEOMODEL_MAX_TRANSACTION_RETRY_TIME`` - Maximum transaction retry time +* ``NEOMODEL_USER_AGENT`` - User agent string +* ``NEOMODEL_FORCE_TIMEZONE`` - Force timezone-aware datetime objects +* ``NEOMODEL_SOFT_CARDINALITY_CHECK`` - Enable soft cardinality checking +* ``NEOMODEL_CYPHER_DEBUG`` - Enable Cypher debug logging +* ``NEOMODEL_SLOW_QUERIES`` - Threshold in seconds for slow query logging (0 = disabled) + +.. note:: + For boolean values, the following strings are supported: ``true``, ``1``, ``yes``, ``on``, ``false``, ``0``, ``no``, ``off``. + +Example:: + + # Set environment variables (in your shell) + export NEOMODEL_DATABASE_URL='bolt://neo4j:password@localhost:7687' + export NEOMODEL_FORCE_TIMEZONE='true' + export NEOMODEL_CONNECTION_TIMEOUT='60.0' + + # Configuration will be automatically loaded from environment (in Python) + from neomodel import get_config + config = get_config() + print(config.database_url) # 'bolt://neo4j:password@localhost:7687' + print(config.force_timezone) # True + print(config.connection_timeout) # 60.0 + +.. autofunction:: neomodel.config.get_config + +See also the :class:`~neomodel.config.NeomodelConfig` dataclass for a full list of fields, typing and options: + +.. 
autoclass:: neomodel.config.NeomodelConfig + :members: + :undoc-members: + :show-inheritance: + +Configuration Validation +~~~~~~~~~~~~~~~~~~~~~~~~ + +The configuration system validates values when they are set:: + + from neomodel import get_config + + config = get_config() + + # This will raise a ValueError + try: + config.connection_timeout = -1 + except ValueError as e: + print(f"Validation error: {e}") + + # Invalid database URLs are also caught + try: + config.database_url = "invalid-url" + except ValueError as e: + print(f"Validation error: {e}") + +Legacy Configuration (Backward Compatibility) +--------------------------------------------- + +.. warning:: + The legacy configuration approach described below is **deprecated** and will be removed in a future version. + Deprecation warnings are shown when using the legacy API to encourage migration to the modern configuration system. + +.. note:: + The following section describes the legacy configuration approach, available in neomodel 5.5.3 and earlier. + While still supported for backward compatibility, we recommend using the modern configuration system described above. + +For existing code, the traditional uppercase configuration attributes are still available:: + + from neomodel import config + + # Legacy approach (still works but shows deprecation warnings) + config.DATABASE_URL = 'bolt://neo4j:neo4j@localhost:7687' + config.MAX_CONNECTION_POOL_SIZE = 100 + config.CONNECTION_ACQUISITION_TIMEOUT = 60.0 + config.CONNECTION_TIMEOUT = 30.0 + config.ENCRYPTED = False + config.KEEP_ALIVE = True + config.MAX_CONNECTION_LIFETIME = 3600 + config.MAX_TRANSACTION_RETRY_TIME = 30.0 + config.RESOLVER = None + config.TRUST = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES + config.USER_AGENT = 'neomodel/v5.5.1' + +Setting the database name with legacy approach:: + + # Using the URL only + config.DATABASE_URL = 'bolt://neo4j:neo4j@localhost:7687/mydb' + # Using config option - config.DATABASE_URL = 'bolt://neo4j:neo4j@localhost:7687` + config.DATABASE_URL = 'bolt://neo4j:neo4j@localhost:7687' config.DATABASE_NAME = 'mydb' -self-managed ------------- - -Create a Neo4j driver:: +Legacy self-managed driver setup:: from neo4j import GraphDatabase my_driver = GraphDatabase().driver('bolt://localhost:7687', auth=('neo4j', 'password')) config.DRIVER = my_driver -See the `driver documentation ` here. +.. note:: + Only the synchronous driver works with the legacy self-managed approach. For async drivers, use the modern configuration system. -This mode allows you to use all the available driver options that neomodel doesn't implement, for example auth tokens for SSO. -Note that you have to manage the driver's lifecycle yourself. +Deprecation Warnings +~~~~~~~~~~~~~~~~~~~~ -However, everything else is still handled by neomodel : sessions, transactions, etc... +When using the legacy configuration API, deprecation warnings are shown to encourage migration:: -NB : Only the synchronous driver will work in this way. See the next section for the preferred method, and how to pass an async driver instance. + import warnings + from neomodel import config + + # This will show a deprecation warning + config.DATABASE_URL = 'bolt://neo4j:password@localhost:7687' + # DeprecationWarning: Setting config.DATABASE_URL is deprecated and will be removed in a future version. 
+ # Use the modern configuration API instead: + # from neomodel import get_config; config = get_config(); config.database_url = value -Change/Close the connection ---------------------------- +To suppress deprecation warnings temporarily (not recommended for production):: -Optionally, you can change the connection at any time by calling ``set_connection``:: + import warnings + warnings.filterwarnings('ignore', category=DeprecationWarning, module='neomodel.config') + + # Now legacy config access won't show warnings + config.DATABASE_URL = 'bolt://neo4j:password@localhost:7687' - from neomodel import db - # Using URL - auto-managed - db.set_connection(url='bolt://neo4j:neo4j@localhost:7687') +Migration Guide +~~~~~~~~~~~~~~~ + +To migrate from legacy to modern configuration: + +**Before (Legacy):** +:: + + from neomodel import config + config.DATABASE_URL = 'bolt://neo4j:password@localhost:7687' + config.FORCE_TIMEZONE = True + config.MAX_CONNECTION_POOL_SIZE = 50 + +**After (Modern):** +:: + + from neomodel import get_config + config = get_config() + config.database_url = 'bolt://neo4j:password@localhost:7687' + config.force_timezone = True + config.max_connection_pool_size = 50 + + +Managing Connections +-------------------- + +Changing Connections +~~~~~~~~~~~~~~~~~~~~ + +You can change the connection at any time using the modern configuration API:: + + from neomodel import get_config + + config = get_config() + config.database_url = 'bolt://new:url@localhost:7687' # Using self-managed driver db.set_connection(driver=my_driver) -The new connection url will be applied to the current thread or process. +Or using the legacy approach:: + + from neomodel import db + # Using URL - auto-managed + db.set_connection(url='bolt://neo4j:neo4j@localhost:7687') -Since Neo4j version 5, driver auto-close is deprecated. Make sure to close the connection anytime you want to replace it, -as well as at the end of your application's lifecycle by calling ``close_connection``:: +Closing Connections +~~~~~~~~~~~~~~~~~~~ + +Since Neo4j version 5, driver auto-close is deprecated. Make sure to close the connection when your application shuts down:: from neomodel import db db.close_connection() - # If you then want a new connection - db.set_connection(url=url) +This will close the Neo4j driver and clean up neomodel's internal resources. -This will close the Neo4j driver, and clean up everything that neomodel creates for its internal workings. +Security Best Practices +----------------------- -Protect your credentials ------------------------- +Protect Your Credentials +~~~~~~~~~~~~~~~~~~~~~~~~ You should `avoid setting database access credentials in plain sight `_. Neo4J defines a number of -`environment variables `_ that are used in its tools and these can be re-used for other applications -too. +www.ndss-symposium.org/wp-content/uploads/2019/02/ndss2019_04B-3_Meli_paper.pdf>`_. 
-These are: +**Recommended approach using environment variables**:: -* ``NEO4J_USERNAME`` -* ``NEO4J_PASSWORD`` -* ``NEO4J_BOLT_URL`` + # Set environment variables + export NEOMODEL_DATABASE_URL='bolt://neo4j:password@localhost:7687' + + # Configuration automatically loads from environment + from neomodel import get_config + config = get_config() -By setting these with (for example): :: - $ export NEO4J_USERNAME=neo4j - $ export NEO4J_PASSWORD=neo4j - $ export NEO4J_BOLT_URL="bolt://$NEO4J_USERNAME:$NEO4J_PASSWORD@localhost:7687" +Additional Configuration Options +-------------------------------- -They can be accessed from a Python script via the ``environ`` dict of module ``os`` and be used to set the connection -with something like: :: +Force Timezone on DateTime Properties +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - import os - from neomodel import config +Ensure all DateTimes are provided with a timezone before being serialized to UTC epoch:: + + from neomodel import get_config + + config = get_config() + config.force_timezone = True # default False - config.DATABASE_URL = os.environ["NEO4J_BOLT_URL"] +Enable Soft Cardinality Checking +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Enable warnings instead of errors for relationship cardinality violations:: -Enable automatic index and constraint creation ----------------------------------------------- + config.soft_cardinality_check = True # default False -Neomodel provides the :ref:`neomodel_install_labels` script for this task, -however if you want to handle this manually see below. +Enable Cypher Debug Logging +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Log all Cypher queries for debugging:: + + config.cypher_debug = True # default False + +Enable Slow Query Logging +~~~~~~~~~~~~~~~~~~~~~~~~~ + +Log queries that take longer than the specified threshold:: + + config.slow_queries = 1.0 # Log queries taking more than 1 second + +Index and Constraint Management +------------------------------- + +Neomodel provides the :ref:`neomodel_install_labels` script for automatic index and constraint creation. Install indexes and constraints for a single class:: from neomodel import install_labels install_labels(YourClass) -Or for an entire 'schema' :: +Or for an entire schema:: import yourapp # make sure your app is loaded from neomodel import install_all_labels @@ -140,11 +361,4 @@ Or for an entire 'schema' :: # ... .. note:: - config.AUTO_INSTALL_LABELS has been removed from neomodel in version 5.3 - -Require timezones on DateTimeProperty -------------------------------------- - -Ensure all DateTimes are provided with a timezone before being serialised to UTC epoch:: - - config.FORCE_TIMEZONE = True # default False + ``config.AUTO_INSTALL_LABELS`` has been removed from neomodel in version 5.3 diff --git a/doc/source/filtering_ordering.rst b/doc/source/filtering_ordering.rst index 41a37581..2232718a 100644 --- a/doc/source/filtering_ordering.rst +++ b/doc/source/filtering_ordering.rst @@ -162,7 +162,7 @@ Sometimes you need to order results based on properties situated on different no # Find the most expensive coffee to deliver # Then order by the date the supplier started supplying - Coffee.nodes.order_by( + Coffee.nodes.traverse("suppliers").order_by( '-suppliers__delivery_cost', 'suppliers|since', ) @@ -172,6 +172,7 @@ In the example above, note the following syntax elements: - The name of relationships as defined in the `StructuredNode` class is used to traverse relationships. `suppliers` in this example. - Double underscore `__` is used to target a property of a node. 
`delivery_cost` in this example. - A pipe `|` is used to separate the relationship traversal from the property filter. This is a special syntax to indicate that the filter is on the relationship itself, not on the node at the end of the relationship. +- The traversal is done explicitly before the ordering, so that the traversed relationship's properties are available for ordering. Traversals can be of any length, with each relationships separated by a double underscore `__`, for example:: diff --git a/doc/source/getting_started.rst b/doc/source/getting_started.rst index fecf6098..d0acdb60 100644 --- a/doc/source/getting_started.rst +++ b/doc/source/getting_started.rst @@ -7,12 +7,14 @@ Connecting Before executing any neomodel code, set the connection url:: - from neomodel import config - config.DATABASE_URL = 'bolt://neo4j_username:neo4j_password@localhost:7687' # default + from neomodel import get_config + + config = get_config() + config.database_url = 'bolt://neo4j:password@localhost:7687' # default This must be called early on in your app, if you are using Django the `settings.py` file is ideal. -See the Configuration page (:ref:`connection_options_doc`) for config options. +See the Configuration page (:ref:`configuration_options_doc`) for config options. If you are using your neo4j server for the first time you will need to change the default password. This can be achieved by visiting the neo4j admin panel (default: ``http://localhost:7474`` ). @@ -31,10 +33,12 @@ Defining Node Entities and Relationships Below is a definition of three related nodes `Person`, `City` and `Country`: :: - from neomodel import (config, StructuredNode, StringProperty, IntegerProperty, + from neomodel import (get_config, StructuredNode, StringProperty, IntegerProperty, UniqueIdProperty, RelationshipTo) - config.DATABASE_URL = 'bolt://neo4j_username:neo4j_password@localhost:7687' + + config = get_config() + config.database_url = 'bolt://neo4j_username:neo4j_password@localhost:7687' class Country(StructuredNode): code = StringProperty(unique_index=True, required=True) diff --git a/doc/source/index.rst b/doc/source/index.rst index 321105fc..3dfcac70 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -41,18 +41,24 @@ To install from github:: $ pip install git+git://github.com/neo4j-contrib/neomodel.git@HEAD#egg=neomodel-dev -.. note:: +.. attention:: - **Breaking changes in 5.3** + **New in 6.0** - Introducing support for asynchronous programming to neomodel required to introduce some breaking changes: + From now on, neomodel will use SemVer (major.minor.patch) for versioning. - - config.AUTO_INSTALL_LABELS has been removed. Please use the `neomodel_install_labels` (:ref:`neomodel_install_labels`) command instead. - - The Database class has been moved into neomodel.sync_.core - and a new AsyncDatabase introduced into neomodel.async_.core + This version introduces a modern configuration system, using a dataclass with typing, runtime and update validation rules, and environment variable support. + See the :ref:`configuration_options_doc` section for more details. - **Deprecations in 5.3** + This version introduces the merge_by parameter for batch operations to customize merge behaviour (label and property keys). + See the :ref:`batch` section for more details.
- - Some standalone methods are moved into the Database() class and will be removed in a future release : + **Breaking changes in 6.0** + + - The soft cardinality check is now available for all cardinalities, and strict check is enabled by default. + - List object resolution from Cypher was creating "2-depth" lists for no apparent reason. This release fixes this so that, for example "RETURN collect(node)" will return the nodes directly as a list in the result. In other words, you can extract this list at `results[0][0]` instead of `results[0][0][0]` + - AsyncDatabase / Database are now true singletons for clarity + - Standalone methods moved into the Database() class have been removed outside of the Database() class : - change_neo4j_password - clear_neo4j_database - drop_constraints @@ -60,7 +66,7 @@ To install from github:: - remove_all_labels - install_labels - install_all_labels - - Additionally, to call these methods with async, use the ones in the AsyncDatabase() _adb_ singleton. + - Note : to call these methods with async, use the ones in the AsyncDatabase() _adb_ singleton. Contents diff --git a/doc/source/relationships.rst b/doc/source/relationships.rst index f79467f1..5e37d0a0 100644 --- a/doc/source/relationships.rst +++ b/doc/source/relationships.rst @@ -17,7 +17,7 @@ This avoids cyclic imports:: Cardinality =========== -It is possible to (softly) enforce cardinality constraints on your relationships. +It is possible to enforce cardinality constraints on your relationships. Remember this needs to be declared on both sides of the relationship definition:: class Person(StructuredNode): @@ -36,12 +36,10 @@ The following cardinality constraints are available: If a cardinality constraint is violated by existing data a :class:`~neomodel.exception.CardinalityViolation` exception is raised. -It is possible to enable a soft check for cardinality violations. This will print a warning to the console and create the relationship anyway. +This enforcement is strict by default and will throw an exception if a cardinality constraint is violated. +It is possible to enable a soft check. This will print a warning to the console and create the relationship anyway. This is useful for development purposes:: - config.SOFT_INVERSE_CARDINALITY_CHECK = True - -Note that this is only available for remote cardinality checks where we check the cardinality of the other end of the relationship. It is enabled by default in this case. -It will be made available for all cardinality checks in version 6.0, and will be disabled by default in all cases. + config.soft_cardinality_check = True Properties @@ -253,7 +251,7 @@ It is possible to specify a node traversal by creating a :class:`~neomodel.match.Traversal` object. This will get all ``Person`` entities that are directly related to another ``Person``, through all relationships:: - definition = dict(node_class=Person, direction=OUTGOING, + definition = dict(node_class=Person, direction=RelationshipDirection.OUTGOING, relation_type=None, model=None) relations_traversal = Traversal(jim, Person.__label__, definition) diff --git a/doc/source/semantic_indexes.rst b/doc/source/semantic_indexes.rst index d054675d..509654f8 100644 --- a/doc/source/semantic_indexes.rst +++ b/doc/source/semantic_indexes.rst @@ -1,4 +1,4 @@ -.. _Semantic Indexes: +.. 
_Semantic Indexes: ================================== Semantic Indexes @@ -6,15 +6,15 @@ Semantic Indexes Full Text Index ---------------- -From version x.x (version number tbc) neomodel provides a way to interact with neo4j `Full Text indexing `_. -The Full Text Index can be be created for both node and relationship properties. Only available for Neo4j version 5.16 or higher. +From version 6.0.0 neomodel provides a way to interact with neo4j `Full Text indexing `_. +The Full Text Index can be created for both node and relationship properties. Only available for Neo4j version 5.16 or higher. Defining a Full Text Index on a Property ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Within neomodel, indexing is a decision that is made at class definition time as the index needs to be built. A Full Text index is defined using :class:`~neomodel.properties.FulltextIndex` -To define a property with a full text index we use the following symantics:: +To define a property with a full text index we use the following syntax:: - StringProperty(fulltext_index=FulltextIndex(analyzer="standard-no-stop-words", eventually_consistent=False) + StringProperty(fulltext_index=FulltextIndex(analyzer="standard-no-stop-words", eventually_consistent=False)) Where, - ``analyzer``: The analyzer to use. The default is ``standard-no-stop-words``. @@ -27,13 +27,46 @@ Please refer to the `Neo4j documentation `_. +From version 5.5.0 neomodel provides a way to interact with neo4j `vector indexing `_. The Vector Index can be created on both node and relationship properties. Only available for Neo4j version 5.15 (node) and 5.18 (relationship) or higher. @@ -41,20 +74,20 @@ Defining a Vector Index on a Property ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Within neomodel, indexing is a decision that is made at class definition time as the index needs to be built. A vector index is defined using :class:`~neomodel.properties.VectorIndex`. -To define a property with a vector index we use the following symantics:: +To define a property with a vector index we use the following syntax:: - ArrayProperty(base_property=FloatProperty(), vector_index=VectorIndex(dimensions=512, similarity_function="cosine") + ArrayProperty(base_property=FloatProperty(), vector_index=VectorIndex(dimensions=512, similarity_function="cosine")) Where, - ``dimensions``: The dimension of the vector. The default is 1536. - ``similarity_function``: The similarity algorithm to use. The default is ``cosine``. -The index must then be built, this occurs when the function :func:`~neomodel.sync_.core.install_all_labels` is run +The index must then be built, this occurs when the function :func:`~neomodel.sync_.core.install_all_labels` is run. The vector indexes will then have the name "vector_index_{node.__label__}_{propertyname_with_vector_index}". .. attention:: - Neomodel creates a new vectorindex for each specified property, thus you cannot have two distinct properties being placed into the same index. + Neomodel creates a new vector index for each specified property, thus you cannot have two distinct properties being placed into the same index. 
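+For example, a minimal sketch that defines such a property and builds the index (the ``Document`` class is illustrative; per the 6.0 breaking changes, ``install_all_labels`` is assumed to be called on the ``db`` singleton)::
+
+    from neomodel import (StructuredNode, ArrayProperty, FloatProperty,
+                          VectorIndex, db)
+
+    class Document(StructuredNode):
+        embedding = ArrayProperty(
+            base_property=FloatProperty(),
+            vector_index=VectorIndex(dimensions=512, similarity_function="cosine"),
+        )
+
+    db.install_all_labels()
+    # Following the naming rule above, the resulting index is named
+    # "vector_index_Document_embedding"
+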
Querying a Vector Index on a Property ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -64,10 +97,10 @@ Node Property ^^^^^^^^^^^^^ The following node vector index property:: class someNode(StructuredNode): - vector = ArrayProperty(base_property=FloatProperty(), vector_index=VectorIndex(dimensions=512, similarity_function="cosine") + vector = ArrayProperty(base_property=FloatProperty(), vector_index=VectorIndex(dimensions=512, similarity_function="cosine")) name = StringProperty() -Can be queried using :class:`~neomodel.sematic_filters.VectorFilter`. Such as:: +Can be queried using :class:`~neomodel.semantic_filters.VectorFilter`. For example:: from neomodel.semantic_filters import VectorFilter result = someNode.nodes.filter(vector_filter=VectorFilter(topk=3, vector_attribute_name="vector")).all() @@ -78,12 +111,11 @@ The :class:`~neomodel.semantic_filters.VectorFilter` can be used in conjunction .. attention:: If you use VectorFilter in conjunction with normal filter types, only nodes that fit the filters will return thus, you may get less than the topk specified. - Furthermore, all node filters **should** work with VectorFilter, relationship filters will also work but WILL NOT return the vector similiarty score alongside the relationship filter, instead the topk nodes and their appropriate relationships will be returned. + Furthermore, all node filters **should** work with VectorFilter; relationship filters will also work but WILL NOT return the vector similarity score alongside the relationship filter. Instead, the topk nodes and their appropriate relationships will be returned. RelationshipProperty ^^^^^^^^^^^^^^^^^^^^ Currently neomodel has not implemented an OGM method for querying vector indexes on relationships. -If this is something that you like please submit a github issue requirements highlighting your usage pattern. +If this is something that you would like, please submit a GitHub issue with requirements highlighting your usage pattern. Alternatively, whilst this has not been implemented yet you can still leverage `db.cypher_query` with the correct syntax to perform your required query. - diff --git a/doc/source/transactions.rst b/doc/source/transactions.rst index 92f5b37e..0fac2abb 100644 --- a/doc/source/transactions.rst +++ b/doc/source/transactions.rst @@ -124,7 +124,7 @@ the resulting bookmark may be extracted only after the context manager has exite # All database access happens after completion of the transactions # listed in bookmark1 and bookmark2 - bookmark = transaction.last_bookmark + bookmarks = transaction.last_bookmarks Bookmarks are strings and may be passed between processes. ``transaction.bookmarks`` may be set to a single bookmark, a sequence of bookmarks, or None. @@ -146,20 +146,20 @@ the second element:: return Person.nodes.all() - result, bookmark = update_user_name(uid, name) + result, bookmarks = update_user_name(uid, name) - users, last_bookmark = get_all_users(bookmarks=[bookmark]) + users, last_bookmarks = get_all_users(bookmarks=bookmarks) for user in users: ...
or manually:: - db.begin(bookmarks=[bookmark]) + db.begin(bookmarks=bookmarks) try: new_user = Person(name=username, email=email).save() send_email(new_user) - bookmark = db.commit() + bookmarks = db.commit() except Exception as e: db.rollback() diff --git a/docker-scripts/tests-with-docker-compose.sh b/docker-scripts/tests-with-docker-compose.sh index 7a97377e..77c8083a 100644 --- a/docker-scripts/tests-with-docker-compose.sh +++ b/docker-scripts/tests-with-docker-compose.sh @@ -13,8 +13,8 @@ for dir in neomodel test; do rm -f ${dir}/**/*.pyc find ${dir} -name __pycache__ -exec rm -Rf {} \; done -: "${NEO4J_VERSIONS:=5.4 4.4}" -: "${PYTHON_VERSIONS:=3.11 3.10 3.9 3.8 3.7}" +: "${NEO4J_VERSIONS:=enterprise community 5.5-enterprise 4.4-enterprise 4.4-community}" +: "${PYTHON_VERSIONS:=3.13 3.12 3.11 3.10}" for NEO4J_VERSION in ${NEO4J_VERSIONS}; do for PYTHON_VERSION in ${PYTHON_VERSIONS}; do export NEO4J_VERSION diff --git a/neomodel/__init__.py b/neomodel/__init__.py index ef524bdc..5f9f9d22 100644 --- a/neomodel/__init__.py +++ b/neomodel/__init__.py @@ -5,8 +5,9 @@ AsyncZeroOrMore, AsyncZeroOrOne, ) -from neomodel.async_.core import AsyncStructuredNode, adb +from neomodel.async_.database import adb from neomodel.async_.match import AsyncNodeSet, AsyncTraversal +from neomodel.async_.node import AsyncStructuredNode from neomodel.async_.path import AsyncNeomodelPath from neomodel.async_.property_manager import AsyncPropertyManager from neomodel.async_.relationship import AsyncStructuredRel @@ -17,6 +18,7 @@ AsyncRelationshipManager, AsyncRelationshipTo, ) +from neomodel.config import NeomodelConfig, get_config, reset_config, set_config from neomodel.exceptions import * from neomodel.match_q import Q # noqa from neomodel.properties import ( @@ -39,18 +41,9 @@ VectorIndex, ) from neomodel.sync_.cardinality import One, OneOrMore, ZeroOrMore, ZeroOrOne -from neomodel.sync_.core import ( - StructuredNode, - change_neo4j_password, - clear_neo4j_database, - db, - drop_constraints, - drop_indexes, - install_all_labels, - install_labels, - remove_all_labels, -) +from neomodel.sync_.database import db from neomodel.sync_.match import NodeSet, Traversal +from neomodel.sync_.node import StructuredNode from neomodel.sync_.path import NeomodelPath from neomodel.sync_.property_manager import PropertyManager from neomodel.sync_.relationship import StructuredRel @@ -61,7 +54,6 @@ RelationshipManager, RelationshipTo, ) -from neomodel.util import EITHER, INCOMING, OUTGOING __author__ = "Robin Edwards" __email__ = "robin.ge@gmail.com" diff --git a/neomodel/_version.py b/neomodel/_version.py index 16b899cb..0f607a5d 100644 --- a/neomodel/_version.py +++ b/neomodel/_version.py @@ -1 +1 @@ -__version__ = "5.5.3" +__version__ = "6.0.0" diff --git a/neomodel/async_/cardinality.py b/neomodel/async_/cardinality.py index 6d6cc6df..7a60177e 100644 --- a/neomodel/async_/cardinality.py +++ b/neomodel/async_/cardinality.py @@ -4,6 +4,7 @@ AsyncRelationshipManager, AsyncZeroOrMore, ) +from neomodel.config import get_config from neomodel.exceptions import AttemptedCardinalityViolation, CardinalityViolation if TYPE_CHECKING: @@ -15,17 +16,16 @@ class AsyncZeroOrOne(AsyncRelationshipManager): description = "zero or one relationship" - async def _check_cardinality( - self, node: "AsyncStructuredNode", soft_check: bool = False - ) -> None: + async def check_cardinality(self, node: "AsyncStructuredNode") -> None: if await self.get_len(): - if soft_check: + detailed_description = str(self) + if 
get_config().soft_cardinality_check: print( - f"Cardinality violation detected : Node already has one relationship of type {self.definition['relation_type']}, should not connect more. Soft check is enabled so the relationship will be created. Note that strict check will be enabled by default in version 6.0" + f"Cardinality violation detected: Node already has {detailed_description}, should not connect more. Soft check is enabled so the relationship will be created." ) else: raise AttemptedCardinalityViolation( - f"Node already has one relationship of type {self.definition['relation_type']}. Use reconnect() to replace the existing relationship." + f"Node already has {detailed_description}. Use reconnect() to replace the existing relationship." ) async def single(self) -> Optional["AsyncStructuredNode"]: @@ -46,7 +46,7 @@ async def all(self) -> list["AsyncStructuredNode"]: return [node] if node else [] async def connect( - self, node: "AsyncStructuredNode", properties: Optional[dict[str, Any]] = None + self, node: "AsyncStructuredNode", properties: dict[str, Any] | None = None ) -> "AsyncStructuredRel": """ Connect to a node. @@ -97,6 +97,11 @@ async def disconnect(self, node: "AsyncStructuredNode") -> None: raise AttemptedCardinalityViolation("One or more expected") return await super().disconnect(node) + async def disconnect_all(self) -> None: + raise AttemptedCardinalityViolation( + "Cardinality is one or more: cannot disconnect_all(), use reconnect() instead." + ) + class AsyncOne(AsyncRelationshipManager): """ @@ -105,17 +110,16 @@ class AsyncOne(AsyncRelationshipManager): description = "one relationship" - async def _check_cardinality( - self, node: "AsyncStructuredNode", soft_check: bool = False - ) -> None: + async def check_cardinality(self, node: "AsyncStructuredNode") -> None: if await self.get_len(): - if soft_check: + detailed_description = str(self) + if get_config().soft_cardinality_check: print( - f"Cardinality violation detected : Node already has one relationship of type {self.definition['relation_type']}, should not connect more. Soft check is enabled so the relationship will be created. Note that strict check will be enabled by default in version 6.0" + f"Cardinality violation detected: Node already has {detailed_description}, should not connect more. Soft check is enabled so the relationship will be created." ) else: raise AttemptedCardinalityViolation( - f"Node already has one relationship of type {self.definition['relation_type']}. Use reconnect() to replace the existing relationship." + f"Node already has {detailed_description}. Use reconnect() to replace the existing relationship." ) async def single(self) -> "AsyncStructuredNode": @@ -150,7 +154,7 @@ async def disconnect_all(self) -> None: ) async def connect( - self, node: "AsyncStructuredNode", properties: Optional[dict[str, Any]] = None + self, node: "AsyncStructuredNode", properties: dict[str, Any] | None = None ) -> "AsyncStructuredRel": """ Connect a node diff --git a/neomodel/async_/core.py b/neomodel/async_/database.py similarity index 56% rename from neomodel/async_/core.py rename to neomodel/async_/database.py index 95a408ac..964f3cb3 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/database.py @@ -1,13 +1,13 @@ +""" +Database connection and management for the async neomodel module.
+""" + import logging import os import sys import time -import warnings -from asyncio import iscoroutinefunction -from functools import wraps -from itertools import combinations -from threading import local -from typing import Any, Callable, Optional, TextIO, Union +from contextvars import ContextVar +from typing import TYPE_CHECKING, Any, Callable, TextIO from urllib.parse import quote, unquote, urlparse from neo4j import ( @@ -23,72 +23,49 @@ from neo4j.exceptions import ClientError, ServiceUnavailable, SessionExpired from neo4j.graph import Node, Path, Relationship -from neomodel import config -from neomodel._async_compat.util import AsyncUtil -from neomodel.async_.property_manager import AsyncPropertyManager +from neomodel.config import get_config +from neomodel.constants import ( + ACCESS_MODE_READ, + ACCESS_MODE_WRITE, + CONSTRAINT_ALREADY_EXISTS, + DROP_CONSTRAINT_COMMAND, + DROP_INDEX_COMMAND, + ELEMENT_ID_METHOD, + ENTERPRISE_EDITION_TAG, + INDEX_ALREADY_EXISTS, + LEGACY_ID_METHOD, + LIST_CONSTRAINTS_COMMAND, + LOOKUP_INDEX_TYPE, + NO_SESSION_OPEN, + NO_TRANSACTION_IN_PROGRESS, + RULE_ALREADY_EXISTS, + UNKNOWN_SERVER_VERSION, + VERSION_FULLTEXT_INDEXES_SUPPORT, + VERSION_LEGACY_ID, + VERSION_PARALLEL_RUNTIME_SUPPORT, + VERSION_RELATIONSHIP_CONSTRAINTS_SUPPORT, + VERSION_RELATIONSHIP_VECTOR_INDEXES_SUPPORT, + VERSION_VECTOR_INDEXES_SUPPORT, +) from neomodel.exceptions import ( ConstraintValidationFailed, - DoesNotExist, FeatureNotSupported, - NodeClassAlreadyDefined, NodeClassNotDefined, RelationshipClassNotDefined, UniqueProperty, ) -from neomodel.hooks import hooks from neomodel.properties import FulltextIndex, Property, VectorIndex -from neomodel.util import ( - _UnsavedNode, - classproperty, - deprecated, - version_tag_to_integer, -) +from neomodel.util import version_tag_to_integer -logger = logging.getLogger(__name__) +# The imports inside this block are only for type checking tools (like mypy or IDEs) to help with code hints and error checking. +# These imports are ignored when the code actually runs, so they don't affect runtime performance or cause circular import problems. 
+if TYPE_CHECKING: + from neomodel.async_.node import AsyncStructuredNode # type: ignore + from neomodel.async_.transaction import AsyncTransactionProxy, ImpersonationHandler -RULE_ALREADY_EXISTS = "Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists" -INDEX_ALREADY_EXISTS = "Neo.ClientError.Schema.IndexAlreadyExists" -CONSTRAINT_ALREADY_EXISTS = "Neo.ClientError.Schema.ConstraintAlreadyExists" -STREAMING_WARNING = "streaming is not supported by bolt, please remove the kwarg" -NOT_COROUTINE_ERROR = "The decorated function must be a coroutine" - -# Access mode constants -ACCESS_MODE_WRITE = "WRITE" -ACCESS_MODE_READ = "READ" - -# Database edition constants -ENTERPRISE_EDITION_TAG = "enterprise" - -# Neo4j version constants -VERSION_LEGACY_ID = "4" -VERSION_RELATIONSHIP_CONSTRAINTS_SUPPORT = "5.7" -VERSION_PARALLEL_RUNTIME_SUPPORT = "5.13" -VERSION_VECTOR_INDEXES_SUPPORT = "5.15" -VERSION_FULLTEXT_INDEXES_SUPPORT = "5.16" -VERSION_RELATIONSHIP_VECTOR_INDEXES_SUPPORT = "5.18" - -# ID method constants -LEGACY_ID_METHOD = "id" -ELEMENT_ID_METHOD = "elementId" - -# Cypher query constants -LIST_CONSTRAINTS_COMMAND = "SHOW CONSTRAINTS" -DROP_CONSTRAINT_COMMAND = "DROP CONSTRAINT " -DROP_INDEX_COMMAND = "DROP INDEX " - -# Index type constants -LOOKUP_INDEX_TYPE = "LOOKUP" - -# Info messages constants -NO_TRANSACTION_IN_PROGRESS = "No transaction in progress" -NO_SESSION_OPEN = "No session open" -UNKNOWN_SERVER_VERSION = """ - Unable to perform this operation because the database server version is not known. - This might mean that the database server is offline. -""" +logger = logging.getLogger(__name__) -# make sure the connection url has been set prior to executing the wrapped function def ensure_connection(func: Callable) -> Callable: """Decorator that ensures a connection is established before executing the decorated function. @@ -97,7 +74,6 @@ def ensure_connection(func: Callable) -> Callable: Returns: callable: The decorated function. - """ async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Callable: @@ -108,38 +84,190 @@ async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Callable: _db = self if not _db.driver: - if hasattr(config, "DATABASE_URL") and config.DATABASE_URL: - await _db.set_connection(url=config.DATABASE_URL) - elif hasattr(config, "DRIVER") and config.DRIVER: - await _db.set_connection(driver=config.DRIVER) + config = get_config() + if hasattr(config, "database_url") and config.database_url: + await _db.set_connection(url=config.database_url) + elif hasattr(config, "driver") and config.driver: + await _db.set_connection(driver=config.driver) return await func(self, *args, **kwargs) return wrapper -class AsyncDatabase(local): +class AsyncDatabase: """ A singleton object via which all operations from neomodel to the Neo4j backend are handled with. + + This class enforces singleton behavior - only one instance can exist at a time. + The singleton instance is accessible via the module-level 'adb' variable. """ + # Shared global registries _NODE_CLASS_REGISTRY: dict[frozenset, Any] = {} _DB_SPECIFIC_CLASS_REGISTRY: dict[str, dict[frozenset, Any]] = {} + # Singleton instance tracking + _instance: "AsyncDatabase | None" = None + _initialized: bool = False + + def __new__(cls) -> "AsyncDatabase": + """ + Enforce singleton pattern - only one instance can exist. 
+ + Returns: + AsyncDatabase: The singleton instance. If an instance already exists, + it is returned instead of creating a new one. + """ + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + def __init__(self) -> None: - self._active_transaction: Optional[AsyncTransaction] = None - self.url: Optional[str] = None - self.driver: Optional[AsyncDriver] = None - self._session: Optional[AsyncSession] = None - self._pid: Optional[int] = None - self._database_name: Optional[str] = DEFAULT_DATABASE - self._database_version: Optional[str] = None - self._database_edition: Optional[str] = None - self.impersonated_user: Optional[str] = None - self._parallel_runtime: Optional[bool] = False + # Prevent re-initialization of the singleton instance + if AsyncDatabase._initialized: + return + # Private to instances and contexts + self.__active_transaction: ContextVar[AsyncTransaction | None] = ContextVar( + "_active_transaction", default=None + ) + self.__url: ContextVar[str | None] = ContextVar("url", default=None) + self.__driver: ContextVar[AsyncDriver | None] = ContextVar( + "driver", default=None + ) + self.__session: ContextVar[AsyncSession | None] = ContextVar( + "_session", default=None + ) + self.__pid: ContextVar[int | None] = ContextVar("_pid", default=None) + self.__database_name: ContextVar[str | None] = ContextVar( + "_database_name", default=DEFAULT_DATABASE + ) + self.__database_version: ContextVar[str | None] = ContextVar( + "_database_version", default=None + ) + self.__database_edition: ContextVar[str | None] = ContextVar( + "_database_edition", default=None + ) + self.__impersonated_user: ContextVar[str | None] = ContextVar( + "impersonated_user", default=None + ) + self.__parallel_runtime: ContextVar[bool | None] = ContextVar( + "_parallel_runtime", default=False + ) + + # Mark the singleton as initialized + AsyncDatabase._initialized = True + + @classmethod + def get_instance(cls) -> "AsyncDatabase": + """ + Get the singleton instance of AsyncDatabase. + + Returns: + AsyncDatabase: The singleton instance + """ + if cls._instance is None: + cls._instance = cls() + return cls._instance + + @classmethod + async def reset_instance(cls) -> None: + """ + Reset the singleton instance. This should only be used for testing purposes. + + Warning: This will close any existing connections and reset all state.
+ """ + if cls._instance is not None: + # Close any existing connections + await cls._instance.close_connection() + + cls._instance = None + cls._initialized = False + + @property + def _active_transaction(self) -> AsyncTransaction | None: + return self.__active_transaction.get() + + @_active_transaction.setter + def _active_transaction(self, value: AsyncTransaction | None) -> None: + self.__active_transaction.set(value) + + @property + def url(self) -> str | None: + return self.__url.get() + + @url.setter + def url(self, value: str | None) -> None: + self.__url.set(value) + + @property + def driver(self) -> AsyncDriver | None: + return self.__driver.get() + + @driver.setter + def driver(self, value: AsyncDriver | None) -> None: + self.__driver.set(value) + + @property + def _session(self) -> AsyncSession | None: + return self.__session.get() + + @_session.setter + def _session(self, value: AsyncSession | None) -> None: + self.__session.set(value) + + @property + def _pid(self) -> int | None: + return self.__pid.get() + + @_pid.setter + def _pid(self, value: int | None) -> None: + self.__pid.set(value) + + @property + def _database_name(self) -> str | None: + return self.__database_name.get() + + @_database_name.setter + def _database_name(self, value: str | None) -> None: + self.__database_name.set(value) + + @property + def _database_version(self) -> str | None: + return self.__database_version.get() + + @_database_version.setter + def _database_version(self, value: str | None) -> None: + self.__database_version.set(value) + + @property + def _database_edition(self) -> str | None: + return self.__database_edition.get() + + @_database_edition.setter + def _database_edition(self, value: str | None) -> None: + self.__database_edition.set(value) + + @property + def impersonated_user(self) -> str | None: + return self.__impersonated_user.get() + + @impersonated_user.setter + def impersonated_user(self, value: str | None) -> None: + self.__impersonated_user.set(value) + + @property + def _parallel_runtime(self) -> bool | None: + return self.__parallel_runtime.get() + + @_parallel_runtime.setter + def _parallel_runtime(self, value: bool | None) -> None: + self.__parallel_runtime.set(value) async def set_connection( - self, url: Optional[str] = None, driver: Optional[AsyncDriver] = None + self, url: str | None = None, driver: AsyncDriver | None = None ) -> None: """ Sets the connection up and relevant internal. This can be done using a Neo4j URL or a driver instance. 
@@ -153,8 +281,9 @@ async def set_connection( """ if driver: self.driver = driver - if hasattr(config, "DATABASE_NAME") and config.DATABASE_NAME: - self._database_name = config.DATABASE_NAME + config = get_config() + if hasattr(config, "database_name") and config.database_name: + self._database_name = config.database_name elif url: self._parse_driver_from_url(url=url) @@ -207,21 +336,22 @@ def _parse_driver_from_url(self, url: str) -> None: f"Expecting url format: bolt://user:password@localhost:7687 got {url}" ) + config = get_config() options = { "auth": basic_auth(username, password), - "connection_acquisition_timeout": config.CONNECTION_ACQUISITION_TIMEOUT, - "connection_timeout": config.CONNECTION_TIMEOUT, - "keep_alive": config.KEEP_ALIVE, - "max_connection_lifetime": config.MAX_CONNECTION_LIFETIME, - "max_connection_pool_size": config.MAX_CONNECTION_POOL_SIZE, - "max_transaction_retry_time": config.MAX_TRANSACTION_RETRY_TIME, - "resolver": config.RESOLVER, - "user_agent": config.USER_AGENT, + "connection_acquisition_timeout": config.connection_acquisition_timeout, + "connection_timeout": config.connection_timeout, + "keep_alive": config.keep_alive, + "max_connection_lifetime": config.max_connection_lifetime, + "max_connection_pool_size": config.max_connection_pool_size, + "max_transaction_retry_time": config.max_transaction_retry_time, + "resolver": config.resolver, + "user_agent": config.user_agent, } if "+s" not in parsed_url.scheme: - options["encrypted"] = config.ENCRYPTED - options["trusted_certificates"] = config.TRUSTED_CERTIFICATES + options["encrypted"] = config.encrypted + options["trusted_certificates"] = config.trusted_certificates # Ignore the type error because the workaround would be duplicating code self.driver = AsyncGraphDatabase.driver( @@ -230,8 +360,8 @@ def _parse_driver_from_url(self, url: str) -> None: self.url = url # The database name can be provided through the url or the config if database_name == "": - if hasattr(config, "DATABASE_NAME") and config.DATABASE_NAME: - self._database_name = config.DATABASE_NAME + if hasattr(config, "database_name") and config.database_name: + self._database_name = config.database_name else: self._database_name = database_name @@ -248,14 +378,14 @@ async def close_connection(self) -> None: self.driver = None @property - async def database_version(self) -> Optional[str]: + async def database_version(self) -> str | None: if self._database_version is None: await self._update_database_version() return self._database_version @property - async def database_edition(self) -> Optional[str]: + async def database_edition(self) -> str | None: if self._database_edition is None: await self._update_database_version() @@ -266,18 +396,26 @@ def transaction(self) -> "AsyncTransactionProxy": """ Returns the current transaction object """ + from neomodel.async_.transaction import AsyncTransactionProxy # type: ignore + return AsyncTransactionProxy(self) @property def write_transaction(self) -> "AsyncTransactionProxy": + from neomodel.async_.transaction import AsyncTransactionProxy # type: ignore + return AsyncTransactionProxy(self, access_mode=ACCESS_MODE_WRITE) @property def read_transaction(self) -> "AsyncTransactionProxy": + from neomodel.async_.transaction import AsyncTransactionProxy # type: ignore + return AsyncTransactionProxy(self, access_mode=ACCESS_MODE_READ) @property def parallel_read_transaction(self) -> "AsyncTransactionProxy": + from neomodel.async_.transaction import AsyncTransactionProxy # type: ignore + return 
AsyncTransactionProxy( self, access_mode=ACCESS_MODE_READ, parallel_runtime=True ) @@ -291,6 +429,8 @@ async def impersonate(self, user: str) -> "ImpersonationHandler": Returns: ImpersonationHandler: Context manager to set/unset the user to impersonate """ + from neomodel.async_.transaction import ImpersonationHandler # type: ignore + db_edition = await self.database_edition if db_edition != ENTERPRISE_EDITION_TAG: raise FeatureNotSupported( @@ -451,12 +591,18 @@ def _object_resolution(self, object_to_resolve: Any) -> Any: ) if isinstance(object_to_resolve, Path): - from neomodel.async_.path import AsyncNeomodelPath + from neomodel.async_.path import AsyncNeomodelPath # type: ignore return AsyncNeomodelPath(object_to_resolve) if isinstance(object_to_resolve, list): - return self._result_resolution([object_to_resolve]) + return [self._object_resolution(item) for item in object_to_resolve] + + if isinstance(object_to_resolve, dict): + return { + key: self._object_resolution(value) + for key, value in object_to_resolve.items() + } return object_to_resolve @@ -492,11 +638,11 @@ def _result_resolution(self, result_list: list) -> list: async def cypher_query( self, query: str, - params: Optional[dict[str, Any]] = None, + params: dict[str, Any] | None = None, handle_unique: bool = True, retry_on_session_expire: bool = False, resolve_objects: bool = False, - ) -> tuple[Optional[list], Optional[tuple[str, ...]]]: + ) -> tuple[list | None, tuple[str, ...] | None]: """ Runs a query on the database and returns a list of results and their headers. @@ -548,13 +694,13 @@ async def cypher_query( async def _run_cypher_query( self, - session: Union[AsyncSession, AsyncTransaction], + session: AsyncSession | AsyncTransaction, query: str, params: dict[str, Any], handle_unique: bool, retry_on_session_expire: bool, resolve_objects: bool, - ) -> tuple[Optional[list], Optional[tuple[str, ...]]]: + ) -> tuple[list | None, tuple[str, ...] | None]: try: # Retrieve the data start = time.time() @@ -615,7 +761,7 @@ async def get_id_method(self) -> str: else: return ELEMENT_ID_METHOD - async def parse_element_id(self, element_id: Optional[str]) -> Union[str, int]: + async def parse_element_id(self, element_id: str | None) -> str | int: if element_id is None: raise ValueError( "Unable to parse element id, are you sure this element has been saved ?" @@ -712,12 +858,12 @@ async def clear_neo4j_database( """ ) if clear_constraints: - await drop_constraints() + await self.drop_constraints() if clear_indexes: - await drop_indexes() + await self.drop_indexes() async def drop_constraints( - self, quiet: bool = True, stdout: Optional[TextIO] = None + self, quiet: bool = True, stdout: TextIO | None = None ) -> None: """ Discover and drop all constraints. @@ -745,7 +891,7 @@ async def drop_constraints( stdout.write("\n") async def drop_indexes( - self, quiet: bool = True, stdout: Optional[TextIO] = None + self, quiet: bool = True, stdout: TextIO | None = None ) -> None: """ Discover and drop all indexes, except the automatically created token lookup indexes. @@ -766,7 +912,7 @@ async def drop_indexes( if not quiet: stdout.write("\n") - async def remove_all_labels(self, stdout: Optional[TextIO] = None) -> None: + async def remove_all_labels(self, stdout: TextIO | None = None) -> None: """ Calls functions for dropping constraints and indexes. 
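+
+        Example (drops all constraints, then all indexes)::
+
+            await adb.remove_all_labels()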
@@ -783,7 +929,7 @@ async def remove_all_labels(self, stdout: Optional[TextIO] = None) -> None: stdout.write("Dropping indexes...\n") await self.drop_indexes(quiet=False, stdout=stdout) - async def install_all_labels(self, stdout: Optional[TextIO] = None) -> None: + async def install_all_labels(self, stdout: TextIO | None = None) -> None: """ Discover all subclasses of StructuredNode in your application and execute install_labels on each. Note: code must be loaded (imported) in order for a class to be discovered. @@ -804,9 +950,11 @@ def subsub(cls: Any) -> list: # recursively return all subclasses stdout.write("Setting up indexes and constraints...\n\n") i = 0 + from .node import AsyncStructuredNode + for cls in subsub(AsyncStructuredNode): stdout.write(f"Found {cls.__module__}.{cls.__name__}\n") - await install_labels(cls, quiet=False, stdout=stdout) + await self.install_labels(cls, quiet=False, stdout=stdout) i += 1 if i: @@ -815,7 +963,7 @@ def subsub(cls: Any) -> list: # recursively return all subclasses stdout.write(f"Finished {i} classes.\n") async def install_labels( - self, cls: Any, quiet: bool = True, stdout: Optional[TextIO] = None + self, cls: Any, quiet: bool = True, stdout: TextIO | None = None ) -> None: """ Setup labels with indexes and constraints for a given class @@ -1195,742 +1343,4 @@ async def _install_relationship( # Create a singleton instance of the database object -adb = AsyncDatabase() - - -# Deprecated methods -async def change_neo4j_password( - db: AsyncDatabase, user: str, new_password: str -) -> None: - deprecated( - """ - This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.change_neo4j_password(user, new_password) instead. - This direct call will be removed in an upcoming version. - """ - ) - await db.change_neo4j_password(user, new_password) - - -async def clear_neo4j_database( - db: AsyncDatabase, clear_constraints: bool = False, clear_indexes: bool = False -) -> None: - deprecated( - """ - This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.clear_neo4j_database(clear_constraints, clear_indexes) instead. - This direct call will be removed in an upcoming version. - """ - ) - await db.clear_neo4j_database(clear_constraints, clear_indexes) - - -async def drop_constraints(quiet: bool = True, stdout: Optional[TextIO] = None) -> None: - deprecated( - """ - This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.drop_constraints(quiet, stdout) instead. - This direct call will be removed in an upcoming version. - """ - ) - await adb.drop_constraints(quiet, stdout) - - -async def drop_indexes(quiet: bool = True, stdout: Optional[TextIO] = None) -> None: - deprecated( - """ - This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.drop_indexes(quiet, stdout) instead. - This direct call will be removed in an upcoming version. - """ - ) - await adb.drop_indexes(quiet, stdout) - - -async def remove_all_labels(stdout: Optional[TextIO] = None) -> None: - deprecated( - """ - This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.remove_all_labels(stdout) instead. - This direct call will be removed in an upcoming version. 
- """ - ) - await adb.remove_all_labels(stdout) - - -async def install_labels( - cls: Any, quiet: bool = True, stdout: Optional[TextIO] = None -) -> None: - deprecated( - """ - This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.install_labels(cls, quiet, stdout) instead. - This direct call will be removed in an upcoming version. - """ - ) - await adb.install_labels(cls, quiet, stdout) - - -async def install_all_labels(stdout: Optional[TextIO] = None) -> None: - deprecated( - """ - This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.install_all_labels(stdout) instead. - This direct call will be removed in an upcoming version. - """ - ) - await adb.install_all_labels(stdout) - - -class AsyncTransactionProxy: - bookmarks: Optional[Bookmarks] = None - - def __init__( - self, - db: AsyncDatabase, - access_mode: Optional[str] = None, - parallel_runtime: Optional[bool] = False, - ): - self.db: AsyncDatabase = db - self.access_mode: Optional[str] = access_mode - self.parallel_runtime: Optional[bool] = parallel_runtime - - @ensure_connection - async def __aenter__(self) -> "AsyncTransactionProxy": - if self.parallel_runtime and not await self.db.parallel_runtime_available(): - warnings.warn( - "Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. " - "Reverting to default runtime.", - UserWarning, - ) - self.parallel_runtime = False - self.db._parallel_runtime = self.parallel_runtime - await self.db.begin(access_mode=self.access_mode, bookmarks=self.bookmarks) - self.bookmarks = None - return self - - async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: - self.db._parallel_runtime = False - if exc_value: - await self.db.rollback() - - if ( - exc_type is ClientError - and exc_value.code == "Neo.ClientError.Schema.ConstraintValidationFailed" - ): - raise UniqueProperty(exc_value.message) - - if not exc_value: - self.last_bookmark = await self.db.commit() - - def __call__(self, func: Callable) -> Callable: - if AsyncUtil.is_async_code and not iscoroutinefunction(func): - raise TypeError(NOT_COROUTINE_ERROR) - - @wraps(func) - async def wrapper(*args: Any, **kwargs: Any) -> Callable: - async with self: - return await func(*args, **kwargs) - - return wrapper - - @property - def with_bookmark(self) -> "BookmarkingAsyncTransactionProxy": - return BookmarkingAsyncTransactionProxy(self.db, self.access_mode) - - -class BookmarkingAsyncTransactionProxy(AsyncTransactionProxy): - def __call__(self, func: Callable) -> Callable: - if AsyncUtil.is_async_code and not iscoroutinefunction(func): - raise TypeError(NOT_COROUTINE_ERROR) - - async def wrapper(*args: Any, **kwargs: Any) -> tuple[Any, None]: - self.bookmarks = kwargs.pop("bookmarks", None) - - async with self: - result = await func(*args, **kwargs) - self.last_bookmark = None - - return result, self.last_bookmark - - return wrapper - - -class ImpersonationHandler: - def __init__(self, db: AsyncDatabase, impersonated_user: str): - self.db = db - self.impersonated_user = impersonated_user - - def __enter__(self) -> "ImpersonationHandler": - self.db.impersonated_user = self.impersonated_user - return self - - def __exit__( - self, exception_type: Any, exception_value: Any, exception_traceback: Any - ) -> None: - self.db.impersonated_user = None - - print("\nException type:", exception_type) - print("\nException value:", exception_value) - print("\nTraceback:", exception_traceback) - - def __call__(self, func: 
Callable) -> Callable: - def wrapper(*args: Any, **kwargs: Any) -> Callable: - with self: - return func(*args, **kwargs) - - return wrapper - - -class NodeMeta(type): - DoesNotExist: type[DoesNotExist] - __required_properties__: tuple[str, ...] - __all_properties__: tuple[tuple[str, Any], ...] - __all_aliases__: tuple[tuple[str, Any], ...] - __all_relationships__: tuple[tuple[str, Any], ...] - __label__: str - __optional_labels__: list[str] - - defined_properties: Callable[..., dict[str, Any]] - - def __new__( - mcs: type, name: str, bases: tuple[type, ...], namespace: dict[str, Any] - ) -> Any: - namespace["DoesNotExist"] = type(name + "DoesNotExist", (DoesNotExist,), {}) - cls: NodeMeta = type.__new__(mcs, name, bases, namespace) - cls.DoesNotExist._model_class = cls - - if hasattr(cls, "__abstract_node__"): - delattr(cls, "__abstract_node__") - else: - if "deleted" in namespace: - raise ValueError( - "Property name 'deleted' is not allowed as it conflicts with neomodel internals." - ) - elif "id" in namespace: - raise ValueError( - """ - Property name 'id' is not allowed as it conflicts with neomodel internals. - Consider using 'uid' or 'identifier' as id is also a Neo4j internal. - """ - ) - elif "element_id" in namespace: - raise ValueError( - """ - Property name 'element_id' is not allowed as it conflicts with neomodel internals. - Consider using 'uid' or 'identifier' as element_id is also a Neo4j internal. - """ - ) - for key, value in ( - (x, y) for x, y in namespace.items() if isinstance(y, Property) - ): - value.name, value.owner = key, cls - if hasattr(value, "setup") and callable(value.setup): - value.setup() - - # cache various groups of properies - cls.__required_properties__ = tuple( - name - for name, property in cls.defined_properties( - aliases=False, rels=False - ).items() - if property.required or property.unique_index - ) - cls.__all_properties__ = tuple( - cls.defined_properties(aliases=False, rels=False).items() - ) - cls.__all_aliases__ = tuple( - cls.defined_properties(properties=False, rels=False).items() - ) - cls.__all_relationships__ = tuple( - cls.defined_properties(aliases=False, properties=False).items() - ) - - cls.__label__ = namespace.get("__label__", name) - cls.__optional_labels__ = namespace.get("__optional_labels__", []) - - build_class_registry(cls) - - return cls - - -def build_class_registry(cls: Any) -> None: - base_label_set = frozenset(cls.inherited_labels()) - optional_label_set = set(cls.inherited_optional_labels()) - - # Construct all possible combinations of labels + optional labels - possible_label_combinations = [ - frozenset(set(x).union(base_label_set)) - for i in range(1, len(optional_label_set) + 1) - for x in combinations(optional_label_set, i) - ] - possible_label_combinations.append(base_label_set) - - for label_set in possible_label_combinations: - if not hasattr(cls, "__target_databases__"): - if label_set not in adb._NODE_CLASS_REGISTRY: - adb._NODE_CLASS_REGISTRY[label_set] = cls - else: - raise NodeClassAlreadyDefined( - cls, adb._NODE_CLASS_REGISTRY, adb._DB_SPECIFIC_CLASS_REGISTRY - ) - else: - for database in cls.__target_databases__: - if database not in adb._DB_SPECIFIC_CLASS_REGISTRY: - adb._DB_SPECIFIC_CLASS_REGISTRY[database] = {} - if label_set not in adb._DB_SPECIFIC_CLASS_REGISTRY[database]: - adb._DB_SPECIFIC_CLASS_REGISTRY[database][label_set] = cls - else: - raise NodeClassAlreadyDefined( - cls, adb._NODE_CLASS_REGISTRY, adb._DB_SPECIFIC_CLASS_REGISTRY - ) - - -NodeBase: type = NodeMeta( - "NodeBase", 
(AsyncPropertyManager,), {"__abstract_node__": True} -) - - -class AsyncStructuredNode(NodeBase): - """ - Base class for all node definitions to inherit from. - - If you want to create your own abstract classes set: - __abstract_node__ = True - """ - - # static properties - - __abstract_node__ = True - - # magic methods - - def __init__(self, *args: Any, **kwargs: Any): - if "deleted" in kwargs: - raise ValueError("deleted property is reserved for neomodel") - - for key, val in self.__all_relationships__: - self.__dict__[key] = val.build_manager(self, key) - - super().__init__(*args, **kwargs) - - def __eq__(self, other: Any) -> bool: - """ - Compare two node objects. - If both nodes were saved to the database, compare them by their element_id. - Otherwise, compare them using object id in memory. - If `other` is not a node, always return False. - """ - if not isinstance(other, (AsyncStructuredNode,)): - return False - if self.was_saved and other.was_saved: - return self.element_id == other.element_id - return id(self) == id(other) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def __repr__(self) -> str: - return f"<{self.__class__.__name__}: {self}>" - - def __str__(self) -> str: - return repr(self.__properties__) - - # dynamic properties - - @classproperty - def nodes(self) -> Any: - """ - Returns a NodeSet object representing all nodes of the classes label - :return: NodeSet - :rtype: NodeSet - """ - from neomodel.async_.match import AsyncNodeSet - - return AsyncNodeSet(self) - - @property - def element_id(self) -> Optional[Any]: - if hasattr(self, "element_id_property"): - return self.element_id_property - return None - - # Version 4.4 support - id is deprecated in version 5.x - @property - def id(self) -> int: - try: - return int(self.element_id_property) - except (TypeError, ValueError): - raise ValueError( - "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." - ) - - @property - def was_saved(self) -> bool: - """ - Shows status of node in the database. False, if node hasn't been saved yet, True otherwise. - """ - return self.element_id is not None - - # methods - - @classmethod - async def _build_merge_query( - cls, - merge_params: tuple[dict[str, Any], ...], - update_existing: bool = False, - lazy: bool = False, - relationship: Optional[Any] = None, - ) -> tuple[str, dict[str, Any]]: - """ - Get a tuple of a CYPHER query and a params dict for the specified MERGE query. - - :param merge_params: The target node match parameters, each node must have a "create" key and optional "update". - :type merge_params: list of dict - :param update_existing: True to update properties of existing nodes, default False to keep existing values. 
- :type update_existing: bool - :rtype: tuple - """ - query_params: dict[str, Any] = {"merge_params": merge_params} - n_merge_labels = ":".join(cls.inherited_labels()) - n_merge_prm = ", ".join( - ( - f"{getattr(cls, p).get_db_property_name(p)}: params.create.{getattr(cls, p).get_db_property_name(p)}" - for p in cls.__required_properties__ - ) - ) - n_merge = f"n:{n_merge_labels} {{{n_merge_prm}}}" - if relationship is None: - # create "simple" unwind query - query = f"UNWIND $merge_params as params\n MERGE ({n_merge})\n " - else: - # validate relationship - if not isinstance(relationship.source, AsyncStructuredNode): - raise ValueError( - f"relationship source [{repr(relationship.source)}] is not a StructuredNode" - ) - relation_type = relationship.definition.get("relation_type") - if not relation_type: - raise ValueError( - "No relation_type is specified on provided relationship" - ) - - from neomodel.async_.match import _rel_helper - - if relationship.source.element_id is None: - raise RuntimeError( - "Could not identify the relationship source, its element id was None." - ) - query_params["source_id"] = await adb.parse_element_id( - relationship.source.element_id - ) - query = f"MATCH (source:{relationship.source.__label__}) WHERE {await adb.get_id_method()}(source) = $source_id\n " - query += "WITH source\n UNWIND $merge_params as params \n " - query += "MERGE " - query += _rel_helper( - lhs="source", - rhs=n_merge, - ident=None, - relation_type=relation_type, - direction=relationship.definition["direction"], - ) - - query += "ON CREATE SET n = params.create\n " - # if update_existing, write properties on match as well - if update_existing is True: - query += "ON MATCH SET n += params.update\n" - - # close query - if lazy: - query += f"RETURN {await adb.get_id_method()}(n)" - else: - query += "RETURN n" - - return query, query_params - - @classmethod - async def create(cls, *props: tuple, **kwargs: dict[str, Any]) -> list: - """ - Call to CREATE with parameters map. A new instance will be created and saved. - - :param props: dict of properties to create the nodes. - :type props: tuple - :param lazy: False by default, specify True to get nodes with id only without the parameters. - :type: bool - :rtype: list - """ - - if "streaming" in kwargs: - warnings.warn( - STREAMING_WARNING, - category=DeprecationWarning, - stacklevel=1, - ) - - lazy = kwargs.get("lazy", False) - # create mapped query - query = f"CREATE (n:{':'.join(cls.inherited_labels())} $create_params)" - - # close query - if lazy: - query += f" RETURN {await adb.get_id_method()}(n)" - else: - query += " RETURN n" - - results = [] - for item in [ - cls.deflate(p, obj=_UnsavedNode(), skip_empty=True) for p in props - ]: - node, _ = await adb.cypher_query(query, {"create_params": item}) - results.extend(node[0]) - - nodes = [cls.inflate(node) for node in results] - - if not lazy and hasattr(cls, "post_create"): - for node in nodes: - node.post_create() - - return nodes - - @classmethod - async def create_or_update(cls, *props: tuple, **kwargs: dict[str, Any]) -> list: - """ - Call to MERGE with parameters map. A new instance will be created and saved if does not already exists, - this is an atomic operation. If an instance already exists all optional properties specified will be updated. - - Note that the post_create hook isn't called after create_or_update - - :param props: List of dict arguments to get or create the entities with. 
- :type props: tuple - :param relationship: Optional, relationship to get/create on when new entity is created. - :param lazy: False by default, specify True to get nodes with id only without the parameters. - :rtype: list - """ - lazy: bool = bool(kwargs.get("lazy", False)) - relationship = kwargs.get("relationship") - - # build merge query, make sure to update only explicitly specified properties - create_or_update_params = [] - for specified, deflated in [ - (p, cls.deflate(p, skip_empty=True)) for p in props - ]: - create_or_update_params.append( - { - "create": deflated, - "update": dict( - (k, v) for k, v in deflated.items() if k in specified - ), - } - ) - query, params = await cls._build_merge_query( - tuple(create_or_update_params), - update_existing=True, - relationship=relationship, - lazy=lazy, - ) - - if "streaming" in kwargs: - warnings.warn( - STREAMING_WARNING, - category=DeprecationWarning, - stacklevel=1, - ) - - # fetch and build instance for each result - results = await adb.cypher_query(query, params) - return [cls.inflate(r[0]) for r in results[0]] - - async def cypher( - self, query: str, params: Optional[dict[str, Any]] = None - ) -> tuple[Optional[list], Optional[tuple[str, ...]]]: - """ - Execute a cypher query with the param 'self' pre-populated with the nodes neo4j id. - - :param query: cypher query string - :type: string - :param params: query parameters - :type: dict - :return: tuple containing a list of query results, and the meta information as a tuple - :rtype: tuple - """ - self._pre_action_check("cypher") - _params = params or {} - if self.element_id is None: - raise ValueError("Can't run cypher operation on unsaved node") - element_id = await adb.parse_element_id(self.element_id) - _params.update({"self": element_id}) - return await adb.cypher_query(query, _params) - - @hooks - async def delete(self) -> bool: - """ - Delete a node and its relationships - - :return: True - """ - self._pre_action_check("delete") - await self.cypher( - f"MATCH (self) WHERE {await adb.get_id_method()}(self)=$self DETACH DELETE self" - ) - delattr(self, "element_id_property") - self.deleted = True - return True - - @classmethod - async def get_or_create(cls: Any, *props: tuple, **kwargs: dict[str, Any]) -> list: - """ - Call to MERGE with parameters map. A new instance will be created and saved if does not already exist, - this is an atomic operation. - Parameters must contain all required properties, any non required properties with defaults will be generated. - - Note that the post_create hook isn't called after get_or_create - - :param props: Arguments to get_or_create as tuple of dict with property names and values to get or create - the entities with. - :type props: tuple - :param relationship: Optional, relationship to get/create on when new entity is created. - :param lazy: False by default, specify True to get nodes with id only without the parameters. 
- :rtype: list - """ - lazy = kwargs.get("lazy", False) - relationship = kwargs.get("relationship") - - # build merge query - get_or_create_params = [ - {"create": cls.deflate(p, skip_empty=True)} for p in props - ] - query, params = await cls._build_merge_query( - tuple(get_or_create_params), relationship=relationship, lazy=lazy - ) - - if "streaming" in kwargs: - warnings.warn( - STREAMING_WARNING, - category=DeprecationWarning, - stacklevel=1, - ) - - # fetch and build instance for each result - results = await adb.cypher_query(query, params) - return [cls.inflate(r[0]) for r in results[0]] - - @classmethod - def inflate(cls: Any, node: Any) -> Any: - """ - Inflate a raw neo4j_driver node to a neomodel node - :param node: - :return: node object - """ - # support lazy loading - if isinstance(node, str) or isinstance(node, int): - snode = cls() - snode.element_id_property = node - else: - snode = super().inflate(node) - snode.element_id_property = node.element_id - - return snode - - @classmethod - def inherited_labels(cls: Any) -> list[str]: - """ - Return list of labels from nodes class hierarchy. - - :return: list - """ - return [ - scls.__label__ - for scls in cls.mro() - if hasattr(scls, "__label__") and not hasattr(scls, "__abstract_node__") - ] - - @classmethod - def inherited_optional_labels(cls: Any) -> list[str]: - """ - Return list of optional labels from nodes class hierarchy. - - :return: list - :rtype: list - """ - return [ - label - for scls in cls.mro() - for label in getattr(scls, "__optional_labels__", []) - if not hasattr(scls, "__abstract_node__") - ] - - async def labels(self) -> list[str]: - """ - Returns list of labels tied to the node from neo4j. - - :return: list of labels - :rtype: list - """ - self._pre_action_check("labels") - result = await self.cypher( - f"MATCH (n) WHERE {await adb.get_id_method()}(n)=$self " "RETURN labels(n)" - ) - if result is None or result[0] is None: - raise ValueError("Could not get labels, node may not exist") - return result[0][0][0] - - def _pre_action_check(self, action: str) -> None: - if hasattr(self, "deleted") and self.deleted: - raise ValueError( - f"{self.__class__.__name__}.{action}() attempted on deleted node" - ) - if not hasattr(self, "element_id"): - raise ValueError( - f"{self.__class__.__name__}.{action}() attempted on unsaved node" - ) - - async def refresh(self) -> None: - """ - Reload the node from neo4j - """ - self._pre_action_check("refresh") - if hasattr(self, "element_id"): - results = await self.cypher( - f"MATCH (n) WHERE {await adb.get_id_method()}(n)=$self RETURN n" - ) - request = results[0] - if not request or not request[0]: - raise self.__class__.DoesNotExist("Can't refresh non existent node") - node = self.inflate(request[0][0]) - for key, val in node.__properties__.items(): - setattr(self, key, val) - else: - raise ValueError("Can't refresh unsaved node") - - @hooks - async def save(self) -> "AsyncStructuredNode": - """ - Save the node to neo4j or raise an exception - - :return: the node instance - """ - - # create or update instance node - if hasattr(self, "element_id_property"): - # update - params = self.deflate(self.__properties__, self) - query = f"MATCH (n) WHERE {await adb.get_id_method()}(n)=$self\n" - - if params: - query += "SET " - query += ",\n".join([f"n.{key} = ${key}" for key in params]) - query += "\n" - if self.inherited_labels(): - query += "\n".join( - [f"SET n:`{label}`" for label in self.inherited_labels()] - ) - await self.cypher(query, params) - elif hasattr(self, "deleted") 
and self.deleted: - raise ValueError( - f"{self.__class__.__name__}.save() attempted on deleted node" - ) - else: # create - result = await self.create(self.__properties__) - created_node = result[0] - self.element_id_property = created_node.element_id - return self +adb = AsyncDatabase.get_instance() diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 55bfdd2d..2ac7d37f 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -1,21 +1,21 @@ import inspect import re import string -import warnings from dataclasses import dataclass from typing import Any, AsyncIterator from typing import Optional as TOptional -from typing import Tuple, Union +from typing import Union from neomodel.async_ import relationship_manager -from neomodel.async_.core import AsyncStructuredNode, adb +from neomodel.async_.database import adb +from neomodel.async_.node import AsyncStructuredNode from neomodel.async_.relationship import AsyncStructuredRel from neomodel.exceptions import MultipleNodesReturned from neomodel.match_q import Q, QBase from neomodel.properties import AliasProperty, ArrayProperty, Property -from neomodel.semantic_filters import VectorFilter +from neomodel.semantic_filters import FulltextFilter, VectorFilter from neomodel.typing import Subquery, Transformation -from neomodel.util import INCOMING, OUTGOING +from neomodel.util import RelationshipDirection CYPHER_ACTIONS_WITH_SIDE_EFFECT_EXPR = re.compile(r"(?i:MERGE|CREATE|DELETE|DETACH)") @@ -23,10 +23,10 @@ def _rel_helper( lhs: str, rhs: str, - ident: TOptional[str] = None, - relation_type: TOptional[str] = None, - direction: TOptional[int] = None, - relation_properties: TOptional[dict] = None, + ident: str | None = None, + relation_type: str | None = None, + direction: int | None = None, + relation_properties: dict | None = None, **kwargs: dict[str, Any], # NOSONAR ) -> str: """ @@ -57,20 +57,19 @@ def _rel_helper( rel_props = f" {{{rel_props_str}}}" rel_def = "" - # relation_type is unspecified - if relation_type is None: - rel_def = "" - # all("*" wildcard) relation_type - elif relation_type == "*": - rel_def = "[*]" - else: - # explicit relation_type - rel_def = f"[{ident if ident else ''}:`{relation_type}`{rel_props}]" + match relation_type: + case None: # relation_type is unspecified + rel_def = "" + case "*": # all("*" wildcard) relation_type + rel_def = "[*]" + case _: # explicit relation_type + rel_def = f"[{ident if ident else ''}:`{relation_type}`{rel_props}]" stmt = "" - if direction == OUTGOING: + + if direction == RelationshipDirection.OUTGOING: stmt = f"-{rel_def}->" - elif direction == INCOMING: + elif direction == RelationshipDirection.INCOMING: stmt = f"<-{rel_def}-" else: stmt = f"-{rel_def}-" @@ -88,9 +87,9 @@ def _rel_merge_helper( lhs: str, rhs: str, ident: str = "neomodelident", - relation_type: TOptional[str] = None, - direction: TOptional[int] = None, - relation_properties: TOptional[dict] = None, + relation_type: str | None = None, + direction: int | None = None, + relation_properties: dict | None = None, **kwargs: dict[str, Any], # NOSONAR ) -> str: """ @@ -113,9 +112,9 @@ def _rel_merge_helper( :returns: string """ - if direction == OUTGOING: + if direction == RelationshipDirection.OUTGOING: stmt = "-{0}->" - elif direction == INCOMING: + elif direction == RelationshipDirection.INCOMING: stmt = "<-{0}-" else: stmt = "-{0}-" @@ -143,15 +142,14 @@ def _rel_merge_helper( rel_none_props = ( f" ON CREATE SET {rel_prop_val_str} ON MATCH SET {rel_prop_val_str}" ) - # relation_type is 
unspecified - if relation_type is None: - stmt = stmt.format("") - # all("*" wildcard) relation_type - elif relation_type == "*": - stmt = stmt.format("[*]") - else: - # explicit relation_type - stmt = stmt.format(f"[{ident}:`{relation_type}`{rel_props}]") + + match relation_type: + case None: # relation_type is unspecified + stmt = stmt.format("") + case "*": # all("*" wildcard) relation_type + stmt = stmt.format("[*]") + case _: # explicit relation_type + stmt = stmt.format(f"[{ident}:`{relation_type}`{rel_props}]") return f"({lhs}){stmt}({rhs}){rel_none_props}" @@ -228,7 +226,7 @@ def install_traversals( def _handle_special_operators( property_obj: Property, key: str, value: str, operator: str, prop: str -) -> Tuple[str, str, str]: +) -> tuple[str, str, str]: if operator == _SPECIAL_OPERATOR_IN: if not isinstance(value, (list, tuple)): raise ValueError( @@ -265,7 +263,7 @@ def _deflate_value( value: str, operator: str, prop: str, -) -> Tuple[str, str, str]: +) -> tuple[str, str, str]: if isinstance(property_obj, AliasProperty): prop = property_obj.aliased_to() deflated_value = getattr(cls, prop).deflate(value) @@ -280,7 +278,7 @@ def _deflate_value( def _initialize_filter_args_variables( cls: type[AsyncStructuredNode], key: str -) -> Tuple[type[AsyncStructuredNode], None, None, str, bool, str]: +) -> tuple[type[AsyncStructuredNode], None, None, str, bool, str]: current_class = cls current_rel_model = None leaf_prop = None @@ -300,7 +298,7 @@ def _initialize_filter_args_variables( def _process_filter_key( cls: type[AsyncStructuredNode], key: str -) -> Tuple[Property, str, str]: +) -> tuple[Property, str, str]: ( current_class, current_rel_model, @@ -396,33 +394,35 @@ class QueryAST: match: list[str] optional_match: list[str] where: list[str] - with_clause: TOptional[str] - return_clause: TOptional[str] - order_by: TOptional[list[str]] - skip: TOptional[int] - limit: TOptional[int] - result_class: TOptional[type] - lookup: TOptional[str] - additional_return: TOptional[list[str]] - is_count: TOptional[bool] - vector_index_query: TOptional[type] + with_clause: str | None + return_clause: str | None + order_by: list[str] | None + skip: int | None + limit: int | None + result_class: type | None + lookup: str | None + additional_return: list[str] | None + is_count: bool | None + vector_index_query: VectorFilter | None + fulltext_index_query: FulltextFilter | None def __init__( self, - match: TOptional[list[str]] = None, - optional_match: TOptional[list[str]] = None, - where: TOptional[list[str]] = None, - optional_where: TOptional[list[str]] = None, - with_clause: TOptional[str] = None, - return_clause: TOptional[str] = None, - order_by: TOptional[list[str]] = None, - skip: TOptional[int] = None, - limit: TOptional[int] = None, - result_class: TOptional[type] = None, - lookup: TOptional[str] = None, - additional_return: TOptional[list[str]] = None, - is_count: TOptional[bool] = False, - vector_index_query: TOptional[type] = None, + match: list[str] | None = None, + optional_match: list[str] | None = None, + where: list[str] | None = None, + optional_where: list[str] | None = None, + with_clause: str | None = None, + return_clause: str | None = None, + order_by: list[str] | None = None, + skip: int | None = None, + limit: int | None = None, + result_class: type | None = None, + lookup: str | None = None, + additional_return: list[str] | None = None, + is_count: bool | None = False, + vector_index_query: VectorFilter | None = None, + fulltext_index_query: FulltextFilter | None = None, ) -> None: 
        self.match = match if match else []
         self.optional_match = optional_match if optional_match else []
@@ -440,13 +440,14 @@ def __init__(
         )
         self.is_count = is_count
         self.vector_index_query = vector_index_query
+        self.fulltext_index_query = fulltext_index_query
         self.subgraph: dict = {}
         self.mixed_filters: bool = False


 class AsyncQueryBuilder:
     def __init__(
-        self, node_set: "AsyncBaseSet", subquery_namespace: TOptional[str] = None
+        self, node_set: "AsyncBaseSet", subquery_namespace: str | None = None
     ) -> None:
         self.node_set = node_set
         self._ast = QueryAST()
@@ -454,7 +455,7 @@ def __init__(
         self._place_holder_registry: dict = {}
         self._relation_identifier_count: int = 0
         self._node_identifier_count: int = 0
-        self._subquery_namespace: TOptional[str] = subquery_namespace
+        self._subquery_namespace: str | None = subquery_namespace

     async def build_ast(self) -> "AsyncQueryBuilder":
         if isinstance(self.node_set, AsyncNodeSet) and hasattr(
@@ -468,7 +469,14 @@ async def build_ast(self) -> "AsyncQueryBuilder":
             and hasattr(self.node_set, "vector_query")
             and self.node_set.vector_query
         ):
-            self.build_vector_query(self.node_set.vector_query, self.node_set.source)
+            self.build_vector_query()
+
+        if (
+            isinstance(self.node_set, AsyncNodeSet)
+            and hasattr(self.node_set, "fulltext_query")
+            and self.node_set.fulltext_query
+        ):
+            self.build_fulltext_query()

         await self.build_source(self.node_set)

@@ -551,30 +559,59 @@ def build_order_by(self, ident: str, source: "AsyncNodeSet") -> None:
                 order_by.append(f"{result[0]}.{prop}")
         self._ast.order_by = order_by

-    def build_vector_query(self, vectorfilter: "VectorFilter", source: "AsyncNodeSet"):
+    def build_vector_query(self) -> None:
         """
         Query a vector indexed property on the node.
         """
+        vector_filter = self.node_set.vector_query
+        source_class = self.node_set.source_class
         try:
-            attribute = getattr(source, vectorfilter.vector_attribute_name)
+            attribute = getattr(
+                self.node_set.source, vector_filter.vector_attribute_name
+            )
         except AttributeError as e:
             raise AttributeError(
-                f"Attribute '{vectorfilter.vector_attribute_name}' not found on '{type(source).__name__}'."
+                f"Attribute '{vector_filter.vector_attribute_name}' not found on '{source_class.__name__}'."
             ) from e

         if not attribute.vector_index:
             raise AttributeError(
-                f"Attribute {vectorfilter.vector_attribute_name} is not declared with a vector index."
+                f"Attribute {vector_filter.vector_attribute_name} is not declared with a vector index."
             )

-        vectorfilter.index_name = (
-            f"vector_index_{source.__label__}_{vectorfilter.vector_attribute_name}"
-        )
-        vectorfilter.node_set_label = source.__label__.lower()
+        vector_filter.index_name = f"vector_index_{source_class.__label__}_{vector_filter.vector_attribute_name}"
+        vector_filter.node_set_label = source_class.__label__.lower()

-        self._ast.vector_index_query = vectorfilter
-        self._ast.return_clause = f"{vectorfilter.node_set_label}, score"
-        self._ast.result_class = source.__class__
+
+        self._ast.vector_index_query = vector_filter
+        self._ast.return_clause = f"{vector_filter.node_set_label}, score"
+        self._ast.result_class = source_class
+
+    def build_fulltext_query(self) -> None:
+        """
+        Query a full-text indexed property on the node.
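+
+        Example (sketch: ``Article`` is a hypothetical model whose ``body``
+        property declares ``fulltext_index``; the FulltextFilter fields shown
+        are the ones this builder reads)::
+
+            await Article.nodes.filter(
+                fulltext_filter=FulltextFilter(
+                    fulltext_attribute_name="body",
+                    query_string="graph databases",
+                    topk=10,
+                )
+            ).all()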
+ """ + full_text_filter = self.node_set.fulltext_query + source_class = self.node_set.source_class + try: + attribute = getattr( + self.node_set.source, full_text_filter.fulltext_attribute_name + ) + except AttributeError as e: + raise AttributeError( + f"Atribute '{full_text_filter.fulltext_attribute_name}' not found on '{source_class.__name__}'." + ) from e + + if not attribute.fulltext_index: + raise AttributeError( + f"Attribute {full_text_filter.fulltext_attribute_name} is not declared with a full text index." + ) + + full_text_filter.index_name = f"fulltext_index_{source_class.__label__}_{full_text_filter.fulltext_attribute_name}" + full_text_filter.node_set_label = source_class.__label__.lower() + + self._ast.fulltext_index_query = full_text_filter + self._ast.return_clause = f"{full_text_filter.node_set_label}, score" + self._ast.result_class = source_class.__class__ async def build_traversal(self, traversal: "AsyncTraversal") -> str: """ @@ -614,7 +651,7 @@ def _additional_return(self, name: str) -> None: def build_traversal_from_path( self, relation: "Path", source_class: Any - ) -> Tuple[str, Any]: + ) -> tuple[str, Any]: path: str = relation.value stmt: str = "" source_class_iterator = source_class @@ -772,7 +809,7 @@ def _register_place_holder(self, key: str) -> str: def _parse_path( self, source_class: type[AsyncStructuredNode], prop: str - ) -> Tuple[str, str, Any, bool]: + ) -> tuple[str, str, Any, bool]: is_rel_filter = "|" in prop if is_rel_filter: path, prop = prop.rsplit("|", 1) @@ -839,7 +876,7 @@ def _build_filter_statements( target.append((statement, is_optional_relation)) def _parse_q_filters( - self, ident: str, q: Union[QBase, Any], source_class: type[AsyncStructuredNode] + self, ident: str, q: QBase | Any, source_class: type[AsyncStructuredNode] ) -> tuple[str, str]: target: list[tuple[str, bool]] = [] @@ -883,7 +920,7 @@ def build_where_stmt( ident: str, filters: list, source_class: type[AsyncStructuredNode], - q_filters: Union[QBase, Any, None] = None, + q_filters: QBase | Any | None = None, ) -> None: """ Construct a where statement from some filters. 
@@ -925,7 +962,7 @@ def build_where_stmt(

     def lookup_query_variable(
         self, path: str, return_relation: bool = False
-    ) -> TOptional[Tuple[str, Any, bool]]:
+    ) -> tuple[str, Any, bool] | None:
         """Retrieve the variable name generated internally for the given traversal path."""
         subgraph = self._ast.subgraph
         if not subgraph:
@@ -937,7 +974,7 @@ def lookup_query_variable(
             return None

         # Check if relation is coming from an optional MATCH
-        # (declared using fetch|traverse_relations)
+        # (declared using traverse)
         is_optional_relation = False
         for relation in self.node_set.relations_to_fetch:
             if relation.value == path:
@@ -977,6 +1014,16 @@ def build_query(self) -> str:
             # This ensures that we bring the context of the new nodeSet and score along with us for metadata filtering
             query += f""" WITH {self._ast.vector_index_query.node_set_label}, score"""

+        if self._ast.fulltext_index_query:
+            query += f""" CALL () {{
+                CALL db.index.fulltext.queryNodes("{self._ast.fulltext_index_query.index_name}", "{self._ast.fulltext_index_query.query_string}")
+                YIELD node AS {self._ast.fulltext_index_query.node_set_label}, score
+                RETURN {self._ast.fulltext_index_query.node_set_label}, score LIMIT {self._ast.fulltext_index_query.topk}
+            }}
+            """
+            # This ensures that we bring the context of the new nodeSet and score along with us for metadata filtering
+            query += f""" WITH {self._ast.fulltext_index_query.node_set_label}, score"""
+
         # Instead of using only one MATCH statement for every relation
         # to follow, we use one MATCH per relation (to avoid cartesian
         # product issues...).
@@ -1125,7 +1172,7 @@ async def _count(self) -> int:
         results, _ = await adb.cypher_query(query, self._query_params)
         return int(results[0][0])

-    async def _contains(self, node_element_id: TOptional[Union[str, int]]) -> bool:
+    async def _contains(self, node_element_id: str | int | None) -> bool:
         # inject id = into ast
         if not self._ast.return_clause and self._ast.additional_return:
             self._ast.return_clause = self._ast.additional_return[0]
@@ -1224,7 +1271,7 @@ async def check_nonzero(self) -> bool:
         """
         return await self.check_bool()

-    async def check_contains(self, obj: Union[AsyncStructuredNode, Any]) -> bool:
+    async def check_contains(self, obj: AsyncStructuredNode | Any) -> bool:
         if isinstance(obj, AsyncStructuredNode):
             if hasattr(obj, "element_id") and obj.element_id is not None:
                 ast = await self.query_cls(self).build_ast()
@@ -1234,7 +1281,7 @@ async def check_contains(self, obj: Union[AsyncStructuredNode, Any]) -> bool:

         raise ValueError("Expecting StructuredNode instance")

-    async def get_item(self, key: Union[int, slice]) -> TOptional["AsyncBaseSet"]:
+    async def get_item(self, key: int | slice) -> TOptional["AsyncBaseSet"]:
         if isinstance(key, slice):
             if key.stop and key.start:
                 self.limit = key.stop - key.start
@@ -1254,8 +1301,6 @@ async def get_item(self, key: Union[int, slice]) -> TOptional["AsyncBaseSet"]:
             _first_item = [node async for node in ast._execute()][0]
             return _first_item

-        return None
-

 @dataclass
 class Optional:  # type: ignore[no-redef]
@@ -1273,7 +1318,7 @@ class Path:
     include_nodes_in_return: bool = True
     include_rels_in_return: bool = True
     relation_filtering: bool = False
-    alias: TOptional[str] = None
+    alias: str | None = None


 @dataclass
@@ -1281,7 +1326,7 @@ class RelationNameResolver:
     """Helper to refer to a relation variable name.

     Since variable names are generated automatically within MATCH statements (for
-    anything injected using fetch_relations or traverse_relations), we need a way to
+    anything injected using traverse), we need a way to
     retrieve them.
     """

@@ -1302,7 +1347,7 @@ class NodeNameResolver:
     """Helper to refer to a node variable name.

     Since variable names are generated automatically within MATCH statements (for
-    anything injected using fetch_relations or traverse_relations), we need a way to
+    anything injected using traverse), we need a way to
     retrieve them.
     """

@@ -1448,13 +1493,14 @@ def __init__(self, source: Any) -> None:
         self._subqueries: list[Subquery] = []
         self._intermediate_transforms: list = []
         self._unique_variables: list[str] = []
-        self.vector_query: Optional[str] = None
+        self.vector_query: VectorFilter | None = None
+        self.fulltext_query: FulltextFilter | None = None

     def __await__(self) -> Any:
         return self.all().__await__()  # type: ignore[attr-defined]

     async def _get(
-        self, limit: TOptional[int] = None, lazy: bool = False, **kwargs: dict[str, Any]
+        self, limit: int | None = None, lazy: bool = False, **kwargs: dict[str, Any]
     ) -> list:
         self.filter(**kwargs)
         if limit:
@@ -1552,15 +1598,16 @@ def filter(self, *args: Any, **kwargs: Any) -> "AsyncBaseSet":
         """
         if args or kwargs:
             # Need to grab and remove the VectorFilter from both args and kwargs
-            new_args = (
-                []
-            )  # As args are a tuple, they're immutable. But we need to remove the vectorfilter from the arguments so they don't go into Q.
+            # As args are a tuple, they're immutable. But we need to remove the
+            # vector and full-text filters from the arguments so they don't go into Q.
+            new_args = []
             for arg in args:
                 if isinstance(arg, VectorFilter) and (not self.vector_query):
                     self.vector_query = arg
-                    new_args.append(arg)
-            new_args = tuple(new_args)
+                elif isinstance(arg, FulltextFilter) and (not self.fulltext_query):
+                    self.fulltext_query = arg
+                else:
+                    new_args.append(arg)

             if (
                 kwargs.get("vector_filter")
@@ -1569,7 +1616,14 @@ def filter(self, *args: Any, **kwargs: Any) -> "AsyncBaseSet":
             ):
                 self.vector_query = kwargs.pop("vector_filter")

-            self.q_filters = Q(self.q_filters & Q(*new_args, **kwargs))
+            if (
+                kwargs.get("fulltext_filter")
+                and isinstance(kwargs["fulltext_filter"], FulltextFilter)
+                and not self.fulltext_query
+            ):
+                self.fulltext_query = kwargs.pop("fulltext_filter")
+
+            self.q_filters = Q(self.q_filters & Q(*new_args, **kwargs))

         return self

@@ -1624,7 +1678,7 @@ def order_by(self, *props: Any) -> "AsyncBaseSet":

         return self

     def _register_relation_to_fetch(
-        self, relation_def: Any, alias: TOptional[str] = None
+        self, relation_def: Any, alias: str | None = None
     ) -> "Path":
         if isinstance(relation_def, Path):
             item = relation_def
@@ -1636,9 +1690,9 @@ def _register_relation_to_fetch(
             item.alias = alias
         return item

-    def unique_variables(self, *paths: tuple[str, ...]) -> "AsyncNodeSet":
+    def unique_variables(self, *paths: str) -> "AsyncNodeSet":
         """Generate unique variable names for the given paths."""
-        self._unique_variables = paths
+        self._unique_variables = list(paths)
         return self

     def traverse(
@@ -1655,60 +1709,12 @@
             self.relations_to_fetch = relations
         return self

-    def fetch_relations(self, *relation_names: tuple[str, ...]) -> "AsyncNodeSet":
-        """Specify a set of relations to traverse and return."""
-        warnings.warn(
-            "fetch_relations() will be deprecated in version 6, use traverse() instead.",
-            DeprecationWarning,
-        )
-        relations = []
-        for relation_name in relation_names:
-            if isinstance(relation_name, Optional):
-                relation_name = Path(value=relation_name.relation, optional=True)
-            relations.append(self._register_relation_to_fetch(relation_name))
-        self.relations_to_fetch = relations
-        return self
-
-    def traverse_relations(
-        self, *relation_names: tuple[str, ...], **aliased_relation_names: dict
-    ) -> "AsyncNodeSet":
-        """Specify a set of relations to traverse only."""
-
-        warnings.warn(
-            "traverse_relations() will be deprecated in version 6, use traverse() instead.",
-            DeprecationWarning,
-        )
-
-        def convert_to_path(input: Union[str, Optional]) -> Path:
-            if isinstance(input, Optional):
-                path = Path(value=input.relation, optional=True)
-            else:
-                path = Path(value=input)
-            path.include_nodes_in_return = False
-            path.include_rels_in_return = False
-            return path
-
-        relations = []
-        for relation_name in relation_names:
-            relations.append(
-                self._register_relation_to_fetch(convert_to_path(relation_name))
-            )
-        for alias, relation_def in aliased_relation_names.items():
-            relations.append(
-                self._register_relation_to_fetch(
-                    convert_to_path(relation_def), alias=alias
-                )
-            )
-
-        self.relations_to_fetch = relations
-        return self
-
     def annotate(self, *vars: tuple, **aliased_vars: tuple) -> "AsyncNodeSet":
         """Annotate node set results with extra variables."""

         def register_extra_var(
-            vardef: Union[AggregatingFunction, ScalarFunction, Any],
-            varname: Union[str, None] = None,
+            vardef: AggregatingFunction | ScalarFunction | Any,
+            varname: str | None = None,
         ) -> None:
             if isinstance(vardef, (AggregatingFunction, ScalarFunction)):
                 self._extra_results.append(
@@ -1767,20 +1773,12 @@ async def resolve_subgraph(self) -> list:
         we use a dedicated property to store node's relations.
         """
-        if (
-            self.relations_to_fetch
-            and not self.relations_to_fetch[0].include_nodes_in_return
-            and not self.relations_to_fetch[0].include_rels_in_return
-        ):
-            raise NotImplementedError(
-                "You cannot use traverse_relations() with resolve_subgraph(), use fetch_relations() instead."
-            )
         results: list = []
         qbuilder = self.query_cls(self)
         await qbuilder.build_ast()
         if not qbuilder._ast.subgraph:
             raise RuntimeError(
-                "Nothing to resolve. Make sure to include relations in the result using fetch_relations() or filter()."
+                "Nothing to resolve. Make sure to include relations in the result using traverse() or filter()."
             )
         other_nodes = {}
         root_node = None
@@ -1789,9 +1787,6 @@ async def resolve_subgraph(self) -> list:
             if node.__class__ is self.source and "_" not in name:
                 root_node = node
                 continue
-            if isinstance(node, list) and isinstance(node[0], list):
-                other_nodes[name] = node[0]
-                continue
             other_nodes[name] = node
         results.append(
             self._to_subgraph(root_node, other_nodes, qbuilder._ast.subgraph)
         )
@@ -1802,7 +1797,7 @@ async def subquery(
         self,
         nodeset: "AsyncNodeSet",
         return_set: list[str],
-        initial_context: TOptional[list[str]] = None,
+        initial_context: list[str] | None = None,
     ) -> "AsyncNodeSet":
         """Add a subquery to this node set.
@@ -1835,7 +1830,7 @@ async def subquery( raise RuntimeError(f"Variable '{var}' is not returned by subquery.") if initial_context: for var in initial_context: - if type(var) is not str and not isinstance( + if not isinstance(var, str) and not isinstance( var, (NodeNameResolver, RelationNameResolver, RawCypher) ): raise ValueError( @@ -1855,7 +1850,7 @@ def intermediate_transform( self, vars: dict[str, Transformation], distinct: bool = False, - ordering: TOptional[list] = None, + ordering: list | None = None, ) -> "AsyncNodeSet": if not vars: raise ValueError( diff --git a/neomodel/async_/node.py b/neomodel/async_/node.py new file mode 100644 index 00000000..3c59a838 --- /dev/null +++ b/neomodel/async_/node.py @@ -0,0 +1,624 @@ +""" +Node classes and metadata for the async neomodel module. +""" + +from __future__ import annotations + +import warnings +from itertools import combinations +from typing import TYPE_CHECKING, Any, Callable + +from neo4j.graph import Node + +from neomodel.async_.database import adb +from neomodel.async_.property_manager import AsyncPropertyManager +from neomodel.constants import STREAMING_WARNING +from neomodel.exceptions import DoesNotExist, NodeClassAlreadyDefined +from neomodel.hooks import hooks +from neomodel.properties import Property +from neomodel.util import _UnsavedNode, classproperty + +if TYPE_CHECKING: + from neomodel.async_.match import AsyncNodeSet + + +class NodeMeta(type): + DoesNotExist: type[DoesNotExist] + __required_properties__: tuple[str, ...] + __all_properties__: tuple[tuple[str, Any], ...] + __all_aliases__: tuple[tuple[str, Any], ...] + __all_relationships__: tuple[tuple[str, Any], ...] + __label__: str + __optional_labels__: list[str] + + defined_properties: Callable[..., dict[str, Any]] + + def __new__( + mcs: type, name: str, bases: tuple[type, ...], namespace: dict[str, Any] + ) -> Any: + namespace["DoesNotExist"] = type(name + "DoesNotExist", (DoesNotExist,), {}) + cls: NodeMeta = type.__new__(mcs, name, bases, namespace) + cls.DoesNotExist._model_class = cls + + if hasattr(cls, "__abstract_node__"): + delattr(cls, "__abstract_node__") + else: + if "deleted" in namespace: + raise ValueError( + "Property name 'deleted' is not allowed as it conflicts with neomodel internals." + ) + elif "id" in namespace: + raise ValueError( + """ + Property name 'id' is not allowed as it conflicts with neomodel internals. + Consider using 'uid' or 'identifier' as id is also a Neo4j internal. + """ + ) + elif "element_id" in namespace: + raise ValueError( + """ + Property name 'element_id' is not allowed as it conflicts with neomodel internals. + Consider using 'uid' or 'identifier' as element_id is also a Neo4j internal. 
+ """ + ) + for key, value in ( + (x, y) for x, y in namespace.items() if isinstance(y, Property) + ): + value.name, value.owner = key, cls + if hasattr(value, "setup") and callable(value.setup): + value.setup() + + # cache various groups of properies + cls.__required_properties__ = tuple( + name + for name, property in cls.defined_properties( + aliases=False, rels=False + ).items() + if property.required or property.unique_index + ) + cls.__all_properties__ = tuple( + cls.defined_properties(aliases=False, rels=False).items() + ) + cls.__all_aliases__ = tuple( + cls.defined_properties(properties=False, rels=False).items() + ) + cls.__all_relationships__ = tuple( + cls.defined_properties(aliases=False, properties=False).items() + ) + + cls.__label__ = namespace.get("__label__", name) + cls.__optional_labels__ = namespace.get("__optional_labels__", []) + + build_class_registry(cls) + + return cls + + +def build_class_registry(cls: Any) -> None: + base_label_set = frozenset(cls.inherited_labels()) + optional_label_set = set(cls.inherited_optional_labels()) + + # Construct all possible combinations of labels + optional labels + possible_label_combinations = [ + frozenset(set(x).union(base_label_set)) + for i in range(1, len(optional_label_set) + 1) + for x in combinations(optional_label_set, i) + ] + possible_label_combinations.append(base_label_set) + + for label_set in possible_label_combinations: + if not hasattr(cls, "__target_databases__"): + if label_set not in adb._NODE_CLASS_REGISTRY: + adb._NODE_CLASS_REGISTRY[label_set] = cls + else: + raise NodeClassAlreadyDefined( + cls, adb._NODE_CLASS_REGISTRY, adb._DB_SPECIFIC_CLASS_REGISTRY + ) + else: + for database in cls.__target_databases__: + if database not in adb._DB_SPECIFIC_CLASS_REGISTRY: + adb._DB_SPECIFIC_CLASS_REGISTRY[database] = {} + if label_set not in adb._DB_SPECIFIC_CLASS_REGISTRY[database]: + adb._DB_SPECIFIC_CLASS_REGISTRY[database][label_set] = cls + else: + raise NodeClassAlreadyDefined( + cls, adb._NODE_CLASS_REGISTRY, adb._DB_SPECIFIC_CLASS_REGISTRY + ) + + +NodeBase: type = NodeMeta( + "NodeBase", (AsyncPropertyManager,), {"__abstract_node__": True} +) + + +class AsyncStructuredNode(NodeBase): + """ + Base class for all node definitions to inherit from. + + If you want to create your own abstract classes set: + __abstract_node__ = True + """ + + # static properties + + __abstract_node__ = True + + # magic methods + + def __init__(self, *args: Any, **kwargs: Any): + if "deleted" in kwargs: + raise ValueError("deleted property is reserved for neomodel") + + for key, val in self.__all_relationships__: + self.__dict__[key] = val.build_manager(self, key) + + super().__init__(*args, **kwargs) + + def __eq__(self, other: Any) -> bool: + """ + Compare two node objects. + If both nodes were saved to the database, compare them by their element_id. + Otherwise, compare them using object id in memory. + If `other` is not a node, always return False. 
+ """ + if not isinstance(other, (AsyncStructuredNode,)): + return False + if self.was_saved and other.was_saved: + return self.element_id == other.element_id + return id(self) == id(other) + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}: {self}>" + + def __str__(self) -> str: + return repr(self.__properties__) + + # dynamic properties + + @classproperty + def nodes(self) -> "AsyncNodeSet": + """ + Returns a NodeSet object representing all nodes of the classes label + :return: NodeSet + :rtype: NodeSet + """ + from neomodel.async_.match import AsyncNodeSet + + return AsyncNodeSet(self) + + @property + def element_id(self) -> Any | None: + if hasattr(self, "element_id_property"): + return self.element_id_property + return None + + # Version 4.4 support - id is deprecated in version 5.x + @property + def id(self) -> int: + try: + return int(self.element_id_property) + except (TypeError, ValueError): + raise ValueError( + "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." + ) + + @property + def was_saved(self) -> bool: + """ + Shows status of node in the database. False, if node hasn't been saved yet, True otherwise. + """ + return self.element_id is not None + + # methods + + @classmethod + async def _build_merge_query( + cls, + merge_params: tuple[dict[str, Any], ...], + update_existing: bool = False, + lazy: bool = False, + relationship: Any | None = None, + merge_by: dict[str, str | list[str]] | None = None, + ) -> tuple[str, dict[str, Any]]: + """ + Get a tuple of a CYPHER query and a params dict for the specified MERGE query. + + :param merge_params: The target node match parameters, each node must have a "create" key and optional "update". + :type merge_params: list of dict + :param update_existing: True to update properties of existing nodes, default False to keep existing values. + :type update_existing: bool + :param lazy: False by default, specify True to get nodes with id only without the properties. + :type lazy: bool + :param relationship: Optional relationship to create when merging nodes. + :type relationship: Any | None + :param merge_by: Optional dict with 'label' and 'keys' to specify custom merge criteria. + 'label' is optional and should be a string, 'keys' is a list of strings. + If 'label' is not provided, uses the node's inherited labels. + If 'keys' is not provided, uses the node's required properties as merge keys. 
+ :type merge_by: dict[str, str | list[str]] | None + :return: tuple of query and params + :rtype: tuple[str, dict[str, Any]] + """ + query_params: dict[str, Any] = {"merge_params": merge_params} + + # Determine merge key and labels + if merge_by: + # Use custom merge keys + merge_keys = merge_by["keys"] + merge_labels = merge_by.get("label", ":".join(cls.inherited_labels())) + + n_merge_prm = ", ".join(f"{key}: params.create.{key}" for key in merge_keys) + else: + # Use default required properties + merge_labels = ":".join(cls.inherited_labels()) + n_merge_prm = ", ".join( + ( + f"{getattr(cls, p).get_db_property_name(p)}: params.create.{getattr(cls, p).get_db_property_name(p)}" + for p in cls.__required_properties__ + ) + ) + + n_merge = f"n:{merge_labels} {{{n_merge_prm}}}" + if relationship is None: + # create "simple" unwind query + query = f"UNWIND $merge_params as params\n MERGE ({n_merge})\n " + else: + # validate relationship + if not isinstance(relationship.source, AsyncStructuredNode): + raise ValueError( + f"relationship source [{repr(relationship.source)}] is not a StructuredNode" + ) + relation_type = relationship.definition.get("relation_type") + if not relation_type: + raise ValueError( + "No relation_type is specified on provided relationship" + ) + + from neomodel.async_.match import _rel_helper + + if relationship.source.element_id is None: + raise RuntimeError( + "Could not identify the relationship source, its element id was None." + ) + query_params["source_id"] = await adb.parse_element_id( + relationship.source.element_id + ) + query = f"MATCH (source:{relationship.source.__label__}) WHERE {await adb.get_id_method()}(source) = $source_id\n " + query += "WITH source\n UNWIND $merge_params as params \n " + query += "MERGE " + query += _rel_helper( + lhs="source", + rhs=n_merge, + ident=None, + relation_type=relation_type, + direction=relationship.definition["direction"], + ) + + query += "ON CREATE SET n = params.create\n " + # if update_existing, write properties on match as well + if update_existing is True: + query += "ON MATCH SET n += params.update\n" + + # close query + if lazy: + query += f"RETURN {await adb.get_id_method()}(n)" + else: + query += "RETURN n" + + return query, query_params + + @classmethod + async def create(cls, *props: tuple, **kwargs: dict[str, Any]) -> list: + """ + Call to CREATE with parameters map. A new instance will be created and saved. + + :param props: dict of properties to create the nodes. + :type props: tuple + :param lazy: False by default, specify True to get nodes with id only without the parameters. + :type: bool + :rtype: list + """ + + if "streaming" in kwargs: + warnings.warn( + STREAMING_WARNING, + category=DeprecationWarning, + stacklevel=1, + ) + + lazy = kwargs.get("lazy", False) + # create mapped query + query = f"CREATE (n:{':'.join(cls.inherited_labels())} $create_params)" + + # close query + if lazy: + query += f" RETURN {await adb.get_id_method()}(n)" + else: + query += " RETURN n" + + results = [] + for item in [ + cls.deflate(p, obj=_UnsavedNode(), skip_empty=True) for p in props + ]: + node, _ = await adb.cypher_query(query, {"create_params": item}) + results.extend(node[0]) + + nodes = [cls.inflate(node) for node in results] + + if not lazy and hasattr(cls, "post_create"): + for node in nodes: + node.post_create() + + return nodes + + @classmethod + async def create_or_update(cls, *props: tuple, **kwargs: dict[str, Any]) -> list: + """ + Call to MERGE with parameters map. 
A new instance will be created and saved if it does not already exist;
+        this is an atomic operation. If an instance already exists, all optional properties specified will be updated.
+
+        Note that the post_create hook isn't called after create_or_update.
+
+        :param props: List of dict arguments to get or create the entities with.
+        :type props: tuple
+        :param relationship: Optional, relationship to get/create on when new entity is created.
+        :type relationship: Any | None
+        :param lazy: False by default, specify True to get nodes with id only without the properties.
+        :type lazy: bool
+        :param merge_by: Optional dict with 'label' and 'keys' to specify custom merge criteria.
+            'label' is optional and should be a string, 'keys' is a list of strings.
+            If 'label' is not provided, uses the node's inherited labels.
+            If 'keys' is not provided, uses the node's required properties as merge keys.
+        :type merge_by: dict[str, str | list[str]] | None
+        :return: list of nodes
+        :rtype: list
+        """
+        lazy: bool = bool(kwargs.get("lazy", False))
+        relationship = kwargs.get("relationship")
+        merge_by = kwargs.get("merge_by")
+
+        # build merge query, make sure to update only explicitly specified properties
+        create_or_update_params = []
+        for specified, deflated in [
+            (p, cls.deflate(p, skip_empty=True)) for p in props
+        ]:
+            create_or_update_params.append(
+                {
+                    "create": deflated,
+                    "update": {k: v for k, v in deflated.items() if k in specified},
+                }
+            )
+        query, params = await cls._build_merge_query(
+            tuple(create_or_update_params),
+            update_existing=True,
+            relationship=relationship,
+            lazy=lazy,
+            merge_by=merge_by,
+        )
+
+        if "streaming" in kwargs:
+            warnings.warn(
+                STREAMING_WARNING,
+                category=DeprecationWarning,
+                stacklevel=1,
+            )
+
+        # fetch and build instance for each result
+        results = await adb.cypher_query(query, params)
+        if lazy:
+            return [r[0] for r in results[0]]
+        else:
+            return [cls.inflate(r[0]) for r in results[0]]
+
+    async def cypher(
+        self, query: str, params: dict[str, Any] | None = None
+    ) -> tuple[list | None, tuple[str, ...] | None]:
+        """
+        Execute a cypher query with the param 'self' pre-populated with the node's Neo4j id.
+
+        :param query: cypher query string
+        :type: string
+        :param params: query parameters
+        :type: dict
+        :return: tuple containing a list of query results, and the meta information as a tuple
+        :rtype: tuple
+        """
+        self._pre_action_check("cypher")
+        _params = params or {}
+        if self.element_id is None:
+            raise ValueError("Can't run cypher operation on unsaved node")
+        element_id = await adb.parse_element_id(self.element_id)
+        _params.update({"self": element_id})
+        return await adb.cypher_query(query, _params)
+
+    @hooks
+    async def delete(self) -> bool:
+        """
+        Delete a node and its relationships.
+
+        :return: True
+        """
+        self._pre_action_check("delete")
+        await self.cypher(
+            f"MATCH (self) WHERE {await adb.get_id_method()}(self)=$self DETACH DELETE self"
+        )
+        delattr(self, "element_id_property")
+        self.deleted = True
+        return True
+
+    @classmethod
+    async def get_or_create(cls: Any, *props: tuple, **kwargs: dict[str, Any]) -> list:
+        """
+        Call to MERGE with parameters map. A new instance will be created and saved if it does not already exist;
+        this is an atomic operation.
+        Parameters must contain all required properties; any non-required properties with defaults will be generated.
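+
+        Example (an illustrative sketch; assumes a ``Person`` model with a required ``uid`` property)::
+
+            people = await Person.get_or_create(
+                {"uid": "jim"},
+                {"uid": "bob"},
+                merge_by={"keys": ["uid"]},
+            )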
+
+        Note that the post_create hook isn't called after get_or_create.
+
+        :param props: Arguments to get_or_create as tuple of dict with property names and values to get or create
+            the entities with.
+        :type props: tuple
+        :param relationship: Optional, relationship to get/create on when new entity is created.
+        :type relationship: Any | None
+        :param lazy: False by default, specify True to get nodes with id only without the parameters.
+        :type lazy: bool
+        :param merge_by: Optional dict with 'label' and 'keys' to specify custom merge criteria.
+            'label' is optional and should be a string, 'keys' is a list of strings.
+            If 'label' is not provided, uses the node's inherited labels.
+            If 'keys' is not provided, uses the node's required properties as merge keys.
+        :type merge_by: dict[str, str | list[str]] | None
+        :return: list of nodes
+        :rtype: list
+        """
+        lazy = kwargs.get("lazy", False)
+        relationship = kwargs.get("relationship")
+        merge_by = kwargs.get("merge_by")
+
+        # build merge query
+        get_or_create_params = [
+            {"create": cls.deflate(p, skip_empty=True)} for p in props
+        ]
+        query, params = await cls._build_merge_query(
+            tuple(get_or_create_params),
+            relationship=relationship,
+            lazy=lazy,
+            merge_by=merge_by,
+        )
+
+        if "streaming" in kwargs:
+            warnings.warn(
+                STREAMING_WARNING,
+                category=DeprecationWarning,
+                stacklevel=1,
+            )
+
+        # fetch and build instance for each result
+        results = await adb.cypher_query(query, params)
+        if lazy:
+            return [r[0] for r in results[0]]
+        else:
+            return [cls.inflate(r[0]) for r in results[0]]
+
+    @classmethod
+    def inflate(cls: Any, graph_entity: Node) -> Any:  # type: ignore[override]
+        """
+        Inflate a raw neo4j_driver node to a neomodel node.
+        :param graph_entity: node
+        :return: node object
+        """
+        # support lazy loading
+        if isinstance(graph_entity, str) or isinstance(graph_entity, int):
+            snode = cls()
+            snode.element_id_property = graph_entity
+        else:
+            snode = super().inflate(graph_entity)
+            snode.element_id_property = graph_entity.element_id
+
+        return snode
+
+    @classmethod
+    def inherited_labels(cls: Any) -> list[str]:
+        """
+        Return the list of labels from the node's class hierarchy.
+
+        :return: list
+        """
+        return [
+            scls.__label__
+            for scls in cls.mro()
+            if hasattr(scls, "__label__") and not hasattr(scls, "__abstract_node__")
+        ]
+
+    @classmethod
+    def inherited_optional_labels(cls: Any) -> list[str]:
+        """
+        Return the list of optional labels from the node's class hierarchy.
+
+        :return: list
+        :rtype: list
+        """
+        return [
+            label
+            for scls in cls.mro()
+            for label in getattr(scls, "__optional_labels__", [])
+            if not hasattr(scls, "__abstract_node__")
+        ]
+
+    async def labels(self) -> list[str]:
+        """
+        Returns the list of labels tied to the node from neo4j.
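+
+        Example (an illustrative sketch; assumes a saved ``person`` node)::
+
+            node_labels = await person.labels()  # e.g. ["Person"]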
+
+        :return: list of labels
+        :rtype: list
+        """
+        self._pre_action_check("labels")
+        result = await self.cypher(
+            f"MATCH (n) WHERE {await adb.get_id_method()}(n)=$self " "RETURN labels(n)"
+        )
+        if result is None or result[0] is None:
+            raise ValueError("Could not get labels, node may not exist")
+        return result[0][0][0]
+
+    def _pre_action_check(self, action: str) -> None:
+        if hasattr(self, "deleted") and self.deleted:
+            raise ValueError(
+                f"{self.__class__.__name__}.{action}() attempted on deleted node"
+            )
+        if not hasattr(self, "element_id"):
+            raise ValueError(
+                f"{self.__class__.__name__}.{action}() attempted on unsaved node"
+            )
+
+    async def refresh(self) -> None:
+        """
+        Reload the node from neo4j.
+        """
+        self._pre_action_check("refresh")
+        if hasattr(self, "element_id"):
+            results = await self.cypher(
+                f"MATCH (n) WHERE {await adb.get_id_method()}(n)=$self RETURN n"
+            )
+            request = results[0]
+            if not request or not request[0]:
+                raise self.__class__.DoesNotExist("Can't refresh non-existent node")
+            node = self.inflate(request[0][0])
+            for key, val in node.__properties__.items():
+                setattr(self, key, val)
+        else:
+            raise ValueError("Can't refresh unsaved node")
+
+    @hooks
+    async def save(self) -> "AsyncStructuredNode":
+        """
+        Save the node to neo4j or raise an exception.
+
+        :return: the node instance
+        """
+
+        # create or update instance node
+        if hasattr(self, "element_id_property"):
+            # update
+            params = self.deflate(self.__properties__, self)
+            query = f"MATCH (n) WHERE {await adb.get_id_method()}(n)=$self\n"
+
+            if params:
+                query += "SET "
+                query += ",\n".join([f"n.{key} = ${key}" for key in params])
+                query += "\n"
+            if self.inherited_labels():
+                query += "\n".join(
+                    [f"SET n:`{label}`" for label in self.inherited_labels()]
+                )
+            await self.cypher(query, params)
+        elif hasattr(self, "deleted") and self.deleted:
+            raise ValueError(
+                f"{self.__class__.__name__}.save() attempted on deleted node"
+            )
+        else:  # create
+            result = await self.create(self.__properties__)
+            created_node = result[0]
+            self.element_id_property = created_node.element_id
+        return self
diff --git a/neomodel/async_/path.py b/neomodel/async_/path.py
index e1d5eb3e..41d1f916 100644
--- a/neomodel/async_/path.py
+++ b/neomodel/async_/path.py
@@ -2,7 +2,8 @@
 
 from neo4j.graph import Path
 
-from neomodel.async_.core import AsyncStructuredNode, adb
+from neomodel.async_.database import adb
+from neomodel.async_.node import AsyncStructuredNode
 from neomodel.async_.relationship import AsyncStructuredRel
 
diff --git a/neomodel/async_/property_manager.py b/neomodel/async_/property_manager.py
index deeea5ed..ff72f28b 100644
--- a/neomodel/async_/property_manager.py
+++ b/neomodel/async_/property_manager.py
@@ -1,7 +1,7 @@
 import types
 from typing import Any
 
-from neo4j.graph import Entity
+from neo4j.graph import Node, Relationship
 
 from neomodel.exceptions import RequiredProperty
 from neomodel.properties import AliasProperty, Property
@@ -101,9 +101,9 @@ def deflate(
         return deflated
 
     @classmethod
-    def inflate(cls: Any, graph_entity: Entity) -> Any:
+    def inflate(cls: Any, graph_entity: Node | Relationship) -> Any:
         """
-        Inflate the properties of a neo4j.graph.Entity (a neo4j.graph.Node or neo4j.graph.Relationship) into an instance
+        Inflate the properties of a graph entity (a neo4j.graph.Node or neo4j.graph.Relationship) into an instance
         of cls.
         Includes mapping from database property name (see Property.db_property) -> python class attribute name.
Ignores any properties that are not defined as python attributes in the class definition. diff --git a/neomodel/async_/relationship.py b/neomodel/async_/relationship.py index f84be73d..f1b30afc 100644 --- a/neomodel/async_/relationship.py +++ b/neomodel/async_/relationship.py @@ -1,8 +1,8 @@ -from typing import Any, Optional +from typing import Any from neo4j.graph import Relationship -from neomodel.async_.core import adb +from neomodel.async_.database import adb from neomodel.async_.property_manager import AsyncPropertyManager from neomodel.hooks import hooks from neomodel.properties import Property @@ -53,23 +53,27 @@ class AsyncStructuredRel(StructuredRelBase): Base class for relationship objects """ - def __init__(self, *args: Any, **kwargs: dict) -> None: - super().__init__(*args, **kwargs) + element_id_property: str + _start_node_element_id_property: str + _end_node_element_id_property: str + + _start_node_class: Any + _end_node_class: Any @property - def element_id(self) -> Optional[Any]: + def element_id(self) -> str | None: if hasattr(self, "element_id_property"): return self.element_id_property return None @property - def _start_node_element_id(self) -> Optional[Any]: + def _start_node_element_id(self) -> str | None: if hasattr(self, "_start_node_element_id_property"): return self._start_node_element_id_property return None @property - def _end_node_element_id(self) -> Optional[Any]: + def _end_node_element_id(self) -> str | None: if hasattr(self, "_end_node_element_id_property"): return self._end_node_element_id_property return None @@ -165,16 +169,16 @@ async def end_node(self) -> Any: return results[0][0][0] @classmethod - def inflate(cls: Any, rel: Relationship) -> "AsyncStructuredRel": + def inflate(cls: Any, graph_entity: Relationship) -> "AsyncStructuredRel": # type: ignore[override] """ Inflate a neo4j_driver relationship object to a neomodel object - :param rel: + :param graph_entity: Relationship :return: StructuredRel """ - srel = super().inflate(rel) - if rel.start_node is not None: - srel._start_node_element_id_property = rel.start_node.element_id - if rel.end_node is not None: - srel._end_node_element_id_property = rel.end_node.element_id - srel.element_id_property = rel.element_id + srel = super().inflate(graph_entity) + if graph_entity.start_node is not None: + srel._start_node_element_id_property = graph_entity.start_node.element_id + if graph_entity.end_node is not None: + srel._end_node_element_id_property = graph_entity.end_node.element_id + srel.element_id_property = graph_entity.element_id return srel diff --git a/neomodel/async_/relationship_manager.py b/neomodel/async_/relationship_manager.py index db7e787d..0e48d1b8 100644 --- a/neomodel/async_/relationship_manager.py +++ b/neomodel/async_/relationship_manager.py @@ -2,10 +2,9 @@ import inspect import sys from importlib import import_module -from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Optional -from neomodel import config -from neomodel.async_.core import adb +from neomodel.async_.database import adb from neomodel.async_.match import ( AsyncNodeSet, AsyncTraversal, @@ -15,9 +14,7 @@ from neomodel.async_.relationship import AsyncStructuredRel from neomodel.exceptions import NotConnected, RelationshipClassRedefined from neomodel.util import ( - EITHER, - INCOMING, - OUTGOING, + RelationshipDirection, enumerate_traceback, get_graph_entity_properties, ) @@ -68,9 +65,9 @@ def __init__(self, source: Any, key: 
str, definition: dict): def __str__(self) -> str: direction = "either" - if self.definition["direction"] == OUTGOING: + if self.definition["direction"] == RelationshipDirection.OUTGOING: direction = "a outgoing" - elif self.definition["direction"] == INCOMING: + elif self.definition["direction"] == RelationshipDirection.INCOMING: direction = "a incoming" return f"{self.description} in {direction} direction of type {self.definition['relation_type']} on node ({self.source.element_id}) of class '{self.source_class.__name__}'" @@ -78,9 +75,7 @@ def __str__(self) -> str: def __await__(self) -> Any: return self.all().__await__() # type: ignore[attr-defined] - async def _check_cardinality( - self, node: "AsyncStructuredNode", soft_check: bool = False - ) -> None: + async def check_cardinality(self, node: "AsyncStructuredNode") -> None: """ Check whether a new connection to a node would violate the cardinality of the relationship. @@ -101,8 +96,8 @@ def _check_node(self, obj: type["AsyncStructuredNode"]) -> None: @check_source async def connect( - self, node: "AsyncStructuredNode", properties: Optional[dict[str, Any]] = None - ) -> Optional[AsyncStructuredRel]: + self, node: "AsyncStructuredNode", properties: dict[str, Any] | None = None + ) -> AsyncStructuredRel | None: """ Connect a node @@ -112,7 +107,7 @@ async def connect( :return: """ self._check_node(node) - await self._check_cardinality(node) + await self.check_cardinality(node) # Check for cardinality on the remote end. for rel_name, rel_def in node.defined_properties( @@ -129,9 +124,7 @@ async def connect( # If we have found the inverse relationship, we need to check # its cardinality. inverse_rel = getattr(node, rel_name) - await inverse_rel._check_cardinality( - self.source, soft_check=config.SOFT_INVERSE_CARDINALITY_CHECK - ) + await inverse_rel.check_cardinality(self.source) break if not self.definition["model"] and properties: @@ -188,7 +181,7 @@ async def connect( @check_source async def replace( - self, node: "AsyncStructuredNode", properties: Optional[dict[str, Any]] = None + self, node: "AsyncStructuredNode", properties: dict[str, Any] | None = None ) -> None: """ Disconnect all existing nodes and connect the supplied node @@ -204,7 +197,7 @@ async def replace( @check_source async def relationship( self, node: "AsyncStructuredNode" - ) -> Optional[AsyncStructuredRel]: + ) -> AsyncStructuredRel | None: """ Retrieve the relationship object for this first relationship between self and node. 
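+
+        Example (an illustrative sketch; assumes saved nodes ``jim`` and ``bob`` with a ``friends`` relationship defined)::
+
+            rel = await jim.friends.relationship(bob)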
@@ -258,7 +251,7 @@ async def all_relationships( def _set_start_end_cls( self, rel_instance: AsyncStructuredRel, obj: "AsyncStructuredNode" ) -> AsyncStructuredRel: - if self.definition["direction"] == INCOMING: + if self.definition["direction"] == RelationshipDirection.INCOMING: rel_instance._start_node_class = obj.__class__ rel_instance._end_node_class = self.source_class else: @@ -456,7 +449,7 @@ async def check_nonzero(self) -> bool: async def check_contains(self, obj: Any) -> bool: return await self._new_traversal().check_contains(obj) - async def get_item(self, key: Union[int, slice]) -> Any: + async def get_item(self, key: int | slice) -> Any: return await self._new_traversal().get_item(key) @@ -467,7 +460,7 @@ def __init__( cls_name: str, direction: int, manager: type[AsyncRelationshipManager] = AsyncRelationshipManager, - model: Optional[type[AsyncStructuredRel]] = None, + model: type[AsyncStructuredRel] | None = None, ) -> None: self._validate_class(cls_name, model) @@ -520,7 +513,7 @@ def __init__( adb._NODE_CLASS_REGISTRY[label_set] = model def _validate_class( - self, cls_name: str, model: Optional[type[AsyncStructuredRel]] = None + self, cls_name: str, model: type[AsyncStructuredRel] | None = None ) -> None: if not isinstance(cls_name, (str, object)): raise ValueError("Expected class name or class got " + repr(cls_name)) @@ -586,10 +579,14 @@ def __init__( cls_name: str, relation_type: str, cardinality: type[AsyncRelationshipManager] = AsyncZeroOrMore, - model: Optional[type[AsyncStructuredRel]] = None, + model: type[AsyncStructuredRel] | None = None, ) -> None: super().__init__( - relation_type, cls_name, OUTGOING, manager=cardinality, model=model + relation_type, + cls_name, + RelationshipDirection.OUTGOING, + manager=cardinality, + model=model, ) @@ -599,10 +596,14 @@ def __init__( cls_name: str, relation_type: str, cardinality: type[AsyncRelationshipManager] = AsyncZeroOrMore, - model: Optional[type[AsyncStructuredRel]] = None, + model: type[AsyncStructuredRel] | None = None, ) -> None: super().__init__( - relation_type, cls_name, INCOMING, manager=cardinality, model=model + relation_type, + cls_name, + RelationshipDirection.INCOMING, + manager=cardinality, + model=model, ) @@ -612,8 +613,12 @@ def __init__( cls_name: str, relation_type: str, cardinality: type[AsyncRelationshipManager] = AsyncZeroOrMore, - model: Optional[type[AsyncStructuredRel]] = None, + model: type[AsyncStructuredRel] | None = None, ) -> None: super().__init__( - relation_type, cls_name, EITHER, manager=cardinality, model=model + relation_type, + cls_name, + RelationshipDirection.EITHER, + manager=cardinality, + model=model, ) diff --git a/neomodel/async_/transaction.py b/neomodel/async_/transaction.py new file mode 100644 index 00000000..eb3505da --- /dev/null +++ b/neomodel/async_/transaction.py @@ -0,0 +1,115 @@ +""" +Transaction management for the async neomodel module. 
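+
+Example (an illustrative sketch; assumes a configured async connection and a ``Person`` node class)::
+
+    from neomodel.async_.database import adb
+
+    async with adb.transaction:
+        await Person(name="Jim").save()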
+""" + +import warnings +from asyncio import iscoroutinefunction +from functools import wraps +from typing import Any, Callable + +from neo4j.api import Bookmarks +from neo4j.exceptions import ClientError + +from neomodel._async_compat.util import AsyncUtil +from neomodel.async_.database import AsyncDatabase +from neomodel.constants import NOT_COROUTINE_ERROR +from neomodel.exceptions import UniqueProperty + + +class AsyncTransactionProxy: + def __init__( + self, + db: AsyncDatabase, + access_mode: str | None = None, + parallel_runtime: bool | None = False, + ): + self.db: AsyncDatabase = db + self.access_mode: str | None = access_mode + self.parallel_runtime: bool | None = parallel_runtime + self.bookmarks: Bookmarks | None = None + self.last_bookmarks: Bookmarks | None = None + + async def __aenter__(self) -> "AsyncTransactionProxy": + if self.parallel_runtime and not await self.db.parallel_runtime_available(): + warnings.warn( + "Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. " + "Reverting to default runtime.", + UserWarning, + ) + self.parallel_runtime = False + self.db._parallel_runtime = self.parallel_runtime + await self.db.begin(access_mode=self.access_mode, bookmarks=self.bookmarks) + self.bookmarks = None + return self + + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + self.db._parallel_runtime = False + if exc_value: + await self.db.rollback() + + if ( + exc_type is ClientError + and exc_value.code == "Neo.ClientError.Schema.ConstraintValidationFailed" + ): + raise UniqueProperty(exc_value.message) + + if not exc_value: + self.last_bookmarks = await self.db.commit() + + def __call__(self, func: Callable) -> Callable: + if AsyncUtil.is_async_code and not iscoroutinefunction(func): + raise TypeError(NOT_COROUTINE_ERROR) + + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Callable: + async with self: + return await func(*args, **kwargs) + + return wrapper + + @property + def with_bookmark(self) -> "BookmarkingAsyncTransactionProxy": + return BookmarkingAsyncTransactionProxy(self.db, self.access_mode) + + +class BookmarkingAsyncTransactionProxy(AsyncTransactionProxy): + def __call__(self, func: Callable) -> Callable: + if AsyncUtil.is_async_code and not iscoroutinefunction(func): + raise TypeError(NOT_COROUTINE_ERROR) + + async def wrapper(*args: Any, **kwargs: Any) -> tuple[Any, None]: + self.bookmarks = kwargs.pop("bookmarks", None) + + async with self: + result = await func(*args, **kwargs) + self.last_bookmarks = None + + return result, self.last_bookmarks + + return wrapper + + +class ImpersonationHandler: + def __init__(self, db: AsyncDatabase, impersonated_user: str): + self.db = db + self.impersonated_user = impersonated_user + + def __enter__(self) -> "ImpersonationHandler": + self.db.impersonated_user = self.impersonated_user + return self + + def __exit__( + self, exception_type: Any, exception_value: Any, exception_traceback: Any + ) -> None: + self.db.impersonated_user = None + + print("\nException type:", exception_type) + print("\nException value:", exception_value) + print("\nTraceback:", exception_traceback) + + def __call__(self, func: Callable) -> Callable: + def wrapper(*args: Any, **kwargs: Any) -> Callable: + with self: + return func(*args, **kwargs) + + return wrapper diff --git a/neomodel/config.py b/neomodel/config.py index 72fa4374..cb8b86d0 100644 --- a/neomodel/config.py +++ b/neomodel/config.py @@ -1,30 +1,513 @@ +""" +Neomodel configuration module. 
+ +This module provides a modern dataclass-based configuration system with validation +and environment variable support, while maintaining backward compatibility. +""" + +import os +import sys +import warnings +from dataclasses import dataclass, field, fields +from typing import Any, Dict, Optional +from urllib.parse import urlparse + import neo4j +from neo4j import Driver from neomodel._version import __version__ -# Use this to connect with automatically created driver -# The following options are the default ones that will be used as driver config -DATABASE_URL = "bolt://neo4j:foobarbaz@localhost:7687" -FORCE_TIMEZONE = False - -CONNECTION_ACQUISITION_TIMEOUT = 60.0 -CONNECTION_TIMEOUT = 30.0 -ENCRYPTED = False -KEEP_ALIVE = True -MAX_CONNECTION_LIFETIME = 3600 -MAX_CONNECTION_POOL_SIZE = 100 -MAX_TRANSACTION_RETRY_TIME = 30.0 -RESOLVER = None -TRUSTED_CERTIFICATES = neo4j.TrustSystemCAs() -USER_AGENT = f"neomodel/v{__version__}" - -# Use this to connect with your self-managed driver instead -# DRIVER = neo4j.GraphDatabase().driver( -# "bolt://localhost:7687", auth=("neo4j", "foobarbaz") -# ) -DRIVER = None -# Use this to connect to a specific database when using the self-managed driver -DATABASE_NAME = None - -# Use this to enable soft cardinality check -SOFT_INVERSE_CARDINALITY_CHECK = True + +@dataclass +class NeomodelConfig: + """ + Neomodel configuration using dataclasses with validation and environment variable support. + + This class provides a modern, type-safe configuration system that can be loaded + from environment variables and validated at startup. + """ + + # Connection settings + database_url: str = field( + default="bolt://neo4j:foobarbaz@localhost:7687", + metadata={ + "env_var": "NEOMODEL_DATABASE_URL", + "description": "Graph database connection URL", + }, + ) + driver: Driver | None = field( + default=None, + metadata={"env_var": None, "description": "Custom database driver instance"}, + ) + database_name: str | None = field( + default=None, + metadata={ + "env_var": "NEOMODEL_DATABASE_NAME", + "description": "Database name for neomodel-managed driver instance", + }, + ) + + # Driver configuration (for neomodel-managed connections) + connection_acquisition_timeout: float = field( + default=60.0, + metadata={ + "env_var": "NEOMODEL_CONNECTION_ACQUISITION_TIMEOUT", + "description": "Connection acquisition timeout in seconds", + }, + ) + connection_timeout: float = field( + default=30.0, + metadata={ + "env_var": "NEOMODEL_CONNECTION_TIMEOUT", + "description": "Connection timeout in seconds", + }, + ) + encrypted: bool = field( + default=False, + metadata={ + "env_var": "NEOMODEL_ENCRYPTED", + "description": "Enable encrypted connections", + }, + ) + keep_alive: bool = field( + default=True, + metadata={ + "env_var": "NEOMODEL_KEEP_ALIVE", + "description": "Enable keep-alive connections", + }, + ) + max_connection_lifetime: int = field( + default=3600, + metadata={ + "env_var": "NEOMODEL_MAX_CONNECTION_LIFETIME", + "description": "Maximum connection lifetime in seconds", + }, + ) + max_connection_pool_size: int = field( + default=100, + metadata={ + "env_var": "NEOMODEL_MAX_CONNECTION_POOL_SIZE", + "description": "Maximum connection pool size", + }, + ) + max_transaction_retry_time: float = field( + default=30.0, + metadata={ + "env_var": "NEOMODEL_MAX_TRANSACTION_RETRY_TIME", + "description": "Maximum transaction retry time in seconds", + }, + ) + resolver: Any | None = field( + default=None, + metadata={ + "env_var": None, + "description": "Custom resolver for 
connection routing", + }, + ) + trusted_certificates: Any = field( + default_factory=neo4j.TrustSystemCAs, + metadata={ + "env_var": None, + "description": "Trusted certificates for encrypted connections", + }, + ) + user_agent: str = field( + default=f"neomodel/v{__version__}", + metadata={ + "env_var": "NEOMODEL_USER_AGENT", + "description": "User agent string for connections", + }, + ) + + # Neomodel-specific settings + force_timezone: bool = field( + default=False, + metadata={ + "env_var": "NEOMODEL_FORCE_TIMEZONE", + "description": "Force timezone-aware datetime objects", + }, + ) + soft_cardinality_check: bool = field( + default=False, + metadata={ + "env_var": "NEOMODEL_SOFT_CARDINALITY_CHECK", + "description": "Enable soft cardinality checking (warnings only)", + }, + ) + cypher_debug: bool = field( + default=False, + metadata={ + "env_var": "NEOMODEL_CYPHER_DEBUG", + "description": "Enable Cypher debug logging", + }, + ) + slow_queries: float = field( + default=0.0, + metadata={ + "env_var": "NEOMODEL_SLOW_QUERIES", + "description": "Threshold in seconds for slow query logging (0 = disabled)", + }, + ) + + def __post_init__(self): + """Validate configuration after initialization.""" + self._validate_config() + + def __setattr__(self, name: str, value: Any) -> None: + """Set attribute and validate configuration.""" + super().__setattr__(name, value) + # Only validate if we're not in __init__ or __post_init__ + if hasattr(self, "_initialized"): + # Don't validate here - let the calling code handle validation + pass + else: + # Mark as initialized after first attribute set + if name != "_initialized": + self._initialized = True + + def _validate_config(self) -> None: + """Validate configuration values.""" + # Validate database URL format + if self.database_url: + try: + parsed = urlparse(self.database_url) + if not parsed.scheme or not parsed.netloc: + raise ValueError( + f"Invalid database URL format: {self.database_url}" + ) + except Exception as e: + raise ValueError(f"Invalid database URL: {e}") from e + + # Validate numeric values + if self.connection_acquisition_timeout <= 0: + raise ValueError("connection_acquisition_timeout must be positive") + + if self.connection_timeout <= 0: + raise ValueError("connection_timeout must be positive") + + if self.max_connection_lifetime <= 0: + raise ValueError("max_connection_lifetime must be positive") + + if self.max_connection_pool_size <= 0: + raise ValueError("max_connection_pool_size must be positive") + + if self.max_transaction_retry_time <= 0: + raise ValueError("max_transaction_retry_time must be positive") + + # Validate slow_queries threshold + if self.slow_queries < 0: + raise ValueError("slow_queries must be non-negative") + + @classmethod + def from_env(cls) -> "NeomodelConfig": + """Create configuration from environment variables.""" + config_data: dict[str, Any] = {} + + # Get all fields with their metadata + for field_info in fields(cls): + env_var = field_info.metadata.get("env_var") + if env_var and env_var in os.environ: + value = os.environ[env_var] + field_type = field_info.type + + # Convert string values to appropriate types + if field_type == bool: + config_data[field_info.name] = value.lower() in ( + "true", + "1", + "yes", + "on", + ) + elif field_type == int: + config_data[field_info.name] = int(value) + elif field_type == float: + config_data[field_info.name] = float(value) + else: + config_data[field_info.name] = value + + return cls(**config_data) + + def to_dict(self) -> Dict[str, Any]: + """Convert 
configuration to dictionary.""" + result: dict[str, Any] = {} + for field_info in fields(self): + value = getattr(self, field_info.name) + # Skip non-serializable values + if field_info.name not in ("driver", "resolver", "trusted_certificates"): + result[field_info.name] = value + return result + + def update(self, **kwargs) -> None: + """Update configuration values.""" + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + else: + warnings.warn(f"Unknown configuration option: {key}") + + # Re-validate after update + self._validate_config() # pylint: disable=protected-access + + +# Global configuration instance +_config: Optional[NeomodelConfig] = None + + +def get_config() -> NeomodelConfig: + """Get the global configuration instance.""" + global _config # noqa: PLW0603 - usage of 'global' is required here for module-level singleton pattern + if _config is None: + _config = NeomodelConfig.from_env() + return _config + + +def set_config(config: NeomodelConfig) -> None: + """Set the global configuration instance.""" + global _config # noqa: PLW0603 - usage of 'global' is required here for module-level singleton pattern + _config = config + + +def reset_config() -> None: + """Reset the global configuration to default values.""" + global _config # noqa: PLW0603 - usage of 'global' is required here for module-level singleton pattern + _config = None + + +def clear_deprecation_warnings() -> None: + """Clear the set of deprecation warnings that have been shown. + + This is primarily useful for testing purposes to reset the warning state. + """ + global _legacy_attr_warnings + _legacy_attr_warnings.clear() + + +# Backward compatibility: Create module-level attributes that delegate to the config instance +_legacy_attr_warnings: set[str] = set() + + +def _get_attr(name: str) -> Any: + """Get attribute from the global config instance.""" + # Issue deprecation warning for legacy attribute access + if name not in _legacy_attr_warnings: + _legacy_attr_warnings.add(name) + warnings.warn( + f"Accessing config.{name.upper()} is deprecated and will be removed in a future version. Use the modern configuration API instead: " + f"from neomodel import get_config; config = get_config(); config.{name}", + DeprecationWarning, + stacklevel=3, + ) + + config = get_config() + return getattr(config, name) + + +def _set_attr(name: str, value: Any) -> None: + """Set attribute on the global config instance.""" + # Issue deprecation warning for legacy attribute setting + if name not in _legacy_attr_warnings: + _legacy_attr_warnings.add(name) + warnings.warn( + f"Setting config.{name.upper()} is deprecated and will be removed in a future version. 
Use the modern configuration API instead: " + f"from neomodel import get_config; config = get_config(); config.{name} = value", + DeprecationWarning, + stacklevel=3, + ) + + config = get_config() + original_value = getattr(config, name) + setattr(config, name, value) + try: + config._validate_config() # pylint: disable=protected-access + except ValueError: + # If validation fails, revert the change + setattr(config, name, original_value) + raise + + +# Create module-level properties for backward compatibility +class _ConfigModule: + """Module-level configuration access for backward compatibility.""" + + @property + def DATABASE_URL( + self, + ) -> str: + return _get_attr("database_url") + + @DATABASE_URL.setter + def DATABASE_URL(self, value: str) -> None: + _set_attr("database_url", value) + + @property + def DRIVER( + self, + ) -> Driver | None: + return _get_attr("driver") + + @DRIVER.setter + def DRIVER(self, value: Driver | None) -> None: + _set_attr("driver", value) + + @property + def DATABASE_NAME( + self, + ) -> str | None: + return _get_attr("database_name") + + @DATABASE_NAME.setter + def DATABASE_NAME(self, value: str | None) -> None: + _set_attr("database_name", value) + + @property + def CONNECTION_ACQUISITION_TIMEOUT( + self, + ) -> float: + return _get_attr("connection_acquisition_timeout") + + @CONNECTION_ACQUISITION_TIMEOUT.setter + def CONNECTION_ACQUISITION_TIMEOUT(self, value: float) -> None: + _set_attr("connection_acquisition_timeout", value) + + @property + def CONNECTION_TIMEOUT( + self, + ) -> float: + return _get_attr("connection_timeout") + + @CONNECTION_TIMEOUT.setter + def CONNECTION_TIMEOUT(self, value: float) -> None: + _set_attr("connection_timeout", value) + + @property + def ENCRYPTED( + self, + ) -> bool: + return _get_attr("encrypted") + + @ENCRYPTED.setter + def ENCRYPTED(self, value: bool) -> None: + _set_attr("encrypted", value) + + @property + def KEEP_ALIVE( + self, + ) -> bool: + return _get_attr("keep_alive") + + @KEEP_ALIVE.setter + def KEEP_ALIVE(self, value: bool) -> None: + _set_attr("keep_alive", value) + + @property + def MAX_CONNECTION_LIFETIME( + self, + ) -> int: + return _get_attr("max_connection_lifetime") + + @MAX_CONNECTION_LIFETIME.setter + def MAX_CONNECTION_LIFETIME(self, value: int) -> None: + _set_attr("max_connection_lifetime", value) + + @property + def MAX_CONNECTION_POOL_SIZE( + self, + ) -> int: + return _get_attr("max_connection_pool_size") + + @MAX_CONNECTION_POOL_SIZE.setter + def MAX_CONNECTION_POOL_SIZE(self, value: int) -> None: + _set_attr("max_connection_pool_size", value) + + @property + def MAX_TRANSACTION_RETRY_TIME( + self, + ) -> float: + return _get_attr("max_transaction_retry_time") + + @MAX_TRANSACTION_RETRY_TIME.setter + def MAX_TRANSACTION_RETRY_TIME(self, value: float) -> None: + _set_attr("max_transaction_retry_time", value) + + @property + def RESOLVER( + self, + ) -> Any | None: + return _get_attr("resolver") + + @RESOLVER.setter + def RESOLVER(self, value: Any | None) -> None: + _set_attr("resolver", value) + + @property + def TRUSTED_CERTIFICATES( + self, + ) -> Any: + return _get_attr("trusted_certificates") + + @TRUSTED_CERTIFICATES.setter + def TRUSTED_CERTIFICATES(self, value: Any) -> None: + _set_attr("trusted_certificates", value) + + @property + def USER_AGENT( + self, + ) -> str: + return _get_attr("user_agent") + + @USER_AGENT.setter + def USER_AGENT(self, value: str) -> None: + _set_attr("user_agent", value) + + @property + def FORCE_TIMEZONE( + self, + ) -> bool: + return 
_get_attr("force_timezone") + + @FORCE_TIMEZONE.setter + def FORCE_TIMEZONE(self, value: bool) -> None: + _set_attr("force_timezone", value) + + @property + def SOFT_CARDINALITY_CHECK( + self, + ) -> bool: + return _get_attr("soft_cardinality_check") + + @SOFT_CARDINALITY_CHECK.setter + def SOFT_CARDINALITY_CHECK(self, value: bool) -> None: + _set_attr("soft_cardinality_check", value) + + @property + def CYPHER_DEBUG( + self, + ) -> bool: + return _get_attr("cypher_debug") + + @CYPHER_DEBUG.setter + def CYPHER_DEBUG(self, value: bool) -> None: + _set_attr("cypher_debug", value) + + @property + def SLOW_QUERIES( + self, + ) -> float: + return _get_attr("slow_queries") + + @SLOW_QUERIES.setter + def SLOW_QUERIES(self, value: float) -> None: + _set_attr("slow_queries", value) + + +# Create the module instance for backward compatibility +_current_module = sys.modules[__name__] +_config_module = _ConfigModule() + +# Replace the module with the config module instance +sys.modules[__name__] = _config_module # type: ignore[assignment] + +# Copy all attributes from the original module to maintain backward compatibility +for attr_name in dir(_current_module): + if not attr_name.startswith("_") and not hasattr(_config_module, attr_name): + setattr(_config_module, attr_name, getattr(_current_module, attr_name)) diff --git a/neomodel/constants.py b/neomodel/constants.py new file mode 100644 index 00000000..956cd068 --- /dev/null +++ b/neomodel/constants.py @@ -0,0 +1,45 @@ +""" +Constants used in various modules of neomodel. +""" + +# Error message constants +RULE_ALREADY_EXISTS = "Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists" +INDEX_ALREADY_EXISTS = "Neo.ClientError.Schema.IndexAlreadyExists" +CONSTRAINT_ALREADY_EXISTS = "Neo.ClientError.Schema.ConstraintAlreadyExists" +STREAMING_WARNING = "streaming is not supported by bolt, please remove the kwarg" +NOT_COROUTINE_ERROR = "The decorated function must be a coroutine" + +# Access mode constants +ACCESS_MODE_WRITE = "WRITE" +ACCESS_MODE_READ = "READ" + +# Database edition constants +ENTERPRISE_EDITION_TAG = "enterprise" + +# Neo4j version constants +VERSION_LEGACY_ID = "4" +VERSION_RELATIONSHIP_CONSTRAINTS_SUPPORT = "5.7" +VERSION_PARALLEL_RUNTIME_SUPPORT = "5.13" +VERSION_VECTOR_INDEXES_SUPPORT = "5.15" +VERSION_FULLTEXT_INDEXES_SUPPORT = "5.16" +VERSION_RELATIONSHIP_VECTOR_INDEXES_SUPPORT = "5.18" + +# ID method constants +LEGACY_ID_METHOD = "id" +ELEMENT_ID_METHOD = "elementId" + +# Cypher query constants +LIST_CONSTRAINTS_COMMAND = "SHOW CONSTRAINTS" +DROP_CONSTRAINT_COMMAND = "DROP CONSTRAINT " +DROP_INDEX_COMMAND = "DROP INDEX " + +# Index type constants +LOOKUP_INDEX_TYPE = "LOOKUP" + +# Info messages constants +NO_TRANSACTION_IN_PROGRESS = "No transaction in progress" +NO_SESSION_OPEN = "No session open" +UNKNOWN_SERVER_VERSION = """ + Unable to perform this operation because the database server version is not known. + This might mean that the database server is offline. 
+""" diff --git a/neomodel/contrib/async_/semi_structured.py b/neomodel/contrib/async_/semi_structured.py index 810ea10f..fa1dd5b7 100644 --- a/neomodel/contrib/async_/semi_structured.py +++ b/neomodel/contrib/async_/semi_structured.py @@ -1,6 +1,5 @@ -from neomodel.async_.core import AsyncStructuredNode +from neomodel.async_.node import AsyncStructuredNode from neomodel.exceptions import DeflateConflict, InflateConflict -from neomodel.util import get_graph_entity_properties class AsyncSemiStructuredNode(AsyncStructuredNode): diff --git a/neomodel/contrib/sync_/semi_structured.py b/neomodel/contrib/sync_/semi_structured.py index 97c43c39..dcfa5a9a 100644 --- a/neomodel/contrib/sync_/semi_structured.py +++ b/neomodel/contrib/sync_/semi_structured.py @@ -1,6 +1,5 @@ from neomodel.exceptions import DeflateConflict, InflateConflict -from neomodel.sync_.core import StructuredNode -from neomodel.util import get_graph_entity_properties +from neomodel.sync_.node import StructuredNode class SemiStructuredNode(StructuredNode): diff --git a/neomodel/exceptions.py b/neomodel/exceptions.py index 45b24e0e..34131bf6 100644 --- a/neomodel/exceptions.py +++ b/neomodel/exceptions.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Type, Union +from typing import Any, Type class NeomodelException(Exception): @@ -26,7 +26,7 @@ class CardinalityViolation(NeomodelException): For example a relationship type `OneOrMore` returns no nodes. """ - def __init__(self, rel_manager: Any, actual: Union[int, str]): + def __init__(self, rel_manager: Any, actual: int | str): self.rel_manager = str(rel_manager) self.actual = str(actual) @@ -116,7 +116,7 @@ def __str__(self) -> str: return f""" Relationship of type {relationship_type} does not resolve to any of the known objects {self._get_node_class_registry_formatted()} - Note that when using the fetch_relations method, the relationship type must be defined in the model, even if only defined to StructuredRel. + Note that when using the traverse method, the relationship type must be defined in the model, even if only defined to StructuredRel. Otherwise, neomodel will not be able to determine which relationship model to resolve into. """ @@ -185,7 +185,7 @@ def __str__(self) -> str: class DoesNotExist(NeomodelException): - _model_class: Optional[Type] = None + _model_class: Type | None = None """ This class property refers the model class that a subclass of this class belongs to. It is set by :class:`~neomodel.core.NodeMeta`. 
@@ -217,7 +217,7 @@ def __str__(self) -> str: class InflateError(ValueError, NeomodelException): - def __init__(self, key: str, cls: Any, msg: str, obj: Optional[Any] = None): + def __init__(self, key: str, cls: Any, msg: str, obj: Any | None = None): self.property_name = key self.node_class = cls self.msg = msg diff --git a/neomodel/properties.py b/neomodel/properties.py index 487be3be..253d3d27 100644 --- a/neomodel/properties.py +++ b/neomodel/properties.py @@ -4,12 +4,12 @@ import uuid from abc import ABCMeta, abstractmethod from datetime import date, datetime -from typing import Any, Callable, Optional +from typing import Any, Callable, Union, overload from zoneinfo import ZoneInfo import neo4j.time -from neomodel import config +from neomodel.config import get_config from neomodel.exceptions import DeflateError, InflateError, NeomodelException TOO_MANY_DEFAULTS = "too many defaults" @@ -22,18 +22,21 @@ def validator(fn: Callable) -> Callable: @functools.wraps(fn) def _validator( # type: ignore - self, value: Any, obj: Optional[Any] = None, rethrow: Optional[bool] = True + self, value: Any, obj: Any | None = None, rethrow: bool | None = True ) -> Any: if rethrow: try: return fn(self, value) except Exception as e: - if fn_name == "inflate": - raise InflateError(self.name, self.owner, str(e), obj) from e - elif fn_name == "deflate": - raise DeflateError(self.name, self.owner, str(e), obj) from e - else: - raise NeomodelException("Unknown Property method " + fn_name) from e + match fn_name: + case "inflate": + raise InflateError(self.name, self.owner, str(e), obj) from e + case "deflate": + raise DeflateError(self.name, self.owner, str(e), obj) from e + case _: + raise NeomodelException( + "Unknown Property method " + fn_name + ) from e else: # For using with ArrayProperty where we don't want an Inflate/Deflate error. 
return fn(self, value) @@ -48,8 +51,8 @@ class FulltextIndex: def __init__( self, - analyzer: Optional[str] = "standard-no-stop-words", - eventually_consistent: Optional[bool] = False, + analyzer: str | None = "standard-no-stop-words", + eventually_consistent: bool | None = False, ): """ Initializes new fulltext index definition with analyzer and eventually consistent @@ -68,8 +71,8 @@ class VectorIndex: def __init__( self, - dimensions: Optional[int] = 1536, - similarity_function: Optional[str] = "cosine", + dimensions: int | None = 1536, + similarity_function: str | None = "cosine", ): """ Initializes new vector index definition with dimensions and similarity @@ -108,32 +111,32 @@ class Property(metaclass=ABCMeta): """ form_field_class = "CharField" - name: Optional[str] = None - owner: Optional[Any] = None + name: str | None = None + owner: Any | None = None unique_index: bool = False index: bool = False - fulltext_index: Optional[FulltextIndex] = None - vector_index: Optional[VectorIndex] = None + fulltext_index: FulltextIndex | None = None + vector_index: VectorIndex | None = None required: bool = False default: Any = None - db_property: Optional[str] = None - label: Optional[str] = None - help_text: Optional[str] = None + db_property: str | None = None + label: str | None = None + help_text: str | None = None # pylint:disable=unused-argument def __init__( self, - name: Optional[str] = None, - owner: Optional[Any] = None, + name: str | None = None, + owner: Any | None = None, unique_index: bool = False, index: bool = False, - fulltext_index: Optional[FulltextIndex] = None, - vector_index: Optional[VectorIndex] = None, + fulltext_index: FulltextIndex | None = None, + vector_index: VectorIndex | None = None, required: bool = False, - default: Optional[Any] = None, - db_property: Optional[str] = None, - label: Optional[str] = None, - help_text: Optional[str] = None, + default: Any | None = None, + db_property: str | None = None, + label: str | None = None, + help_text: str | None = None, **kwargs: dict[str, Any], ): if default is not None and required: @@ -145,6 +148,8 @@ def __init__( "The arguments `unique_index` and `index` are mutually exclusive." ) + self.name = name + self.owner = owner self.required = required self.unique_index = unique_index self.index = index @@ -155,6 +160,9 @@ def __init__( self.db_property = db_property self.label = label self.help_text = help_text + # Set any extra kwargs as attributes on the property + for key, value in kwargs.items(): + setattr(self, key, value) def default_value(self) -> Any: """ @@ -223,7 +231,7 @@ class RegexProperty(NormalizedProperty): expression: str - def __init__(self, expression: Optional[str] = None, **kwargs: Any): + def __init__(self, expression: str | None = None, **kwargs: Any): """ Initializes new property with an expression. @@ -262,8 +270,8 @@ class StringProperty(NormalizedProperty): def __init__( self, - choices: Optional[Any] = None, - max_length: Optional[int] = None, + choices: Any | None = None, + max_length: int | None = None, **kwargs: Any, ): if max_length is not None: @@ -330,7 +338,7 @@ class ArrayProperty(Property): Stores a list of items """ - def __init__(self, base_property: Optional[Property] = None, **kwargs: Any): + def __init__(self, base_property: Property | None = None, **kwargs: Any): """ Store a list of values, optionally of a specific type. 
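+
+        Example (an illustrative sketch)::
+
+            class Person(StructuredNode):
+                tags = ArrayProperty(StringProperty())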
@@ -515,7 +523,7 @@ def deflate(self, value: datetime) -> float: if value.tzinfo: value = value.astimezone(ZoneInfo("UTC")) epoch_date = datetime(1970, 1, 1, tzinfo=ZoneInfo("UTC")) - elif config.FORCE_TIMEZONE: + elif get_config().force_timezone: raise ValueError(f"Error deflating {value}: No timezone provided.") else: # No timezone specified on datetime object.. assuming UTC @@ -595,7 +603,17 @@ def __init__(self, to: str): def aliased_to(self) -> str: return self.target - def __get__(self, obj: Any, _type: Optional[Any] = None) -> Property: + @overload + def __get__(self, obj: None, _type: Any | None = None) -> "AliasProperty": + ... + + @overload + def __get__(self, obj: Any, _type: Any | None = None) -> Property: + ... + + def __get__( + self, obj: Any | None, _type: Any | None = None + ) -> Union[Property, "AliasProperty"]: return getattr(obj, self.aliased_to()) if obj else self def __set__(self, obj: Any, value: Property) -> None: diff --git a/neomodel/scripts/neomodel_inspect_database.py b/neomodel/scripts/neomodel_inspect_database.py index cb3a7884..c777398e 100644 --- a/neomodel/scripts/neomodel_inspect_database.py +++ b/neomodel/scripts/neomodel_inspect_database.py @@ -7,11 +7,11 @@ :: usage: neomodel_inspect_database [-h] [--db bolt://neo4j:neo4j@localhost:7687] [--write-to ...] - + Connects to a Neo4j database and inspects existing nodes and relationships. Infers the schema of the database and generates Python class definitions. - If a connection URL is not specified, the tool will look up the environment + If a connection URL is not specified, the tool will look up the environment variable NEO4J_BOLT_URL. If that environment variable is not set, the tool will attempt to connect to the default URL bolt://neo4j:neo4j@localhost:7687 @@ -19,7 +19,7 @@ If no file is specified, the tool will print the class definitions to stdout. Note : this script only has a synchronous mode. - + options: -h, --help show this help message and exit --db bolt://neo4j:neo4j@localhost:7687 @@ -36,7 +36,7 @@ from os import environ from typing import Any -from neomodel.sync_.core import db +from neomodel.sync_.database import db IMPORTS = [] diff --git a/neomodel/scripts/neomodel_install_labels.py b/neomodel/scripts/neomodel_install_labels.py index f4159434..38890c56 100755 --- a/neomodel/scripts/neomodel_install_labels.py +++ b/neomodel/scripts/neomodel_install_labels.py @@ -8,19 +8,19 @@ :: usage: neomodel_install_labels [-h] [--db bolt://neo4j:neo4j@localhost:7687] [ ...] - + Setup indexes and constraints on labels in Neo4j for your neomodel schema. - - If a connection URL is not specified, the tool will look up the environment + + If a connection URL is not specified, the tool will look up the environment variable NEO4J_BOLT_URL. If that environment variable is not set, the tool will attempt to connect to the default URL bolt://neo4j:neo4j@localhost:7687 Note : this script only has a synchronous mode. - + positional arguments: python modules or files with neomodel schema declarations. 
- + options: -h, --help show this help message and exit --db bolt://neo4j:neo4j@localhost:7687 @@ -32,7 +32,7 @@ from os import environ from neomodel.scripts.utils import load_python_module_or_file -from neomodel.sync_.core import db +from neomodel.sync_.database import db def main(): diff --git a/neomodel/scripts/neomodel_remove_labels.py b/neomodel/scripts/neomodel_remove_labels.py index d879932f..e9840de0 100755 --- a/neomodel/scripts/neomodel_remove_labels.py +++ b/neomodel/scripts/neomodel_remove_labels.py @@ -8,15 +8,15 @@ :: usage: neomodel_remove_labels [-h] [--db bolt://neo4j:neo4j@localhost:7687] - + Drop all indexes and constraints on labels from schema in Neo4j database. - - If a connection URL is not specified, the tool will look up the environment + + If a connection URL is not specified, the tool will look up the environment variable NEO4J_BOLT_URL. If that environment variable is not set, the tool will attempt to connect to the default URL bolt://neo4j:neo4j@localhost:7687 Note : this script only has a synchronous mode. - + options: -h, --help show this help message and exit --db bolt://neo4j:neo4j@localhost:7687 @@ -28,7 +28,7 @@ from argparse import ArgumentParser, RawDescriptionHelpFormatter from os import environ -from neomodel.sync_.core import db +from neomodel.sync_.database import db def main(): diff --git a/neomodel/semantic_filters.py b/neomodel/semantic_filters.py index a1421111..f0756e74 100644 --- a/neomodel/semantic_filters.py +++ b/neomodel/semantic_filters.py @@ -1,6 +1,3 @@ -from typing import List - - class VectorFilter(object): """ Represents a CALL db.index.vector.query* neo functions call within the OGM @@ -15,10 +12,29 @@ class VectorFilter(object): """ def __init__( - self, topk: int, vector_attribute_name: str, candidate_vector: List[float] + self, topk: int, vector_attribute_name: str, candidate_vector: list[float] ): self.topk = topk self.vector_attribute_name = vector_attribute_name self.index_name = None self.node_set_label = None self.vector = candidate_vector + +class FulltextFilter(object): + """ + Represents a CALL db.index.fulltext.query* neo function call within the OGM. + :param query_string: The query string to match against the fulltext index + :type query_string: str + :param fulltext_attribute_name: The property name for the fulltext indexed property.
+ :type fulltext_attribute_name: str + :param topk: Number of nodes to return + :type topk: int + + """ + + def __init__(self, query_string: str, fulltext_attribute_name: str, topk: int): + self.query_string = query_string + self.fulltext_attribute_name = fulltext_attribute_name + self.index_name = None + self.node_set_label = None + self.topk = topk diff --git a/neomodel/sync_/cardinality.py b/neomodel/sync_/cardinality.py index 849e0205..71a121f4 100644 --- a/neomodel/sync_/cardinality.py +++ b/neomodel/sync_/cardinality.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING, Any, Optional +from neomodel.config import get_config from neomodel.exceptions import AttemptedCardinalityViolation, CardinalityViolation from neomodel.sync_.relationship_manager import ( # pylint:disable=unused-import RelationshipManager, @@ -15,17 +16,16 @@ class ZeroOrOne(RelationshipManager): description = "zero or one relationship" - def _check_cardinality( - self, node: "StructuredNode", soft_check: bool = False - ) -> None: + def check_cardinality(self, node: "StructuredNode") -> None: if self.__len__(): - if soft_check: + detailed_description = str(self) + if get_config().soft_cardinality_check: print( - f"Cardinality violation detected : Node already has one relationship of type {self.definition['relation_type']}, should not connect more. Soft check is enabled so the relationship will be created. Note that strict check will be enabled by default in version 6.0" + f"Cardinality violation detected : Node already has {detailed_description}, should not connect more. Soft check is enabled so the relationship will be created." ) else: raise AttemptedCardinalityViolation( - f"Node already has one relationship of type {self.definition['relation_type']}. Use reconnect() to replace the existing relationship." + f"Node already has {detailed_description}. Use reconnect() to replace the existing relationship." ) def single(self) -> Optional["StructuredNode"]: @@ -46,7 +46,7 @@ def all(self) -> list["StructuredNode"]: return [node] if node else [] def connect( - self, node: "StructuredNode", properties: Optional[dict[str, Any]] = None + self, node: "StructuredNode", properties: dict[str, Any] | None = None ) -> "StructuredRel": """ Connect to a node. @@ -97,6 +97,11 @@ def disconnect(self, node: "StructuredNode") -> None: raise AttemptedCardinalityViolation("One or more expected") return super().disconnect(node) + def disconnect_all(self) -> None: + raise AttemptedCardinalityViolation( + "Cardinality one or more, cannot disconnect_all. Use reconnect() instead." + ) + class One(RelationshipManager): """ @@ -105,17 +110,16 @@ class One(RelationshipManager): description = "one relationship" - def _check_cardinality( - self, node: "StructuredNode", soft_check: bool = False - ) -> None: + def check_cardinality(self, node: "StructuredNode") -> None: if self.__len__(): - if soft_check: + detailed_description = str(self) + if get_config().soft_cardinality_check: print( - f"Cardinality violation detected : Node already has one relationship of type {self.definition['relation_type']}, should not connect more. Soft check is enabled so the relationship will be created. Note that strict check will be enabled by default in version 6.0" + f"Cardinality violation detected : Node already has {detailed_description}, should not connect more. Soft check is enabled so the relationship will be created." ) else: raise AttemptedCardinalityViolation( - f"Node already has one relationship of type {self.definition['relation_type']}.
Use reconnect() to replace the existing relationship." + f"Node already has {detailed_description}. Use reconnect() to replace the existing relationship." ) def single(self) -> "StructuredNode": @@ -150,7 +154,7 @@ def disconnect_all(self) -> None: ) def connect( - self, node: "StructuredNode", properties: Optional[dict[str, Any]] = None + self, node: "StructuredNode", properties: dict[str, Any] | None = None ) -> "StructuredRel": """ Connect a node diff --git a/neomodel/sync_/core.py b/neomodel/sync_/database.py similarity index 56% rename from neomodel/sync_/core.py rename to neomodel/sync_/database.py index 5f8a8182..48ba76cf 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/database.py @@ -1,13 +1,13 @@ +""" +Database connection and management for the neomodel module. +""" + import logging import os import sys import time -import warnings -from asyncio import iscoroutinefunction -from functools import wraps -from itertools import combinations -from threading import local -from typing import Any, Callable, Optional, TextIO, Union +from contextvars import ContextVar +from typing import TYPE_CHECKING, Any, Callable, TextIO from urllib.parse import quote, unquote, urlparse from neo4j import ( @@ -23,72 +23,49 @@ from neo4j.exceptions import ClientError, ServiceUnavailable, SessionExpired from neo4j.graph import Node, Path, Relationship -from neomodel import config -from neomodel._async_compat.util import Util +from neomodel.config import get_config +from neomodel.constants import ( + ACCESS_MODE_READ, + ACCESS_MODE_WRITE, + CONSTRAINT_ALREADY_EXISTS, + DROP_CONSTRAINT_COMMAND, + DROP_INDEX_COMMAND, + ELEMENT_ID_METHOD, + ENTERPRISE_EDITION_TAG, + INDEX_ALREADY_EXISTS, + LEGACY_ID_METHOD, + LIST_CONSTRAINTS_COMMAND, + LOOKUP_INDEX_TYPE, + NO_SESSION_OPEN, + NO_TRANSACTION_IN_PROGRESS, + RULE_ALREADY_EXISTS, + UNKNOWN_SERVER_VERSION, + VERSION_FULLTEXT_INDEXES_SUPPORT, + VERSION_LEGACY_ID, + VERSION_PARALLEL_RUNTIME_SUPPORT, + VERSION_RELATIONSHIP_CONSTRAINTS_SUPPORT, + VERSION_RELATIONSHIP_VECTOR_INDEXES_SUPPORT, + VERSION_VECTOR_INDEXES_SUPPORT, +) from neomodel.exceptions import ( ConstraintValidationFailed, - DoesNotExist, FeatureNotSupported, - NodeClassAlreadyDefined, NodeClassNotDefined, RelationshipClassNotDefined, UniqueProperty, ) -from neomodel.hooks import hooks from neomodel.properties import FulltextIndex, Property, VectorIndex -from neomodel.sync_.property_manager import PropertyManager -from neomodel.util import ( - _UnsavedNode, - classproperty, - deprecated, - version_tag_to_integer, -) +from neomodel.util import version_tag_to_integer -logger = logging.getLogger(__name__) +# The imports inside this block are only for type checking tools (like mypy or IDEs) to help with code hints and error checking. +# These imports are ignored when the code actually runs, so they don't affect runtime performance or cause circular import problems. 
+if TYPE_CHECKING: + from neomodel.sync_.node import StructuredNode # type: ignore + from neomodel.sync_.transaction import ImpersonationHandler, TransactionProxy -RULE_ALREADY_EXISTS = "Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists" -INDEX_ALREADY_EXISTS = "Neo.ClientError.Schema.IndexAlreadyExists" -CONSTRAINT_ALREADY_EXISTS = "Neo.ClientError.Schema.ConstraintAlreadyExists" -STREAMING_WARNING = "streaming is not supported by bolt, please remove the kwarg" -NOT_COROUTINE_ERROR = "The decorated function must be a coroutine" - -# Access mode constants -ACCESS_MODE_WRITE = "WRITE" -ACCESS_MODE_READ = "READ" - -# Database edition constants -ENTERPRISE_EDITION_TAG = "enterprise" - -# Neo4j version constants -VERSION_LEGACY_ID = "4" -VERSION_RELATIONSHIP_CONSTRAINTS_SUPPORT = "5.7" -VERSION_PARALLEL_RUNTIME_SUPPORT = "5.13" -VERSION_VECTOR_INDEXES_SUPPORT = "5.15" -VERSION_FULLTEXT_INDEXES_SUPPORT = "5.16" -VERSION_RELATIONSHIP_VECTOR_INDEXES_SUPPORT = "5.18" - -# ID method constants -LEGACY_ID_METHOD = "id" -ELEMENT_ID_METHOD = "elementId" - -# Cypher query constants -LIST_CONSTRAINTS_COMMAND = "SHOW CONSTRAINTS" -DROP_CONSTRAINT_COMMAND = "DROP CONSTRAINT " -DROP_INDEX_COMMAND = "DROP INDEX " - -# Index type constants -LOOKUP_INDEX_TYPE = "LOOKUP" - -# Info messages constants -NO_TRANSACTION_IN_PROGRESS = "No transaction in progress" -NO_SESSION_OPEN = "No session open" -UNKNOWN_SERVER_VERSION = """ - Unable to perform this operation because the database server version is not known. - This might mean that the database server is offline. -""" +logger = logging.getLogger(__name__) -# make sure the connection url has been set prior to executing the wrapped function def ensure_connection(func: Callable) -> Callable: """Decorator that ensures a connection is established before executing the decorated function. @@ -97,7 +74,6 @@ def ensure_connection(func: Callable) -> Callable: Returns: callable: The decorated function. - """ def wrapper(self: Any, *args: Any, **kwargs: Any) -> Callable: @@ -108,38 +84,188 @@ def wrapper(self: Any, *args: Any, **kwargs: Any) -> Callable: _db = self if not _db.driver: - if hasattr(config, "DATABASE_URL") and config.DATABASE_URL: - _db.set_connection(url=config.DATABASE_URL) - elif hasattr(config, "DRIVER") and config.DRIVER: - _db.set_connection(driver=config.DRIVER) + config = get_config() + if hasattr(config, "database_url") and config.database_url: + _db.set_connection(url=config.database_url) + elif hasattr(config, "driver") and config.driver: + _db.set_connection(driver=config.driver) return func(self, *args, **kwargs) return wrapper -class Database(local): +class Database: """ A singleton object via which all operations from neomodel to the Neo4j backend are handled with. + + This class enforces singleton behavior - only one instance can exist at a time. + The singleton instance is accessible via the module-level 'db' variable. """ + # Shared global registries _NODE_CLASS_REGISTRY: dict[frozenset, Any] = {} _DB_SPECIFIC_CLASS_REGISTRY: dict[str, dict[frozenset, Any]] = {} + # Singleton instance tracking + _instance: "Database | None" = None + _initialized: bool = False + + def __new__(cls) -> "Database": + """ + Enforce singleton pattern - only one instance can exist. 
+ + Returns: + Database: The singleton instance + + Raises: + RuntimeError: If attempting to create a second instance + """ + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + def __init__(self) -> None: - self._active_transaction: Optional[Transaction] = None - self.url: Optional[str] = None - self.driver: Optional[Driver] = None - self._session: Optional[Session] = None - self._pid: Optional[int] = None - self._database_name: Optional[str] = DEFAULT_DATABASE - self._database_version: Optional[str] = None - self._database_edition: Optional[str] = None - self.impersonated_user: Optional[str] = None - self._parallel_runtime: Optional[bool] = False + # Prevent re-initialization of the singleton instance + if Database._initialized: + return + # Private to instances and contexts + self.__active_transaction: ContextVar[Transaction | None] = ContextVar( + "_active_transaction", default=None + ) + self.__url: ContextVar[str | None] = ContextVar("url", default=None) + self.__driver: ContextVar[Driver | None] = ContextVar("driver", default=None) + self.__session: ContextVar[Session | None] = ContextVar( + "_session", default=None + ) + self.__pid: ContextVar[int | None] = ContextVar("_pid", default=None) + self.__database_name: ContextVar[str | None] = ContextVar( + "_database_name", default=DEFAULT_DATABASE + ) + self.__database_version: ContextVar[str | None] = ContextVar( + "_database_version", default=None + ) + self.__database_edition: ContextVar[str | None] = ContextVar( + "_database_edition", default=None + ) + self.__impersonated_user: ContextVar[str | None] = ContextVar( + "impersonated_user", default=None + ) + self.__parallel_runtime: ContextVar[bool | None] = ContextVar( + "_parallel_runtime", default=False + ) + + # Mark the singleton as initialized + Database._initialized = True + + @classmethod + def get_instance(cls) -> "Database": + """ + Get the singleton instance of Database. + + Returns: + Database: The singleton instance + """ + if cls._instance is None: + cls._instance = cls() + return cls._instance + + @classmethod + def reset_instance(cls) -> None: + """ + Reset the singleton instance. This should only be used for testing purposes. + + Warning: This will close any existing connections and reset all state. 
+ """ + if cls._instance is not None: + # Close any existing connections + cls._instance.close_connection() + + cls._instance = None + cls._initialized = False + + @property + def _active_transaction(self) -> Transaction | None: + return self.__active_transaction.get() + + @_active_transaction.setter + def _active_transaction(self, value: Transaction | None) -> None: + self.__active_transaction.set(value) + + @property + def url(self) -> str | None: + return self.__url.get() + + @url.setter + def url(self, value: str | None) -> None: + self.__url.set(value) + + @property + def driver(self) -> Driver | None: + return self.__driver.get() + + @driver.setter + def driver(self, value: Driver | None) -> None: + self.__driver.set(value) + + @property + def _session(self) -> Session | None: + return self.__session.get() + + @_session.setter + def _session(self, value: Session | None) -> None: + self.__session.set(value) + + @property + def _pid(self) -> int | None: + return self.__pid.get() + + @_pid.setter + def _pid(self, value: int | None) -> None: + self.__pid.set(value) + + @property + def _database_name(self) -> str | None: + return self.__database_name.get() + + @_database_name.setter + def _database_name(self, value: str | None) -> None: + self.__database_name.set(value) + + @property + def _database_version(self) -> str | None: + return self.__database_version.get() + + @_database_version.setter + def _database_version(self, value: str | None) -> None: + self.__database_version.set(value) + + @property + def _database_edition(self) -> str | None: + return self.__database_edition.get() + + @_database_edition.setter + def _database_edition(self, value: str | None) -> None: + self.__database_edition.set(value) + + @property + def impersonated_user(self) -> str | None: + return self.__impersonated_user.get() + + @impersonated_user.setter + def impersonated_user(self, value: str | None) -> None: + self.__impersonated_user.set(value) + + @property + def _parallel_runtime(self) -> bool | None: + return self.__parallel_runtime.get() + + @_parallel_runtime.setter + def _parallel_runtime(self, value: bool | None) -> None: + self.__parallel_runtime.set(value) def set_connection( - self, url: Optional[str] = None, driver: Optional[Driver] = None + self, url: str | None = None, driver: Driver | None = None ) -> None: """ Sets the connection up and relevant internal. This can be done using a Neo4j URL or a driver instance. 
@@ -153,8 +279,9 @@ def set_connection( """ if driver: self.driver = driver - if hasattr(config, "DATABASE_NAME") and config.DATABASE_NAME: - self._database_name = config.DATABASE_NAME + config = get_config() + if hasattr(config, "database_name") and config.database_name: + self._database_name = config.database_name elif url: self._parse_driver_from_url(url=url) @@ -207,21 +334,22 @@ def _parse_driver_from_url(self, url: str) -> None: f"Expecting url format: bolt://user:password@localhost:7687 got {url}" ) + config = get_config() options = { "auth": basic_auth(username, password), - "connection_acquisition_timeout": config.CONNECTION_ACQUISITION_TIMEOUT, - "connection_timeout": config.CONNECTION_TIMEOUT, - "keep_alive": config.KEEP_ALIVE, - "max_connection_lifetime": config.MAX_CONNECTION_LIFETIME, - "max_connection_pool_size": config.MAX_CONNECTION_POOL_SIZE, - "max_transaction_retry_time": config.MAX_TRANSACTION_RETRY_TIME, - "resolver": config.RESOLVER, - "user_agent": config.USER_AGENT, + "connection_acquisition_timeout": config.connection_acquisition_timeout, + "connection_timeout": config.connection_timeout, + "keep_alive": config.keep_alive, + "max_connection_lifetime": config.max_connection_lifetime, + "max_connection_pool_size": config.max_connection_pool_size, + "max_transaction_retry_time": config.max_transaction_retry_time, + "resolver": config.resolver, + "user_agent": config.user_agent, } if "+s" not in parsed_url.scheme: - options["encrypted"] = config.ENCRYPTED - options["trusted_certificates"] = config.TRUSTED_CERTIFICATES + options["encrypted"] = config.encrypted + options["trusted_certificates"] = config.trusted_certificates # Ignore the type error because the workaround would be duplicating code self.driver = GraphDatabase.driver( @@ -230,8 +358,8 @@ def _parse_driver_from_url(self, url: str) -> None: self.url = url # The database name can be provided through the url or the config if database_name == "": - if hasattr(config, "DATABASE_NAME") and config.DATABASE_NAME: - self._database_name = config.DATABASE_NAME + if hasattr(config, "database_name") and config.database_name: + self._database_name = config.database_name else: self._database_name = database_name @@ -248,14 +376,14 @@ def close_connection(self) -> None: self.driver = None @property - def database_version(self) -> Optional[str]: + def database_version(self) -> str | None: if self._database_version is None: self._update_database_version() return self._database_version @property - def database_edition(self) -> Optional[str]: + def database_edition(self) -> str | None: if self._database_edition is None: self._update_database_version() @@ -266,18 +394,26 @@ def transaction(self) -> "TransactionProxy": """ Returns the current transaction object """ + from neomodel.sync_.transaction import TransactionProxy # type: ignore + return TransactionProxy(self) @property def write_transaction(self) -> "TransactionProxy": + from neomodel.sync_.transaction import TransactionProxy # type: ignore + return TransactionProxy(self, access_mode=ACCESS_MODE_WRITE) @property def read_transaction(self) -> "TransactionProxy": + from neomodel.sync_.transaction import TransactionProxy # type: ignore + return TransactionProxy(self, access_mode=ACCESS_MODE_READ) @property def parallel_read_transaction(self) -> "TransactionProxy": + from neomodel.sync_.transaction import TransactionProxy # type: ignore + return TransactionProxy( self, access_mode=ACCESS_MODE_READ, parallel_runtime=True ) @@ -291,6 +427,8 @@ def impersonate(self, user: str) 
-> "ImpersonationHandler": Returns: ImpersonationHandler: Context manager to set/unset the user to impersonate """ + from neomodel.sync_.transaction import ImpersonationHandler # type: ignore + db_edition = self.database_edition if db_edition != ENTERPRISE_EDITION_TAG: raise FeatureNotSupported( @@ -449,12 +587,18 @@ def _object_resolution(self, object_to_resolve: Any) -> Any: ) if isinstance(object_to_resolve, Path): - from neomodel.sync_.path import NeomodelPath + from neomodel.sync_.path import NeomodelPath # type: ignore return NeomodelPath(object_to_resolve) if isinstance(object_to_resolve, list): - return self._result_resolution([object_to_resolve]) + return [self._object_resolution(item) for item in object_to_resolve] + + if isinstance(object_to_resolve, dict): + return { + key: self._object_resolution(value) + for key, value in object_to_resolve.items() + } return object_to_resolve @@ -490,11 +634,11 @@ def _result_resolution(self, result_list: list) -> list: def cypher_query( self, query: str, - params: Optional[dict[str, Any]] = None, + params: dict[str, Any] | None = None, handle_unique: bool = True, retry_on_session_expire: bool = False, resolve_objects: bool = False, - ) -> tuple[Optional[list], Optional[tuple[str, ...]]]: + ) -> tuple[list | None, tuple[str, ...] | None]: """ Runs a query on the database and returns a list of results and their headers. @@ -546,13 +690,13 @@ def cypher_query( def _run_cypher_query( self, - session: Union[Session, Transaction], + session: Session | Transaction, query: str, params: dict[str, Any], handle_unique: bool, retry_on_session_expire: bool, resolve_objects: bool, - ) -> tuple[Optional[list], Optional[tuple[str, ...]]]: + ) -> tuple[list | None, tuple[str, ...] | None]: try: # Retrieve the data start = time.time() @@ -613,7 +757,7 @@ def get_id_method(self) -> str: else: return ELEMENT_ID_METHOD - def parse_element_id(self, element_id: Optional[str]) -> Union[str, int]: + def parse_element_id(self, element_id: str | None) -> str | int: if element_id is None: raise ValueError( "Unable to parse element id, are you sure this element has been saved ?" @@ -708,12 +852,12 @@ def clear_neo4j_database( """ ) if clear_constraints: - drop_constraints() + self.drop_constraints() if clear_indexes: - drop_indexes() + self.drop_indexes() def drop_constraints( - self, quiet: bool = True, stdout: Optional[TextIO] = None + self, quiet: bool = True, stdout: TextIO | None = None ) -> None: """ Discover and drop all constraints. @@ -740,7 +884,7 @@ def drop_constraints( if not quiet: stdout.write("\n") - def drop_indexes(self, quiet: bool = True, stdout: Optional[TextIO] = None) -> None: + def drop_indexes(self, quiet: bool = True, stdout: TextIO | None = None) -> None: """ Discover and drop all indexes, except the automatically created token lookup indexes. @@ -760,7 +904,7 @@ def drop_indexes(self, quiet: bool = True, stdout: Optional[TextIO] = None) -> N if not quiet: stdout.write("\n") - def remove_all_labels(self, stdout: Optional[TextIO] = None) -> None: + def remove_all_labels(self, stdout: TextIO | None = None) -> None: """ Calls functions for dropping constraints and indexes. 
@@ -777,7 +921,7 @@ def remove_all_labels(self, stdout: Optional[TextIO] = None) -> None: stdout.write("Dropping indexes...\n") self.drop_indexes(quiet=False, stdout=stdout) - def install_all_labels(self, stdout: Optional[TextIO] = None) -> None: + def install_all_labels(self, stdout: TextIO | None = None) -> None: """ Discover all subclasses of StructuredNode in your application and execute install_labels on each. Note: code must be loaded (imported) in order for a class to be discovered. @@ -798,9 +942,11 @@ def subsub(cls: Any) -> list: # recursively return all subclasses stdout.write("Setting up indexes and constraints...\n\n") i = 0 + from .node import StructuredNode + for cls in subsub(StructuredNode): stdout.write(f"Found {cls.__module__}.{cls.__name__}\n") - install_labels(cls, quiet=False, stdout=stdout) + self.install_labels(cls, quiet=False, stdout=stdout) i += 1 if i: @@ -809,7 +955,7 @@ def subsub(cls: Any) -> list: # recursively return all subclasses stdout.write(f"Finished {i} classes.\n") def install_labels( - self, cls: Any, quiet: bool = True, stdout: Optional[TextIO] = None + self, cls: Any, quiet: bool = True, stdout: TextIO | None = None ) -> None: """ Setup labels with indexes and constraints for a given class @@ -1187,738 +1333,4 @@ def _install_relationship( # Create a singleton instance of the database object -db = Database() - - -# Deprecated methods -def change_neo4j_password(db: Database, user: str, new_password: str) -> None: - deprecated( - """ - This method has been moved to the Database singleton (db for sync, db for async). - Please use db.change_neo4j_password(user, new_password) instead. - This direct call will be removed in an upcoming version. - """ - ) - db.change_neo4j_password(user, new_password) - - -def clear_neo4j_database( - db: Database, clear_constraints: bool = False, clear_indexes: bool = False -) -> None: - deprecated( - """ - This method has been moved to the Database singleton (db for sync, db for async). - Please use db.clear_neo4j_database(clear_constraints, clear_indexes) instead. - This direct call will be removed in an upcoming version. - """ - ) - db.clear_neo4j_database(clear_constraints, clear_indexes) - - -def drop_constraints(quiet: bool = True, stdout: Optional[TextIO] = None) -> None: - deprecated( - """ - This method has been moved to the Database singleton (db for sync, db for async). - Please use db.drop_constraints(quiet, stdout) instead. - This direct call will be removed in an upcoming version. - """ - ) - db.drop_constraints(quiet, stdout) - - -def drop_indexes(quiet: bool = True, stdout: Optional[TextIO] = None) -> None: - deprecated( - """ - This method has been moved to the Database singleton (db for sync, db for async). - Please use db.drop_indexes(quiet, stdout) instead. - This direct call will be removed in an upcoming version. - """ - ) - db.drop_indexes(quiet, stdout) - - -def remove_all_labels(stdout: Optional[TextIO] = None) -> None: - deprecated( - """ - This method has been moved to the Database singleton (db for sync, db for async). - Please use db.remove_all_labels(stdout) instead. - This direct call will be removed in an upcoming version. - """ - ) - db.remove_all_labels(stdout) - - -def install_labels( - cls: Any, quiet: bool = True, stdout: Optional[TextIO] = None -) -> None: - deprecated( - """ - This method has been moved to the Database singleton (db for sync, db for async). - Please use db.install_labels(cls, quiet, stdout) instead. - This direct call will be removed in an upcoming version. 
- """ - ) - db.install_labels(cls, quiet, stdout) - - -def install_all_labels(stdout: Optional[TextIO] = None) -> None: - deprecated( - """ - This method has been moved to the Database singleton (db for sync, db for async). - Please use db.install_all_labels(stdout) instead. - This direct call will be removed in an upcoming version. - """ - ) - db.install_all_labels(stdout) - - -class TransactionProxy: - bookmarks: Optional[Bookmarks] = None - - def __init__( - self, - db: Database, - access_mode: Optional[str] = None, - parallel_runtime: Optional[bool] = False, - ): - self.db: Database = db - self.access_mode: Optional[str] = access_mode - self.parallel_runtime: Optional[bool] = parallel_runtime - - @ensure_connection - def __enter__(self) -> "TransactionProxy": - if self.parallel_runtime and not self.db.parallel_runtime_available(): - warnings.warn( - "Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. " - "Reverting to default runtime.", - UserWarning, - ) - self.parallel_runtime = False - self.db._parallel_runtime = self.parallel_runtime - self.db.begin(access_mode=self.access_mode, bookmarks=self.bookmarks) - self.bookmarks = None - return self - - def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: - self.db._parallel_runtime = False - if exc_value: - self.db.rollback() - - if ( - exc_type is ClientError - and exc_value.code == "Neo.ClientError.Schema.ConstraintValidationFailed" - ): - raise UniqueProperty(exc_value.message) - - if not exc_value: - self.last_bookmark = self.db.commit() - - def __call__(self, func: Callable) -> Callable: - if Util.is_async_code and not iscoroutinefunction(func): - raise TypeError(NOT_COROUTINE_ERROR) - - @wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Callable: - with self: - return func(*args, **kwargs) - - return wrapper - - @property - def with_bookmark(self) -> "BookmarkingAsyncTransactionProxy": - return BookmarkingAsyncTransactionProxy(self.db, self.access_mode) - - -class BookmarkingAsyncTransactionProxy(TransactionProxy): - def __call__(self, func: Callable) -> Callable: - if Util.is_async_code and not iscoroutinefunction(func): - raise TypeError(NOT_COROUTINE_ERROR) - - def wrapper(*args: Any, **kwargs: Any) -> tuple[Any, None]: - self.bookmarks = kwargs.pop("bookmarks", None) - - with self: - result = func(*args, **kwargs) - self.last_bookmark = None - - return result, self.last_bookmark - - return wrapper - - -class ImpersonationHandler: - def __init__(self, db: Database, impersonated_user: str): - self.db = db - self.impersonated_user = impersonated_user - - def __enter__(self) -> "ImpersonationHandler": - self.db.impersonated_user = self.impersonated_user - return self - - def __exit__( - self, exception_type: Any, exception_value: Any, exception_traceback: Any - ) -> None: - self.db.impersonated_user = None - - print("\nException type:", exception_type) - print("\nException value:", exception_value) - print("\nTraceback:", exception_traceback) - - def __call__(self, func: Callable) -> Callable: - def wrapper(*args: Any, **kwargs: Any) -> Callable: - with self: - return func(*args, **kwargs) - - return wrapper - - -class NodeMeta(type): - DoesNotExist: type[DoesNotExist] - __required_properties__: tuple[str, ...] - __all_properties__: tuple[tuple[str, Any], ...] - __all_aliases__: tuple[tuple[str, Any], ...] - __all_relationships__: tuple[tuple[str, Any], ...] 
- __label__: str - __optional_labels__: list[str] - - defined_properties: Callable[..., dict[str, Any]] - - def __new__( - mcs: type, name: str, bases: tuple[type, ...], namespace: dict[str, Any] - ) -> Any: - namespace["DoesNotExist"] = type(name + "DoesNotExist", (DoesNotExist,), {}) - cls: NodeMeta = type.__new__(mcs, name, bases, namespace) - cls.DoesNotExist._model_class = cls - - if hasattr(cls, "__abstract_node__"): - delattr(cls, "__abstract_node__") - else: - if "deleted" in namespace: - raise ValueError( - "Property name 'deleted' is not allowed as it conflicts with neomodel internals." - ) - elif "id" in namespace: - raise ValueError( - """ - Property name 'id' is not allowed as it conflicts with neomodel internals. - Consider using 'uid' or 'identifier' as id is also a Neo4j internal. - """ - ) - elif "element_id" in namespace: - raise ValueError( - """ - Property name 'element_id' is not allowed as it conflicts with neomodel internals. - Consider using 'uid' or 'identifier' as element_id is also a Neo4j internal. - """ - ) - for key, value in ( - (x, y) for x, y in namespace.items() if isinstance(y, Property) - ): - value.name, value.owner = key, cls - if hasattr(value, "setup") and callable(value.setup): - value.setup() - - # cache various groups of properies - cls.__required_properties__ = tuple( - name - for name, property in cls.defined_properties( - aliases=False, rels=False - ).items() - if property.required or property.unique_index - ) - cls.__all_properties__ = tuple( - cls.defined_properties(aliases=False, rels=False).items() - ) - cls.__all_aliases__ = tuple( - cls.defined_properties(properties=False, rels=False).items() - ) - cls.__all_relationships__ = tuple( - cls.defined_properties(aliases=False, properties=False).items() - ) - - cls.__label__ = namespace.get("__label__", name) - cls.__optional_labels__ = namespace.get("__optional_labels__", []) - - build_class_registry(cls) - - return cls - - -def build_class_registry(cls: Any) -> None: - base_label_set = frozenset(cls.inherited_labels()) - optional_label_set = set(cls.inherited_optional_labels()) - - # Construct all possible combinations of labels + optional labels - possible_label_combinations = [ - frozenset(set(x).union(base_label_set)) - for i in range(1, len(optional_label_set) + 1) - for x in combinations(optional_label_set, i) - ] - possible_label_combinations.append(base_label_set) - - for label_set in possible_label_combinations: - if not hasattr(cls, "__target_databases__"): - if label_set not in db._NODE_CLASS_REGISTRY: - db._NODE_CLASS_REGISTRY[label_set] = cls - else: - raise NodeClassAlreadyDefined( - cls, db._NODE_CLASS_REGISTRY, db._DB_SPECIFIC_CLASS_REGISTRY - ) - else: - for database in cls.__target_databases__: - if database not in db._DB_SPECIFIC_CLASS_REGISTRY: - db._DB_SPECIFIC_CLASS_REGISTRY[database] = {} - if label_set not in db._DB_SPECIFIC_CLASS_REGISTRY[database]: - db._DB_SPECIFIC_CLASS_REGISTRY[database][label_set] = cls - else: - raise NodeClassAlreadyDefined( - cls, db._NODE_CLASS_REGISTRY, db._DB_SPECIFIC_CLASS_REGISTRY - ) - - -NodeBase: type = NodeMeta("NodeBase", (PropertyManager,), {"__abstract_node__": True}) - - -class StructuredNode(NodeBase): - """ - Base class for all node definitions to inherit from. 
- - If you want to create your own abstract classes set: - __abstract_node__ = True - """ - - # static properties - - __abstract_node__ = True - - # magic methods - - def __init__(self, *args: Any, **kwargs: Any): - if "deleted" in kwargs: - raise ValueError("deleted property is reserved for neomodel") - - for key, val in self.__all_relationships__: - self.__dict__[key] = val.build_manager(self, key) - - super().__init__(*args, **kwargs) - - def __eq__(self, other: Any) -> bool: - """ - Compare two node objects. - If both nodes were saved to the database, compare them by their element_id. - Otherwise, compare them using object id in memory. - If `other` is not a node, always return False. - """ - if not isinstance(other, (StructuredNode,)): - return False - if self.was_saved and other.was_saved: - return self.element_id == other.element_id - return id(self) == id(other) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def __repr__(self) -> str: - return f"<{self.__class__.__name__}: {self}>" - - def __str__(self) -> str: - return repr(self.__properties__) - - # dynamic properties - - @classproperty - def nodes(self) -> Any: - """ - Returns a NodeSet object representing all nodes of the classes label - :return: NodeSet - :rtype: NodeSet - """ - from neomodel.sync_.match import NodeSet - - return NodeSet(self) - - @property - def element_id(self) -> Optional[Any]: - if hasattr(self, "element_id_property"): - return self.element_id_property - return None - - # Version 4.4 support - id is deprecated in version 5.x - @property - def id(self) -> int: - try: - return int(self.element_id_property) - except (TypeError, ValueError): - raise ValueError( - "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." - ) - - @property - def was_saved(self) -> bool: - """ - Shows status of node in the database. False, if node hasn't been saved yet, True otherwise. - """ - return self.element_id is not None - - # methods - - @classmethod - def _build_merge_query( - cls, - merge_params: tuple[dict[str, Any], ...], - update_existing: bool = False, - lazy: bool = False, - relationship: Optional[Any] = None, - ) -> tuple[str, dict[str, Any]]: - """ - Get a tuple of a CYPHER query and a params dict for the specified MERGE query. - - :param merge_params: The target node match parameters, each node must have a "create" key and optional "update". - :type merge_params: list of dict - :param update_existing: True to update properties of existing nodes, default False to keep existing values. 
- :type update_existing: bool - :rtype: tuple - """ - query_params: dict[str, Any] = {"merge_params": merge_params} - n_merge_labels = ":".join(cls.inherited_labels()) - n_merge_prm = ", ".join( - ( - f"{getattr(cls, p).get_db_property_name(p)}: params.create.{getattr(cls, p).get_db_property_name(p)}" - for p in cls.__required_properties__ - ) - ) - n_merge = f"n:{n_merge_labels} {{{n_merge_prm}}}" - if relationship is None: - # create "simple" unwind query - query = f"UNWIND $merge_params as params\n MERGE ({n_merge})\n " - else: - # validate relationship - if not isinstance(relationship.source, StructuredNode): - raise ValueError( - f"relationship source [{repr(relationship.source)}] is not a StructuredNode" - ) - relation_type = relationship.definition.get("relation_type") - if not relation_type: - raise ValueError( - "No relation_type is specified on provided relationship" - ) - - from neomodel.sync_.match import _rel_helper - - if relationship.source.element_id is None: - raise RuntimeError( - "Could not identify the relationship source, its element id was None." - ) - query_params["source_id"] = db.parse_element_id( - relationship.source.element_id - ) - query = f"MATCH (source:{relationship.source.__label__}) WHERE {db.get_id_method()}(source) = $source_id\n " - query += "WITH source\n UNWIND $merge_params as params \n " - query += "MERGE " - query += _rel_helper( - lhs="source", - rhs=n_merge, - ident=None, - relation_type=relation_type, - direction=relationship.definition["direction"], - ) - - query += "ON CREATE SET n = params.create\n " - # if update_existing, write properties on match as well - if update_existing is True: - query += "ON MATCH SET n += params.update\n" - - # close query - if lazy: - query += f"RETURN {db.get_id_method()}(n)" - else: - query += "RETURN n" - - return query, query_params - - @classmethod - def create(cls, *props: tuple, **kwargs: dict[str, Any]) -> list: - """ - Call to CREATE with parameters map. A new instance will be created and saved. - - :param props: dict of properties to create the nodes. - :type props: tuple - :param lazy: False by default, specify True to get nodes with id only without the parameters. - :type: bool - :rtype: list - """ - - if "streaming" in kwargs: - warnings.warn( - STREAMING_WARNING, - category=DeprecationWarning, - stacklevel=1, - ) - - lazy = kwargs.get("lazy", False) - # create mapped query - query = f"CREATE (n:{':'.join(cls.inherited_labels())} $create_params)" - - # close query - if lazy: - query += f" RETURN {db.get_id_method()}(n)" - else: - query += " RETURN n" - - results = [] - for item in [ - cls.deflate(p, obj=_UnsavedNode(), skip_empty=True) for p in props - ]: - node, _ = db.cypher_query(query, {"create_params": item}) - results.extend(node[0]) - - nodes = [cls.inflate(node) for node in results] - - if not lazy and hasattr(cls, "post_create"): - for node in nodes: - node.post_create() - - return nodes - - @classmethod - def create_or_update(cls, *props: tuple, **kwargs: dict[str, Any]) -> list: - """ - Call to MERGE with parameters map. A new instance will be created and saved if does not already exists, - this is an atomic operation. If an instance already exists all optional properties specified will be updated. - - Note that the post_create hook isn't called after create_or_update - - :param props: List of dict arguments to get or create the entities with. - :type props: tuple - :param relationship: Optional, relationship to get/create on when new entity is created. 
- :param lazy: False by default, specify True to get nodes with id only without the parameters. - :rtype: list - """ - lazy: bool = bool(kwargs.get("lazy", False)) - relationship = kwargs.get("relationship") - - # build merge query, make sure to update only explicitly specified properties - create_or_update_params = [] - for specified, deflated in [ - (p, cls.deflate(p, skip_empty=True)) for p in props - ]: - create_or_update_params.append( - { - "create": deflated, - "update": dict( - (k, v) for k, v in deflated.items() if k in specified - ), - } - ) - query, params = cls._build_merge_query( - tuple(create_or_update_params), - update_existing=True, - relationship=relationship, - lazy=lazy, - ) - - if "streaming" in kwargs: - warnings.warn( - STREAMING_WARNING, - category=DeprecationWarning, - stacklevel=1, - ) - - # fetch and build instance for each result - results = db.cypher_query(query, params) - return [cls.inflate(r[0]) for r in results[0]] - - def cypher( - self, query: str, params: Optional[dict[str, Any]] = None - ) -> tuple[Optional[list], Optional[tuple[str, ...]]]: - """ - Execute a cypher query with the param 'self' pre-populated with the nodes neo4j id. - - :param query: cypher query string - :type: string - :param params: query parameters - :type: dict - :return: tuple containing a list of query results, and the meta information as a tuple - :rtype: tuple - """ - self._pre_action_check("cypher") - _params = params or {} - if self.element_id is None: - raise ValueError("Can't run cypher operation on unsaved node") - element_id = db.parse_element_id(self.element_id) - _params.update({"self": element_id}) - return db.cypher_query(query, _params) - - @hooks - def delete(self) -> bool: - """ - Delete a node and its relationships - - :return: True - """ - self._pre_action_check("delete") - self.cypher( - f"MATCH (self) WHERE {db.get_id_method()}(self)=$self DETACH DELETE self" - ) - delattr(self, "element_id_property") - self.deleted = True - return True - - @classmethod - def get_or_create(cls: Any, *props: tuple, **kwargs: dict[str, Any]) -> list: - """ - Call to MERGE with parameters map. A new instance will be created and saved if does not already exist, - this is an atomic operation. - Parameters must contain all required properties, any non required properties with defaults will be generated. - - Note that the post_create hook isn't called after get_or_create - - :param props: Arguments to get_or_create as tuple of dict with property names and values to get or create - the entities with. - :type props: tuple - :param relationship: Optional, relationship to get/create on when new entity is created. - :param lazy: False by default, specify True to get nodes with id only without the parameters. 
- :rtype: list - """ - lazy = kwargs.get("lazy", False) - relationship = kwargs.get("relationship") - - # build merge query - get_or_create_params = [ - {"create": cls.deflate(p, skip_empty=True)} for p in props - ] - query, params = cls._build_merge_query( - tuple(get_or_create_params), relationship=relationship, lazy=lazy - ) - - if "streaming" in kwargs: - warnings.warn( - STREAMING_WARNING, - category=DeprecationWarning, - stacklevel=1, - ) - - # fetch and build instance for each result - results = db.cypher_query(query, params) - return [cls.inflate(r[0]) for r in results[0]] - - @classmethod - def inflate(cls: Any, node: Any) -> Any: - """ - Inflate a raw neo4j_driver node to a neomodel node - :param node: - :return: node object - """ - # support lazy loading - if isinstance(node, str) or isinstance(node, int): - snode = cls() - snode.element_id_property = node - else: - snode = super().inflate(node) - snode.element_id_property = node.element_id - - return snode - - @classmethod - def inherited_labels(cls: Any) -> list[str]: - """ - Return list of labels from nodes class hierarchy. - - :return: list - """ - return [ - scls.__label__ - for scls in cls.mro() - if hasattr(scls, "__label__") and not hasattr(scls, "__abstract_node__") - ] - - @classmethod - def inherited_optional_labels(cls: Any) -> list[str]: - """ - Return list of optional labels from nodes class hierarchy. - - :return: list - :rtype: list - """ - return [ - label - for scls in cls.mro() - for label in getattr(scls, "__optional_labels__", []) - if not hasattr(scls, "__abstract_node__") - ] - - def labels(self) -> list[str]: - """ - Returns list of labels tied to the node from neo4j. - - :return: list of labels - :rtype: list - """ - self._pre_action_check("labels") - result = self.cypher( - f"MATCH (n) WHERE {db.get_id_method()}(n)=$self " "RETURN labels(n)" - ) - if result is None or result[0] is None: - raise ValueError("Could not get labels, node may not exist") - return result[0][0][0] - - def _pre_action_check(self, action: str) -> None: - if hasattr(self, "deleted") and self.deleted: - raise ValueError( - f"{self.__class__.__name__}.{action}() attempted on deleted node" - ) - if not hasattr(self, "element_id"): - raise ValueError( - f"{self.__class__.__name__}.{action}() attempted on unsaved node" - ) - - def refresh(self) -> None: - """ - Reload the node from neo4j - """ - self._pre_action_check("refresh") - if hasattr(self, "element_id"): - results = self.cypher( - f"MATCH (n) WHERE {db.get_id_method()}(n)=$self RETURN n" - ) - request = results[0] - if not request or not request[0]: - raise self.__class__.DoesNotExist("Can't refresh non existent node") - node = self.inflate(request[0][0]) - for key, val in node.__properties__.items(): - setattr(self, key, val) - else: - raise ValueError("Can't refresh unsaved node") - - @hooks - def save(self) -> "StructuredNode": - """ - Save the node to neo4j or raise an exception - - :return: the node instance - """ - - # create or update instance node - if hasattr(self, "element_id_property"): - # update - params = self.deflate(self.__properties__, self) - query = f"MATCH (n) WHERE {db.get_id_method()}(n)=$self\n" - - if params: - query += "SET " - query += ",\n".join([f"n.{key} = ${key}" for key in params]) - query += "\n" - if self.inherited_labels(): - query += "\n".join( - [f"SET n:`{label}`" for label in self.inherited_labels()] - ) - self.cypher(query, params) - elif hasattr(self, "deleted") and self.deleted: - raise ValueError( - f"{self.__class__.__name__}.save() 
attempted on deleted node" - ) - else: # create - result = self.create(self.__properties__) - created_node = result[0] - self.element_id_property = created_node.element_id - return self +db = Database.get_instance() diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 2635ebcc..4aace76d 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -1,21 +1,21 @@ import inspect import re import string -import warnings from dataclasses import dataclass from typing import Any, Iterator from typing import Optional as TOptional -from typing import Tuple, Union +from typing import Union from neomodel.exceptions import MultipleNodesReturned from neomodel.match_q import Q, QBase from neomodel.properties import AliasProperty, ArrayProperty, Property -from neomodel.semantic_filters import VectorFilter +from neomodel.semantic_filters import FulltextFilter, VectorFilter from neomodel.sync_ import relationship_manager -from neomodel.sync_.core import StructuredNode, db +from neomodel.sync_.database import db +from neomodel.sync_.node import StructuredNode from neomodel.sync_.relationship import StructuredRel from neomodel.typing import Subquery, Transformation -from neomodel.util import INCOMING, OUTGOING +from neomodel.util import RelationshipDirection CYPHER_ACTIONS_WITH_SIDE_EFFECT_EXPR = re.compile(r"(?i:MERGE|CREATE|DELETE|DETACH)") @@ -23,10 +23,10 @@ def _rel_helper( lhs: str, rhs: str, - ident: TOptional[str] = None, - relation_type: TOptional[str] = None, - direction: TOptional[int] = None, - relation_properties: TOptional[dict] = None, + ident: str | None = None, + relation_type: str | None = None, + direction: int | None = None, + relation_properties: dict | None = None, **kwargs: dict[str, Any], # NOSONAR ) -> str: """ @@ -57,20 +57,19 @@ def _rel_helper( rel_props = f" {{{rel_props_str}}}" rel_def = "" - # relation_type is unspecified - if relation_type is None: - rel_def = "" - # all("*" wildcard) relation_type - elif relation_type == "*": - rel_def = "[*]" - else: - # explicit relation_type - rel_def = f"[{ident if ident else ''}:`{relation_type}`{rel_props}]" + match relation_type: + case None: # relation_type is unspecified + rel_def = "" + case "*": # all("*" wildcard) relation_type + rel_def = "[*]" + case _: # explicit relation_type + rel_def = f"[{ident if ident else ''}:`{relation_type}`{rel_props}]" stmt = "" - if direction == OUTGOING: + + if direction == RelationshipDirection.OUTGOING: stmt = f"-{rel_def}->" - elif direction == INCOMING: + elif direction == RelationshipDirection.INCOMING: stmt = f"<-{rel_def}-" else: stmt = f"-{rel_def}-" @@ -88,9 +87,9 @@ def _rel_merge_helper( lhs: str, rhs: str, ident: str = "neomodelident", - relation_type: TOptional[str] = None, - direction: TOptional[int] = None, - relation_properties: TOptional[dict] = None, + relation_type: str | None = None, + direction: int | None = None, + relation_properties: dict | None = None, **kwargs: dict[str, Any], # NOSONAR ) -> str: """ @@ -113,9 +112,9 @@ def _rel_merge_helper( :returns: string """ - if direction == OUTGOING: + if direction == RelationshipDirection.OUTGOING: stmt = "-{0}->" - elif direction == INCOMING: + elif direction == RelationshipDirection.INCOMING: stmt = "<-{0}-" else: stmt = "-{0}-" @@ -143,15 +142,14 @@ def _rel_merge_helper( rel_none_props = ( f" ON CREATE SET {rel_prop_val_str} ON MATCH SET {rel_prop_val_str}" ) - # relation_type is unspecified - if relation_type is None: - stmt = stmt.format("") - # all("*" wildcard) relation_type - elif relation_type == 
"*": - stmt = stmt.format("[*]") - else: - # explicit relation_type - stmt = stmt.format(f"[{ident}:`{relation_type}`{rel_props}]") + + match relation_type: + case None: # relation_type is unspecified + stmt = stmt.format("") + case "*": # all("*" wildcard) relation_type + stmt = stmt.format("[*]") + case _: # explicit relation_type + stmt = stmt.format(f"[{ident}:`{relation_type}`{rel_props}]") return f"({lhs}){stmt}({rhs}){rel_none_props}" @@ -226,7 +224,7 @@ def install_traversals(cls: type[StructuredNode], node_set: "NodeSet") -> None: def _handle_special_operators( property_obj: Property, key: str, value: str, operator: str, prop: str -) -> Tuple[str, str, str]: +) -> tuple[str, str, str]: if operator == _SPECIAL_OPERATOR_IN: if not isinstance(value, (list, tuple)): raise ValueError( @@ -263,7 +261,7 @@ def _deflate_value( value: str, operator: str, prop: str, -) -> Tuple[str, str, str]: +) -> tuple[str, str, str]: if isinstance(property_obj, AliasProperty): prop = property_obj.aliased_to() deflated_value = getattr(cls, prop).deflate(value) @@ -278,7 +276,7 @@ def _deflate_value( def _initialize_filter_args_variables( cls: type[StructuredNode], key: str -) -> Tuple[type[StructuredNode], None, None, str, bool, str]: +) -> tuple[type[StructuredNode], None, None, str, bool, str]: current_class = cls current_rel_model = None leaf_prop = None @@ -298,7 +296,7 @@ def _initialize_filter_args_variables( def _process_filter_key( cls: type[StructuredNode], key: str -) -> Tuple[Property, str, str]: +) -> tuple[Property, str, str]: ( current_class, current_rel_model, @@ -394,33 +392,35 @@ class QueryAST: match: list[str] optional_match: list[str] where: list[str] - with_clause: TOptional[str] - return_clause: TOptional[str] - order_by: TOptional[list[str]] - skip: TOptional[int] - limit: TOptional[int] - result_class: TOptional[type] - lookup: TOptional[str] - additional_return: TOptional[list[str]] - is_count: TOptional[bool] - vector_index_query: TOptional[type] + with_clause: str | None + return_clause: str | None + order_by: list[str] | None + skip: int | None + limit: int | None + result_class: type | None + lookup: str | None + additional_return: list[str] | None + is_count: bool | None + vector_index_query: VectorFilter | None + fulltext_index_query: FulltextFilter | None def __init__( self, - match: TOptional[list[str]] = None, - optional_match: TOptional[list[str]] = None, - where: TOptional[list[str]] = None, - optional_where: TOptional[list[str]] = None, - with_clause: TOptional[str] = None, - return_clause: TOptional[str] = None, - order_by: TOptional[list[str]] = None, - skip: TOptional[int] = None, - limit: TOptional[int] = None, - result_class: TOptional[type] = None, - lookup: TOptional[str] = None, - additional_return: TOptional[list[str]] = None, - is_count: TOptional[bool] = False, - vector_index_query: TOptional[type] = None, + match: list[str] | None = None, + optional_match: list[str] | None = None, + where: list[str] | None = None, + optional_where: list[str] | None = None, + with_clause: str | None = None, + return_clause: str | None = None, + order_by: list[str] | None = None, + skip: int | None = None, + limit: int | None = None, + result_class: type | None = None, + lookup: str | None = None, + additional_return: list[str] | None = None, + is_count: bool | None = False, + vector_index_query: VectorFilter | None = None, + fulltext_index_query: FulltextFilter | None = None, ) -> None: self.match = match if match else [] self.optional_match = optional_match if 
optional_match else [] @@ -438,13 +438,14 @@ def __init__( ) self.is_count = is_count self.vector_index_query = vector_index_query + self.fulltext_index_query = fulltext_index_query self.subgraph: dict = {} self.mixed_filters: bool = False class QueryBuilder: def __init__( - self, node_set: "BaseSet", subquery_namespace: TOptional[str] = None + self, node_set: "BaseSet", subquery_namespace: str | None = None ) -> None: self.node_set = node_set self._ast = QueryAST() @@ -452,7 +453,7 @@ def __init__( self._place_holder_registry: dict = {} self._relation_identifier_count: int = 0 self._node_identifier_count: int = 0 - self._subquery_namespace: TOptional[str] = subquery_namespace + self._subquery_namespace: str | None = subquery_namespace def build_ast(self) -> "QueryBuilder": if isinstance(self.node_set, NodeSet) and hasattr( @@ -466,7 +467,14 @@ def build_ast(self) -> "QueryBuilder": and hasattr(self.node_set, "vector_query") and self.node_set.vector_query ): - self.build_vector_query(self.node_set.vector_query, self.node_set.source) + self.build_vector_query() + + if ( + isinstance(self.node_set, NodeSet) + and hasattr(self.node_set, "fulltext_query") + and self.node_set.fulltext_query + ): + self.build_fulltext_query() self.build_source(self.node_set) @@ -549,30 +557,59 @@ def build_order_by(self, ident: str, source: "NodeSet") -> None: order_by.append(f"{result[0]}.{prop}") self._ast.order_by = order_by - def build_vector_query(self, vectorfilter: "VectorFilter", source: "NodeSet"): + def build_vector_query(self): """ Query a vector indexed property on the node. """ + vector_filter = self.node_set.vector_query + source_class = self.node_set.source_class try: - attribute = getattr(source, vectorfilter.vector_attribute_name) + attribute = getattr( + self.node_set.source, vector_filter.vector_attribute_name + ) except AttributeError as e: raise AttributeError( - f"Attribute '{vectorfilter.vector_attribute_name}' not found on '{type(source).__name__}'." + f"Attribute '{vector_filter.vector_attribute_name}' not found on '{source_class.__name__}'." ) from e if not attribute.vector_index: raise AttributeError( - f"Attribute {vectorfilter.vector_attribute_name} is not declared with a vector index." + f"Attribute {vector_filter.vector_attribute_name} is not declared with a vector index." ) - vectorfilter.index_name = ( - f"vector_index_{source.__label__}_{vectorfilter.vector_attribute_name}" - ) - vectorfilter.node_set_label = source.__label__.lower() + vector_filter.index_name = f"vector_index_{source_class.__label__}_{vector_filter.vector_attribute_name}" + vector_filter.node_set_label = source_class.__label__.lower() - self._ast.vector_index_query = vectorfilter - self._ast.return_clause = f"{vectorfilter.node_set_label}, score" - self._ast.result_class = source.__class__ + self._ast.vector_index_query = vector_filter + self._ast.return_clause = f"{vector_filter.node_set_label}, score" + self._ast.result_class = source_class.__class__ + + def build_fulltext_query(self): + """ + Query a free text indexed property on the node. + """ + full_text_filter = self.node_set.fulltext_query + source_class = self.node_set.source_class + try: + attribute = getattr( + self.node_set.source, full_text_filter.fulltext_attribute_name + ) + except AttributeError as e: + raise AttributeError( + f"Attribute '{full_text_filter.fulltext_attribute_name}' not found on '{source_class.__name__}'."
+ ) from e + + if not attribute.fulltext_index: + raise AttributeError( + f"Attribute {full_text_filter.fulltext_attribute_name} is not declared with a full text index." + ) + + full_text_filter.index_name = f"fulltext_index_{source_class.__label__}_{full_text_filter.fulltext_attribute_name}" + full_text_filter.node_set_label = source_class.__label__.lower() + + self._ast.fulltext_index_query = full_text_filter + self._ast.return_clause = f"{full_text_filter.node_set_label}, score" + self._ast.result_class = source_class.__class__ def build_traversal(self, traversal: "Traversal") -> str: """ @@ -612,7 +649,7 @@ def _additional_return(self, name: str) -> None: def build_traversal_from_path( self, relation: "Path", source_class: Any - ) -> Tuple[str, Any]: + ) -> tuple[str, Any]: path: str = relation.value stmt: str = "" source_class_iterator = source_class @@ -770,7 +807,7 @@ def _register_place_holder(self, key: str) -> str: def _parse_path( self, source_class: type[StructuredNode], prop: str - ) -> Tuple[str, str, Any, bool]: + ) -> tuple[str, str, Any, bool]: is_rel_filter = "|" in prop if is_rel_filter: path, prop = prop.rsplit("|", 1) @@ -837,7 +874,7 @@ def _build_filter_statements( target.append((statement, is_optional_relation)) def _parse_q_filters( - self, ident: str, q: Union[QBase, Any], source_class: type[StructuredNode] + self, ident: str, q: QBase | Any, source_class: type[StructuredNode] ) -> tuple[str, str]: target: list[tuple[str, bool]] = [] @@ -881,7 +918,7 @@ def build_where_stmt( ident: str, filters: list, source_class: type[StructuredNode], - q_filters: Union[QBase, Any, None] = None, + q_filters: QBase | Any | None = None, ) -> None: """ Construct a where statement from some filters. @@ -923,7 +960,7 @@ def build_where_stmt( def lookup_query_variable( self, path: str, return_relation: bool = False - ) -> TOptional[Tuple[str, Any, bool]]: + ) -> tuple[str, Any, bool] | None: """Retrieve the variable name generated internally for the given traversal path.""" subgraph = self._ast.subgraph if not subgraph: @@ -935,7 +972,7 @@ def lookup_query_variable( return None # Check if relation is coming from an optional MATCH - # (declared using fetch|traverse_relations) + # (declared using traverse) is_optional_relation = False for relation in self.node_set.relations_to_fetch: if relation.value == path: @@ -975,6 +1012,16 @@ def build_query(self) -> str: # This ensures that we bring the context of the new nodeSet and score along with us for metadata filtering query += f""" WITH {self._ast.vector_index_query.node_set_label}, score""" + if self._ast.fulltext_index_query: + query += f"""CALL () {{ + CALL db.index.fulltext.queryNodes("{self._ast.fulltext_index_query.index_name}", "{self._ast.fulltext_index_query.query_string}") + YIELD node AS {self._ast.fulltext_index_query.node_set_label}, score + RETURN {self._ast.fulltext_index_query.node_set_label}, score LIMIT {self._ast.fulltext_index_query.topk} + }} + """ + # This ensures that we bring the context of the new nodeSet and score along with us for metadata filtering + query += f""" WITH {self._ast.fulltext_index_query.node_set_label}, score""" + # Instead of using only one MATCH statement for every relation # to follow, we use one MATCH per relation (to avoid cartesian # product issues...). 
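A hedged usage sketch for the vector and fulltext query paths built above, reusing the illustrative Document model from earlier; filter() accepts both filters as keyword arguments, as the changes further below show:

from neomodel.semantic_filters import FulltextFilter, VectorFilter

# Fulltext: top 10 nodes matching the query string on the indexed "body" property
matches = Document.nodes.filter(
    fulltext_filter=FulltextFilter(
        query_string="graph databases",
        fulltext_attribute_name="body",
        topk=10,
    )
).all()

# Vector: the 5 nearest neighbours of a candidate embedding
nearest = Document.nodes.filter(
    vector_filter=VectorFilter(
        topk=5,
        vector_attribute_name="embedding",
        candidate_vector=[0.1] * 1536,
    )
).all()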
@@ -1123,7 +1170,7 @@ def _count(self) -> int: results, _ = db.cypher_query(query, self._query_params) return int(results[0][0]) - def _contains(self, node_element_id: TOptional[Union[str, int]]) -> bool: + def _contains(self, node_element_id: str | int | None) -> bool: # inject id = into ast if not self._ast.return_clause and self._ast.additional_return: self._ast.return_clause = self._ast.additional_return[0] @@ -1220,7 +1267,7 @@ def __nonzero__(self) -> bool: """ return self.__bool__() - def __contains__(self, obj: Union[StructuredNode, Any]) -> bool: + def __contains__(self, obj: StructuredNode | Any) -> bool: if isinstance(obj, StructuredNode): if hasattr(obj, "element_id") and obj.element_id is not None: ast = self.query_cls(self).build_ast() @@ -1230,7 +1277,7 @@ def __contains__(self, obj: Union[StructuredNode, Any]) -> bool: raise ValueError("Expecting StructuredNode instance") - def __getitem__(self, key: Union[int, slice]) -> TOptional["BaseSet"]: + def __getitem__(self, key: int | slice) -> TOptional["BaseSet"]: if isinstance(key, slice): if key.stop and key.start: self.limit = key.stop - key.start @@ -1250,8 +1297,6 @@ def __getitem__(self, key: Union[int, slice]) -> TOptional["BaseSet"]: _first_item = [node for node in ast._execute()][0] return _first_item - return None - @dataclass class Optional: # type: ignore[no-redef] @@ -1269,7 +1314,7 @@ class Path: include_nodes_in_return: bool = True include_rels_in_return: bool = True relation_filtering: bool = False - alias: TOptional[str] = None + alias: str | None = None @dataclass @@ -1277,7 +1322,7 @@ class RelationNameResolver: """Helper to refer to a relation variable name. Since variable names are generated automatically within MATCH statements (for - anything injected using fetch_relations or traverse_relations), we need a way to + anything injected using traverse), we need a way to retrieve them. """ @@ -1298,7 +1343,7 @@ class NodeNameResolver: """Helper to refer to a node variable name. Since variable names are generated automatically within MATCH statements (for - anything injected using fetch_relations or traverse_relations), we need a way to + anything injected using traverse), we need a way to retrieve them. """ @@ -1444,13 +1489,14 @@ def __init__(self, source: Any) -> None: self._subqueries: list[Subquery] = [] self._intermediate_transforms: list = [] self._unique_variables: list[str] = [] - self.vector_query: Optional[str] = None + self.vector_query: VectorFilter | None = None + self.fulltext_query: FulltextFilter | None = None def __await__(self) -> Any: return self.all().__await__() # type: ignore[attr-defined] def _get( - self, limit: TOptional[int] = None, lazy: bool = False, **kwargs: dict[str, Any] + self, limit: int | None = None, lazy: bool = False, **kwargs: dict[str, Any] ) -> list: self.filter(**kwargs) if limit: @@ -1548,15 +1594,16 @@ def filter(self, *args: Any, **kwargs: Any) -> "BaseSet": """ if args or kwargs: # Need to grab and remove the VectorFilter from both args and kwargs - new_args = ( - [] - ) # As args are a tuple, they're immutable. But we need to remove the vectorfilter from the arguments so they don't go into Q. + # As args are a tuple, they're immutable. But we need to remove the vectorfilter from the arguments so they don't go into Q. 
+ new_args = [] for arg in args: if isinstance(arg, VectorFilter) and (not self.vector_query): self.vector_query = arg - new_args.append(arg) - new_args = tuple(new_args) + if isinstance(arg, FulltextFilter) and (not self.fulltext_query): + self.fulltext_query = arg + + new_args.append(arg) if ( kwargs.get("vector_filter") @@ -1565,7 +1612,14 @@ def filter(self, *args: Any, **kwargs: Any) -> "BaseSet": ): self.vector_query = kwargs.pop("vector_filter") - self.q_filters = Q(self.q_filters & Q(*new_args, **kwargs)) + if ( + kwargs.get("fulltext_filter") + and isinstance(kwargs["fulltext_filter"], FulltextFilter) + and not self.fulltext_query + ): + self.fulltext_query = kwargs.pop("fulltext_filter") + + self.q_filters = Q(self.q_filters & Q(*tuple(new_args), **kwargs)) return self @@ -1620,7 +1674,7 @@ def order_by(self, *props: Any) -> "BaseSet": return self def _register_relation_to_fetch( - self, relation_def: Any, alias: TOptional[str] = None + self, relation_def: Any, alias: str | None = None ) -> "Path": if isinstance(relation_def, Path): item = relation_def @@ -1632,9 +1686,9 @@ def _register_relation_to_fetch( item.alias = alias return item - def unique_variables(self, *paths: tuple[str, ...]) -> "NodeSet": + def unique_variables(self, *paths: str) -> "NodeSet": """Generate unique variable names for the given paths.""" - self._unique_variables = paths + self._unique_variables = list(paths) return self def traverse(self, *paths: tuple[str, ...], **aliased_paths: dict) -> "NodeSet": @@ -1649,60 +1703,12 @@ def traverse(self, *paths: tuple[str, ...], **aliased_paths: dict) -> "NodeSet": self.relations_to_fetch = relations return self - def fetch_relations(self, *relation_names: tuple[str, ...]) -> "NodeSet": - """Specify a set of relations to traverse and return.""" - warnings.warn( - "fetch_relations() will be deprecated in version 6, use traverse() instead.", - DeprecationWarning, - ) - relations = [] - for relation_name in relation_names: - if isinstance(relation_name, Optional): - relation_name = Path(value=relation_name.relation, optional=True) - relations.append(self._register_relation_to_fetch(relation_name)) - self.relations_to_fetch = relations - return self - - def traverse_relations( - self, *relation_names: tuple[str, ...], **aliased_relation_names: dict - ) -> "NodeSet": - """Specify a set of relations to traverse only.""" - - warnings.warn( - "traverse_relations() will be deprecated in version 6, use traverse() instead.", - DeprecationWarning, - ) - - def convert_to_path(input: Union[str, Optional]) -> Path: - if isinstance(input, Optional): - path = Path(value=input.relation, optional=True) - else: - path = Path(value=input) - path.include_nodes_in_return = False - path.include_rels_in_return = False - return path - - relations = [] - for relation_name in relation_names: - relations.append( - self._register_relation_to_fetch(convert_to_path(relation_name)) - ) - for alias, relation_def in aliased_relation_names.items(): - relations.append( - self._register_relation_to_fetch( - convert_to_path(relation_def), alias=alias - ) - ) - - self.relations_to_fetch = relations - return self - def annotate(self, *vars: tuple, **aliased_vars: tuple) -> "NodeSet": """Annotate node set results with extra variables.""" def register_extra_var( - vardef: Union[AggregatingFunction, ScalarFunction, Any], - varname: Union[str, None] = None, + vardef: AggregatingFunction | ScalarFunction | Any, + varname: str | None = None, ) -> None: if isinstance(vardef, (AggregatingFunction, 
ScalarFunction)): self._extra_results.append( @@ -1761,20 +1767,12 @@ def resolve_subgraph(self) -> list: we use a dedicated property to store node's relations. """ - if ( - self.relations_to_fetch - and not self.relations_to_fetch[0].include_nodes_in_return - and not self.relations_to_fetch[0].include_rels_in_return - ): - raise NotImplementedError( - "You cannot use traverse_relations() with resolve_subgraph(), use fetch_relations() instead." - ) results: list = [] qbuilder = self.query_cls(self) qbuilder.build_ast() if not qbuilder._ast.subgraph: raise RuntimeError( - "Nothing to resolve. Make sure to include relations in the result using fetch_relations() or filter()." + "Nothing to resolve. Make sure to include relations in the result using traverse() or filter()." ) other_nodes = {} root_node = None @@ -1783,9 +1781,6 @@ def resolve_subgraph(self) -> list: if node.__class__ is self.source and "_" not in name: root_node = node continue - if isinstance(node, list) and isinstance(node[0], list): - other_nodes[name] = node[0] - continue other_nodes[name] = node results.append( self._to_subgraph(root_node, other_nodes, qbuilder._ast.subgraph) @@ -1796,7 +1791,7 @@ def subquery( self, nodeset: "NodeSet", return_set: list[str], - initial_context: TOptional[list[str]] = None, + initial_context: list[str] | None = None, ) -> "NodeSet": """Add a subquery to this node set. @@ -1827,7 +1822,7 @@ def subquery( raise RuntimeError(f"Variable '{var}' is not returned by subquery.") if initial_context: for var in initial_context: - if type(var) is not str and not isinstance( + if not isinstance(var, str) and not isinstance( var, (NodeNameResolver, RelationNameResolver, RawCypher) ): raise ValueError( @@ -1847,7 +1842,7 @@ def intermediate_transform( self, vars: dict[str, Transformation], distinct: bool = False, - ordering: TOptional[list] = None, + ordering: list | None = None, ) -> "NodeSet": if not vars: raise ValueError( diff --git a/neomodel/sync_/node.py b/neomodel/sync_/node.py new file mode 100644 index 00000000..4f6b382a --- /dev/null +++ b/neomodel/sync_/node.py @@ -0,0 +1,622 @@ +""" +Node classes and metadata for the neomodel module. +""" + +from __future__ import annotations + +import warnings +from itertools import combinations +from typing import TYPE_CHECKING, Any, Callable + +from neo4j.graph import Node + +from neomodel.constants import STREAMING_WARNING +from neomodel.exceptions import DoesNotExist, NodeClassAlreadyDefined +from neomodel.hooks import hooks +from neomodel.properties import Property +from neomodel.sync_.database import db +from neomodel.sync_.property_manager import PropertyManager +from neomodel.util import _UnsavedNode, classproperty + +if TYPE_CHECKING: + from neomodel.sync_.match import NodeSet + + +class NodeMeta(type): + DoesNotExist: type[DoesNotExist] + __required_properties__: tuple[str, ...] + __all_properties__: tuple[tuple[str, Any], ...] + __all_aliases__: tuple[tuple[str, Any], ...] + __all_relationships__: tuple[tuple[str, Any], ...] 
+ __label__: str + __optional_labels__: list[str] + + defined_properties: Callable[..., dict[str, Any]] + + def __new__( + mcs: type, name: str, bases: tuple[type, ...], namespace: dict[str, Any] + ) -> Any: + namespace["DoesNotExist"] = type(name + "DoesNotExist", (DoesNotExist,), {}) + cls: NodeMeta = type.__new__(mcs, name, bases, namespace) + cls.DoesNotExist._model_class = cls + + if hasattr(cls, "__abstract_node__"): + delattr(cls, "__abstract_node__") + else: + if "deleted" in namespace: + raise ValueError( + "Property name 'deleted' is not allowed as it conflicts with neomodel internals." + ) + elif "id" in namespace: + raise ValueError( + """ + Property name 'id' is not allowed as it conflicts with neomodel internals. + Consider using 'uid' or 'identifier' as id is also a Neo4j internal. + """ + ) + elif "element_id" in namespace: + raise ValueError( + """ + Property name 'element_id' is not allowed as it conflicts with neomodel internals. + Consider using 'uid' or 'identifier' as element_id is also a Neo4j internal. + """ + ) + for key, value in ( + (x, y) for x, y in namespace.items() if isinstance(y, Property) + ): + value.name, value.owner = key, cls + if hasattr(value, "setup") and callable(value.setup): + value.setup() + + # cache various groups of properties + cls.__required_properties__ = tuple( + name + for name, property in cls.defined_properties( + aliases=False, rels=False + ).items() + if property.required or property.unique_index + ) + cls.__all_properties__ = tuple( + cls.defined_properties(aliases=False, rels=False).items() + ) + cls.__all_aliases__ = tuple( + cls.defined_properties(properties=False, rels=False).items() + ) + cls.__all_relationships__ = tuple( + cls.defined_properties(aliases=False, properties=False).items() + ) + + cls.__label__ = namespace.get("__label__", name) + cls.__optional_labels__ = namespace.get("__optional_labels__", []) + + build_class_registry(cls) + + return cls + + +def build_class_registry(cls: Any) -> None: + base_label_set = frozenset(cls.inherited_labels()) + optional_label_set = set(cls.inherited_optional_labels()) + + # Construct all possible combinations of labels + optional labels + possible_label_combinations = [ + frozenset(set(x).union(base_label_set)) + for i in range(1, len(optional_label_set) + 1) + for x in combinations(optional_label_set, i) + ] + possible_label_combinations.append(base_label_set) + + for label_set in possible_label_combinations: + if not hasattr(cls, "__target_databases__"): + if label_set not in db._NODE_CLASS_REGISTRY: + db._NODE_CLASS_REGISTRY[label_set] = cls + else: + raise NodeClassAlreadyDefined( + cls, db._NODE_CLASS_REGISTRY, db._DB_SPECIFIC_CLASS_REGISTRY + ) + else: + for database in cls.__target_databases__: + if database not in db._DB_SPECIFIC_CLASS_REGISTRY: + db._DB_SPECIFIC_CLASS_REGISTRY[database] = {} + if label_set not in db._DB_SPECIFIC_CLASS_REGISTRY[database]: + db._DB_SPECIFIC_CLASS_REGISTRY[database][label_set] = cls + else: + raise NodeClassAlreadyDefined( + cls, db._NODE_CLASS_REGISTRY, db._DB_SPECIFIC_CLASS_REGISTRY + ) + + +NodeBase: type = NodeMeta("NodeBase", (PropertyManager,), {"__abstract_node__": True}) + + +class StructuredNode(NodeBase): + """ + Base class for all node definitions to inherit from.
+ + If you want to create your own abstract classes set: + __abstract_node__ = True + """ + + # static properties + + __abstract_node__ = True + + # magic methods + + def __init__(self, *args: Any, **kwargs: Any): + if "deleted" in kwargs: + raise ValueError("deleted property is reserved for neomodel") + + for key, val in self.__all_relationships__: + self.__dict__[key] = val.build_manager(self, key) + + super().__init__(*args, **kwargs) + + def __eq__(self, other: Any) -> bool: + """ + Compare two node objects. + If both nodes were saved to the database, compare them by their element_id. + Otherwise, compare them using object id in memory. + If `other` is not a node, always return False. + """ + if not isinstance(other, (StructuredNode,)): + return False + if self.was_saved and other.was_saved: + return self.element_id == other.element_id + return id(self) == id(other) + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}: {self}>" + + def __str__(self) -> str: + return repr(self.__properties__) + + # dynamic properties + + @classproperty + def nodes(self) -> "NodeSet": + """ + Returns a NodeSet object representing all nodes of the class's label + :return: NodeSet + :rtype: NodeSet + """ + from neomodel.sync_.match import NodeSet + + return NodeSet(self) + + @property + def element_id(self) -> Any | None: + if hasattr(self, "element_id_property"): + return self.element_id_property + return None + + # Version 4.4 support - id is deprecated in version 5.x + @property + def id(self) -> int: + try: + return int(self.element_id_property) + except (TypeError, ValueError): + raise ValueError( + "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." + ) + + @property + def was_saved(self) -> bool: + """ + Shows the status of the node in the database: False if it hasn't been saved yet, True otherwise. + """ + return self.element_id is not None + + # methods + + @classmethod + def _build_merge_query( + cls, + merge_params: tuple[dict[str, Any], ...], + update_existing: bool = False, + lazy: bool = False, + relationship: Any | None = None, + merge_by: dict[str, str | list[str]] | None = None, + ) -> tuple[str, dict[str, Any]]: + """ + Get a tuple of a CYPHER query and a params dict for the specified MERGE query. + + :param merge_params: The target node match parameters, each node must have a "create" key and optional "update". + :type merge_params: list of dict + :param update_existing: True to update properties of existing nodes, default False to keep existing values. + :type update_existing: bool + :param lazy: False by default, specify True to get nodes with id only without the properties. + :type lazy: bool + :param relationship: Optional relationship to create when merging nodes. + :type relationship: Any | None + :param merge_by: Optional dict with 'label' and 'keys' to specify custom merge criteria. + 'label' is optional and should be a string, 'keys' is a list of strings. + If 'label' is not provided, uses the node's inherited labels. + If 'keys' is not provided, uses the node's required properties as merge keys.
+ :type merge_by: dict[str, str | list[str]] | None + :return: tuple of query and params + :rtype: tuple[str, dict[str, Any]] + """ + query_params: dict[str, Any] = {"merge_params": merge_params} + + # Determine merge key and labels + if merge_by: + # Use custom merge keys + merge_keys = merge_by["keys"] + merge_labels = merge_by.get("label", ":".join(cls.inherited_labels())) + + n_merge_prm = ", ".join(f"{key}: params.create.{key}" for key in merge_keys) + else: + # Use default required properties + merge_labels = ":".join(cls.inherited_labels()) + n_merge_prm = ", ".join( + ( + f"{getattr(cls, p).get_db_property_name(p)}: params.create.{getattr(cls, p).get_db_property_name(p)}" + for p in cls.__required_properties__ + ) + ) + + n_merge = f"n:{merge_labels} {{{n_merge_prm}}}" + if relationship is None: + # create "simple" unwind query + query = f"UNWIND $merge_params as params\n MERGE ({n_merge})\n " + else: + # validate relationship + if not isinstance(relationship.source, StructuredNode): + raise ValueError( + f"relationship source [{repr(relationship.source)}] is not a StructuredNode" + ) + relation_type = relationship.definition.get("relation_type") + if not relation_type: + raise ValueError( + "No relation_type is specified on provided relationship" + ) + + from neomodel.sync_.match import _rel_helper + + if relationship.source.element_id is None: + raise RuntimeError( + "Could not identify the relationship source, its element id was None." + ) + query_params["source_id"] = db.parse_element_id( + relationship.source.element_id + ) + query = f"MATCH (source:{relationship.source.__label__}) WHERE {db.get_id_method()}(source) = $source_id\n " + query += "WITH source\n UNWIND $merge_params as params \n " + query += "MERGE " + query += _rel_helper( + lhs="source", + rhs=n_merge, + ident=None, + relation_type=relation_type, + direction=relationship.definition["direction"], + ) + + query += "ON CREATE SET n = params.create\n " + # if update_existing, write properties on match as well + if update_existing is True: + query += "ON MATCH SET n += params.update\n" + + # close query + if lazy: + query += f"RETURN {db.get_id_method()}(n)" + else: + query += "RETURN n" + + return query, query_params + + @classmethod + def create(cls, *props: tuple, **kwargs: dict[str, Any]) -> list: + """ + Call to CREATE with parameters map. A new instance will be created and saved. + + :param props: dict of properties to create the nodes. + :type props: tuple + :param lazy: False by default, specify True to get nodes with id only without the parameters. + :type: bool + :rtype: list + """ + + if "streaming" in kwargs: + warnings.warn( + STREAMING_WARNING, + category=DeprecationWarning, + stacklevel=1, + ) + + lazy = kwargs.get("lazy", False) + # create mapped query + query = f"CREATE (n:{':'.join(cls.inherited_labels())} $create_params)" + + # close query + if lazy: + query += f" RETURN {db.get_id_method()}(n)" + else: + query += " RETURN n" + + results = [] + for item in [ + cls.deflate(p, obj=_UnsavedNode(), skip_empty=True) for p in props + ]: + node, _ = db.cypher_query(query, {"create_params": item}) + results.extend(node[0]) + + nodes = [cls.inflate(node) for node in results] + + if not lazy and hasattr(cls, "post_create"): + for node in nodes: + node.post_create() + + return nodes + + @classmethod + def create_or_update(cls, *props: tuple, **kwargs: dict[str, Any]) -> list: + """ + Call to MERGE with parameters map. 
A new instance will be created and saved if it does not already exist, + this is an atomic operation. If an instance already exists, all optional properties specified will be updated. + + Note that the post_create hook isn't called after create_or_update + + :param props: List of dict arguments to get or create the entities with. + :type props: tuple + :param relationship: Optional, relationship to get/create on when new entity is created. + :type relationship: Any | None + :param lazy: False by default, specify True to get nodes with id only without the properties. + :type lazy: bool + :param merge_by: Optional dict with 'label' and 'keys' to specify custom merge criteria. + 'label' is optional and should be a string, 'keys' is a list of strings. + If 'label' is not provided, uses the node's inherited labels. + If 'keys' is not provided, uses the node's required properties as merge keys. + :type merge_by: dict[str, str | list[str]] | None + :return: list of nodes + :rtype: list + """ + lazy: bool = bool(kwargs.get("lazy", False)) + relationship = kwargs.get("relationship") + merge_by = kwargs.get("merge_by") + + # build merge query, make sure to update only explicitly specified properties + create_or_update_params = [] + for specified, deflated in [ + (p, cls.deflate(p, skip_empty=True)) for p in props + ]: + create_or_update_params.append( + { + "create": deflated, + "update": {k: v for k, v in deflated.items() if k in specified}, + } + ) + query, params = cls._build_merge_query( + tuple(create_or_update_params), + update_existing=True, + relationship=relationship, + lazy=lazy, + merge_by=merge_by, + ) + + if "streaming" in kwargs: + warnings.warn( + STREAMING_WARNING, + category=DeprecationWarning, + stacklevel=1, + ) + + # fetch and build instance for each result + results = db.cypher_query(query, params) + if lazy: + return [r[0] for r in results[0]] + else: + return [cls.inflate(r[0]) for r in results[0]] + + def cypher( + self, query: str, params: dict[str, Any] | None = None + ) -> tuple[list | None, tuple[str, ...] | None]: + """ + Execute a cypher query with the param 'self' pre-populated with the node's neo4j id. + + :param query: cypher query string + :type: string + :param params: query parameters + :type: dict + :return: tuple containing a list of query results, and the meta information as a tuple + :rtype: tuple + """ + self._pre_action_check("cypher") + _params = params or {} + if self.element_id is None: + raise ValueError("Can't run cypher operation on unsaved node") + element_id = db.parse_element_id(self.element_id) + _params.update({"self": element_id}) + return db.cypher_query(query, _params) + + @hooks + def delete(self) -> bool: + """ + Delete a node and its relationships + + :return: True + """ + self._pre_action_check("delete") + self.cypher( + f"MATCH (self) WHERE {db.get_id_method()}(self)=$self DETACH DELETE self" + ) + delattr(self, "element_id_property") + self.deleted = True + return True + + @classmethod + def get_or_create(cls: Any, *props: tuple, **kwargs: dict[str, Any]) -> list: + """ + Call to MERGE with parameters map. A new instance will be created and saved if it does not already exist, + this is an atomic operation. + Parameters must contain all required properties, any non-required properties with defaults will be generated. + + Note that the post_create hook isn't called after get_or_create + + :param props: Arguments to get_or_create as tuple of dict with property names and values to get or create + the entities with.
+ :type props: tuple + :param relationship: Optional, relationship to get/create on when new entity is created. + :type relationship: Any | None + :param lazy: False by default, specify True to get nodes with id only without the parameters. + :type lazy: bool + :param merge_by: Optional dict with 'label' and 'keys' to specify custom merge criteria. + 'label' is optional and should be a string, 'keys' is a list of strings. + If 'label' is not provided, uses the node's inherited labels. + If 'keys' is not provided, uses the node's required properties as merge keys. + :type merge_by: dict[str, str | list[str]] | None + :return: list of nodes + :rtype: list + """ + lazy = kwargs.get("lazy", False) + relationship = kwargs.get("relationship") + merge_by = kwargs.get("merge_by") + + # build merge query + get_or_create_params = [ + {"create": cls.deflate(p, skip_empty=True)} for p in props + ] + query, params = cls._build_merge_query( + tuple(get_or_create_params), + relationship=relationship, + lazy=lazy, + merge_by=merge_by, + ) + + if "streaming" in kwargs: + warnings.warn( + STREAMING_WARNING, + category=DeprecationWarning, + stacklevel=1, + ) + + # fetch and build instance for each result + results = db.cypher_query(query, params) + if lazy: + return [r[0] for r in results[0]] + else: + return [cls.inflate(r[0]) for r in results[0]] + + @classmethod + def inflate(cls: Any, graph_entity: Node) -> Any: # type: ignore[override] + """ + Inflate a raw neo4j_driver node to a neomodel node + :param graph_entity: node + :return: node object + """ + # support lazy loading + if isinstance(graph_entity, (str, int)): + snode = cls() + snode.element_id_property = graph_entity + else: + snode = super().inflate(graph_entity) + snode.element_id_property = graph_entity.element_id + + return snode + + @classmethod + def inherited_labels(cls: Any) -> list[str]: + """ + Return list of labels from the node's class hierarchy. + + :return: list + """ + return [ + scls.__label__ + for scls in cls.mro() + if hasattr(scls, "__label__") and not hasattr(scls, "__abstract_node__") + ] + + @classmethod + def inherited_optional_labels(cls: Any) -> list[str]: + """ + Return list of optional labels from the node's class hierarchy. + + :return: list + :rtype: list + """ + return [ + label + for scls in cls.mro() + for label in getattr(scls, "__optional_labels__", []) + if not hasattr(scls, "__abstract_node__") + ] + + def labels(self) -> list[str]: + """ + Returns list of labels tied to the node from neo4j.
+ + :return: list of labels + :rtype: list + """ + self._pre_action_check("labels") + result = self.cypher( + f"MATCH (n) WHERE {db.get_id_method()}(n)=$self " "RETURN labels(n)" + ) + if result is None or result[0] is None: + raise ValueError("Could not get labels, node may not exist") + return result[0][0][0] + + def _pre_action_check(self, action: str) -> None: + if hasattr(self, "deleted") and self.deleted: + raise ValueError( + f"{self.__class__.__name__}.{action}() attempted on deleted node" + ) + if not hasattr(self, "element_id"): + raise ValueError( + f"{self.__class__.__name__}.{action}() attempted on unsaved node" + ) + + def refresh(self) -> None: + """ + Reload the node from neo4j + """ + self._pre_action_check("refresh") + if hasattr(self, "element_id"): + results = self.cypher( + f"MATCH (n) WHERE {db.get_id_method()}(n)=$self RETURN n" + ) + request = results[0] + if not request or not request[0]: + raise self.__class__.DoesNotExist("Can't refresh non existent node") + node = self.inflate(request[0][0]) + for key, val in node.__properties__.items(): + setattr(self, key, val) + else: + raise ValueError("Can't refresh unsaved node") + + @hooks + def save(self) -> "StructuredNode": + """ + Save the node to neo4j or raise an exception + + :return: the node instance + """ + + # create or update instance node + if hasattr(self, "element_id_property"): + # update + params = self.deflate(self.__properties__, self) + query = f"MATCH (n) WHERE {db.get_id_method()}(n)=$self\n" + + if params: + query += "SET " + query += ",\n".join([f"n.{key} = ${key}" for key in params]) + query += "\n" + if self.inherited_labels(): + query += "\n".join( + [f"SET n:`{label}`" for label in self.inherited_labels()] + ) + self.cypher(query, params) + elif hasattr(self, "deleted") and self.deleted: + raise ValueError( + f"{self.__class__.__name__}.save() attempted on deleted node" + ) + else: # create + result = self.create(self.__properties__) + created_node = result[0] + self.element_id_property = created_node.element_id + return self diff --git a/neomodel/sync_/path.py b/neomodel/sync_/path.py index b5c45931..e4357d73 100644 --- a/neomodel/sync_/path.py +++ b/neomodel/sync_/path.py @@ -2,7 +2,8 @@ from neo4j.graph import Path -from neomodel.sync_.core import StructuredNode, db +from neomodel.sync_.database import db +from neomodel.sync_.node import StructuredNode from neomodel.sync_.relationship import StructuredRel diff --git a/neomodel/sync_/property_manager.py b/neomodel/sync_/property_manager.py index 160475cb..6982130a 100644 --- a/neomodel/sync_/property_manager.py +++ b/neomodel/sync_/property_manager.py @@ -1,7 +1,7 @@ import types from typing import Any -from neo4j.graph import Entity +from neo4j.graph import Node, Relationship from neomodel.exceptions import RequiredProperty from neomodel.properties import AliasProperty, Property @@ -101,9 +101,9 @@ def deflate( return deflated @classmethod - def inflate(cls: Any, graph_entity: Entity) -> Any: + def inflate(cls: Any, graph_entity: Node | Relationship) -> Any: """ - Inflate the properties of a neo4j.graph.Entity (a neo4j.graph.Node or neo4j.graph.Relationship) into an instance + Inflate the properties of a graph entity (a neo4j.graph.Node or neo4j.graph.Relationship) into an instance of cls. Includes mapping from database property name (see Property.db_property) -> python class attribute name. Ignores any properties that are not defined as python attributes in the class definition. 
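The merge_by parameter documented in the node.py docstrings above is easiest to read next to the MERGE clause that _build_merge_query produces. A sketch under the documented shape; the Person model is illustrative only, mirroring the MergeKeyTestNode tests further down:

```python
from neomodel import IntegerProperty, StringProperty, StructuredNode


class Person(StructuredNode):
    name = StringProperty(required=True)
    email = StringProperty(required=True, unique_index=True)
    age = IntegerProperty()


# Default behaviour: required/unique properties (name, email) become the
# merge keys, roughly MERGE (n:Person {name: ..., email: ...}).
Person.create_or_update({"name": "John", "email": "john@example.com", "age": 31})

# Custom behaviour: merge on email only, so a changed name updates the
# existing node instead of creating a second one,
# roughly MERGE (n:Person {email: params.create.email}).
Person.create_or_update(
    {"name": "John Doe", "email": "john@example.com", "age": 32},
    merge_by={"label": "Person", "keys": ["email"]},
)
```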
diff --git a/neomodel/sync_/relationship.py b/neomodel/sync_/relationship.py index 638a78a0..f65ffe39 100644 --- a/neomodel/sync_/relationship.py +++ b/neomodel/sync_/relationship.py @@ -1,10 +1,10 @@ -from typing import Any, Optional +from typing import Any from neo4j.graph import Relationship from neomodel.hooks import hooks from neomodel.properties import Property -from neomodel.sync_.core import db +from neomodel.sync_.database import db from neomodel.sync_.property_manager import PropertyManager ELEMENT_ID_MIGRATION_NOTICE = "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." @@ -53,23 +53,27 @@ class StructuredRel(StructuredRelBase): Base class for relationship objects """ - def __init__(self, *args: Any, **kwargs: dict) -> None: - super().__init__(*args, **kwargs) + element_id_property: str + _start_node_element_id_property: str + _end_node_element_id_property: str + + _start_node_class: Any + _end_node_class: Any @property - def element_id(self) -> Optional[Any]: + def element_id(self) -> str | None: if hasattr(self, "element_id_property"): return self.element_id_property return None @property - def _start_node_element_id(self) -> Optional[Any]: + def _start_node_element_id(self) -> str | None: if hasattr(self, "_start_node_element_id_property"): return self._start_node_element_id_property return None @property - def _end_node_element_id(self) -> Optional[Any]: + def _end_node_element_id(self) -> str | None: if hasattr(self, "_end_node_element_id_property"): return self._end_node_element_id_property return None @@ -157,16 +161,16 @@ def end_node(self) -> Any: return results[0][0][0] @classmethod - def inflate(cls: Any, rel: Relationship) -> "StructuredRel": + def inflate(cls: Any, graph_entity: Relationship) -> "StructuredRel": # type: ignore[override] """ Inflate a neo4j_driver relationship object to a neomodel object - :param rel: + :param graph_entity: Relationship :return: StructuredRel """ - srel = super().inflate(rel) - if rel.start_node is not None: - srel._start_node_element_id_property = rel.start_node.element_id - if rel.end_node is not None: - srel._end_node_element_id_property = rel.end_node.element_id - srel.element_id_property = rel.element_id + srel = super().inflate(graph_entity) + if graph_entity.start_node is not None: + srel._start_node_element_id_property = graph_entity.start_node.element_id + if graph_entity.end_node is not None: + srel._end_node_element_id_property = graph_entity.end_node.element_id + srel.element_id_property = graph_entity.element_id return srel diff --git a/neomodel/sync_/relationship_manager.py b/neomodel/sync_/relationship_manager.py index 18df8e35..a780eff4 100644 --- a/neomodel/sync_/relationship_manager.py +++ b/neomodel/sync_/relationship_manager.py @@ -2,11 +2,10 @@ import inspect import sys from importlib import import_module -from typing import TYPE_CHECKING, Any, Callable, Iterator, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Iterator, Optional -from neomodel import config from neomodel.exceptions import NotConnected, RelationshipClassRedefined -from neomodel.sync_.core import db +from neomodel.sync_.database import db from neomodel.sync_.match import ( NodeSet, Traversal, @@ -15,9 +14,7 @@ ) from neomodel.sync_.relationship import StructuredRel from neomodel.util import ( - EITHER, - INCOMING, - OUTGOING, + RelationshipDirection, enumerate_traceback, get_graph_entity_properties, ) @@ -68,9 +65,9 @@ def __init__(self, source: 
Any, key: str, definition: dict): def __str__(self) -> str: direction = "either" - if self.definition["direction"] == OUTGOING: + if self.definition["direction"] == RelationshipDirection.OUTGOING: direction = "a outgoing" - elif self.definition["direction"] == INCOMING: + elif self.definition["direction"] == RelationshipDirection.INCOMING: direction = "a incoming" return f"{self.description} in {direction} direction of type {self.definition['relation_type']} on node ({self.source.element_id}) of class '{self.source_class.__name__}'" @@ -78,9 +75,7 @@ def __str__(self) -> str: def __await__(self) -> Any: return self.all().__await__() # type: ignore[attr-defined] - def _check_cardinality( - self, node: "StructuredNode", soft_check: bool = False - ) -> None: + def check_cardinality(self, node: "StructuredNode") -> None: """ Check whether a new connection to a node would violate the cardinality of the relationship. @@ -101,8 +96,8 @@ def _check_node(self, obj: type["StructuredNode"]) -> None: @check_source def connect( - self, node: "StructuredNode", properties: Optional[dict[str, Any]] = None - ) -> Optional[StructuredRel]: + self, node: "StructuredNode", properties: dict[str, Any] | None = None + ) -> StructuredRel | None: """ Connect a node @@ -112,7 +107,7 @@ def connect( :return: """ self._check_node(node) - self._check_cardinality(node) + self.check_cardinality(node) # Check for cardinality on the remote end. for rel_name, rel_def in node.defined_properties( @@ -129,9 +124,7 @@ def connect( # If we have found the inverse relationship, we need to check # its cardinality. inverse_rel = getattr(node, rel_name) - inverse_rel._check_cardinality( - self.source, soft_check=config.SOFT_INVERSE_CARDINALITY_CHECK - ) + inverse_rel.check_cardinality(self.source) break if not self.definition["model"] and properties: @@ -188,7 +181,7 @@ def connect( @check_source def replace( - self, node: "StructuredNode", properties: Optional[dict[str, Any]] = None + self, node: "StructuredNode", properties: dict[str, Any] | None = None ) -> None: """ Disconnect all existing nodes and connect the supplied node @@ -202,7 +195,7 @@ def replace( self.connect(node, properties) @check_source - def relationship(self, node: "StructuredNode") -> Optional[StructuredRel]: + def relationship(self, node: "StructuredNode") -> StructuredRel | None: """ Retrieve the relationship object for this first relationship between self and node. 
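Side note on the direction changes in this file: RelationshipDirection (added in the neomodel/util.py hunk further down) is an IntEnum, so its members compare equal to the legacy OUTGOING/INCOMING/EITHER integers and any definition["direction"] values already stored in relationship definitions keep working. A small self-contained illustration of that property:

```python
from enum import IntEnum


class RelationshipDirection(IntEnum):
    """Mirror of the enum added in neomodel/util.py."""

    OUTGOING = 1
    INCOMING = -1
    EITHER = 0


# IntEnum members are ints: comparisons and dict lookups against the old
# integer constants (1, -1, 0) behave exactly as before.
assert RelationshipDirection.OUTGOING == 1
assert RelationshipDirection(-1) is RelationshipDirection.INCOMING
assert {RelationshipDirection.EITHER: "either"}[0] == "either"
```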
@@ -250,7 +243,7 @@ def all_relationships(self, node: "StructuredNode") -> list[StructuredRel]: def _set_start_end_cls( self, rel_instance: StructuredRel, obj: "StructuredNode" ) -> StructuredRel: - if self.definition["direction"] == INCOMING: + if self.definition["direction"] == RelationshipDirection.INCOMING: rel_instance._start_node_class = obj.__class__ rel_instance._end_node_class = self.source_class else: @@ -438,7 +431,7 @@ def __nonzero__(self) -> bool: def __contains__(self, obj: Any) -> bool: return self._new_traversal().__contains__(obj) - def __getitem__(self, key: Union[int, slice]) -> Any: + def __getitem__(self, key: int | slice) -> Any: return self._new_traversal().__getitem__(key) @@ -449,7 +442,7 @@ def __init__( cls_name: str, direction: int, manager: type[RelationshipManager] = RelationshipManager, - model: Optional[type[StructuredRel]] = None, + model: type[StructuredRel] | None = None, ) -> None: self._validate_class(cls_name, model) @@ -502,7 +495,7 @@ def __init__( db._NODE_CLASS_REGISTRY[label_set] = model def _validate_class( - self, cls_name: str, model: Optional[type[StructuredRel]] = None + self, cls_name: str, model: type[StructuredRel] | None = None ) -> None: if not isinstance(cls_name, (str, object)): raise ValueError("Expected class name or class got " + repr(cls_name)) @@ -566,10 +559,14 @@ def __init__( cls_name: str, relation_type: str, cardinality: type[RelationshipManager] = ZeroOrMore, - model: Optional[type[StructuredRel]] = None, + model: type[StructuredRel] | None = None, ) -> None: super().__init__( - relation_type, cls_name, OUTGOING, manager=cardinality, model=model + relation_type, + cls_name, + RelationshipDirection.OUTGOING, + manager=cardinality, + model=model, ) @@ -579,10 +576,14 @@ def __init__( cls_name: str, relation_type: str, cardinality: type[RelationshipManager] = ZeroOrMore, - model: Optional[type[StructuredRel]] = None, + model: type[StructuredRel] | None = None, ) -> None: super().__init__( - relation_type, cls_name, INCOMING, manager=cardinality, model=model + relation_type, + cls_name, + RelationshipDirection.INCOMING, + manager=cardinality, + model=model, ) @@ -592,8 +593,12 @@ def __init__( cls_name: str, relation_type: str, cardinality: type[RelationshipManager] = ZeroOrMore, - model: Optional[type[StructuredRel]] = None, + model: type[StructuredRel] | None = None, ) -> None: super().__init__( - relation_type, cls_name, EITHER, manager=cardinality, model=model + relation_type, + cls_name, + RelationshipDirection.EITHER, + manager=cardinality, + model=model, ) diff --git a/neomodel/sync_/transaction.py b/neomodel/sync_/transaction.py new file mode 100644 index 00000000..006295db --- /dev/null +++ b/neomodel/sync_/transaction.py @@ -0,0 +1,115 @@ +""" +Transaction management for the neomodel module. 
+""" + +import warnings +from asyncio import iscoroutinefunction +from functools import wraps +from typing import Any, Callable + +from neo4j.api import Bookmarks +from neo4j.exceptions import ClientError + +from neomodel._async_compat.util import Util +from neomodel.constants import NOT_COROUTINE_ERROR +from neomodel.exceptions import UniqueProperty +from neomodel.sync_.database import Database + + +class TransactionProxy: + def __init__( + self, + db: Database, + access_mode: str | None = None, + parallel_runtime: bool | None = False, + ): + self.db: Database = db + self.access_mode: str | None = access_mode + self.parallel_runtime: bool | None = parallel_runtime + self.bookmarks: Bookmarks | None = None + self.last_bookmarks: Bookmarks | None = None + + def __enter__(self) -> "TransactionProxy": + if self.parallel_runtime and not self.db.parallel_runtime_available(): + warnings.warn( + "Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. " + "Reverting to default runtime.", + UserWarning, + ) + self.parallel_runtime = False + self.db._parallel_runtime = self.parallel_runtime + self.db.begin(access_mode=self.access_mode, bookmarks=self.bookmarks) + self.bookmarks = None + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + self.db._parallel_runtime = False + if exc_value: + self.db.rollback() + + if ( + exc_type is ClientError + and exc_value.code == "Neo.ClientError.Schema.ConstraintValidationFailed" + ): + raise UniqueProperty(exc_value.message) + + if not exc_value: + self.last_bookmarks = self.db.commit() + + def __call__(self, func: Callable) -> Callable: + if Util.is_async_code and not iscoroutinefunction(func): + raise TypeError(NOT_COROUTINE_ERROR) + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Callable: + with self: + return func(*args, **kwargs) + + return wrapper + + @property + def with_bookmark(self) -> "BookmarkingAsyncTransactionProxy": + return BookmarkingAsyncTransactionProxy(self.db, self.access_mode) + + +class BookmarkingAsyncTransactionProxy(TransactionProxy): + def __call__(self, func: Callable) -> Callable: + if Util.is_async_code and not iscoroutinefunction(func): + raise TypeError(NOT_COROUTINE_ERROR) + + def wrapper(*args: Any, **kwargs: Any) -> tuple[Any, None]: + self.bookmarks = kwargs.pop("bookmarks", None) + + with self: + result = func(*args, **kwargs) + self.last_bookmarks = None + + return result, self.last_bookmarks + + return wrapper + + +class ImpersonationHandler: + def __init__(self, db: Database, impersonated_user: str): + self.db = db + self.impersonated_user = impersonated_user + + def __enter__(self) -> "ImpersonationHandler": + self.db.impersonated_user = self.impersonated_user + return self + + def __exit__( + self, exception_type: Any, exception_value: Any, exception_traceback: Any + ) -> None: + self.db.impersonated_user = None + + print("\nException type:", exception_type) + print("\nException value:", exception_value) + print("\nTraceback:", exception_traceback) + + def __call__(self, func: Callable) -> Callable: + def wrapper(*args: Any, **kwargs: Any) -> Callable: + with self: + return func(*args, **kwargs) + + return wrapper diff --git a/neomodel/typing.py b/neomodel/typing.py index a23f88eb..7e26512e 100644 --- a/neomodel/typing.py +++ b/neomodel/typing.py @@ -1,13 +1,13 @@ """Custom types used for annotations.""" -from typing import Any, Optional, TypedDict +from typing import Any, TypedDict Transformation = TypedDict( "Transformation", { "source": Any, 
- "source_prop": Optional[str], - "include_in_return": Optional[bool], + "source_prop": str | None, + "include_in_return": bool | None, }, ) @@ -18,6 +18,6 @@ "query": str, "query_params": dict, "return_set": list[str], - "initial_context": Optional[list[Any]], + "initial_context": list[Any] | None, }, ) diff --git a/neomodel/util.py b/neomodel/util.py index 36fffdd2..509bdd97 100644 --- a/neomodel/util.py +++ b/neomodel/util.py @@ -1,10 +1,17 @@ import warnings +from enum import IntEnum from types import FrameType -from typing import Any, Callable, Optional +from typing import Any, Callable from neo4j.graph import Entity -OUTGOING, INCOMING, EITHER = 1, -1, 0 + +class RelationshipDirection(IntEnum): + """Enum representing the direction of relationships in Neo4j.""" + + OUTGOING = 1 + INCOMING = -1 + EITHER = 0 def deprecated(message: str) -> Callable: @@ -27,7 +34,7 @@ class cpf: def __init__(self, getter: Callable) -> None: self.getter = getter - def __get__(self, obj: Any, type: Optional[Any] = None) -> Any: + def __get__(self, obj: Any, type: Any | None = None) -> Any: return self.getter(type) return cpf(f) @@ -49,7 +56,7 @@ def get_graph_entity_properties(entity: Entity) -> dict: return entity._properties -def enumerate_traceback(initial_frame: Optional[FrameType] = None) -> Any: +def enumerate_traceback(initial_frame: FrameType | None = None) -> Any: depth, frame = 0, initial_frame while frame is not None: yield depth, frame diff --git a/pyproject.toml b/pyproject.toml index 69fb43cc..91b16a3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ classifiers = [ dependencies = [ "neo4j~=5.28.2", ] -requires-python = ">=3.9" +requires-python = ">=3.10" dynamic = ["version"] [project.urls] diff --git a/test/async_/conftest.py b/test/async_/conftest.py index 8cbf952b..f516e072 100644 --- a/test/async_/conftest.py +++ b/test/async_/conftest.py @@ -5,7 +5,7 @@ mark_async_session_auto_fixture, ) -from neomodel import adb, config +from neomodel import adb, get_config @mark_async_session_auto_fixture @@ -19,7 +19,7 @@ async def setup_neo4j_session(request): warnings.simplefilter("default") - config.DATABASE_URL = os.environ.get( + get_config().database_url = os.environ.get( "NEO4J_BOLT_URL", "bolt://neo4j:foobarbaz@localhost:7687" ) @@ -51,5 +51,8 @@ async def setup_neo4j_session(request): @mark_async_function_auto_fixture async def setUp(): + await adb.set_connection(url=get_config().database_url) await adb.cypher_query("MATCH (n) DETACH DELETE n") yield + + await adb.close_connection() diff --git a/test/async_/test_batch.py b/test/async_/test_batch.py index 653dce0d..d2aa3e6d 100644 --- a/test/async_/test_batch.py +++ b/test/async_/test_batch.py @@ -9,13 +9,11 @@ IntegerProperty, StringProperty, UniqueIdProperty, - config, + adb, ) from neomodel._async_compat.util import AsyncUtil from neomodel.exceptions import DeflateError, UniqueProperty -config.AUTO_INSTALL_LABELS = True - class UniqueUser(AsyncStructuredNode): uid = UniqueIdProperty() @@ -136,3 +134,454 @@ async def test_get_or_create_with_rel(): # not the same gizmo assert bobs_gizmo[0] != tims_gizmo[0] + + +class NodeWithDefaultProp(AsyncStructuredNode): + name = StringProperty(required=True) + age = IntegerProperty(default=30) + other_prop = StringProperty() + + +@mark_async_test +async def test_get_or_create_with_ignored_properties(): + node = await NodeWithDefaultProp.get_or_create({"name": "Tania", "age": 20}) + assert node[0].name == "Tania" + assert node[0].age == 20 + + node = await 
NodeWithDefaultProp.get_or_create({"name": "Tania"}) + assert node[0].name == "Tania" + assert node[0].age == 20 # Tania was fetched and not created + + node = await NodeWithDefaultProp.get_or_create({"name": "Tania", "age": 30}) + assert node[0].name == "Tania" + assert node[0].age == 20 # Tania was fetched and not created + + +@mark_async_test +async def test_create_or_update_with_ignored_properties(): + node = await NodeWithDefaultProp.create_or_update({"name": "Tania", "age": 20}) + assert node[0].name == "Tania" + assert node[0].age == 20 + + node = await NodeWithDefaultProp.create_or_update( + {"name": "Tania", "other_prop": "other"} + ) + assert node[0].name == "Tania" + assert ( + node[0].age == 20 + ) # Tania is still 20 even though default says she should be 30 + assert ( + node[0].other_prop == "other" + ) # She does have a brand new other_prop, lucky her ! + + node = await NodeWithDefaultProp.create_or_update( + {"name": "Tania", "age": 30, "other_prop": "other2"} + ) + assert node[0].name == "Tania" + assert node[0].age == 30 # Tania is now 30, congrats Tania ! + assert ( + node[0].other_prop == "other2" + ) # Plus she has a new other_prop - as a birthday gift ? + + +@mark_async_test +async def test_lazy_mode(): + """Test lazy mode functionality.""" + + node1 = (await NodeWithDefaultProp.create({"name": "Tania", "age": 20}))[0] + node = await NodeWithDefaultProp.get_or_create( + {"name": "Tania", "age": 20}, lazy=True + ) + if await adb.version_is_higher_than("5.0.0"): + assert node[0] == node1.element_id + else: + assert node[0] == node1.id + + node = await NodeWithDefaultProp.create_or_update( + {"name": "Tania", "age": 25}, lazy=True + ) + if await adb.version_is_higher_than("5.0.0"): + assert node[0] == node1.element_id + else: + assert node[0] == node1.id + + +class MergeKeyTestNode(AsyncStructuredNode): + """Test node for merge key functionality tests.""" + + name = StringProperty(required=True) + email = StringProperty(required=True, unique_index=True) + age = IntegerProperty() + department = StringProperty() + + +class MergeKeyChildTestNode(MergeKeyTestNode): + """Test node for merge key functionality tests with an extra label.""" + + +@mark_async_test +async def test_default_merge_behavior(): + """Test default merge behavior using required properties.""" + + # Create initial node + node1 = ( + await MergeKeyTestNode.create( + {"name": "John", "email": "john@example.com", "age": 30} + ) + )[0] + + # Update with same name and email (should update existing) + nodes = await MergeKeyTestNode.create_or_update( + {"name": "John", "email": "john@example.com", "age": 31} + ) + + assert len(nodes) == 1 + assert nodes[0].element_id == node1.element_id + assert nodes[0].age == 31 + + # Verify the node was updated correctly + assert nodes[0].name == "John" + assert nodes[0].email == "john@example.com" + assert nodes[0].age == 31 + + +@mark_async_test +async def test_custom_merge_key_email(): + """Test custom merge key using email only.""" + + # Create initial node + node1 = ( + await MergeKeyTestNode.create( + {"name": "Jane", "email": "jane@example.com", "age": 25} + ) + )[0] + + # Update with custom merge key (email only) + nodes = await MergeKeyTestNode.create_or_update( + { + "name": "Jane Doe", # Different name + "email": "jane@example.com", # Same email + "age": 26, + }, + merge_by={"label": "MergeKeyTestNode", "keys": ["email"]}, + ) + + assert len(nodes) == 1 + assert nodes[0].element_id == node1.element_id # Node should be the same + assert nodes[0].name == "Jane Doe" # 
Name should be updated + assert nodes[0].age == 26 # Age should be updated + + +@mark_async_test +async def test_merge_key_unspecified_label(): + """Test merge key with unspecified label.""" + + # Create initial node + node1 = ( + await MergeKeyTestNode.create( + {"name": "Jane", "email": "jane@example.com", "age": 25} + ) + )[0] + + # Update with custom merge key (email only) + nodes = await MergeKeyTestNode.create_or_update( + { + "name": "Jane Doe", # Different name + "email": "jane@example.com", # Same email + "age": 26, + }, + merge_by={"keys": ["email"]}, + ) + + assert len(nodes) == 1 + assert nodes[0].element_id == node1.element_id # Node should be the same + assert nodes[0].name == "Jane Doe" # Name should be updated + assert nodes[0].age == 26 # Age should be updated + + +@mark_async_test +async def test_get_or_create_with_merge_key(): + """Test get_or_create with custom merge key.""" + + # Create initial node + node1 = ( + await MergeKeyTestNode.create( + {"name": "Alice", "email": "alice@example.com", "age": 28} + ) + )[0] + + # Use get_or_create with custom merge key + nodes = await MergeKeyTestNode.get_or_create( + { + "name": "Alice Smith", # Different name + "email": "alice@example.com", # Same email + "age": 29, + }, + merge_by={"label": "MergeKeyTestNode", "keys": ["email"]}, + ) + + assert len(nodes) == 1 + assert nodes[0].element_id == node1.element_id # Node was fetched and not created + assert nodes[0].name == "Alice" # Name should be the same + assert nodes[0].age == 28 + + +@mark_async_test +async def test_merge_key_create_new_node(): + """Test that merge key creates new node when no match is found.""" + + # Create node with merge key that won't match anything + nodes = await MergeKeyTestNode.create_or_update( + {"name": "New User", "email": "new@example.com", "age": 30}, + merge_by={"label": "MergeKeyTestNode", "keys": ["email"]}, + ) + + assert len(nodes) == 1 + assert nodes[0].name == "New User" + assert nodes[0].email == "new@example.com" + assert nodes[0].age == 30 + + # Verify the node was created correctly + assert nodes[0].name == "New User" + assert nodes[0].email == "new@example.com" + assert nodes[0].age == 30 + + +@mark_async_test +async def test_get_or_create_with_different_label(): + """Test merge key with different label specification.""" + + # Create initial node + node1 = ( + await MergeKeyTestNode.create( + {"name": "Eve", "email": "eve@example.com", "age": 27} + ) + )[0] + node2 = ( + await MergeKeyChildTestNode.create( + {"name": "Eve Child", "email": "evechild@example.com", "age": 27} + ) + )[0] + + # Use merge key with explicit label - child one + nodes = await MergeKeyTestNode.get_or_create( + { + "name": "Eve Child", # Different name + "email": "evechild@example.com", # Same email + "age": 27, + }, + merge_by={"label": "MergeKeyChildTestNode", "keys": ["age"]}, + ) + + assert len(nodes) == 1 # Only the node with child label + assert nodes[0].element_id == node2.element_id # Node was fetched and not created + assert nodes[0].name == "Eve Child" + + # Use merge key with explicit label - parent one + nodes = await MergeKeyTestNode.get_or_create( + { + "name": "Eve Child", # Different name + "email": "evechild@example.com", # Same email + "age": 27, + }, + merge_by={"label": "MergeKeyTestNode", "keys": ["age"]}, + ) + + assert len(nodes) == 2 # Both nodes were fetched + element_ids = [node.element_id for node in nodes] + assert node1.element_id in element_ids + assert node2.element_id in element_ids + + +@mark_async_test +async def 
test_multiple_merge_operations(): + """Test multiple merge operations with different keys.""" + + # Create initial nodes + node1 = ( + await MergeKeyTestNode.create( + {"name": "Frank", "email": "frank@example.com", "age": 45} + ) + )[0] + node2 = ( + await MergeKeyTestNode.create( + {"name": "Grace", "email": "grace@example.com", "age": 38} + ) + )[0] + + # Update Frank by email + nodes1 = await MergeKeyTestNode.create_or_update( + {"name": "Franklin", "email": "frank@example.com", "age": 46}, + merge_by={"label": "MergeKeyTestNode", "keys": ["email"]}, + ) + + # Update Grace by name + nodes2 = await MergeKeyTestNode.create_or_update( + {"name": "Grace", "email": "grace.new@example.com", "age": 39}, + merge_by={"label": "MergeKeyTestNode", "keys": ["name"]}, + ) + + assert len(nodes1) == 1 + assert len(nodes2) == 1 + assert nodes1[0].element_id == node1.element_id + assert nodes2[0].element_id == node2.element_id + assert nodes1[0].name == "Franklin" + assert nodes2[0].email == "grace.new@example.com" + + # Verify both nodes were updated correctly + assert nodes1[0].name == "Franklin" + assert nodes2[0].email == "grace.new@example.com" + + +@mark_async_test +async def test_merge_key_lazy_mode(): + """Test merge key functionality with lazy mode.""" + + # Create initial node + node1 = ( + await MergeKeyTestNode.create( + {"name": "Diana", "email": "diana@example.com", "age": 32} + ) + )[0] + + # Test with lazy mode + nodes = await MergeKeyTestNode.create_or_update( + {"name": "Diana Prince", "email": "diana@example.com", "age": 33}, + merge_by={"label": "MergeKeyTestNode", "keys": ["email"]}, + lazy=True, + ) + + assert len(nodes) == 1 + # In lazy mode, we should get the element_id back + if await adb.version_is_higher_than("5.0.0"): + assert nodes[0] == node1.element_id + else: + assert nodes[0] == node1.id + + +@mark_async_test +async def test_merge_key_with_multiple_properties(): + """Test merge key with a property that has multiple values.""" + + # Create initial node + node1 = ( + await MergeKeyTestNode.create( + { + "name": "Multi", + "email": "multi@example.com", + "age": 25, + "department": "Engineering", + } + ) + )[0] + + # Update with different department but same email + nodes = await MergeKeyTestNode.create_or_update( + { + "name": "Multi Updated", + "email": "multi@example.com", + "age": 26, + "department": "Management", + }, + merge_by={"label": "MergeKeyTestNode", "keys": ["email"]}, + ) + + assert len(nodes) == 1 + assert nodes[0].element_id == node1.element_id # Node should be the same + assert nodes[0].name == "Multi Updated" + assert nodes[0].department == "Management" + assert nodes[0].age == 26 + + +@mark_async_test +async def test_merge_key_with_get_or_create_multiple_keys(): + """Test merge key for get_or_create with multiple keys.""" + + # Create initial node + node1 = ( + await MergeKeyTestNode.create( + { + "name": "Charlie", + "email": "charlie@example.com", + "age": 35, + } + ) + )[0] + + # Use get_or_create with multiple keys + nodes = await MergeKeyTestNode.get_or_create( + { + "name": "Charlie", + "email": "charlie@example.com", + "age": 36, + }, + merge_by={"label": "MergeKeyTestNode", "keys": ["name", "email"]}, + ) + + assert len(nodes) == 1 + assert nodes[0].element_id == node1.element_id + assert nodes[0].age == 35 + + # Try to get_or_create with different keys (should create new) + nodes2 = await MergeKeyTestNode.get_or_create( + { + "name": "Charlie", + "email": "charlie.doe@example.com", # Different email + "age": 37, + }, + merge_by={"label": 
"MergeKeyTestNode", "keys": ["name", "email"]}, + ) + + assert len(nodes2) == 1 + assert nodes2[0].element_id != node1.element_id + assert nodes2[0].name == "Charlie" + assert nodes2[0].email == "charlie.doe@example.com" + + +@mark_async_test +async def test_merge_key_with_create_or_update_multiple_keys(): + """Test merge key for create_or_update with multiple keys.""" + + # Create initial node + node1 = ( + await MergeKeyTestNode.create( + { + "name": "John", + "email": "john@example.com", + "age": 30, + "department": "Engineering", + } + ) + )[0] + + # Update with same name and email (both keys match) + nodes = await MergeKeyTestNode.create_or_update( + { + "name": "John", + "email": "john@example.com", + "age": 31, + "department": "Management", + }, + merge_by={"label": "MergeKeyTestNode", "keys": ["name", "email"]}, + ) + + assert len(nodes) == 1 + assert nodes[0].element_id == node1.element_id + assert nodes[0].age == 31 + assert nodes[0].department == "Management" + + # Test with only one key matching (should create new node) + nodes2 = await MergeKeyTestNode.create_or_update( + { + "name": "John", + "email": "john.doe@example.com", # Different email + "age": 32, + "department": "Sales", + }, + merge_by={"label": "MergeKeyTestNode", "keys": ["name", "email"]}, + ) + + assert len(nodes2) == 1 + assert nodes2[0].element_id != node1.element_id # Should be a new node + assert nodes2[0].name == "John" + assert nodes2[0].email == "john.doe@example.com" diff --git a/test/async_/test_cardinality.py b/test/async_/test_cardinality.py index c79e6d99..4b0729db 100644 --- a/test/async_/test_cardinality.py +++ b/test/async_/test_cardinality.py @@ -17,7 +17,7 @@ IntegerProperty, StringProperty, adb, - config, + get_config, ) @@ -67,6 +67,11 @@ class Company(AsyncStructuredNode): class Employee(AsyncStructuredNode): name = StringProperty(required=True) employer = AsyncRelationshipFrom("Company", "EMPLOYS", cardinality=AsyncZeroOrOne) + offices = AsyncRelationshipFrom("Office", "HOSTS", cardinality=AsyncOneOrMore) + + +class Office(AsyncStructuredNode): + name = StringProperty(required=True) class Manager(AsyncStructuredNode): @@ -120,9 +125,15 @@ async def test_cardinality_zero_or_one(): assert single_driver.version == 1 j = await ScrewDriver(version=2).save() - with raises(AttemptedCardinalityViolation): + with raises(AttemptedCardinalityViolation) as exc_info: await m.driver.connect(j) + error_message = str(exc_info.value) + assert ( + f"Node already has zero or one relationship in a outgoing direction of type HAS_SCREWDRIVER on node ({m.element_id}) of class 'Monkey'. Use reconnect() to replace the existing relationship." 
+ == error_message + ) + await m.driver.reconnect(h, j) single_driver = await m.driver.single() assert single_driver.version == 2 @@ -161,9 +172,12 @@ async def test_cardinality_one_or_more(): cars = await m.car.all() assert len(cars) == 1 - with raises(AttemptedCardinalityViolation): + with raises(AttemptedCardinalityViolation) as exc_info: await m.car.disconnect(c) + error_message = str(exc_info.value) + assert "One or more expected" == error_message + d = await Car(version=3).save() await m.car.connect(d) cars = await m.car.all() @@ -173,6 +187,11 @@ async def test_cardinality_one_or_more(): cars = await m.car.all() assert len(cars) == 1 + with raises(AttemptedCardinalityViolation): + await m.car.disconnect_all() + + assert await m.car.single() is not None + @mark_async_test async def test_cardinality_one(): @@ -192,9 +211,15 @@ async def test_cardinality_one(): assert single_toothbrush.name == "Jim" x = await ToothBrush(name="Jim").save() - with raises(AttemptedCardinalityViolation): + with raises(AttemptedCardinalityViolation) as exc_info: await m.toothbrush.connect(x) + error_message = str(exc_info.value) + assert ( + f"Node already has one relationship in a outgoing direction of type HAS_TOOTHBRUSH on node ({m.element_id}) of class 'Monkey'. Use reconnect() to replace the existing relationship." + == error_message + ) + with raises(AttemptedCardinalityViolation): await m.toothbrush.disconnect(b) @@ -230,7 +255,8 @@ async def test_relationship_from_one_cardinality_enforced(): were not being enforced. """ # Setup - config.SOFT_INVERSE_CARDINALITY_CHECK = False + config = get_config() + config.soft_cardinality_check = False owner1 = await Owner(name="Alice").save() owner2 = await Owner(name="Bob").save() pet = await Pet(name="Fluffy").save() @@ -248,14 +274,15 @@ async def test_relationship_from_one_cardinality_enforced(): stream = io.StringIO() with patch("sys.stdout", new=stream): - config.SOFT_INVERSE_CARDINALITY_CHECK = True + config.soft_cardinality_check = True await owner2.pets.connect(pet) assert pet in await owner2.pets.all() console_output = stream.getvalue() assert "Cardinality violation detected" in console_output assert "Soft check is enabled so the relationship will be created" in console_output - assert "strict check will be enabled by default in version 6.0" in console_output + + config.soft_cardinality_check = False @mark_async_test @@ -264,7 +291,8 @@ async def test_relationship_from_zero_or_one_cardinality_enforced(): Test that RelationshipFrom with cardinality=ZeroOrOne prevents multiple connections. 
""" # Setup - config.SOFT_INVERSE_CARDINALITY_CHECK = False + config = get_config() + config.soft_cardinality_check = False company1 = await Company(name="TechCorp").save() company2 = await Company(name="StartupInc").save() employee = await Employee(name="John").save() @@ -282,14 +310,38 @@ async def test_relationship_from_zero_or_one_cardinality_enforced(): stream = io.StringIO() with patch("sys.stdout", new=stream): - config.SOFT_INVERSE_CARDINALITY_CHECK = True + config.soft_cardinality_check = True await company2.employees.connect(employee) assert employee in await company2.employees.all() console_output = stream.getvalue() assert "Cardinality violation detected" in console_output assert "Soft check is enabled so the relationship will be created" in console_output - assert "strict check will be enabled by default in version 6.0" in console_output + + config.soft_cardinality_check = False + + +@mark_async_test +async def test_relationship_from_one_or_more_cardinality_enforced(): + """ + Test that RelationshipFrom with cardinality=OneOrMore prevents disconnecting all nodes. + """ + # Setup + config = get_config() + config.soft_cardinality_check = False + office = await Office(name="Headquarters").save() + employee = await Employee(name="John").save() + await employee.offices.connect(office) + + with raises(AttemptedCardinalityViolation): + await employee.offices.disconnect(office) + + with raises(AttemptedCardinalityViolation): + await employee.offices.disconnect_all() + + assert await employee.offices.single() is not None + + config.soft_cardinality_check = False @mark_async_test @@ -298,7 +350,8 @@ async def test_bidirectional_cardinality_validation(): Test that cardinality is validated on both ends when both sides have constraints. """ # Setup - config.SOFT_INVERSE_CARDINALITY_CHECK = False + config = get_config() + config.soft_cardinality_check = False manager1 = await Manager(name="Sarah").save() manager2 = await Manager(name="David").save() assistant = await Assistant(name="Alex").save() @@ -316,11 +369,12 @@ async def test_bidirectional_cardinality_validation(): stream = io.StringIO() with patch("sys.stdout", new=stream): - config.SOFT_INVERSE_CARDINALITY_CHECK = True + config.soft_cardinality_check = True await manager2.assistant.connect(assistant) assert assistant in await manager2.assistant.all() console_output = stream.getvalue() assert "Cardinality violation detected" in console_output assert "Soft check is enabled so the relationship will be created" in console_output - assert "strict check will be enabled by default in version 6.0" in console_output + + config.soft_cardinality_check = False diff --git a/test/async_/test_connection.py b/test/async_/test_connection.py index fa9e14b0..da0ed43d 100644 --- a/test/async_/test_connection.py +++ b/test/async_/test_connection.py @@ -1,26 +1,38 @@ import os -from test._async_compat import mark_async_test +from test._async_compat import ( + mark_async_function_auto_fixture, + mark_async_session_auto_fixture, + mark_async_test, +) from test.conftest import NEO4J_PASSWORD, NEO4J_URL, NEO4J_USERNAME import pytest from neo4j import AsyncDriver, AsyncGraphDatabase from neo4j.debug import watch -from neomodel import AsyncStructuredNode, StringProperty, adb, config +from neomodel import AsyncStructuredNode, StringProperty, adb, get_config -@mark_async_test -@pytest.fixture(autouse=True) -async def setup_teardown(): +@mark_async_function_auto_fixture +async def setup_teardown(request): yield # Teardown actions after tests have run # Reconnect 
to initial URL for potential subsequent tests - await adb.close_connection() - await adb.set_connection(url=config.DATABASE_URL) + # Skip reconnection for Aura tests except bolt+ssc parameter + should_reconnect = True + if ( + "test_connect_to_aura" in request.node.name + and "bolt+ssc" not in request.node.name + ): + should_reconnect = False + + if should_reconnect: + await adb.close_connection() + await adb.set_connection(url=get_config().database_url) -@pytest.fixture(autouse=True, scope="session") -def neo4j_logging(): +@mark_async_session_auto_fixture +async def neo4j_logging(): with watch("neo4j"): yield @@ -69,12 +81,13 @@ async def test_config_driver_works(): NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD) ) - config.DRIVER = driver + config = get_config() + config.driver = driver assert await Pastry(name="Grignette").save() # Clear config # No need to close connection - pytest teardown will do it - config.DRIVER = None + config.driver = None @mark_async_test @@ -85,17 +98,18 @@ async def test_connect_to_non_default_database(): await adb.cypher_query(f"CREATE DATABASE {database_name} IF NOT EXISTS") await adb.close_connection() + config = get_config() # Set database name in url - for url init only - await adb.set_connection(url=f"{config.DATABASE_URL}/{database_name}") + await adb.set_connection(url=f"{config.database_url}/{database_name}") assert await get_current_database_name() == "pastries" await adb.close_connection() # Set database name in config - for both url and driver init - config.DATABASE_NAME = database_name + config.database_name = database_name # url init - await adb.set_connection(url=config.DATABASE_URL) + await adb.set_connection(url=config.database_url) assert await get_current_database_name() == "pastries" await adb.close_connection() @@ -110,7 +124,7 @@ async def test_connect_to_non_default_database(): # Clear config # No need to close connection - pytest teardown will do it - config.DATABASE_NAME = None + config.database_name = None @mark_async_test @@ -156,9 +170,9 @@ async def test_connect_to_aura(protocol): async def _set_connection(protocol): - AURA_TEST_DB_USER = os.environ["AURA_TEST_DB_USER"] - AURA_TEST_DB_PASSWORD = os.environ["AURA_TEST_DB_PASSWORD"] - AURA_TEST_DB_HOSTNAME = os.environ["AURA_TEST_DB_HOSTNAME"] + aura_test_db_user = os.environ["AURA_TEST_DB_USER"] + aura_test_db_password = os.environ["AURA_TEST_DB_PASSWORD"] + aura_test_db_hostname = os.environ["AURA_TEST_DB_HOSTNAME"] - database_url = f"{protocol}://{AURA_TEST_DB_USER}:{AURA_TEST_DB_PASSWORD}@{AURA_TEST_DB_HOSTNAME}" + database_url = f"{protocol}://{aura_test_db_user}:{aura_test_db_password}@{aura_test_db_hostname}" await adb.set_connection(url=database_url) diff --git a/test/async_/test_core_additional.py b/test/async_/test_core_additional.py new file mode 100644 index 00000000..1b1ba4f4 --- /dev/null +++ b/test/async_/test_core_additional.py @@ -0,0 +1,248 @@ +""" +Additional tests for neomodel.async_.core module to improve coverage. 
+""" + +from test._async_compat import mark_async_test +from unittest.mock import AsyncMock, patch + +import pytest +from neo4j.exceptions import ClientError + +from neomodel.async_.database import AsyncDatabase, ensure_connection +from neomodel.async_.transaction import AsyncTransactionProxy + + +@mark_async_test +async def test_ensure_connection_decorator_no_driver(): + """Test ensure_connection decorator when driver is None.""" + + class MockDB: + def __init__(self): + self.driver = None + + async def set_connection(self, **kwargs): + # Dummy implementation for testing + pass + + @ensure_connection + async def test_method(self): + return "success" + + test_db = MockDB() + with patch.object( + test_db, "set_connection", new_callable=AsyncMock + ) as mock_set_connection: + result = await test_db.test_method() + assert result == "success" + mock_set_connection.assert_called_once_with( + url="bolt://neo4j:foobarbaz@localhost:7687" + ) + + +@mark_async_test +async def test_ensure_connection_decorator_with_driver(): + """Test ensure_connection decorator when driver is set.""" + + class MockDB: + def __init__(self): + self.driver = "existing_driver" + + @ensure_connection + async def test_method(self): + return "success" + + test_db = MockDB() + result = await test_db.test_method() + assert result == "success" + + +@mark_async_test +async def test_clear_neo4j_database(): + """Test clear_neo4j_database method.""" + test_db = AsyncDatabase() + + with patch.object(test_db, "cypher_query", new_callable=AsyncMock) as mock_cypher: + with patch.object( + test_db, "drop_constraints", new_callable=AsyncMock + ) as mock_drop_constraints: + with patch.object( + test_db, "drop_indexes", new_callable=AsyncMock + ) as mock_drop_indexes: + await test_db.clear_neo4j_database( + clear_constraints=True, clear_indexes=True + ) + + mock_cypher.assert_called_once() + mock_drop_constraints.assert_called_once() + mock_drop_indexes.assert_called_once() + + +@mark_async_test +async def test_drop_constraints(): + """Test drop_constraints method.""" + test_db = AsyncDatabase() + + mock_results = [ + {"name": "constraint1", "labelsOrTypes": ["Label1"], "properties": ["prop1"]}, + {"name": "constraint2", "labelsOrTypes": ["Label2"], "properties": ["prop2"]}, + ] + + with patch.object(test_db, "cypher_query", new_callable=AsyncMock) as mock_cypher: + mock_cypher.return_value = ( + mock_results, + ["name", "labelsOrTypes", "properties"], + ) + + await test_db.drop_constraints(quiet=False) + + # Should call LIST_CONSTRAINTS_COMMAND and DROP_CONSTRAINT_COMMAND for each constraint + assert mock_cypher.call_count == 3 # 1 for list + 2 for drop + + +@mark_async_test +async def test_drop_indexes(): + """Test drop_indexes method.""" + test_db = AsyncDatabase() + + mock_indexes = [ + {"name": "index1", "labelsOrTypes": ["Label1"], "properties": ["prop1"]}, + {"name": "index2", "labelsOrTypes": ["Label2"], "properties": ["prop2"]}, + ] + + with patch.object( + test_db, "list_indexes", new_callable=AsyncMock + ) as mock_list_indexes: + mock_list_indexes.return_value = mock_indexes + + with patch.object( + test_db, "cypher_query", new_callable=AsyncMock + ) as mock_cypher: + await test_db.drop_indexes(quiet=False) + + # Should call DROP_INDEX_COMMAND for each index + assert mock_cypher.call_count == 2 + + +@mark_async_test +async def test_remove_all_labels(): + """Test remove_all_labels method.""" + test_db = AsyncDatabase() + + with patch.object( + test_db, "drop_constraints", new_callable=AsyncMock + ) as mock_drop_constraints: + 
with patch.object( + test_db, "drop_indexes", new_callable=AsyncMock + ) as mock_drop_indexes: + with patch("sys.stdout") as mock_stdout: + await test_db.remove_all_labels() + + mock_drop_constraints.assert_called_once_with( + quiet=False, stdout=mock_stdout + ) + mock_drop_indexes.assert_called_once_with( + quiet=False, stdout=mock_stdout + ) + + +@mark_async_test +async def test_install_all_labels(): + """Test install_all_labels method.""" + test_db = AsyncDatabase() + + class MockNode: + def __init__(self, name): + self.__name__ = name + + @classmethod + async def install_labels(cls, quiet=True, stdout=None): + pass + + with patch("neomodel.async_.node.AsyncStructuredNode", MockNode): + with patch("sys.stdout"): + await test_db.install_all_labels() + + # Should call install_labels on each node class + assert True # Test passes if no exception is raised + + +@mark_async_test +async def test_proxy_aenter_parallel_runtime_warning(): + """Test AsyncTransactionProxy __aenter__ with parallel runtime warning.""" + test_db = AsyncDatabase() + proxy = AsyncTransactionProxy(test_db, parallel_runtime=True) + + with patch.object( + test_db, "parallel_runtime_available", new_callable=AsyncMock + ) as mock_available: + mock_available.return_value = False + + with patch("warnings.warn") as mock_warn: + with patch.object(test_db, "begin", new_callable=AsyncMock) as mock_begin: + await proxy.__aenter__() + + mock_warn.assert_called_once() + mock_begin.assert_called_once() + + +@mark_async_test +async def test_proxy_aexit_with_exception(): + """Test AsyncTransactionProxy __aexit__ with exception.""" + test_db = AsyncDatabase() + proxy = AsyncTransactionProxy(test_db) + + with patch.object(test_db, "rollback", new_callable=AsyncMock) as mock_rollback: + with patch.object(test_db, "commit", new_callable=AsyncMock) as mock_commit: + # Test with exception + await proxy.__aexit__(ValueError, ValueError("test"), None) + mock_rollback.assert_called_once() + mock_commit.assert_not_called() + + +@mark_async_test +async def test_proxy_aexit_success(): + """Test AsyncTransactionProxy __aexit__ with success.""" + test_db = AsyncDatabase() + proxy = AsyncTransactionProxy(test_db) + + with patch.object(test_db, "rollback", new_callable=AsyncMock) as mock_rollback: + with patch.object(test_db, "commit", new_callable=AsyncMock) as mock_commit: + mock_commit.return_value = "bookmarks" + + await proxy.__aexit__(None, None, None) + mock_rollback.assert_not_called() + mock_commit.assert_called_once() + assert proxy.last_bookmarks == "bookmarks" + + +@mark_async_test +async def test_proxy_call_decorator(): + """Test AsyncTransactionProxy __call__ decorator.""" + test_db = AsyncDatabase() + proxy = AsyncTransactionProxy(test_db) + + async def test_func(): + return "success" + + decorated = proxy(test_func) + assert callable(decorated) + + # Test that the decorated function works + with patch.object(proxy, "__aenter__", new_callable=AsyncMock) as mock_enter: + with patch.object(proxy, "__aexit__", new_callable=AsyncMock): + mock_enter.return_value = proxy + result = await decorated() + assert result == "success" + + +@mark_async_test +async def test_cypher_query_client_error_generic(): + """Test cypher_query with generic ClientError.""" + test_db = AsyncDatabase() + + with patch.object(test_db, "_run_cypher_query", new_callable=AsyncMock) as mock_run: + client_error = ClientError("Neo.ClientError.Generic", "message") + mock_run.side_effect = client_error + + with pytest.raises(ClientError): + await test_db.cypher_query("MATCH 
(n) RETURN n") diff --git a/test/async_/test_database_management.py b/test/async_/test_database_management.py index 5159642a..592613b4 100644 --- a/test/async_/test_database_management.py +++ b/test/async_/test_database_management.py @@ -1,5 +1,7 @@ +import asyncio from test._async_compat import mark_async_test +import neo4j import pytest from neo4j.exceptions import AuthError @@ -11,6 +13,8 @@ StringProperty, adb, ) +from neomodel._async_compat.util import AsyncUtil +from neomodel.async_.database import AsyncDatabase class City(AsyncStructuredNode): @@ -79,3 +83,76 @@ async def test_change_password(): await adb.close_connection() await adb.set_connection(url=prev_url) + + +@mark_async_test +async def test_adb_singleton_behavior(): + """Test that AsyncDatabase enforces singleton behavior.""" + + # Get the module-level instance + adb1 = AsyncDatabase.get_instance() + + # Try to create another instance directly + adb2 = AsyncDatabase() + + # Try to create another instance via get_instance + adb3 = AsyncDatabase.get_instance() + + # All instances should be the same object + assert adb1 is adb2, "Direct instantiation should return the same instance" + assert adb1 is adb3, "get_instance should return the same instance" + assert adb2 is adb3, "All instances should be the same object" + + # Test that the module-level 'adb' is also the same instance + assert adb is adb1, "Module-level 'adb' should be the same instance" + + +@mark_async_test +async def test_async_database_properties(): + # A fresh instance of AsyncDatabase is not yet connected + await AsyncDatabase.reset_instance() + reset_singleton = AsyncDatabase.get_instance() + assert reset_singleton._active_transaction is None + assert reset_singleton.url is None + assert reset_singleton.driver is None + assert reset_singleton._session is None + assert reset_singleton._pid is None + assert reset_singleton._database_name is neo4j.DEFAULT_DATABASE + assert reset_singleton._database_version is None + assert reset_singleton._database_edition is None + assert reset_singleton.impersonated_user is None + assert reset_singleton._parallel_runtime is False + + +@mark_async_test +async def test_parallel_transactions(): + if not AsyncUtil.is_async_code: + pytest.skip("Async only test") + + transactions = set() + sessions = set() + + async def query(i: int): + await asyncio.sleep(0.05) + + assert adb._active_transaction is None + assert adb._session is None + + async with adb.transaction: + # ensure transaction and session are unique for async context + transaction_id = id(adb._active_transaction) + assert transaction_id not in transactions + transactions.add(transaction_id) + + session_id = id(adb._session) + assert session_id not in sessions + sessions.add(session_id) + + result, _ = await adb.cypher_query( + "CALL apoc.util.sleep($delay_ms) RETURN $task_id as task_id, $delay_ms as slept", + {"delay_ms": i * 505, "task_id": i}, + ) + + return result[0][0], result[0][1], transaction_id, session_id + + _ = await asyncio.gather(*(query(i) for i in range(1, 5))) diff --git a/test/async_/test_fulltextfilter.py b/test/async_/test_fulltextfilter.py new file mode 100644 index 00000000..a992ee4e --- /dev/null +++ b/test/async_/test_fulltextfilter.py @@ -0,0 +1,326 @@ +from datetime import datetime +from test._async_compat import mark_async_test + +import pytest + +from neomodel import ( + AsyncRelationshipFrom, + AsyncStructuredNode, + AsyncStructuredRel, + DateTimeProperty, + FloatProperty, + StringProperty, + FulltextIndex, + adb, +) +from 
neomodel.semantic_filters import FulltextFilter + + +@mark_async_test +async def test_base_fulltextfilter(): + """ + Tests that the fulltextquery is run, node and score are returned. + """ + + if not await adb.version_is_higher_than("5.16"): + pytest.skip("Not supported before 5.16") + + class fulltextNode(AsyncStructuredNode): + description = StringProperty( + fulltext_index=FulltextIndex( + analyzer="standard-no-stop-words", eventually_consistent=False + ) + ) + other = StringProperty() + + await adb.install_labels(fulltextNode) + + node1 = await fulltextNode(other="thing", description="Another thing").save() + + node2 = await fulltextNode( + other="other thing", description="Another other thing" + ).save() + + fulltextNodeSearch = fulltextNode.nodes.filter( + fulltext_filter=FulltextFilter( + topk=3, fulltext_attribute_name="description", query_string="thing" + ) + ) + + result = await fulltextNodeSearch.all() + assert all(isinstance(x[0], fulltextNode) for x in result) + assert all(isinstance(x[1], float) for x in result) + + +@mark_async_test +async def test_fulltextfilter_topk_works(): + """ + Tests that the topk filter works. + """ + + if not await adb.version_is_higher_than("5.16"): + pytest.skip("Not supported before 5.16") + + class fulltextNodetopk(AsyncStructuredNode): + description = StringProperty( + fulltext_index=FulltextIndex( + analyzer="standard-no-stop-words", eventually_consistent=False + ) + ) + await adb.install_labels(fulltextNodetopk) + + node1 = await fulltextNodetopk(description="this description").save() + node2 = await fulltextNodetopk(description="that description").save() + node3 = await fulltextNodetopk(description="my description").save() + + fulltextNodeSearch = fulltextNodetopk.nodes.filter( + fulltext_filter=FulltextFilter( + topk=2, fulltext_attribute_name="description", query_string="description" + ) + ) + + result = await fulltextNodeSearch.all() + assert len(result) == 2 + +@mark_async_test +async def test_fulltextfilter_with_node_propertyfilter(): + """ + Tests that the fulltext query is run, and "thing" node is only node returned. + """ + + if not await adb.version_is_higher_than("5.16"): + pytest.skip("Not supported before 5.16") + + class fulltextNodeBis(AsyncStructuredNode): + description = StringProperty( + fulltext_index=FulltextIndex( + analyzer="standard-no-stop-words", eventually_consistent=False + ) + ) + other = StringProperty() + + await adb.install_labels(fulltextNodeBis) + + node1 = await fulltextNodeBis(other="thing", description="Another thing").save() + + node2 = await fulltextNodeBis( + other="other thing", description="Another other thing" + ).save() + + fulltextFilterforthing = fulltextNodeBis.nodes.filter( + fulltext_filter=FulltextFilter( + topk=3, fulltext_attribute_name="description", query_string="thing" + ), + other="thing", + ) + + result = await fulltextFilterforthing.all() + + assert len(result) == 1 + assert all(isinstance(x[0], fulltextNodeBis) for x in result) + assert result[0][0].other == "thing" + assert all(isinstance(x[1], float) for x in result) + + +@mark_async_test +async def test_dont_duplicate_fulltext_filter_node(): + """ + Tests the situation that another node has the same filter value. + Testing that we are only performing the fulltextfilter and metadata filter on the right nodes. 
+    """
+
+    if not await adb.version_is_higher_than("5.16"):
+        pytest.skip("Not supported before 5.16")
+
+    class fulltextNodeTer(AsyncStructuredNode):
+        description = StringProperty(
+            fulltext_index=FulltextIndex(
+                analyzer="standard-no-stop-words", eventually_consistent=False
+            )
+        )
+        name = StringProperty()
+
+    class otherfulltextNodeTer(AsyncStructuredNode):
+        other_description = StringProperty(
+            fulltext_index=FulltextIndex(
+                analyzer="standard-no-stop-words", eventually_consistent=False
+            )
+        )
+        other_name = StringProperty()
+
+    await adb.install_labels(fulltextNodeTer)
+    await adb.install_labels(otherfulltextNodeTer)
+
+    node1 = await fulltextNodeTer(name="John", description="thing one").save()
+    node2 = await fulltextNodeTer(name="Fred", description="thing two").save()
+    node3 = await otherfulltextNodeTer(
+        other_name="John", other_description="thing three"
+    ).save()
+    node4 = await otherfulltextNodeTer(
+        other_name="Fred", other_description="thing four"
+    ).save()
+
+    john_fulltext_search = fulltextNodeTer.nodes.filter(
+        fulltext_filter=FulltextFilter(
+            topk=3, fulltext_attribute_name="description", query_string="thing"
+        ),
+        name="John",
+    )
+
+    result = await john_fulltext_search.all()
+
+    assert len(result) == 1
+    assert isinstance(result[0][0], fulltextNodeTer)
+    assert result[0][0].name == "John"
+    assert isinstance(result[0][1], float)
+
+
+@mark_async_test
+async def test_django_filter_w_fulltext_filter():
+    """
+    Tests that Django-style filters still work with the fulltext filter.
+    """
+
+    if not await adb.version_is_higher_than("5.16"):
+        pytest.skip("Not supported before 5.16")
+
+    class fulltextDjangoNode(AsyncStructuredNode):
+        description = StringProperty(
+            fulltext_index=FulltextIndex(
+                analyzer="standard-no-stop-words", eventually_consistent=False
+            )
+        )
+        name = StringProperty()
+        number = FloatProperty()
+
+    await adb.install_labels(fulltextDjangoNode)
+
+    nodeone = await fulltextDjangoNode(
+        name="John", description="thing one", number=float(10)
+    ).save()
+
+    nodetwo = await fulltextDjangoNode(
+        name="Fred", description="thing two", number=float(3)
+    ).save()
+
+    fulltext_index_with_django_filter = fulltextDjangoNode.nodes.filter(
+        fulltext_filter=FulltextFilter(
+            topk=3, fulltext_attribute_name="description", query_string="thing"
+        ),
+        number__gt=5,
+    )
+
+    result = await fulltext_index_with_django_filter.all()
+    assert len(result) == 1
+    assert isinstance(result[0][0], fulltextDjangoNode)
+    assert result[0][0].number > 5
+
+
+@mark_async_test
+async def test_fulltextfilter_with_relationshipfilter():
+    """
+    Tests that filtering on fulltext similarity and then performing a relationship filter works.
+    """
+
+    if not await adb.version_is_higher_than("5.16"):
+        pytest.skip("Not supported before 5.16")
+
+    class SupplierFT(AsyncStructuredNode):
+        name = StringProperty()
+
+    class SuppliesFTRel(AsyncStructuredRel):
+        since = DateTimeProperty(default=datetime.now)
+
+    class ProductFT(AsyncStructuredNode):
+        name = StringProperty()
+        description = StringProperty(
+            fulltext_index=FulltextIndex(
+                analyzer="standard-no-stop-words", eventually_consistent=False
+            )
+        )
+        suppliers = AsyncRelationshipFrom(SupplierFT, "SUPPLIES", model=SuppliesFTRel)
+
+    await adb.install_labels(SupplierFT)
+    await adb.install_labels(SuppliesFTRel)
+    await adb.install_labels(ProductFT)
+
+    supplier1 = await SupplierFT(name="Supplier 1").save()
+    supplier2 = await SupplierFT(name="Supplier 2").save()
+    product1 = await ProductFT(
+        name="Product A",
+        description="High quality product",
+    ).save()
+    product2 = await ProductFT(
+        name="Product B",
+        description="Very High quality product",
+    ).save()
+    await product1.suppliers.connect(supplier1)
+    await product1.suppliers.connect(supplier2)
+    await product2.suppliers.connect(supplier1)
+
+    filtered_product = ProductFT.nodes.filter(
+        fulltext_filter=FulltextFilter(
+            topk=1, fulltext_attribute_name="description", query_string="product"
+        ),
+        suppliers__name="Supplier 1",
+    )
+
+    result = await filtered_product.all()
+
+    assert len(result) == 1
+    assert isinstance(result[0][0], ProductFT)
+    assert isinstance(result[0][1], SupplierFT)
+    assert isinstance(result[0][2], SuppliesFTRel)
+
+
+@mark_async_test
+async def test_fulltextfilter_nonexistent_attribute():
+    """
+    Tests that AttributeError is raised when fulltext_attribute_name doesn't exist on the source.
+    """
+
+    if not await adb.version_is_higher_than("5.16"):
+        pytest.skip("Not supported before 5.16")
+
+    class TestNodeWithFT(AsyncStructuredNode):
+        name = StringProperty()
+        fulltext = StringProperty(
+            fulltext_index=FulltextIndex(
+                analyzer="standard-no-stop-words", eventually_consistent=False
+            )
+        )
+
+    await adb.install_labels(TestNodeWithFT)
+
+    with pytest.raises(
+        AttributeError, match="Atribute 'nonexistent_fulltext' not found"
+    ):
+        nodeset = TestNodeWithFT.nodes.filter(
+            fulltext_filter=FulltextFilter(
+                topk=1,
+                fulltext_attribute_name="nonexistent_fulltext",
+                query_string="something",
+            )
+        )
+        await nodeset.all()
+
+
+@mark_async_test
+async def test_fulltextfilter_no_fulltext_index():
+    """
+    Tests that AttributeError is raised when the attribute is not declared with a fulltext index.
+ """ + + if not await adb.version_is_higher_than("5.16"): + pytest.skip("Not supported before 5.16") + + class TestNodeWithoutFT(AsyncStructuredNode): + name = StringProperty() + fulltext = StringProperty() # No fulltext_index + + await adb.install_labels(TestNodeWithoutFT) + + with pytest.raises(AttributeError, match="is not declared with a full text index"): + nodeset = TestNodeWithoutFT.nodes.filter( + fulltext_filter=FulltextFilter( + topk=1, fulltext_attribute_name="fulltext", query_string="something" + ) + ) + await nodeset.all() diff --git a/test/async_/test_issue283.py b/test/async_/test_issue283.py index ddbd6808..3ee29f18 100644 --- a/test/async_/test_issue283.py +++ b/test/async_/test_issue283.py @@ -5,10 +5,11 @@ More information about the same issue at: https://github.com/aanastasiou/neomodelInheritanceTest -The following example uses a recursive relationship for economy, but the -idea remains the same: "Instantiate the correct type of node at the end of +The following example uses a recursive relationship for economy, but the +idea remains the same: "Instantiate the correct type of node at the end of a relationship as specified by the model" """ + import random from test._async_compat import mark_async_test @@ -123,56 +124,6 @@ async def test_automatic_result_resolution(): assert type((await A.friends_with)[0]) is TechnicalPerson -@mark_async_test -async def test_recursive_automatic_result_resolution(): - """ - Node objects are instantiated to native Python objects, both at the top - level of returned results and in the case where they are returned within - lists. - """ - - # Create a few entities - A = ( - await TechnicalPerson.get_or_create( - {"name": "Grumpier", "expertise": "Grumpiness"} - ) - )[0] - B = ( - await TechnicalPerson.get_or_create( - {"name": "Happier", "expertise": "Grumpiness"} - ) - )[0] - C = ( - await TechnicalPerson.get_or_create( - {"name": "Sleepier", "expertise": "Pillows"} - ) - )[0] - D = ( - await TechnicalPerson.get_or_create( - {"name": "Sneezier", "expertise": "Pillows"} - ) - )[0] - - # Retrieve mixed results, both at the top level and nested - L, _ = await adb.cypher_query( - "MATCH (a:TechnicalPerson) " - "WHERE a.expertise='Grumpiness' " - "WITH collect(a) as Alpha " - "MATCH (b:TechnicalPerson) " - "WHERE b.expertise='Pillows' " - "WITH Alpha, collect(b) as Beta " - "RETURN [Alpha, [Beta, [Beta, ['Banana', " - "Alpha]]]]", - resolve_objects=True, - ) - - # Assert that a Node returned deep in a nested list structure is of the - # correct type - assert type(L[0][0][0][1][0][0][0][0]) is TechnicalPerson - # Assert that primitive data types remain primitive data types - assert issubclass(type(L[0][0][0][1][0][1][0][1][0][0]), basestring) - - @mark_async_test async def test_validation_with_inheritance_from_db(): """ diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 07e5d446..cc27e194 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -1,11 +1,11 @@ import re from datetime import datetime from test._async_compat import mark_async_test +from unittest.mock import AsyncMock, MagicMock from pytest import raises, skip, warns from neomodel import ( - INCOMING, ArrayProperty, AsyncRelationshipFrom, AsyncRelationshipTo, @@ -34,6 +34,7 @@ Size, ) from neomodel.exceptions import MultipleNodesReturned, RelationshipClassNotDefined +from neomodel.util import RelationshipDirection class SupplierRel(AsyncStructuredRel): @@ -426,7 +427,7 @@ def test_traversal_definition_keys_are_valid(): "a_name", 
{ "node_class": Supplier, - "direction": INCOMING, + "direction": RelationshipDirection.INCOMING, "relationship_type": "KNOWS", "model": None, }, @@ -437,7 +438,7 @@ def test_traversal_definition_keys_are_valid(): "a_name", { "node_class": Supplier, - "direction": INCOMING, + "direction": RelationshipDirection.INCOMING, "relation_type": "KNOWS", "model": None, }, @@ -548,7 +549,7 @@ async def test_q_filters(): robusta = await Species(name="Robusta").save() await c4.species.connect(robusta) latte_or_robusta_coffee = ( - await Coffee.nodes.fetch_relations(Optional("species")) + await Coffee.nodes.traverse(Path(value="species", optional=True)) .filter(Q(name="Latte") | Q(species__name="Robusta")) .all() ) @@ -557,7 +558,7 @@ async def test_q_filters(): arabica = await Species(name="Arabica").save() await c1.species.connect(arabica) robusta_coffee = ( - await Coffee.nodes.fetch_relations(Optional("species")) + await Coffee.nodes.traverse(Path(value="species", optional=True)) .filter(species__name="Robusta") .all() ) @@ -683,78 +684,17 @@ async def test_relation_prop_ordering(): await nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)}) await nescafe.species.connect(arabica) - results = ( - await Supplier.nodes.fetch_relations("coffees").order_by("-coffees|since").all() - ) + results = await Supplier.nodes.traverse("coffees").order_by("-coffees|since").all() assert len(results) == 2 assert results[0][0] == supplier1 assert results[1][0] == supplier2 - results = ( - await Supplier.nodes.fetch_relations("coffees").order_by("coffees|since").all() - ) + results = await Supplier.nodes.traverse("coffees").order_by("coffees|since").all() assert len(results) == 2 assert results[0][0] == supplier2 assert results[1][0] == supplier1 -@mark_async_test -async def test_fetch_relations(): - arabica = await Species(name="Arabica").save() - robusta = await Species(name="Robusta").save() - nescafe = await Coffee(name="Nescafe", price=99).save() - nescafe_gold = await Coffee(name="Nescafe Gold", price=11).save() - - tesco = await Supplier(name="Tesco", delivery_cost=3).save() - await nescafe.suppliers.connect(tesco) - await nescafe_gold.suppliers.connect(tesco) - await nescafe.species.connect(arabica) - - result = ( - await Supplier.nodes.filter(name="Tesco") - .fetch_relations("coffees__species") - .all() - ) - assert len(result[0]) == 5 - assert arabica in result[0] - assert robusta not in result[0] - assert tesco in result[0] - assert nescafe in result[0] - assert nescafe_gold not in result[0] - - result = ( - await Species.nodes.filter(name="Robusta") - .fetch_relations(Optional("coffees__suppliers")) - .all() - ) - assert len(result) == 1 - - if AsyncUtil.is_async_code: - count = ( - await Supplier.nodes.filter(name="Tesco") - .fetch_relations("coffees__species") - .get_len() - ) - assert count == 1 - - assert ( - await Supplier.nodes.fetch_relations("coffees__species") - .filter(name="Tesco") - .check_contains(tesco) - ) - else: - count = len( - Supplier.nodes.filter(name="Tesco") - .fetch_relations("coffees__species") - .all() - ) - assert count == 1 - - assert tesco in Supplier.nodes.fetch_relations("coffees__species").filter( - name="Tesco" - ) - - @mark_async_test async def test_traverse(): arabica = await Species(name="Arabica").save() @@ -820,9 +760,7 @@ async def test_traverse_and_order_by(): await nescafe.species.connect(arabica) await nescafe_gold.species.connect(robusta) - results = ( - await Species.nodes.fetch_relations("coffees").order_by("-coffees__price").all() - ) + 
results = await Species.nodes.traverse("coffees").order_by("-coffees__price").all() assert len(results) == 2 assert len(results[0]) == 3 # 2 nodes and 1 relation assert results[0][0] == robusta @@ -844,36 +782,66 @@ async def test_annotate_and_collect(): await nescafe_gold.species.connect(arabica) result = ( - await Supplier.nodes.traverse_relations(species="coffees__species") + await Supplier.nodes.traverse( + species=Path( + value="coffees__species", + include_rels_in_return=False, + include_nodes_in_return=False, + ) + ) .annotate(Collect("species")) .all() ) assert len(result) == 1 - assert len(result[0][1][0]) == 3 # 3 species must be there (with 2 duplicates) + assert len(result[0][1]) == 3 # 3 species must be there (with 2 duplicates) result = ( - await Supplier.nodes.traverse_relations(species="coffees__species") + await Supplier.nodes.traverse( + species=Path( + value="coffees__species", + include_rels_in_return=False, + include_nodes_in_return=False, + ) + ) .annotate(Collect("species", distinct=True)) .all() ) - assert len(result[0][1][0]) == 2 # 2 species must be there + assert len(result[0][1]) == 2 # 2 species must be there result = ( - await Supplier.nodes.traverse_relations(species="coffees__species") + await Supplier.nodes.traverse( + species=Path( + value="coffees__species", + include_rels_in_return=False, + include_nodes_in_return=False, + ) + ) .annotate(Size(Collect("species", distinct=True))) .all() ) assert result[0][1] == 2 # 2 species result = ( - await Supplier.nodes.traverse_relations(species="coffees__species") + await Supplier.nodes.traverse( + species=Path( + value="coffees__species", + include_rels_in_return=False, + include_nodes_in_return=False, + ) + ) .annotate(all_species=Collect("species", distinct=True)) .all() ) - assert len(result[0][1][0]) == 2 # 2 species must be there + assert len(result[0][1]) == 2 # 2 species must be there result = ( - await Supplier.nodes.traverse_relations("coffees__species") + await Supplier.nodes.traverse( + species=Path( + value="coffees__species", + include_rels_in_return=False, + include_nodes_in_return=False, + ) + ) .annotate( all_species=Collect(NodeNameResolver("coffees__species"), distinct=True), all_species_rels=Collect( @@ -882,8 +850,8 @@ async def test_annotate_and_collect(): ) .all() ) - assert len(result[0][1][0]) == 2 # 2 species must be there - assert len(result[0][2][0]) == 3 # 3 species relations must be there + assert len(result[0][1]) == 2 # 2 species must be there + assert len(result[0][2]) == 3 # 3 species relations must be there @mark_async_test @@ -902,22 +870,12 @@ async def test_resolve_subgraph(): with raises( RuntimeError, match=re.escape( - "Nothing to resolve. Make sure to include relations in the result using fetch_relations() or filter()." + "Nothing to resolve. Make sure to include relations in the result using traverse() or filter()." ), ): result = await Supplier.nodes.resolve_subgraph() - with raises( - NotImplementedError, - match=re.escape( - "You cannot use traverse_relations() with resolve_subgraph(), use fetch_relations() instead." 
- ), - ): - result = await Supplier.nodes.traverse_relations( - "coffees__species" - ).resolve_subgraph() - - result = await Supplier.nodes.fetch_relations("coffees__species").resolve_subgraph() + result = await Supplier.nodes.traverse("coffees__species").resolve_subgraph() assert len(result) == 2 assert hasattr(result[0], "_relations") @@ -944,8 +902,8 @@ async def test_resolve_subgraph_optional(): await nescafe_gold.suppliers.connect(tesco) await nescafe.species.connect(arabica) - result = await Supplier.nodes.fetch_relations( - Optional("coffees__species") + result = await Supplier.nodes.traverse( + Path(value="coffees__species", optional=True) ).resolve_subgraph() assert len(result) == 1 @@ -969,7 +927,7 @@ async def test_subquery(): await nescafe.species.connect(arabica) subquery = await Coffee.nodes.subquery( - Coffee.nodes.traverse_relations(suppliers="suppliers") + Coffee.nodes.traverse(suppliers="suppliers") .intermediate_transform( {"suppliers": {"source": "suppliers"}}, ordering=["suppliers.delivery_cost"] ) @@ -987,16 +945,14 @@ async def test_subquery(): match=re.escape("Variable 'unknown' is not returned by subquery."), ): result = await Coffee.nodes.subquery( - Coffee.nodes.traverse_relations(suppliers="suppliers").annotate( + Coffee.nodes.traverse(suppliers="suppliers").annotate( supps=Collect("suppliers") ), ["unknown"], ) result_string_context = await subquery.subquery( - Coffee.nodes.traverse_relations(supps2="suppliers").annotate( - supps2=Collect("supps") - ), + Coffee.nodes.traverse(supps2="suppliers").annotate(supps2=Collect("supps")), ["supps2"], ["supps"], ) @@ -1010,7 +966,7 @@ async def test_subquery(): with raises(ValueError, match=r"Wrong variable specified in initial context"): result = await Coffee.nodes.subquery( - Coffee.nodes.traverse_relations(suppliers="suppliers").annotate( + Coffee.nodes.traverse(suppliers="suppliers").annotate( supps=Collect("suppliers") ), ["supps"], @@ -1058,7 +1014,7 @@ async def test_intermediate_transform(): await nescafe.species.connect(arabica) result = ( - await Coffee.nodes.fetch_relations("suppliers") + await Coffee.nodes.traverse("suppliers") .intermediate_transform( { "coffee": {"source": "coffee", "include_in_return": True}, @@ -1086,7 +1042,7 @@ async def test_intermediate_transform(): r"Wrong source type specified for variable 'test', should be a string or an instance of NodeNameResolver or RelationNameResolver" ), ): - Coffee.nodes.traverse_relations(suppliers="suppliers").intermediate_transform( + Coffee.nodes.traverse(suppliers="suppliers").intermediate_transform( { "test": {"source": Collect("suppliers")}, } @@ -1097,9 +1053,7 @@ async def test_intermediate_transform(): r"You must provide one variable at least when calling intermediate_transform()" ), ): - Coffee.nodes.traverse_relations(suppliers="suppliers").intermediate_transform( - {} - ) + Coffee.nodes.traverse(suppliers="suppliers").intermediate_transform({}) @mark_async_test @@ -1152,12 +1106,12 @@ async def test_mix_functions(): full_nodeset = ( await Student.nodes.filter(name__istartswith="m", lives_in__name="Eiffel Tower") .order_by("name") - .fetch_relations( + .traverse( "parents", - Optional("children__preferred_course"), + Path(value="children__preferred_course", optional=True), ) .subquery( - Student.nodes.fetch_relations("courses") + Student.nodes.traverse("courses") .intermediate_transform( {"rel": {"source": RelationNameResolver("courses")}}, ordering=[ @@ -1205,9 +1159,9 @@ async def test_issue_795(): with raises( RelationshipClassNotDefined, - 
match=r"[\s\S]*Note that when using the fetch_relations method, the relationship type must be defined in the model.*", + match=r"[\s\S]*Note that when using the traverse method, the relationship type must be defined in the model.*", ): - _ = await PersonX.nodes.fetch_relations("country").all() + _ = await PersonX.nodes.traverse("country").all() @mark_async_test @@ -1245,7 +1199,7 @@ async def test_unique_variables(): await gold3000.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)}) await gold3000.species.connect(arabica) - nodeset = Supplier.nodes.fetch_relations("coffees", "coffees__species").filter( + nodeset = Supplier.nodes.traverse("coffees", "coffees__species").filter( coffees__name="Nescafe" ) ast = await nodeset.query_cls(nodeset).build_ast() @@ -1258,7 +1212,7 @@ async def test_unique_variables(): assert len(results) == 3 nodeset = ( - Supplier.nodes.fetch_relations("coffees", "coffees__species") + Supplier.nodes.traverse("coffees", "coffees__species") .filter(coffees__name="Nescafe") .unique_variables("coffees") ) @@ -1307,6 +1261,20 @@ def assert_last_query_startswith(mock_func, query) -> bool: return mock_func.call_args_list[-1].kwargs["query"].startswith(query) +def create_mock_async_result(): + """Create a mock AsyncResult that behaves like neo4j.AsyncResult""" + mock_result = MagicMock() + mock_result.keys.return_value = () + + # Create an async iterator that yields empty records + async def async_iter(self): + return + yield # This makes it an async generator + + mock_result.__aiter__ = async_iter + return mock_result + + @mark_async_test async def test_parallel_runtime(mocker): if ( @@ -1322,7 +1290,11 @@ async def test_parallel_runtime(mocker): # Mock transaction.run to access executed query # Assert query starts with CYPHER runtime=parallel assert adb._parallel_runtime == True - mock_transaction_run = mocker.patch("neo4j.AsyncTransaction.run") + mock_transaction_run = mocker.patch( + "neo4j.AsyncTransaction.run", + new_callable=AsyncMock, + return_value=create_mock_async_result(), + ) await adb.cypher_query("MATCH (n:Coffee) RETURN n") assert assert_last_query_startswith( mock_transaction_run, "CYPHER runtime=parallel" @@ -1332,7 +1304,11 @@ async def test_parallel_runtime(mocker): # Parallel should be applied to neomodel queries async with adb.parallel_read_transaction: - mock_transaction_run_2 = mocker.patch("neo4j.AsyncTransaction.run") + mock_transaction_run_2 = mocker.patch( + "neo4j.AsyncTransaction.run", + new_callable=AsyncMock, + return_value=create_mock_async_result(), + ) await Coffee.nodes.all() assert assert_last_query_startswith( mock_transaction_run_2, "CYPHER runtime=parallel" @@ -1345,7 +1321,11 @@ async def test_parallel_runtime_conflict(mocker): skip("Test for unavailable parallel runtime.") assert not await adb.parallel_runtime_available() - mock_transaction_run = mocker.patch("neo4j.AsyncTransaction.run") + mock_transaction_run = mocker.patch( + "neo4j.AsyncTransaction.run", + new_callable=AsyncMock, + return_value=create_mock_async_result(), + ) with warns( UserWarning, match="Parallel runtime is only available in Neo4j Enterprise Edition 5.13", diff --git a/test/async_/test_object_resolution.py b/test/async_/test_object_resolution.py new file mode 100644 index 00000000..2eff1daf --- /dev/null +++ b/test/async_/test_object_resolution.py @@ -0,0 +1,554 @@ +""" +Test cases for object resolution with resolve_objects=True in raw Cypher queries. 
+ +This test file covers various scenarios for automatic class resolution, +including the issues identified in GitHub issues #905 and #906: +- Issue #905: Nested lists in results of raw Cypher queries with collect keyword +- Issue #906: Automatic class resolution for raw queries with nodes nested in maps + +Additional scenarios tested: +- Basic object resolution +- Nested structures (lists, maps, mixed) +- Path resolution +- Relationship resolution +- Complex nested scenarios with collect() and other Cypher functions +""" + +from test._async_compat import mark_async_test + +from neomodel import ( + AsyncRelationshipTo, + AsyncStructuredNode, + AsyncStructuredRel, + IntegerProperty, + StringProperty, + adb, +) + + +class ResolutionRelationship(AsyncStructuredRel): + """Test relationship with properties.""" + + weight = IntegerProperty(default=1) + description = StringProperty(default="test") + + +class ResolutionNode(AsyncStructuredNode): + """Base test node class.""" + + name = StringProperty(required=True) + value = IntegerProperty(default=0) + related = AsyncRelationshipTo( + "ResolutionNode", "RELATED_TO", model=ResolutionRelationship + ) + + +class ResolutionSpecialNode(AsyncStructuredNode): + """Specialized test node class.""" + + name = StringProperty(required=True) + special_value = IntegerProperty(default=42) + related = AsyncRelationshipTo( + ResolutionNode, "RELATED_TO", model=ResolutionRelationship + ) + + +class ResolutionContainerNode(AsyncStructuredNode): + """Container node for testing nested structures.""" + + name = StringProperty(required=True) + items = AsyncRelationshipTo( + ResolutionNode, "CONTAINS", model=ResolutionRelationship + ) + + +@mark_async_test +async def test_basic_object_resolution(): + """Test basic object resolution for nodes and relationships.""" + # Create test data + await ResolutionNode(name="Node1", value=10).save() + await ResolutionNode(name="Node2", value=20).save() + + # Test basic node resolution + results, _ = await adb.cypher_query( + "MATCH (n:ResolutionNode) WHERE n.name = $name RETURN n", + {"name": "Node1"}, + resolve_objects=True, + ) + + assert len(results) == 1 + assert len(results[0]) == 1 + resolved_node = results[0][0] + assert isinstance(resolved_node, ResolutionNode) + assert resolved_node.name == "Node1" + assert resolved_node.value == 10 + + +@mark_async_test +async def test_relationship_resolution(): + """Test relationship resolution in queries.""" + # Create test data with relationships + node1 = await ResolutionNode(name="Source", value=100).save() + node2 = await ResolutionNode(name="Target", value=200).save() + + # Create relationship + await node1.related.connect(node2, {"weight": 5, "description": "test_rel"}) + + # Test relationship resolution + results, _ = await adb.cypher_query( + "MATCH (a:ResolutionNode)-[r:RELATED_TO]->(b:ResolutionNode) RETURN a, r, b", + resolve_objects=True, + ) + + assert len(results) == 1 + source, rel, target = results[0] + + assert isinstance(source, ResolutionNode) + assert isinstance(rel, ResolutionRelationship) + assert isinstance(target, ResolutionNode) + + assert source.name == "Source" + assert target.name == "Target" + assert rel.weight == 5 + assert rel.description == "test_rel" + + +@mark_async_test +async def test_path_resolution(): + """Test path resolution in queries.""" + # Create test data + node1 = await ResolutionNode(name="Start", value=1).save() + node2 = await ResolutionNode(name="Middle", value=2).save() + node3 = await ResolutionNode(name="End", value=3).save() + + # 
Create path + await node1.related.connect(node2, {"weight": 1}) + await node2.related.connect(node3, {"weight": 2}) + + # Test path resolution + results, _ = await adb.cypher_query( + "MATCH p=(a:ResolutionNode)-[:RELATED_TO*2]->(c:ResolutionNode) RETURN p", + resolve_objects=True, + ) + + assert len(results) == 1 + path = results[0][0] + + # Path should be resolved to AsyncNeomodelPath + from neomodel.async_.path import AsyncNeomodelPath + + assert isinstance(path, AsyncNeomodelPath) + assert len(path._nodes) == 3 # pylint: disable=protected-access + assert len(path._relationships) == 2 # pylint: disable=protected-access + + +@mark_async_test +async def test_nested_lists_basic(): + """Test basic nested list resolution (Issue #905 - basic case).""" + # Create test data + nodes = [] + for i in range(3): + node = await ResolutionNode(name=f"Node{i}", value=i * 10).save() + nodes.append(node) + + # Test nested list resolution + results, _ = await adb.cypher_query( + """ + MATCH (n:ResolutionNode) + WITH n ORDER BY n.name + RETURN collect(n) as nodes + """, + resolve_objects=True, + ) + + assert len(results) == 1 + collected_nodes = results[0][0] + + assert isinstance(collected_nodes, list) + assert len(collected_nodes) == 3 + + for i, node in enumerate(collected_nodes): + assert isinstance(node, ResolutionNode) + assert node.name == f"Node{i}" + assert node.value == i * 10 + + +@mark_async_test +async def test_nested_lists_complex(): + """Test complex nested list resolution with collect() (Issue #905 - complex case).""" + # Create test data with relationships + container = await ResolutionContainerNode(name="Container").save() + items = [] + for i in range(2): + item = await ResolutionNode(name=f"Item{i}", value=i * 5).save() + items.append(item) + await container.items.connect(item, {"weight": i + 1}) + + # Test complex nested list with collect + results, _ = await adb.cypher_query( + """ + MATCH (c:ResolutionContainerNode)-[r:CONTAINS]->(i:ResolutionNode) + WITH c, r, i ORDER BY i.name + WITH c, collect({item: i, rel: r}) as items + RETURN c, items + """, + resolve_objects=True, + ) + + assert len(results) == 1 + container_result, items_result = results[0] + + assert isinstance(container_result, ResolutionContainerNode) + assert container_result.name == "Container" + + assert isinstance(items_result, list) + assert len(items_result) == 2 + + for i, item_data in enumerate(items_result): + assert isinstance(item_data, dict) + assert "item" in item_data + assert "rel" in item_data + + item = item_data["item"] + rel = item_data["rel"] + + assert isinstance(item, ResolutionNode) + assert isinstance(rel, ResolutionRelationship) + assert item.name == f"Item{i}" + assert rel.weight == i + 1 + + +@mark_async_test +async def test_nodes_nested_in_maps(): + """Test nodes nested in maps (Issue #906).""" + # Create test data + await ResolutionNode(name="Node1", value=100).save() + await ResolutionNode(name="Node2", value=200).save() + + # Test nodes nested in maps + results, _ = await adb.cypher_query( + """ + MATCH (n1:ResolutionNode), (n2:ResolutionNode) + WHERE n1.name = 'Node1' AND n2.name = 'Node2' + RETURN { + first: n1, + second: n2, + metadata: { + count: 2, + description: 'test map' + } + } as result_map + """, + resolve_objects=True, + ) + + assert len(results) == 1 + result_map = results[0][0] + + assert isinstance(result_map, dict) + assert "first" in result_map + assert "second" in result_map + assert "metadata" in result_map + + # Check that nodes are properly resolved + first_node = 
result_map["first"] + second_node = result_map["second"] + + assert isinstance(first_node, ResolutionNode) + assert isinstance(second_node, ResolutionNode) + assert first_node.name == "Node1" + assert second_node.name == "Node2" + + # Check metadata (should remain as primitive types) + metadata = result_map["metadata"] + assert isinstance(metadata, dict) + assert metadata["count"] == 2 + assert metadata["description"] == "test map" + + +@mark_async_test +async def test_mixed_nested_structures(): + """Test mixed nested structures with lists, maps, and nodes.""" + # Create test data + special = await ResolutionSpecialNode(name="Special", special_value=999).save() + test_nodes = [] + for i in range(2): + node = await ResolutionNode(name=f"Test{i}", value=i * 100).save() + test_nodes.append(node) + await special.related.connect(node, {"weight": i + 10}) + + # Test complex mixed structure + results, _ = await adb.cypher_query( + """ + MATCH (s:ResolutionSpecialNode)-[r:RELATED_TO]->(t:ResolutionNode) + WITH s, r, t ORDER BY t.name + WITH s, collect({node: t, rel: r}) as related_items + RETURN { + special_node: s, + related: related_items, + summary: { + total_relations: size(related_items), + node_names: [item in related_items | item.node.name] + } + } as complex_result + """, + resolve_objects=True, + ) + + assert len(results) == 1 + complex_result = results[0][0] + + assert isinstance(complex_result, dict) + assert "special_node" in complex_result + assert "related" in complex_result + assert "summary" in complex_result + + # Check special node resolution + special_node = complex_result["special_node"] + assert isinstance(special_node, ResolutionSpecialNode) + assert special_node.name == "Special" + assert special_node.special_value == 999 + + # Check related items (list of dicts with nodes and relationships) + related = complex_result["related"] + assert isinstance(related, list) + assert len(related) == 2 + + for i, item in enumerate(related): + assert isinstance(item, dict) + assert "node" in item + assert "rel" in item + + node = item["node"] + rel = item["rel"] + + assert isinstance(node, ResolutionNode) + assert isinstance(rel, ResolutionRelationship) + assert node.name == f"Test{i}" + assert rel.weight == i + 10 + + # Check summary (should remain as primitive types) + summary = complex_result["summary"] + assert isinstance(summary, dict) + assert summary["total_relations"] == 2 + assert isinstance(summary["node_names"], list) + assert summary["node_names"] == ["Test0", "Test1"] + + +@mark_async_test +async def test_deeply_nested_structures(): + """Test deeply nested structures to ensure recursive resolution works.""" + # Create test data + nodes = [] + for i in range(3): + node = await ResolutionNode(name=f"Deep{i}", value=i * 50).save() + nodes.append(node) + + # Test deeply nested structure + results, _ = await adb.cypher_query( + """ + MATCH (n:ResolutionNode) + WITH n ORDER BY n.name + WITH collect(n) as level1 + RETURN { + level1: level1, + level2: { + nodes: level1, + metadata: { + level3: { + count: size(level1), + items: level1 + } + } + } + } as deep_result + """, + resolve_objects=True, + ) + + assert len(results) == 1 + deep_result = results[0][0] + + assert isinstance(deep_result, dict) + assert "level1" in deep_result + assert "level2" in deep_result + + # Check level1 (direct list of nodes) + level1 = deep_result["level1"] + assert isinstance(level1, list) + assert len(level1) == 3 + for i, node in enumerate(level1): + assert isinstance(node, ResolutionNode) + assert 
node.name == f"Deep{i}" + + # Check level2 (nested structure) + level2 = deep_result["level2"] + assert isinstance(level2, dict) + assert "nodes" in level2 + assert "metadata" in level2 + + # Check nodes in level2 + level2_nodes = level2["nodes"] + assert isinstance(level2_nodes, list) + assert len(level2_nodes) == 3 + for i, node in enumerate(level2_nodes): + assert isinstance(node, ResolutionNode) + assert node.name == f"Deep{i}" + + # Check metadata in level2 + metadata = level2["metadata"] + assert isinstance(metadata, dict) + assert "level3" in metadata + + level3 = metadata["level3"] + assert isinstance(level3, dict) + assert "count" in level3 + assert "items" in level3 + + assert level3["count"] == 3 + level3_items = level3["items"] + assert isinstance(level3_items, list) + assert len(level3_items) == 3 + for i, node in enumerate(level3_items): + assert isinstance(node, ResolutionNode) + assert node.name == f"Deep{i}" + + +@mark_async_test +async def test_collect_with_aggregation(): + """Test collect() with aggregation functions.""" + # Create test data + for i in range(5): + node = await ResolutionNode(name=f"AggNode{i}", value=i * 10).save() + + # Test collect with aggregation + results, _ = await adb.cypher_query( + """ + MATCH (n:ResolutionNode) + WHERE n.name STARTS WITH 'Agg' + WITH n ORDER BY n.name + WITH collect(n) as all_nodes + RETURN { + nodes: all_nodes, + count: size(all_nodes), + total_value: reduce(total = 0, n in all_nodes | total + n.value), + names: [n in all_nodes | n.name] + } as aggregated_result + """, + resolve_objects=True, + ) + + assert len(results) == 1 + aggregated_result = results[0][0] + + assert isinstance(aggregated_result, dict) + assert "nodes" in aggregated_result + assert "count" in aggregated_result + assert "total_value" in aggregated_result + assert "names" in aggregated_result + + # Check nodes are resolved + nodes = aggregated_result["nodes"] + assert isinstance(nodes, list) + assert len(nodes) == 5 + for i, node in enumerate(nodes): + assert isinstance(node, ResolutionNode) + assert node.name == f"AggNode{i}" + assert node.value == i * 10 + + # Check aggregated values + assert aggregated_result["count"] == 5 + assert aggregated_result["total_value"] == 100 # 0+10+20+30+40 + assert aggregated_result["names"] == [ + "AggNode0", + "AggNode1", + "AggNode2", + "AggNode3", + "AggNode4", + ] + + +@mark_async_test +async def test_resolve_objects_false_comparison(): + """Test that resolve_objects=False returns raw Neo4j objects.""" + # Create test data + await ResolutionNode(name="RawNode", value=123).save() + + # Test with resolve_objects=False + results_false, _ = await adb.cypher_query( + "MATCH (n:ResolutionNode) WHERE n.name = $name RETURN n", + {"name": "RawNode"}, + resolve_objects=False, + ) + + # Test with resolve_objects=True + results_true, _ = await adb.cypher_query( + "MATCH (n:ResolutionNode) WHERE n.name = $name RETURN n", + {"name": "RawNode"}, + resolve_objects=True, + ) + + # Compare results + raw_node = results_false[0][0] + resolved_node = results_true[0][0] + + # Raw node should be a Neo4j Node object + from neo4j.graph import Node + + assert isinstance(raw_node, Node) + assert raw_node["name"] == "RawNode" + assert raw_node["value"] == 123 + + # Resolved node should be a ResolutionNode instance + assert isinstance(resolved_node, ResolutionNode) + assert resolved_node.name == "RawNode" + assert resolved_node.value == 123 + + +@mark_async_test +async def test_empty_results(): + """Test object resolution with empty results.""" + # 
Test empty results + results, _ = await adb.cypher_query( + "MATCH (n:ResolutionNode) WHERE n.name = 'NonExistent' RETURN n", + resolve_objects=True, + ) + + assert len(results) == 0 + + +@mark_async_test +async def test_primitive_types_preserved(): + """Test that primitive types are preserved during object resolution.""" + # Create test data + await ResolutionNode(name="PrimitiveTest", value=456).save() + + # Test with mixed primitive and node types + results, _ = await adb.cypher_query( + """ + MATCH (n:ResolutionNode) WHERE n.name = $name + RETURN n, n.value as int_val, n.name as str_val, true as bool_val, 3.14 as float_val + """, + {"name": "PrimitiveTest"}, + resolve_objects=True, + ) + + assert len(results) == 1 + node_result, int_val, str_val, bool_val, float_val = results[0] + + # Node should be resolved + assert isinstance(node_result, ResolutionNode) + assert node_result.name == "PrimitiveTest" + + # Primitives should remain primitive + assert isinstance(int_val, int) + assert int_val == 456 + + assert isinstance(str_val, str) + assert str_val == "PrimitiveTest" + + assert isinstance(bool_val, bool) + assert bool_val is True + + assert isinstance(float_val, float) + assert float_val == 3.14 diff --git a/test/async_/test_properties.py b/test/async_/test_properties.py index 4a0704a0..1b32d2bd 100644 --- a/test/async_/test_properties.py +++ b/test/async_/test_properties.py @@ -1,4 +1,4 @@ -from datetime import date, datetime, timedelta +from datetime import date, datetime, timedelta, timezone from test._async_compat import mark_async_test from zoneinfo import ZoneInfo @@ -10,7 +10,7 @@ AsyncStructuredNode, AsyncStructuredRel, adb, - config, + get_config, ) from neomodel.contrib import AsyncSemiStructuredNode from neomodel.exceptions import ( @@ -137,7 +137,7 @@ def test_datetimes_timezones(): prop = DateTimeProperty() prop.name = "foo" prop.owner = FooBar - t = datetime.utcnow() + t = datetime.now(timezone.utc) gr = ZoneInfo("Europe/Athens") gb = ZoneInfo("Europe/London") dt1 = t.replace(tzinfo=gr) @@ -153,12 +153,13 @@ def test_datetimes_timezones(): default_now=True, default=datetime(1900, 1, 1, 0, 0, 0) ) - prev_force_timezone = config.FORCE_TIMEZONE - config.FORCE_TIMEZONE = True + config = get_config() + prev_force_timezone = config.force_timezone + config.force_timezone = True with raises(ValueError, match=r".*No timezone provided."): prop.deflate(datetime.now()) - config.FORCE_TIMEZONE = prev_force_timezone + config.force_timezone = prev_force_timezone def test_date(): @@ -379,28 +380,26 @@ class DefaultTestValueThree(AsyncStructuredNode): assert x.uid == "123" -class TestDBNamePropertyRel(AsyncStructuredRel): +class DBNamePropertyRel(AsyncStructuredRel): known_for = StringProperty(db_property="knownFor") # This must be defined outside of the test, otherwise the `Relationship` definition cannot look up -# `TestDBNamePropertyNode` -class TestDBNamePropertyNode(AsyncStructuredNode): +# `DBNamePropertyNode` +class DBNamePropertyNode(AsyncStructuredNode): name_ = StringProperty(db_property="name") - knows = AsyncRelationship( - "TestDBNamePropertyNode", "KNOWS", model=TestDBNamePropertyRel - ) + knows = AsyncRelationship("DBNamePropertyNode", "KNOWS", model=DBNamePropertyRel) @mark_async_test async def test_independent_property_name(): # -- test node -- - x = TestDBNamePropertyNode() + x = DBNamePropertyNode() x.name_ = "jim" await x.save() # check database property name on low level - results, meta = await adb.cypher_query("MATCH (n:TestDBNamePropertyNode) RETURN n") + results, 
meta = await adb.cypher_query("MATCH (n:DBNamePropertyNode) RETURN n") node_properties = get_graph_entity_properties(results[0][0]) assert node_properties["name"] == "jim" assert "name_" not in node_properties @@ -408,10 +407,10 @@ async def test_independent_property_name(): # check python class property name at a high level assert not hasattr(x, "name") assert hasattr(x, "name_") - assert (await TestDBNamePropertyNode.nodes.filter(name_="jim").all())[ + assert (await DBNamePropertyNode.nodes.filter(name_="jim").all())[ 0 ].name_ == x.name_ - assert (await TestDBNamePropertyNode.nodes.get(name_="jim")).name_ == x.name_ + assert (await DBNamePropertyNode.nodes.get(name_="jim")).name_ == x.name_ # -- test relationship -- @@ -421,7 +420,7 @@ async def test_independent_property_name(): # check database property name on low level results, meta = await adb.cypher_query( - "MATCH (:TestDBNamePropertyNode)-[r:KNOWS]->(:TestDBNamePropertyNode) RETURN r" + "MATCH (:DBNamePropertyNode)-[r:KNOWS]->(:DBNamePropertyNode) RETURN r" ) rel_properties = get_graph_entity_properties(results[0][0]) assert rel_properties["knownFor"] == "10 years" @@ -436,15 +435,15 @@ async def test_independent_property_name(): @mark_async_test async def test_independent_property_name_for_semi_structured(): - class TestDBNamePropertySemiStructuredNode(AsyncSemiStructuredNode): + class DBNamePropertySemiStructuredNode(AsyncSemiStructuredNode): title_ = StringProperty(db_property="title") - semi = TestDBNamePropertySemiStructuredNode(title_="sir", extra="data") + semi = DBNamePropertySemiStructuredNode(title_="sir", extra="data") await semi.save() # check database property name on low level results, meta = await adb.cypher_query( - "MATCH (n:TestDBNamePropertySemiStructuredNode) RETURN n" + "MATCH (n:DBNamePropertySemiStructuredNode) RETURN n" ) node_properties = get_graph_entity_properties(results[0][0]) assert node_properties["title"] == "sir" @@ -456,12 +455,12 @@ class TestDBNamePropertySemiStructuredNode(AsyncSemiStructuredNode): assert not hasattr(semi, "title") assert hasattr(semi, "extra") from_filter = ( - await TestDBNamePropertySemiStructuredNode.nodes.filter(title_="sir").all() + await DBNamePropertySemiStructuredNode.nodes.filter(title_="sir").all() )[0] assert from_filter.title_ == "sir" # assert not hasattr(from_filter, "title") assert from_filter.extra == "data" - from_get = await TestDBNamePropertySemiStructuredNode.nodes.get(title_="sir") + from_get = await DBNamePropertySemiStructuredNode.nodes.get(title_="sir") assert from_get.title_ == "sir" # assert not hasattr(from_get, "title") assert from_get.extra == "data" diff --git a/test/async_/test_registry.py b/test/async_/test_registry.py index aac8fa64..c1b6a9be 100644 --- a/test/async_/test_registry.py +++ b/test/async_/test_registry.py @@ -2,21 +2,8 @@ from pytest import raises, skip -from neomodel import ( - AsyncRelationshipTo, - AsyncStructuredNode, - AsyncStructuredRel, - DateProperty, - IntegerProperty, - StringProperty, - adb, - config, -) -from neomodel.exceptions import ( - NodeClassAlreadyDefined, - NodeClassNotDefined, - RelationshipClassRedefined, -) +from neomodel import AsyncStructuredNode, StringProperty, adb, get_config +from neomodel.exceptions import NodeClassAlreadyDefined, NodeClassNotDefined @mark_async_test @@ -63,9 +50,10 @@ class PatientOneBis(AsyncStructuredNode): await PatientOneBis(name="patient1.2").save() + config = get_config() # Now, we will test object resolution await adb.close_connection() - await 
adb.set_connection(url=f"{config.DATABASE_URL}/{db_one}") + await adb.set_connection(url=f"{config.database_url}/{db_one}") await adb.clear_neo4j_database() patient1 = await PatientOne(name="patient1").save() patients, _ = await adb.cypher_query( @@ -75,7 +63,7 @@ class PatientOneBis(AsyncStructuredNode): assert patients[0][0] == patient1 await adb.close_connection() - await adb.set_connection(url=f"{config.DATABASE_URL}/{db_two}") + await adb.set_connection(url=f"{config.database_url}/{db_two}") await adb.clear_neo4j_database() patient2 = await PatientTwo(identifier="patient2").save() patients, _ = await adb.cypher_query( @@ -84,7 +72,7 @@ class PatientOneBis(AsyncStructuredNode): assert patients[0][0] == patient2 await adb.close_connection() - await adb.set_connection(url=config.DATABASE_URL) + await adb.set_connection(url=config.database_url) @mark_async_test diff --git a/test/async_/test_transactions.py b/test/async_/test_transactions.py index de7a13e5..95128c12 100644 --- a/test/async_/test_transactions.py +++ b/test/async_/test_transactions.py @@ -127,14 +127,14 @@ async def test_bookmark_transaction_decorator(): async def test_bookmark_transaction_as_a_context(): async with adb.transaction as transaction: await APerson(name="Tanya").save() - assert isinstance(transaction.last_bookmark, Bookmarks) + assert isinstance(transaction.last_bookmarks, Bookmarks) assert await APerson.nodes.filter(name="Tanya") with raises(UniqueProperty): async with adb.transaction as transaction: await APerson(name="Tanya").save() - assert not hasattr(transaction, "last_bookmark") + assert transaction.last_bookmarks is None @pytest.fixture @@ -157,14 +157,14 @@ async def test_bookmark_passed_in_to_context(spy_on_db_begin): pass assert (spy_on_db_begin)[-1] == ((), {"access_mode": None, "bookmarks": None}) - last_bookmark = transaction.last_bookmark + last_bookmarks = transaction.last_bookmarks - transaction.bookmarks = last_bookmark + transaction.bookmarks = last_bookmarks async with transaction: pass assert spy_on_db_begin[-1] == ( (), - {"access_mode": None, "bookmarks": last_bookmark}, + {"access_mode": None, "bookmarks": last_bookmarks}, ) @@ -176,4 +176,4 @@ async def test_query_inside_bookmark_transaction(): assert len([p.name for p in await APerson.nodes]) == 2 - assert isinstance(transaction.last_bookmark, Bookmarks) + assert isinstance(transaction.last_bookmarks, Bookmarks) diff --git a/test/async_/test_vectorfilter.py b/test/async_/test_vectorfilter.py index 05f03f59..cc99e03e 100644 --- a/test/async_/test_vectorfilter.py +++ b/test/async_/test_vectorfilter.py @@ -184,7 +184,7 @@ class ProductV(AsyncStructuredNode): description_embedding = ArrayProperty( FloatProperty(), vector_index=VectorIndex(dimensions=2) ) - suppliers = AsyncRelationshipFrom(SupplierV, "SUPPLIES", model=SuppliesVRel) + suppliers = AsyncRelationshipFrom(SupplierV, "SUPPLIESV", model=SuppliesVRel) await adb.install_labels(SupplierV) await adb.install_labels(SuppliesVRel) diff --git a/test/sync_/conftest.py b/test/sync_/conftest.py index cbe38140..73dbfe21 100644 --- a/test/sync_/conftest.py +++ b/test/sync_/conftest.py @@ -5,7 +5,7 @@ mark_sync_session_auto_fixture, ) -from neomodel import config, db +from neomodel import db, get_config @mark_sync_session_auto_fixture @@ -19,7 +19,7 @@ def setup_neo4j_session(request): warnings.simplefilter("default") - config.DATABASE_URL = os.environ.get( + get_config().database_url = os.environ.get( "NEO4J_BOLT_URL", "bolt://neo4j:foobarbaz@localhost:7687" ) @@ -51,5 +51,8 @@ def 
setup_neo4j_session(request): @mark_async_function_auto_fixture def setUp(): + db.set_connection(url=get_config().database_url) db.cypher_query("MATCH (n) DETACH DELETE n") yield + + db.close_connection() diff --git a/test/sync_/test_batch.py b/test/sync_/test_batch.py index 80812d31..52b8e1e6 100644 --- a/test/sync_/test_batch.py +++ b/test/sync_/test_batch.py @@ -9,13 +9,11 @@ StringProperty, StructuredNode, UniqueIdProperty, - config, + db, ) from neomodel._async_compat.util import Util from neomodel.exceptions import DeflateError, UniqueProperty -config.AUTO_INSTALL_LABELS = True - class UniqueUser(StructuredNode): uid = UniqueIdProperty() @@ -132,3 +130,448 @@ def test_get_or_create_with_rel(): # not the same gizmo assert bobs_gizmo[0] != tims_gizmo[0] + + +class NodeWithDefaultProp(StructuredNode): + name = StringProperty(required=True) + age = IntegerProperty(default=30) + other_prop = StringProperty() + + +@mark_sync_test +def test_get_or_create_with_ignored_properties(): + node = NodeWithDefaultProp.get_or_create({"name": "Tania", "age": 20}) + assert node[0].name == "Tania" + assert node[0].age == 20 + + node = NodeWithDefaultProp.get_or_create({"name": "Tania"}) + assert node[0].name == "Tania" + assert node[0].age == 20 # Tania was fetched and not created + + node = NodeWithDefaultProp.get_or_create({"name": "Tania", "age": 30}) + assert node[0].name == "Tania" + assert node[0].age == 20 # Tania was fetched and not created + + +@mark_sync_test +def test_create_or_update_with_ignored_properties(): + node = NodeWithDefaultProp.create_or_update({"name": "Tania", "age": 20}) + assert node[0].name == "Tania" + assert node[0].age == 20 + + node = NodeWithDefaultProp.create_or_update( + {"name": "Tania", "other_prop": "other"} + ) + assert node[0].name == "Tania" + assert ( + node[0].age == 20 + ) # Tania is still 20 even though default says she should be 30 + assert ( + node[0].other_prop == "other" + ) # She does have a brand new other_prop, lucky her ! + + node = NodeWithDefaultProp.create_or_update( + {"name": "Tania", "age": 30, "other_prop": "other2"} + ) + assert node[0].name == "Tania" + assert node[0].age == 30 # Tania is now 30, congrats Tania ! + assert ( + node[0].other_prop == "other2" + ) # Plus she has a new other_prop - as a birthday gift ? 
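+
+    # How the matching works (rough sketch, not the exact Cypher neomodel
+    # generates; `row`, `create_props` and `update_props` are illustrative
+    # names): only required/unique properties take part in the MERGE, all
+    # other properties are only applied on the write:
+    #   MERGE (n:NodeWithDefaultProp {name: row.name})
+    #   ON CREATE SET n += row.create_props
+    #   ON MATCH SET n += row.update_props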
+ + +@mark_sync_test +def test_lazy_mode(): + """Test lazy mode functionality.""" + + node1 = (NodeWithDefaultProp.create({"name": "Tania", "age": 20}))[0] + node = NodeWithDefaultProp.get_or_create({"name": "Tania", "age": 20}, lazy=True) + if db.version_is_higher_than("5.0.0"): + assert node[0] == node1.element_id + else: + assert node[0] == node1.id + + node = NodeWithDefaultProp.create_or_update({"name": "Tania", "age": 25}, lazy=True) + if db.version_is_higher_than("5.0.0"): + assert node[0] == node1.element_id + else: + assert node[0] == node1.id + + +class MergeKeyTestNode(StructuredNode): + """Test node for merge key functionality tests.""" + + name = StringProperty(required=True) + email = StringProperty(required=True, unique_index=True) + age = IntegerProperty() + department = StringProperty() + + +class MergeKeyChildTestNode(MergeKeyTestNode): + """Test node for merge key functionality tests with an extra label.""" + + +@mark_sync_test +def test_default_merge_behavior(): + """Test default merge behavior using required properties.""" + + # Create initial node + node1 = ( + MergeKeyTestNode.create( + {"name": "John", "email": "john@example.com", "age": 30} + ) + )[0] + + # Update with same name and email (should update existing) + nodes = MergeKeyTestNode.create_or_update( + {"name": "John", "email": "john@example.com", "age": 31} + ) + + assert len(nodes) == 1 + assert nodes[0].element_id == node1.element_id + assert nodes[0].age == 31 + + # Verify the node was updated correctly + assert nodes[0].name == "John" + assert nodes[0].email == "john@example.com" + assert nodes[0].age == 31 + + +@mark_sync_test +def test_custom_merge_key_email(): + """Test custom merge key using email only.""" + + # Create initial node + node1 = ( + MergeKeyTestNode.create( + {"name": "Jane", "email": "jane@example.com", "age": 25} + ) + )[0] + + # Update with custom merge key (email only) + nodes = MergeKeyTestNode.create_or_update( + { + "name": "Jane Doe", # Different name + "email": "jane@example.com", # Same email + "age": 26, + }, + merge_by={"label": "MergeKeyTestNode", "keys": ["email"]}, + ) + + assert len(nodes) == 1 + assert nodes[0].element_id == node1.element_id # Node should be the same + assert nodes[0].name == "Jane Doe" # Name should be updated + assert nodes[0].age == 26 # Age should be updated + + +@mark_sync_test +def test_merge_key_unspecified_label(): + """Test merge key with unspecified label.""" + + # Create initial node + node1 = ( + MergeKeyTestNode.create( + {"name": "Jane", "email": "jane@example.com", "age": 25} + ) + )[0] + + # Update with custom merge key (email only) + nodes = MergeKeyTestNode.create_or_update( + { + "name": "Jane Doe", # Different name + "email": "jane@example.com", # Same email + "age": 26, + }, + merge_by={"keys": ["email"]}, + ) + + assert len(nodes) == 1 + assert nodes[0].element_id == node1.element_id # Node should be the same + assert nodes[0].name == "Jane Doe" # Name should be updated + assert nodes[0].age == 26 # Age should be updated + + +@mark_sync_test +def test_get_or_create_with_merge_key(): + """Test get_or_create with custom merge key.""" + + # Create initial node + node1 = ( + MergeKeyTestNode.create( + {"name": "Alice", "email": "alice@example.com", "age": 28} + ) + )[0] + + # Use get_or_create with custom merge key + nodes = MergeKeyTestNode.get_or_create( + { + "name": "Alice Smith", # Different name + "email": "alice@example.com", # Same email + "age": 29, + }, + merge_by={"label": "MergeKeyTestNode", "keys": ["email"]}, + ) + + 
assert len(nodes) == 1 + assert nodes[0].element_id == node1.element_id # Node was fetched and not created + assert nodes[0].name == "Alice" # Name should be the same + assert nodes[0].age == 28 + + +@mark_sync_test +def test_merge_key_create_new_node(): + """Test that merge key creates new node when no match is found.""" + + # Create node with merge key that won't match anything + nodes = MergeKeyTestNode.create_or_update( + {"name": "New User", "email": "new@example.com", "age": 30}, + merge_by={"label": "MergeKeyTestNode", "keys": ["email"]}, + ) + + assert len(nodes) == 1 + assert nodes[0].name == "New User" + assert nodes[0].email == "new@example.com" + assert nodes[0].age == 30 + + # Verify the node was created correctly + assert nodes[0].name == "New User" + assert nodes[0].email == "new@example.com" + assert nodes[0].age == 30 + + +@mark_sync_test +def test_get_or_create_with_different_label(): + """Test merge key with different label specification.""" + + # Create initial node + node1 = ( + MergeKeyTestNode.create({"name": "Eve", "email": "eve@example.com", "age": 27}) + )[0] + node2 = ( + MergeKeyChildTestNode.create( + {"name": "Eve Child", "email": "evechild@example.com", "age": 27} + ) + )[0] + + # Use merge key with explicit label - child one + nodes = MergeKeyTestNode.get_or_create( + { + "name": "Eve Child", # Different name + "email": "evechild@example.com", # Same email + "age": 27, + }, + merge_by={"label": "MergeKeyChildTestNode", "keys": ["age"]}, + ) + + assert len(nodes) == 1 # Only the node with child label + assert nodes[0].element_id == node2.element_id # Node was fetched and not created + assert nodes[0].name == "Eve Child" + + # Use merge key with explicit label - parent one + nodes = MergeKeyTestNode.get_or_create( + { + "name": "Eve Child", # Different name + "email": "evechild@example.com", # Same email + "age": 27, + }, + merge_by={"label": "MergeKeyTestNode", "keys": ["age"]}, + ) + + assert len(nodes) == 2 # Both nodes were fetched + element_ids = [node.element_id for node in nodes] + assert node1.element_id in element_ids + assert node2.element_id in element_ids + + +@mark_sync_test +def test_multiple_merge_operations(): + """Test multiple merge operations with different keys.""" + + # Create initial nodes + node1 = ( + MergeKeyTestNode.create( + {"name": "Frank", "email": "frank@example.com", "age": 45} + ) + )[0] + node2 = ( + MergeKeyTestNode.create( + {"name": "Grace", "email": "grace@example.com", "age": 38} + ) + )[0] + + # Update Frank by email + nodes1 = MergeKeyTestNode.create_or_update( + {"name": "Franklin", "email": "frank@example.com", "age": 46}, + merge_by={"label": "MergeKeyTestNode", "keys": ["email"]}, + ) + + # Update Grace by name + nodes2 = MergeKeyTestNode.create_or_update( + {"name": "Grace", "email": "grace.new@example.com", "age": 39}, + merge_by={"label": "MergeKeyTestNode", "keys": ["name"]}, + ) + + assert len(nodes1) == 1 + assert len(nodes2) == 1 + assert nodes1[0].element_id == node1.element_id + assert nodes2[0].element_id == node2.element_id + assert nodes1[0].name == "Franklin" + assert nodes2[0].email == "grace.new@example.com" + + # Verify both nodes were updated correctly + assert nodes1[0].name == "Franklin" + assert nodes2[0].email == "grace.new@example.com" + + +@mark_sync_test +def test_merge_key_lazy_mode(): + """Test merge key functionality with lazy mode.""" + + # Create initial node + node1 = ( + MergeKeyTestNode.create( + {"name": "Diana", "email": "diana@example.com", "age": 32} + ) + )[0] + + # Test with 
lazy mode + nodes = MergeKeyTestNode.create_or_update( + {"name": "Diana Prince", "email": "diana@example.com", "age": 33}, + merge_by={"label": "MergeKeyTestNode", "keys": ["email"]}, + lazy=True, + ) + + assert len(nodes) == 1 + # In lazy mode, we should get the element_id back + if db.version_is_higher_than("5.0.0"): + assert nodes[0] == node1.element_id + else: + assert nodes[0] == node1.id + + +@mark_sync_test +def test_merge_key_with_multiple_properties(): + """Test merge key with a property that has multiple values.""" + + # Create initial node + node1 = ( + MergeKeyTestNode.create( + { + "name": "Multi", + "email": "multi@example.com", + "age": 25, + "department": "Engineering", + } + ) + )[0] + + # Update with different department but same email + nodes = MergeKeyTestNode.create_or_update( + { + "name": "Multi Updated", + "email": "multi@example.com", + "age": 26, + "department": "Management", + }, + merge_by={"label": "MergeKeyTestNode", "keys": ["email"]}, + ) + + assert len(nodes) == 1 + assert nodes[0].element_id == node1.element_id # Node should be the same + assert nodes[0].name == "Multi Updated" + assert nodes[0].department == "Management" + assert nodes[0].age == 26 + + +@mark_sync_test +def test_merge_key_with_get_or_create_multiple_keys(): + """Test merge key for get_or_create with multiple keys.""" + + # Create initial node + node1 = ( + MergeKeyTestNode.create( + { + "name": "Charlie", + "email": "charlie@example.com", + "age": 35, + } + ) + )[0] + + # Use get_or_create with multiple keys + nodes = MergeKeyTestNode.get_or_create( + { + "name": "Charlie", + "email": "charlie@example.com", + "age": 36, + }, + merge_by={"label": "MergeKeyTestNode", "keys": ["name", "email"]}, + ) + + assert len(nodes) == 1 + assert nodes[0].element_id == node1.element_id + assert nodes[0].age == 35 + + # Try to get_or_create with different keys (should create new) + nodes2 = MergeKeyTestNode.get_or_create( + { + "name": "Charlie", + "email": "charlie.doe@example.com", # Different email + "age": 37, + }, + merge_by={"label": "MergeKeyTestNode", "keys": ["name", "email"]}, + ) + + assert len(nodes2) == 1 + assert nodes2[0].element_id != node1.element_id + assert nodes2[0].name == "Charlie" + assert nodes2[0].email == "charlie.doe@example.com" + + +@mark_sync_test +def test_merge_key_with_create_or_update_multiple_keys(): + """Test merge key for create_or_update with multiple keys.""" + + # Create initial node + node1 = ( + MergeKeyTestNode.create( + { + "name": "John", + "email": "john@example.com", + "age": 30, + "department": "Engineering", + } + ) + )[0] + + # Update with same name and email (both keys match) + nodes = MergeKeyTestNode.create_or_update( + { + "name": "John", + "email": "john@example.com", + "age": 31, + "department": "Management", + }, + merge_by={"label": "MergeKeyTestNode", "keys": ["name", "email"]}, + ) + + assert len(nodes) == 1 + assert nodes[0].element_id == node1.element_id + assert nodes[0].age == 31 + assert nodes[0].department == "Management" + + # Test with only one key matching (should create new node) + nodes2 = MergeKeyTestNode.create_or_update( + { + "name": "John", + "email": "john.doe@example.com", # Different email + "age": 32, + "department": "Sales", + }, + merge_by={"label": "MergeKeyTestNode", "keys": ["name", "email"]}, + ) + + assert len(nodes2) == 1 + assert nodes2[0].element_id != node1.element_id # Should be a new node + assert nodes2[0].name == "John" + assert nodes2[0].email == "john.doe@example.com" diff --git 
a/test/sync_/test_cardinality.py b/test/sync_/test_cardinality.py index 77014671..a6067e9c 100644 --- a/test/sync_/test_cardinality.py +++ b/test/sync_/test_cardinality.py @@ -16,8 +16,8 @@ StructuredNode, ZeroOrMore, ZeroOrOne, - config, db, + get_config, ) @@ -63,6 +63,11 @@ class Company(StructuredNode): class Employee(StructuredNode): name = StringProperty(required=True) employer = RelationshipFrom("Company", "EMPLOYS", cardinality=ZeroOrOne) + offices = RelationshipFrom("Office", "HOSTS", cardinality=OneOrMore) + + +class Office(StructuredNode): + name = StringProperty(required=True) class Manager(StructuredNode): @@ -116,9 +121,15 @@ def test_cardinality_zero_or_one(): assert single_driver.version == 1 j = ScrewDriver(version=2).save() - with raises(AttemptedCardinalityViolation): + with raises(AttemptedCardinalityViolation) as exc_info: m.driver.connect(j) + error_message = str(exc_info.value) + assert ( + f"Node already has zero or one relationship in a outgoing direction of type HAS_SCREWDRIVER on node ({m.element_id}) of class 'Monkey'. Use reconnect() to replace the existing relationship." + == error_message + ) + m.driver.reconnect(h, j) single_driver = m.driver.single() assert single_driver.version == 2 @@ -157,9 +168,12 @@ def test_cardinality_one_or_more(): cars = m.car.all() assert len(cars) == 1 - with raises(AttemptedCardinalityViolation): + with raises(AttemptedCardinalityViolation) as exc_info: m.car.disconnect(c) + error_message = str(exc_info.value) + assert "One or more expected" == error_message + d = Car(version=3).save() m.car.connect(d) cars = m.car.all() @@ -169,6 +183,11 @@ def test_cardinality_one_or_more(): cars = m.car.all() assert len(cars) == 1 + with raises(AttemptedCardinalityViolation): + m.car.disconnect_all() + + assert m.car.single() is not None + @mark_sync_test def test_cardinality_one(): @@ -188,9 +207,15 @@ def test_cardinality_one(): assert single_toothbrush.name == "Jim" x = ToothBrush(name="Jim").save() - with raises(AttemptedCardinalityViolation): + with raises(AttemptedCardinalityViolation) as exc_info: m.toothbrush.connect(x) + error_message = str(exc_info.value) + assert ( + f"Node already has one relationship in a outgoing direction of type HAS_TOOTHBRUSH on node ({m.element_id}) of class 'Monkey'. Use reconnect() to replace the existing relationship." + == error_message + ) + with raises(AttemptedCardinalityViolation): m.toothbrush.disconnect(b) @@ -226,7 +251,8 @@ def test_relationship_from_one_cardinality_enforced(): were not being enforced. """ # Setup - config.SOFT_INVERSE_CARDINALITY_CHECK = False + config = get_config() + config.soft_cardinality_check = False owner1 = Owner(name="Alice").save() owner2 = Owner(name="Bob").save() pet = Pet(name="Fluffy").save() @@ -244,14 +270,15 @@ def test_relationship_from_one_cardinality_enforced(): stream = io.StringIO() with patch("sys.stdout", new=stream): - config.SOFT_INVERSE_CARDINALITY_CHECK = True + config.soft_cardinality_check = True owner2.pets.connect(pet) assert pet in owner2.pets.all() console_output = stream.getvalue() assert "Cardinality violation detected" in console_output assert "Soft check is enabled so the relationship will be created" in console_output - assert "strict check will be enabled by default in version 6.0" in console_output + + config.soft_cardinality_check = False @mark_sync_test @@ -260,7 +287,8 @@ def test_relationship_from_zero_or_one_cardinality_enforced(): Test that RelationshipFrom with cardinality=ZeroOrOne prevents multiple connections. 
""" # Setup - config.SOFT_INVERSE_CARDINALITY_CHECK = False + config = get_config() + config.soft_cardinality_check = False company1 = Company(name="TechCorp").save() company2 = Company(name="StartupInc").save() employee = Employee(name="John").save() @@ -278,14 +306,38 @@ def test_relationship_from_zero_or_one_cardinality_enforced(): stream = io.StringIO() with patch("sys.stdout", new=stream): - config.SOFT_INVERSE_CARDINALITY_CHECK = True + config.soft_cardinality_check = True company2.employees.connect(employee) assert employee in company2.employees.all() console_output = stream.getvalue() assert "Cardinality violation detected" in console_output assert "Soft check is enabled so the relationship will be created" in console_output - assert "strict check will be enabled by default in version 6.0" in console_output + + config.soft_cardinality_check = False + + +@mark_sync_test +def test_relationship_from_one_or_more_cardinality_enforced(): + """ + Test that RelationshipFrom with cardinality=OneOrMore prevents disconnecting all nodes. + """ + # Setup + config = get_config() + config.soft_cardinality_check = False + office = Office(name="Headquarters").save() + employee = Employee(name="John").save() + employee.offices.connect(office) + + with raises(AttemptedCardinalityViolation): + employee.offices.disconnect(office) + + with raises(AttemptedCardinalityViolation): + employee.offices.disconnect_all() + + assert employee.offices.single() is not None + + config.soft_cardinality_check = False @mark_sync_test @@ -294,7 +346,8 @@ def test_bidirectional_cardinality_validation(): Test that cardinality is validated on both ends when both sides have constraints. """ # Setup - config.SOFT_INVERSE_CARDINALITY_CHECK = False + config = get_config() + config.soft_cardinality_check = False manager1 = Manager(name="Sarah").save() manager2 = Manager(name="David").save() assistant = Assistant(name="Alex").save() @@ -312,11 +365,12 @@ def test_bidirectional_cardinality_validation(): stream = io.StringIO() with patch("sys.stdout", new=stream): - config.SOFT_INVERSE_CARDINALITY_CHECK = True + config.soft_cardinality_check = True manager2.assistant.connect(assistant) assert assistant in manager2.assistant.all() console_output = stream.getvalue() assert "Cardinality violation detected" in console_output assert "Soft check is enabled so the relationship will be created" in console_output - assert "strict check will be enabled by default in version 6.0" in console_output + + config.soft_cardinality_check = False diff --git a/test/sync_/test_connection.py b/test/sync_/test_connection.py index d5a19e02..164a82de 100644 --- a/test/sync_/test_connection.py +++ b/test/sync_/test_connection.py @@ -1,25 +1,37 @@ import os -from test._async_compat import mark_sync_test +from test._async_compat import ( + mark_async_function_auto_fixture, + mark_sync_session_auto_fixture, + mark_sync_test, +) from test.conftest import NEO4J_PASSWORD, NEO4J_URL, NEO4J_USERNAME import pytest from neo4j import Driver, GraphDatabase from neo4j.debug import watch -from neomodel import StringProperty, StructuredNode, config, db +from neomodel import StringProperty, StructuredNode, db, get_config -@mark_sync_test -@pytest.fixture(autouse=True) -def setup_teardown(): +@mark_async_function_auto_fixture +def setup_teardown(request): yield # Teardown actions after tests have run # Reconnect to initial URL for potential subsequent tests - db.close_connection() - db.set_connection(url=config.DATABASE_URL) + # Skip reconnection for Aura tests except 
bolt+ssc parameter + should_reconnect = True + if ( + "test_connect_to_aura" in request.node.name + and "bolt+ssc" not in request.node.name + ): + should_reconnect = False + + if should_reconnect: + db.close_connection() + db.set_connection(url=get_config().database_url) -@pytest.fixture(autouse=True, scope="session") +@mark_sync_session_auto_fixture def neo4j_logging(): with watch("neo4j"): yield @@ -67,12 +79,13 @@ def test_config_driver_works(): NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD) ) - config.DRIVER = driver + config = get_config() + config.driver = driver assert Pastry(name="Grignette").save() # Clear config # No need to close connection - pytest teardown will do it - config.DRIVER = None + config.driver = None @mark_sync_test @@ -83,17 +96,18 @@ def test_connect_to_non_default_database(): db.cypher_query(f"CREATE DATABASE {database_name} IF NOT EXISTS") db.close_connection() + config = get_config() # Set database name in url - for url init only - db.set_connection(url=f"{config.DATABASE_URL}/{database_name}") + db.set_connection(url=f"{config.database_url}/{database_name}") assert get_current_database_name() == "pastries" db.close_connection() # Set database name in config - for both url and driver init - config.DATABASE_NAME = database_name + config.database_name = database_name # url init - db.set_connection(url=config.DATABASE_URL) + db.set_connection(url=config.database_url) assert get_current_database_name() == "pastries" db.close_connection() @@ -106,7 +120,7 @@ def test_connect_to_non_default_database(): # Clear config # No need to close connection - pytest teardown will do it - config.DATABASE_NAME = None + config.database_name = None @mark_sync_test @@ -152,9 +166,9 @@ def test_connect_to_aura(protocol): def _set_connection(protocol): - AURA_TEST_DB_USER = os.environ["AURA_TEST_DB_USER"] - AURA_TEST_DB_PASSWORD = os.environ["AURA_TEST_DB_PASSWORD"] - AURA_TEST_DB_HOSTNAME = os.environ["AURA_TEST_DB_HOSTNAME"] + aura_test_db_user = os.environ["AURA_TEST_DB_USER"] + aura_test_db_password = os.environ["AURA_TEST_DB_PASSWORD"] + aura_test_db_hostname = os.environ["AURA_TEST_DB_HOSTNAME"] - database_url = f"{protocol}://{AURA_TEST_DB_USER}:{AURA_TEST_DB_PASSWORD}@{AURA_TEST_DB_HOSTNAME}" + database_url = f"{protocol}://{aura_test_db_user}:{aura_test_db_password}@{aura_test_db_hostname}" db.set_connection(url=database_url) diff --git a/test/sync_/test_core_additional.py b/test/sync_/test_core_additional.py new file mode 100644 index 00000000..4caa9e04 --- /dev/null +++ b/test/sync_/test_core_additional.py @@ -0,0 +1,242 @@ +""" +Additional tests for neomodel.sync_.core module to improve coverage. 
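+
+The Database and TransactionProxy behaviour exercised here now lives in
+neomodel.sync_.database and neomodel.sync_.transaction after the core.py split.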
+""" + +from test._async_compat import mark_sync_test +from unittest.mock import Mock, patch + +import pytest +from neo4j.exceptions import ClientError + +from neomodel.sync_.database import Database, ensure_connection +from neomodel.sync_.transaction import TransactionProxy + + +@mark_sync_test +def test_ensure_connection_decorator_no_driver(): + """Test ensure_connection decorator when driver is None.""" + + class MockDB: + def __init__(self): + self.driver = None + + def set_connection(self, **kwargs): + # Dummy implementation for testing + pass + + @ensure_connection + def test_method(self): + return "success" + + test_db = MockDB() + with patch.object( + test_db, "set_connection", new_callable=Mock + ) as mock_set_connection: + result = test_db.test_method() + assert result == "success" + mock_set_connection.assert_called_once_with( + url="bolt://neo4j:foobarbaz@localhost:7687" + ) + + +@mark_sync_test +def test_ensure_connection_decorator_with_driver(): + """Test ensure_connection decorator when driver is set.""" + + class MockDB: + def __init__(self): + self.driver = "existing_driver" + + @ensure_connection + def test_method(self): + return "success" + + test_db = MockDB() + result = test_db.test_method() + assert result == "success" + + +@mark_sync_test +def test_clear_neo4j_database(): + """Test clear_neo4j_database method.""" + test_db = Database() + + with patch.object(test_db, "cypher_query", new_callable=Mock) as mock_cypher: + with patch.object( + test_db, "drop_constraints", new_callable=Mock + ) as mock_drop_constraints: + with patch.object( + test_db, "drop_indexes", new_callable=Mock + ) as mock_drop_indexes: + test_db.clear_neo4j_database(clear_constraints=True, clear_indexes=True) + + mock_cypher.assert_called_once() + mock_drop_constraints.assert_called_once() + mock_drop_indexes.assert_called_once() + + +@mark_sync_test +def test_drop_constraints(): + """Test drop_constraints method.""" + test_db = Database() + + mock_results = [ + {"name": "constraint1", "labelsOrTypes": ["Label1"], "properties": ["prop1"]}, + {"name": "constraint2", "labelsOrTypes": ["Label2"], "properties": ["prop2"]}, + ] + + with patch.object(test_db, "cypher_query", new_callable=Mock) as mock_cypher: + mock_cypher.return_value = ( + mock_results, + ["name", "labelsOrTypes", "properties"], + ) + + test_db.drop_constraints(quiet=False) + + # Should call LIST_CONSTRAINTS_COMMAND and DROP_CONSTRAINT_COMMAND for each constraint + assert mock_cypher.call_count == 3 # 1 for list + 2 for drop + + +@mark_sync_test +def test_drop_indexes(): + """Test drop_indexes method.""" + test_db = Database() + + mock_indexes = [ + {"name": "index1", "labelsOrTypes": ["Label1"], "properties": ["prop1"]}, + {"name": "index2", "labelsOrTypes": ["Label2"], "properties": ["prop2"]}, + ] + + with patch.object(test_db, "list_indexes", new_callable=Mock) as mock_list_indexes: + mock_list_indexes.return_value = mock_indexes + + with patch.object(test_db, "cypher_query", new_callable=Mock) as mock_cypher: + test_db.drop_indexes(quiet=False) + + # Should call DROP_INDEX_COMMAND for each index + assert mock_cypher.call_count == 2 + + +@mark_sync_test +def test_remove_all_labels(): + """Test remove_all_labels method.""" + test_db = Database() + + with patch.object( + test_db, "drop_constraints", new_callable=Mock + ) as mock_drop_constraints: + with patch.object( + test_db, "drop_indexes", new_callable=Mock + ) as mock_drop_indexes: + with patch("sys.stdout") as mock_stdout: + test_db.remove_all_labels() + + 
mock_drop_constraints.assert_called_once_with( + quiet=False, stdout=mock_stdout + ) + mock_drop_indexes.assert_called_once_with( + quiet=False, stdout=mock_stdout + ) + + +@mark_sync_test +def test_install_all_labels(): + """Test install_all_labels method.""" + test_db = Database() + + class MockNode: + def __init__(self, name): + self.__name__ = name + + @classmethod + def install_labels(cls, quiet=True, stdout=None): + pass + + with patch("neomodel.sync_.node.StructuredNode", MockNode): + with patch("sys.stdout"): + test_db.install_all_labels() + + # Should call install_labels on each node class + assert True # Test passes if no exception is raised + + +@mark_sync_test +def test_proxy_aenter_parallel_runtime_warning(): + """Test TransactionProxy __enter__ with parallel runtime warning.""" + test_db = Database() + proxy = TransactionProxy(test_db, parallel_runtime=True) + + with patch.object( + test_db, "parallel_runtime_available", new_callable=Mock + ) as mock_available: + mock_available.return_value = False + + with patch("warnings.warn") as mock_warn: + with patch.object(test_db, "begin", new_callable=Mock) as mock_begin: + proxy.__enter__() + + mock_warn.assert_called_once() + mock_begin.assert_called_once() + + +@mark_sync_test +def test_proxy_aexit_with_exception(): + """Test TransactionProxy __exit__ with exception.""" + test_db = Database() + proxy = TransactionProxy(test_db) + + with patch.object(test_db, "rollback", new_callable=Mock) as mock_rollback: + with patch.object(test_db, "commit", new_callable=Mock) as mock_commit: + # Test with exception + proxy.__exit__(ValueError, ValueError("test"), None) + mock_rollback.assert_called_once() + mock_commit.assert_not_called() + + +@mark_sync_test +def test_proxy_aexit_success(): + """Test TransactionProxy __exit__ with success.""" + test_db = Database() + proxy = TransactionProxy(test_db) + + with patch.object(test_db, "rollback", new_callable=Mock) as mock_rollback: + with patch.object(test_db, "commit", new_callable=Mock) as mock_commit: + mock_commit.return_value = "bookmarks" + + proxy.__exit__(None, None, None) + mock_rollback.assert_not_called() + mock_commit.assert_called_once() + assert proxy.last_bookmarks == "bookmarks" + + +@mark_sync_test +def test_proxy_call_decorator(): + """Test TransactionProxy __call__ decorator.""" + test_db = Database() + proxy = TransactionProxy(test_db) + + def test_func(): + return "success" + + decorated = proxy(test_func) + assert callable(decorated) + + # Test that the decorated function works + with patch.object(proxy, "__enter__", new_callable=Mock) as mock_enter: + with patch.object(proxy, "__exit__", new_callable=Mock): + mock_enter.return_value = proxy + result = decorated() + assert result == "success" + + +@mark_sync_test +def test_cypher_query_client_error_generic(): + """Test cypher_query with generic ClientError.""" + test_db = Database() + + with patch.object(test_db, "_run_cypher_query", new_callable=Mock) as mock_run: + client_error = ClientError("Neo.ClientError.Generic", "message") + mock_run.side_effect = client_error + + with pytest.raises(ClientError): + test_db.cypher_query("MATCH (n) RETURN n") diff --git a/test/sync_/test_database_management.py b/test/sync_/test_database_management.py index 9f663994..70f407df 100644 --- a/test/sync_/test_database_management.py +++ b/test/sync_/test_database_management.py @@ -1,5 +1,7 @@ +import asyncio from test._async_compat import mark_sync_test +import neo4j import pytest from neo4j.exceptions import AuthError @@ -11,6 +13,8 @@ 
     StructuredRel,
     db,
 )
+from neomodel._async_compat.util import Util
+from neomodel.sync_.database import Database
 
 
 class City(StructuredNode):
@@ -79,3 +83,76 @@ def test_change_password():
 
     db.close_connection()
     db.set_connection(url=prev_url)
+
+
+@mark_sync_test
+def test_adb_singleton_behavior():
+    """Test that Database enforces singleton behavior."""
+
+    # Get the module-level instance
+    adb1 = Database.get_instance()
+
+    # Try to create another instance directly
+    adb2 = Database()
+
+    # Try to create another instance via get_instance
+    adb3 = Database.get_instance()
+
+    # All instances should be the same object
+    assert adb1 is adb2, "Direct instantiation should return the same instance"
+    assert adb1 is adb3, "get_instance should return the same instance"
+    assert adb2 is adb3, "All instances should be the same object"
+
+    # Test that the module-level 'adb' is also the same instance
+    assert db is adb1, "Module-level 'db' should be the same instance"
+
+
+@mark_sync_test
+def test_async_database_properties():
+    # A fresh instance of Database is not yet connected
+    Database.reset_instance()
+    reset_singleton = Database.get_instance()
+    assert reset_singleton._active_transaction is None
+    assert reset_singleton.url is None
+    assert reset_singleton.driver is None
+    assert reset_singleton._session is None
+    assert reset_singleton._pid is None
+    assert reset_singleton._database_name is neo4j.DEFAULT_DATABASE
+    assert reset_singleton._database_version is None
+    assert reset_singleton._database_edition is None
+    assert reset_singleton.impersonated_user is None
+    assert reset_singleton._parallel_runtime is False
+
+
+@mark_sync_test
+def test_parallel_transactions():
+    if not Util.is_async_code:
+        pytest.skip("Async only test")
+
+    transactions = set()
+    sessions = set()
+
+    def query(i: int):
+        asyncio.sleep(0.05)
+
+        assert db._active_transaction is None
+        assert db._session is None
+
+        with db.transaction:
+            # ensure transaction and session are unique for async context
+            transaction_id = id(db._active_transaction)
+            assert transaction_id not in transactions
+            transactions.add(transaction_id)
+
+            session_id = id(db._session)
+            assert session_id not in sessions
+            sessions.add(session_id)
+
+            result, _ = db.cypher_query(
+                "CALL apoc.util.sleep($delay_ms) RETURN $task_id as task_id, $delay_ms as slept",
+                {"delay_ms": i * 505, "task_id": i},
+            )
+
+            return result[0][0], result[0][1], transaction_id, session_id
+
+    _ = asyncio.gather(*(query(i) for i in range(1, 5)))
diff --git a/test/sync_/test_fulltextfilter.py b/test/sync_/test_fulltextfilter.py
new file mode 100644
index 00000000..fc49d2e5
--- /dev/null
+++ b/test/sync_/test_fulltextfilter.py
@@ -0,0 +1,325 @@
+from datetime import datetime
+from test._async_compat import mark_sync_test
+
+import pytest
+
+from neomodel import (
+    DateTimeProperty,
+    FloatProperty,
+    FulltextIndex,
+    RelationshipFrom,
+    StringProperty,
+    StructuredNode,
+    StructuredRel,
+    db,
+)
+from neomodel.semantic_filters import FulltextFilter
+
+
+@mark_sync_test
+def test_base_fulltextfilter():
+    """
+    Tests that the fulltext query is run, and that node and score are returned.
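+    Each result row is a (node, score) pair, the score being the fulltext
+    relevance returned by the index as a float.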
+ """ + + if not db.version_is_higher_than("5.16"): + pytest.skip("Not supported before 5.16") + + class fulltextNode(StructuredNode): + description = StringProperty( + fulltext_index=FulltextIndex( + analyzer="standard-no-stop-words", eventually_consistent=False + ) + ) + other = StringProperty() + + db.install_labels(fulltextNode) + + node1 = fulltextNode(other="thing", description="Another thing").save() + + node2 = fulltextNode(other="other thing", description="Another other thing").save() + + fulltextNodeSearch = fulltextNode.nodes.filter( + fulltext_filter=FulltextFilter( + topk=3, fulltext_attribute_name="description", query_string="thing" + ) + ) + + result = fulltextNodeSearch.all() + assert all(isinstance(x[0], fulltextNode) for x in result) + assert all(isinstance(x[1], float) for x in result) + + +@mark_sync_test +def test_fulltextfilter_topk_works(): + """ + Tests that the topk filter works. + """ + + if not db.version_is_higher_than("5.16"): + pytest.skip("Not supported before 5.16") + + class fulltextNodetopk(StructuredNode): + description = StringProperty( + fulltext_index=FulltextIndex( + analyzer="standard-no-stop-words", eventually_consistent=False + ) + ) + + db.install_labels(fulltextNodetopk) + + node1 = fulltextNodetopk(description="this description").save() + node2 = fulltextNodetopk(description="that description").save() + node3 = fulltextNodetopk(description="my description").save() + + fulltextNodeSearch = fulltextNodetopk.nodes.filter( + fulltext_filter=FulltextFilter( + topk=2, fulltext_attribute_name="description", query_string="description" + ) + ) + + result = fulltextNodeSearch.all() + assert len(result) == 2 + + +@mark_sync_test +def test_fulltextfilter_with_node_propertyfilter(): + """ + Tests that the fulltext query is run, and "thing" node is only node returned. + """ + + if not db.version_is_higher_than("5.16"): + pytest.skip("Not supported before 5.16") + + class fulltextNodeBis(StructuredNode): + description = StringProperty( + fulltext_index=FulltextIndex( + analyzer="standard-no-stop-words", eventually_consistent=False + ) + ) + other = StringProperty() + + db.install_labels(fulltextNodeBis) + + node1 = fulltextNodeBis(other="thing", description="Another thing").save() + + node2 = fulltextNodeBis( + other="other thing", description="Another other thing" + ).save() + + fulltextFilterforthing = fulltextNodeBis.nodes.filter( + fulltext_filter=FulltextFilter( + topk=3, fulltext_attribute_name="description", query_string="thing" + ), + other="thing", + ) + + result = fulltextFilterforthing.all() + + assert len(result) == 1 + assert all(isinstance(x[0], fulltextNodeBis) for x in result) + assert result[0][0].other == "thing" + assert all(isinstance(x[1], float) for x in result) + + +@mark_sync_test +def test_dont_duplicate_fulltext_filter_node(): + """ + Tests the situation that another node has the same filter value. + Testing that we are only performing the fulltextfilter and metadata filter on the right nodes. 
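+    Both classes declare a fulltext index on their description property, but
+    only the source NodeSet's label (fulltextNodeTer) should be searched.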
+    """
+
+    if not db.version_is_higher_than("5.16"):
+        pytest.skip("Not supported before 5.16")
+
+    class fulltextNodeTer(StructuredNode):
+        description = StringProperty(
+            fulltext_index=FulltextIndex(
+                analyzer="standard-no-stop-words", eventually_consistent=False
+            )
+        )
+        name = StringProperty()
+
+    class otherfulltextNodeTer(StructuredNode):
+        other_description = StringProperty(
+            fulltext_index=FulltextIndex(
+                analyzer="standard-no-stop-words", eventually_consistent=False
+            )
+        )
+        other_name = StringProperty()
+
+    db.install_labels(fulltextNodeTer)
+    db.install_labels(otherfulltextNodeTer)
+
+    node1 = fulltextNodeTer(name="John", description="thing one").save()
+    node2 = fulltextNodeTer(name="Fred", description="thing two").save()
+    node3 = otherfulltextNodeTer(
+        other_name="John", other_description="thing three"
+    ).save()
+    node4 = otherfulltextNodeTer(
+        other_name="Fred", other_description="thing four"
+    ).save()
+
+    john_fulltext_search = fulltextNodeTer.nodes.filter(
+        fulltext_filter=FulltextFilter(
+            topk=3, fulltext_attribute_name="description", query_string="thing"
+        ),
+        name="John",
+    )
+
+    result = john_fulltext_search.all()
+
+    assert len(result) == 1
+    assert isinstance(result[0][0], fulltextNodeTer)
+    assert result[0][0].name == "John"
+    assert isinstance(result[0][1], float)
+
+
+@mark_sync_test
+def test_django_filter_w_fulltext_filter():
+    """
+    Tests that Django-style filters still work with the fulltext filter.
+    """
+
+    if not db.version_is_higher_than("5.16"):
+        pytest.skip("Not supported before 5.16")
+
+    class fulltextDjangoNode(StructuredNode):
+        description = StringProperty(
+            fulltext_index=FulltextIndex(
+                analyzer="standard-no-stop-words", eventually_consistent=False
+            )
+        )
+        name = StringProperty()
+        number = FloatProperty()
+
+    db.install_labels(fulltextDjangoNode)
+
+    nodeone = fulltextDjangoNode(
+        name="John", description="thing one", number=float(10)
+    ).save()
+
+    nodetwo = fulltextDjangoNode(
+        name="Fred", description="thing two", number=float(3)
+    ).save()
+
+    fulltext_index_with_django_filter = fulltextDjangoNode.nodes.filter(
+        fulltext_filter=FulltextFilter(
+            topk=3, fulltext_attribute_name="description", query_string="thing"
+        ),
+        number__gt=5,
+    )
+
+    result = fulltext_index_with_django_filter.all()
+    assert len(result) == 1
+    assert isinstance(result[0][0], fulltextDjangoNode)
+    assert result[0][0].number > 5
+
+
+@mark_sync_test
+def test_fulltextfilter_with_relationshipfilter():
+    """
+    Tests that filtering on fulltext similarity and then applying a relationship filter works.
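+    The traversal also returns the matched supplier and the relationship, so
+    each result row is a (product, supplier, rel) triple.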
+    """
+
+    if not db.version_is_higher_than("5.16"):
+        pytest.skip("Not supported before 5.16")
+
+    class SupplierFT(StructuredNode):
+        name = StringProperty()
+
+    class SuppliesFTRel(StructuredRel):
+        since = DateTimeProperty(default=datetime.now)
+
+    class ProductFT(StructuredNode):
+        name = StringProperty()
+        description = StringProperty(
+            fulltext_index=FulltextIndex(
+                analyzer="standard-no-stop-words", eventually_consistent=False
+            )
+        )
+        suppliers = RelationshipFrom(SupplierFT, "SUPPLIES", model=SuppliesFTRel)
+
+    db.install_labels(SupplierFT)
+    db.install_labels(SuppliesFTRel)
+    db.install_labels(ProductFT)
+
+    supplier1 = SupplierFT(name="Supplier 1").save()
+    supplier2 = SupplierFT(name="Supplier 2").save()
+    product1 = ProductFT(
+        name="Product A",
+        description="High quality product",
+    ).save()
+    product2 = ProductFT(
+        name="Product B",
+        description="Very High quality product",
+    ).save()
+    product1.suppliers.connect(supplier1)
+    product1.suppliers.connect(supplier2)
+    product2.suppliers.connect(supplier1)
+
+    filtered_product = ProductFT.nodes.filter(
+        fulltext_filter=FulltextFilter(
+            topk=1, fulltext_attribute_name="description", query_string="product"
+        ),
+        suppliers__name="Supplier 1",
+    )
+
+    result = filtered_product.all()
+
+    assert len(result) == 1
+    assert isinstance(result[0][0], ProductFT)
+    assert isinstance(result[0][1], SupplierFT)
+    assert isinstance(result[0][2], SuppliesFTRel)
+
+
+@mark_sync_test
+def test_fulltextfilter_nonexistent_attribute():
+    """
+    Tests that AttributeError is raised when fulltext_attribute_name doesn't exist on the source.
+    """
+
+    if not db.version_is_higher_than("5.16"):
+        pytest.skip("Not supported before 5.16")
+
+    class TestNodeWithFT(StructuredNode):
+        name = StringProperty()
+        fulltext = StringProperty(
+            fulltext_index=FulltextIndex(
+                analyzer="standard-no-stop-words", eventually_consistent=False
+            )
+        )
+
+    db.install_labels(TestNodeWithFT)
+
+    with pytest.raises(
+        AttributeError, match="Atribute 'nonexistent_fulltext' not found"
+    ):
+        nodeset = TestNodeWithFT.nodes.filter(
+            fulltext_filter=FulltextFilter(
+                topk=1,
+                fulltext_attribute_name="nonexistent_fulltext",
+                query_string="something",
+            )
+        )
+        nodeset.all()
+
+
+@mark_sync_test
+def test_fulltextfilter_no_fulltext_index():
+    """
+    Tests that AttributeError is raised when the property is not declared with a full text index.
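+    The property exists on the class, but since it has no fulltext_index
+    declaration the filter must refuse to use it.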
+ """ + + if not db.version_is_higher_than("5.16"): + pytest.skip("Not supported before 5.16") + + class TestNodeWithoutFT(StructuredNode): + name = StringProperty() + fulltext = StringProperty() # No fulltext_index + + db.install_labels(TestNodeWithoutFT) + + with pytest.raises(AttributeError, match="is not declared with a full text index"): + nodeset = TestNodeWithoutFT.nodes.filter( + fulltext_filter=FulltextFilter( + topk=1, fulltext_attribute_name="fulltext", query_string="something" + ) + ) + nodeset.all() diff --git a/test/sync_/test_issue283.py b/test/sync_/test_issue283.py index fab4f0d7..3b30c14f 100644 --- a/test/sync_/test_issue283.py +++ b/test/sync_/test_issue283.py @@ -5,8 +5,8 @@ More information about the same issue at: https://github.com/aanastasiou/neomodelInheritanceTest -The following example uses a recursive relationship for economy, but the -idea remains the same: "Instantiate the correct type of node at the end of +The following example uses a recursive relationship for economy, but the +idea remains the same: "Instantiate the correct type of node at the end of a relationship as specified by the model" """ @@ -118,44 +118,6 @@ def test_automatic_result_resolution(): assert type((A.friends_with)[0]) is TechnicalPerson -@mark_sync_test -def test_recursive_automatic_result_resolution(): - """ - Node objects are instantiated to native Python objects, both at the top - level of returned results and in the case where they are returned within - lists. - """ - - # Create a few entities - A = ( - TechnicalPerson.get_or_create({"name": "Grumpier", "expertise": "Grumpiness"}) - )[0] - B = (TechnicalPerson.get_or_create({"name": "Happier", "expertise": "Grumpiness"}))[ - 0 - ] - C = (TechnicalPerson.get_or_create({"name": "Sleepier", "expertise": "Pillows"}))[0] - D = (TechnicalPerson.get_or_create({"name": "Sneezier", "expertise": "Pillows"}))[0] - - # Retrieve mixed results, both at the top level and nested - L, _ = db.cypher_query( - "MATCH (a:TechnicalPerson) " - "WHERE a.expertise='Grumpiness' " - "WITH collect(a) as Alpha " - "MATCH (b:TechnicalPerson) " - "WHERE b.expertise='Pillows' " - "WITH Alpha, collect(b) as Beta " - "RETURN [Alpha, [Beta, [Beta, ['Banana', " - "Alpha]]]]", - resolve_objects=True, - ) - - # Assert that a Node returned deep in a nested list structure is of the - # correct type - assert type(L[0][0][0][1][0][0][0][0]) is TechnicalPerson - # Assert that primitive data types remain primitive data types - assert issubclass(type(L[0][0][0][1][0][1][0][1][0][0]), basestring) - - @mark_sync_test def test_validation_with_inheritance_from_db(): """ diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 41afb1c5..ba2d20a3 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -1,11 +1,11 @@ import re from datetime import datetime from test._async_compat import mark_sync_test +from unittest.mock import MagicMock, Mock from pytest import raises, skip, warns from neomodel import ( - INCOMING, ArrayProperty, DateTimeProperty, IntegerProperty, @@ -34,6 +34,7 @@ Size, Traversal, ) +from neomodel.util import RelationshipDirection class SupplierRel(StructuredRel): @@ -422,7 +423,7 @@ def test_traversal_definition_keys_are_valid(): "a_name", { "node_class": Supplier, - "direction": INCOMING, + "direction": RelationshipDirection.INCOMING, "relationship_type": "KNOWS", "model": None, }, @@ -433,7 +434,7 @@ def test_traversal_definition_keys_are_valid(): "a_name", { "node_class": Supplier, - "direction": INCOMING, + 
"direction": RelationshipDirection.INCOMING, "relation_type": "KNOWS", "model": None, }, @@ -544,7 +545,7 @@ def test_q_filters(): robusta = Species(name="Robusta").save() c4.species.connect(robusta) latte_or_robusta_coffee = ( - Coffee.nodes.fetch_relations(Optional("species")) + Coffee.nodes.traverse(Path(value="species", optional=True)) .filter(Q(name="Latte") | Q(species__name="Robusta")) .all() ) @@ -553,7 +554,7 @@ def test_q_filters(): arabica = Species(name="Arabica").save() c1.species.connect(arabica) robusta_coffee = ( - Coffee.nodes.fetch_relations(Optional("species")) + Coffee.nodes.traverse(Path(value="species", optional=True)) .filter(species__name="Robusta") .all() ) @@ -675,72 +676,17 @@ def test_relation_prop_ordering(): nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)}) nescafe.species.connect(arabica) - results = Supplier.nodes.fetch_relations("coffees").order_by("-coffees|since").all() + results = Supplier.nodes.traverse("coffees").order_by("-coffees|since").all() assert len(results) == 2 assert results[0][0] == supplier1 assert results[1][0] == supplier2 - results = Supplier.nodes.fetch_relations("coffees").order_by("coffees|since").all() + results = Supplier.nodes.traverse("coffees").order_by("coffees|since").all() assert len(results) == 2 assert results[0][0] == supplier2 assert results[1][0] == supplier1 -@mark_sync_test -def test_fetch_relations(): - arabica = Species(name="Arabica").save() - robusta = Species(name="Robusta").save() - nescafe = Coffee(name="Nescafe", price=99).save() - nescafe_gold = Coffee(name="Nescafe Gold", price=11).save() - - tesco = Supplier(name="Tesco", delivery_cost=3).save() - nescafe.suppliers.connect(tesco) - nescafe_gold.suppliers.connect(tesco) - nescafe.species.connect(arabica) - - result = ( - Supplier.nodes.filter(name="Tesco").fetch_relations("coffees__species").all() - ) - assert len(result[0]) == 5 - assert arabica in result[0] - assert robusta not in result[0] - assert tesco in result[0] - assert nescafe in result[0] - assert nescafe_gold not in result[0] - - result = ( - Species.nodes.filter(name="Robusta") - .fetch_relations(Optional("coffees__suppliers")) - .all() - ) - assert len(result) == 1 - - if Util.is_async_code: - count = ( - Supplier.nodes.filter(name="Tesco") - .fetch_relations("coffees__species") - .__len__() - ) - assert count == 1 - - assert ( - Supplier.nodes.fetch_relations("coffees__species") - .filter(name="Tesco") - .__contains__(tesco) - ) - else: - count = len( - Supplier.nodes.filter(name="Tesco") - .fetch_relations("coffees__species") - .all() - ) - assert count == 1 - - assert tesco in Supplier.nodes.fetch_relations("coffees__species").filter( - name="Tesco" - ) - - @mark_sync_test def test_traverse(): arabica = Species(name="Arabica").save() @@ -804,7 +750,7 @@ def test_traverse_and_order_by(): nescafe.species.connect(arabica) nescafe_gold.species.connect(robusta) - results = Species.nodes.fetch_relations("coffees").order_by("-coffees__price").all() + results = Species.nodes.traverse("coffees").order_by("-coffees__price").all() assert len(results) == 2 assert len(results[0]) == 3 # 2 nodes and 1 relation assert results[0][0] == robusta @@ -826,36 +772,66 @@ def test_annotate_and_collect(): nescafe_gold.species.connect(arabica) result = ( - Supplier.nodes.traverse_relations(species="coffees__species") + Supplier.nodes.traverse( + species=Path( + value="coffees__species", + include_rels_in_return=False, + include_nodes_in_return=False, + ) + ) .annotate(Collect("species")) 
.all() ) assert len(result) == 1 - assert len(result[0][1][0]) == 3 # 3 species must be there (with 2 duplicates) + assert len(result[0][1]) == 3 # 3 species must be there (with 2 duplicates) result = ( - Supplier.nodes.traverse_relations(species="coffees__species") + Supplier.nodes.traverse( + species=Path( + value="coffees__species", + include_rels_in_return=False, + include_nodes_in_return=False, + ) + ) .annotate(Collect("species", distinct=True)) .all() ) - assert len(result[0][1][0]) == 2 # 2 species must be there + assert len(result[0][1]) == 2 # 2 species must be there result = ( - Supplier.nodes.traverse_relations(species="coffees__species") + Supplier.nodes.traverse( + species=Path( + value="coffees__species", + include_rels_in_return=False, + include_nodes_in_return=False, + ) + ) .annotate(Size(Collect("species", distinct=True))) .all() ) assert result[0][1] == 2 # 2 species result = ( - Supplier.nodes.traverse_relations(species="coffees__species") + Supplier.nodes.traverse( + species=Path( + value="coffees__species", + include_rels_in_return=False, + include_nodes_in_return=False, + ) + ) .annotate(all_species=Collect("species", distinct=True)) .all() ) - assert len(result[0][1][0]) == 2 # 2 species must be there + assert len(result[0][1]) == 2 # 2 species must be there result = ( - Supplier.nodes.traverse_relations("coffees__species") + Supplier.nodes.traverse( + species=Path( + value="coffees__species", + include_rels_in_return=False, + include_nodes_in_return=False, + ) + ) .annotate( all_species=Collect(NodeNameResolver("coffees__species"), distinct=True), all_species_rels=Collect( @@ -864,8 +840,8 @@ def test_annotate_and_collect(): ) .all() ) - assert len(result[0][1][0]) == 2 # 2 species must be there - assert len(result[0][2][0]) == 3 # 3 species relations must be there + assert len(result[0][1]) == 2 # 2 species must be there + assert len(result[0][2]) == 3 # 3 species relations must be there @mark_sync_test @@ -884,22 +860,12 @@ def test_resolve_subgraph(): with raises( RuntimeError, match=re.escape( - "Nothing to resolve. Make sure to include relations in the result using fetch_relations() or filter()." + "Nothing to resolve. Make sure to include relations in the result using traverse() or filter()." ), ): result = Supplier.nodes.resolve_subgraph() - with raises( - NotImplementedError, - match=re.escape( - "You cannot use traverse_relations() with resolve_subgraph(), use fetch_relations() instead." 
- ), - ): - result = Supplier.nodes.traverse_relations( - "coffees__species" - ).resolve_subgraph() - - result = Supplier.nodes.fetch_relations("coffees__species").resolve_subgraph() + result = Supplier.nodes.traverse("coffees__species").resolve_subgraph() assert len(result) == 2 assert hasattr(result[0], "_relations") @@ -926,8 +892,8 @@ def test_resolve_subgraph_optional(): nescafe_gold.suppliers.connect(tesco) nescafe.species.connect(arabica) - result = Supplier.nodes.fetch_relations( - Optional("coffees__species") + result = Supplier.nodes.traverse( + Path(value="coffees__species", optional=True) ).resolve_subgraph() assert len(result) == 1 @@ -951,7 +917,7 @@ def test_subquery(): nescafe.species.connect(arabica) subquery = Coffee.nodes.subquery( - Coffee.nodes.traverse_relations(suppliers="suppliers") + Coffee.nodes.traverse(suppliers="suppliers") .intermediate_transform( {"suppliers": {"source": "suppliers"}}, ordering=["suppliers.delivery_cost"] ) @@ -969,16 +935,14 @@ def test_subquery(): match=re.escape("Variable 'unknown' is not returned by subquery."), ): result = Coffee.nodes.subquery( - Coffee.nodes.traverse_relations(suppliers="suppliers").annotate( + Coffee.nodes.traverse(suppliers="suppliers").annotate( supps=Collect("suppliers") ), ["unknown"], ) result_string_context = subquery.subquery( - Coffee.nodes.traverse_relations(supps2="suppliers").annotate( - supps2=Collect("supps") - ), + Coffee.nodes.traverse(supps2="suppliers").annotate(supps2=Collect("supps")), ["supps2"], ["supps"], ) @@ -992,7 +956,7 @@ def test_subquery(): with raises(ValueError, match=r"Wrong variable specified in initial context"): result = Coffee.nodes.subquery( - Coffee.nodes.traverse_relations(suppliers="suppliers").annotate( + Coffee.nodes.traverse(suppliers="suppliers").annotate( supps=Collect("suppliers") ), ["supps"], @@ -1040,7 +1004,7 @@ def test_intermediate_transform(): nescafe.species.connect(arabica) result = ( - Coffee.nodes.fetch_relations("suppliers") + Coffee.nodes.traverse("suppliers") .intermediate_transform( { "coffee": {"source": "coffee", "include_in_return": True}, @@ -1068,7 +1032,7 @@ def test_intermediate_transform(): r"Wrong source type specified for variable 'test', should be a string or an instance of NodeNameResolver or RelationNameResolver" ), ): - Coffee.nodes.traverse_relations(suppliers="suppliers").intermediate_transform( + Coffee.nodes.traverse(suppliers="suppliers").intermediate_transform( { "test": {"source": Collect("suppliers")}, } @@ -1079,9 +1043,7 @@ def test_intermediate_transform(): r"You must provide one variable at least when calling intermediate_transform()" ), ): - Coffee.nodes.traverse_relations(suppliers="suppliers").intermediate_transform( - {} - ) + Coffee.nodes.traverse(suppliers="suppliers").intermediate_transform({}) @mark_sync_test @@ -1134,12 +1096,12 @@ def test_mix_functions(): full_nodeset = ( Student.nodes.filter(name__istartswith="m", lives_in__name="Eiffel Tower") .order_by("name") - .fetch_relations( + .traverse( "parents", - Optional("children__preferred_course"), + Path(value="children__preferred_course", optional=True), ) .subquery( - Student.nodes.fetch_relations("courses") + Student.nodes.traverse("courses") .intermediate_transform( {"rel": {"source": RelationNameResolver("courses")}}, ordering=[ @@ -1187,9 +1149,9 @@ def test_issue_795(): with raises( RelationshipClassNotDefined, - match=r"[\s\S]*Note that when using the fetch_relations method, the relationship type must be defined in the model.*", + match=r"[\s\S]*Note that when 
using the traverse method, the relationship type must be defined in the model.*",
     ):
-        _ = PersonX.nodes.fetch_relations("country").all()
+        _ = PersonX.nodes.traverse("country").all()
 
 
 @mark_sync_test
@@ -1227,7 +1189,7 @@ def test_unique_variables():
     gold3000.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)})
     gold3000.species.connect(arabica)
 
-    nodeset = Supplier.nodes.fetch_relations("coffees", "coffees__species").filter(
+    nodeset = Supplier.nodes.traverse("coffees", "coffees__species").filter(
         coffees__name="Nescafe"
     )
     ast = nodeset.query_cls(nodeset).build_ast()
@@ -1240,7 +1202,7 @@
     assert len(results) == 3
 
     nodeset = (
-        Supplier.nodes.fetch_relations("coffees", "coffees__species")
+        Supplier.nodes.traverse("coffees", "coffees__species")
         .filter(coffees__name="Nescafe")
         .unique_variables("coffees")
     )
@@ -1289,6 +1251,20 @@ def assert_last_query_startswith(mock_func, query) -> bool:
     return mock_func.call_args_list[-1].kwargs["query"].startswith(query)
 
 
+def create_mock_result():
+    """Create a mock Result that behaves like neo4j.Result"""
+    mock_result = MagicMock()
+    mock_result.keys.return_value = ()
+
+    # Iterator that yields no records, mimicking an empty result stream
+    def _iter(self):
+        return
+        yield  # The unreachable yield makes this a generator function
+
+    mock_result.__iter__ = _iter
+    return mock_result
+
+
 @mark_sync_test
 def test_parallel_runtime(mocker):
     if not db.version_is_higher_than("5.13") or not db.edition_is_enterprise():
@@ -1301,7 +1277,11 @@ def test_parallel_runtime(mocker):
     # Mock transaction.run to access executed query
     # Assert query starts with CYPHER runtime=parallel
     assert db._parallel_runtime == True
-    mock_transaction_run = mocker.patch("neo4j.Transaction.run")
+    mock_transaction_run = mocker.patch(
+        "neo4j.Transaction.run",
+        new_callable=Mock,
+        return_value=create_mock_result(),
+    )
     db.cypher_query("MATCH (n:Coffee) RETURN n")
     assert assert_last_query_startswith(
         mock_transaction_run, "CYPHER runtime=parallel"
@@ -1311,7 +1291,11 @@
 
     # Parallel should be applied to neomodel queries
     with db.parallel_read_transaction:
-        mock_transaction_run_2 = mocker.patch("neo4j.Transaction.run")
+        mock_transaction_run_2 = mocker.patch(
+            "neo4j.Transaction.run",
+            new_callable=Mock,
+            return_value=create_mock_result(),
+        )
         Coffee.nodes.all()
         assert assert_last_query_startswith(
             mock_transaction_run_2, "CYPHER runtime=parallel"
@@ -1324,7 +1308,11 @@ def test_parallel_runtime_conflict(mocker):
         skip("Test for unavailable parallel runtime.")
     assert not db.parallel_runtime_available()
 
-    mock_transaction_run = mocker.patch("neo4j.Transaction.run")
+    mock_transaction_run = mocker.patch(
+        "neo4j.Transaction.run",
+        new_callable=Mock,
+        return_value=create_mock_result(),
+    )
     with warns(
         UserWarning,
         match="Parallel runtime is only available in Neo4j Enterprise Edition 5.13",
diff --git a/test/sync_/test_object_resolution.py b/test/sync_/test_object_resolution.py
new file mode 100644
index 00000000..6ea63201
--- /dev/null
+++ b/test/sync_/test_object_resolution.py
@@ -0,0 +1,550 @@
+"""
+Test cases for object resolution with resolve_objects=True in raw Cypher queries.
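+
+As an illustrative sketch (using the models defined further down in this
+file), the collect() case from issue #905 is expected to resolve like this:
+
+    results, _ = db.cypher_query(
+        "MATCH (n:ResolutionNode) RETURN collect(n) AS nodes",
+        resolve_objects=True,
+    )
+    nodes = results[0][0]  # a flat list of ResolutionNode instances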
+ +This test file covers various scenarios for automatic class resolution, +including the issues identified in GitHub issues #905 and #906: +- Issue #905: Nested lists in results of raw Cypher queries with collect keyword +- Issue #906: Automatic class resolution for raw queries with nodes nested in maps + +Additional scenarios tested: +- Basic object resolution +- Nested structures (lists, maps, mixed) +- Path resolution +- Relationship resolution +- Complex nested scenarios with collect() and other Cypher functions +""" + +from test._async_compat import mark_sync_test + +from neomodel import ( + IntegerProperty, + RelationshipTo, + StringProperty, + StructuredNode, + StructuredRel, + db, +) + + +class ResolutionRelationship(StructuredRel): + """Test relationship with properties.""" + + weight = IntegerProperty(default=1) + description = StringProperty(default="test") + + +class ResolutionNode(StructuredNode): + """Base test node class.""" + + name = StringProperty(required=True) + value = IntegerProperty(default=0) + related = RelationshipTo( + "ResolutionNode", "RELATED_TO", model=ResolutionRelationship + ) + + +class ResolutionSpecialNode(StructuredNode): + """Specialized test node class.""" + + name = StringProperty(required=True) + special_value = IntegerProperty(default=42) + related = RelationshipTo(ResolutionNode, "RELATED_TO", model=ResolutionRelationship) + + +class ResolutionContainerNode(StructuredNode): + """Container node for testing nested structures.""" + + name = StringProperty(required=True) + items = RelationshipTo(ResolutionNode, "CONTAINS", model=ResolutionRelationship) + + +@mark_sync_test +def test_basic_object_resolution(): + """Test basic object resolution for nodes and relationships.""" + # Create test data + ResolutionNode(name="Node1", value=10).save() + ResolutionNode(name="Node2", value=20).save() + + # Test basic node resolution + results, _ = db.cypher_query( + "MATCH (n:ResolutionNode) WHERE n.name = $name RETURN n", + {"name": "Node1"}, + resolve_objects=True, + ) + + assert len(results) == 1 + assert len(results[0]) == 1 + resolved_node = results[0][0] + assert isinstance(resolved_node, ResolutionNode) + assert resolved_node.name == "Node1" + assert resolved_node.value == 10 + + +@mark_sync_test +def test_relationship_resolution(): + """Test relationship resolution in queries.""" + # Create test data with relationships + node1 = ResolutionNode(name="Source", value=100).save() + node2 = ResolutionNode(name="Target", value=200).save() + + # Create relationship + node1.related.connect(node2, {"weight": 5, "description": "test_rel"}) + + # Test relationship resolution + results, _ = db.cypher_query( + "MATCH (a:ResolutionNode)-[r:RELATED_TO]->(b:ResolutionNode) RETURN a, r, b", + resolve_objects=True, + ) + + assert len(results) == 1 + source, rel, target = results[0] + + assert isinstance(source, ResolutionNode) + assert isinstance(rel, ResolutionRelationship) + assert isinstance(target, ResolutionNode) + + assert source.name == "Source" + assert target.name == "Target" + assert rel.weight == 5 + assert rel.description == "test_rel" + + +@mark_sync_test +def test_path_resolution(): + """Test path resolution in queries.""" + # Create test data + node1 = ResolutionNode(name="Start", value=1).save() + node2 = ResolutionNode(name="Middle", value=2).save() + node3 = ResolutionNode(name="End", value=3).save() + + # Create path + node1.related.connect(node2, {"weight": 1}) + node2.related.connect(node3, {"weight": 2}) + + # Test path resolution + results, _ = 
db.cypher_query(
+        "MATCH p=(a:ResolutionNode)-[:RELATED_TO*2]->(c:ResolutionNode) RETURN p",
+        resolve_objects=True,
+    )
+
+    assert len(results) == 1
+    path = results[0][0]
+
+    # Path should be resolved to NeomodelPath
+    from neomodel.sync_.path import NeomodelPath
+
+    assert isinstance(path, NeomodelPath)
+    assert len(path._nodes) == 3  # pylint: disable=protected-access
+    assert len(path._relationships) == 2  # pylint: disable=protected-access
+
+
+@mark_sync_test
+def test_nested_lists_basic():
+    """Test basic nested list resolution (Issue #905 - basic case)."""
+    # Create test data
+    nodes = []
+    for i in range(3):
+        node = ResolutionNode(name=f"Node{i}", value=i * 10).save()
+        nodes.append(node)
+
+    # Test nested list resolution
+    results, _ = db.cypher_query(
+        """
+        MATCH (n:ResolutionNode)
+        WITH n ORDER BY n.name
+        RETURN collect(n) as nodes
+        """,
+        resolve_objects=True,
+    )
+
+    assert len(results) == 1
+    collected_nodes = results[0][0]
+
+    assert isinstance(collected_nodes, list)
+    assert len(collected_nodes) == 3
+
+    for i, node in enumerate(collected_nodes):
+        assert isinstance(node, ResolutionNode)
+        assert node.name == f"Node{i}"
+        assert node.value == i * 10
+
+
+@mark_sync_test
+def test_nested_lists_complex():
+    """Test complex nested list resolution with collect() (Issue #905 - complex case)."""
+    # Create test data with relationships
+    container = ResolutionContainerNode(name="Container").save()
+    items = []
+    for i in range(2):
+        item = ResolutionNode(name=f"Item{i}", value=i * 5).save()
+        items.append(item)
+        container.items.connect(item, {"weight": i + 1})
+
+    # Test complex nested list with collect
+    results, _ = db.cypher_query(
+        """
+        MATCH (c:ResolutionContainerNode)-[r:CONTAINS]->(i:ResolutionNode)
+        WITH c, r, i ORDER BY i.name
+        WITH c, collect({item: i, rel: r}) as items
+        RETURN c, items
+        """,
+        resolve_objects=True,
+    )
+
+    assert len(results) == 1
+    container_result, items_result = results[0]
+
+    assert isinstance(container_result, ResolutionContainerNode)
+    assert container_result.name == "Container"
+
+    assert isinstance(items_result, list)
+    assert len(items_result) == 2
+
+    for i, item_data in enumerate(items_result):
+        assert isinstance(item_data, dict)
+        assert "item" in item_data
+        assert "rel" in item_data
+
+        item = item_data["item"]
+        rel = item_data["rel"]
+
+        assert isinstance(item, ResolutionNode)
+        assert isinstance(rel, ResolutionRelationship)
+        assert item.name == f"Item{i}"
+        assert rel.weight == i + 1
+
+
+@mark_sync_test
+def test_nodes_nested_in_maps():
+    """Test nodes nested in maps (Issue #906)."""
+    # Create test data
+    ResolutionNode(name="Node1", value=100).save()
+    ResolutionNode(name="Node2", value=200).save()
+
+    # Test nodes nested in maps
+    results, _ = db.cypher_query(
+        """
+        MATCH (n1:ResolutionNode), (n2:ResolutionNode)
+        WHERE n1.name = 'Node1' AND n2.name = 'Node2'
+        RETURN {
+            first: n1,
+            second: n2,
+            metadata: {
+                count: 2,
+                description: 'test map'
+            }
+        } as result_map
+        """,
+        resolve_objects=True,
+    )
+
+    assert len(results) == 1
+    result_map = results[0][0]
+
+    assert isinstance(result_map, dict)
+    assert "first" in result_map
+    assert "second" in result_map
+    assert "metadata" in result_map
+
+    # Check that nodes are properly resolved
+    first_node = result_map["first"]
+    second_node = result_map["second"]
+
+    assert isinstance(first_node, ResolutionNode)
+    assert isinstance(second_node, ResolutionNode)
+    assert first_node.name == "Node1"
+    assert second_node.name == "Node2"
+
+    # Check metadata (should 
remain as primitive types) + metadata = result_map["metadata"] + assert isinstance(metadata, dict) + assert metadata["count"] == 2 + assert metadata["description"] == "test map" + + +@mark_sync_test +def test_mixed_nested_structures(): + """Test mixed nested structures with lists, maps, and nodes.""" + # Create test data + special = ResolutionSpecialNode(name="Special", special_value=999).save() + test_nodes = [] + for i in range(2): + node = ResolutionNode(name=f"Test{i}", value=i * 100).save() + test_nodes.append(node) + special.related.connect(node, {"weight": i + 10}) + + # Test complex mixed structure + results, _ = db.cypher_query( + """ + MATCH (s:ResolutionSpecialNode)-[r:RELATED_TO]->(t:ResolutionNode) + WITH s, r, t ORDER BY t.name + WITH s, collect({node: t, rel: r}) as related_items + RETURN { + special_node: s, + related: related_items, + summary: { + total_relations: size(related_items), + node_names: [item in related_items | item.node.name] + } + } as complex_result + """, + resolve_objects=True, + ) + + assert len(results) == 1 + complex_result = results[0][0] + + assert isinstance(complex_result, dict) + assert "special_node" in complex_result + assert "related" in complex_result + assert "summary" in complex_result + + # Check special node resolution + special_node = complex_result["special_node"] + assert isinstance(special_node, ResolutionSpecialNode) + assert special_node.name == "Special" + assert special_node.special_value == 999 + + # Check related items (list of dicts with nodes and relationships) + related = complex_result["related"] + assert isinstance(related, list) + assert len(related) == 2 + + for i, item in enumerate(related): + assert isinstance(item, dict) + assert "node" in item + assert "rel" in item + + node = item["node"] + rel = item["rel"] + + assert isinstance(node, ResolutionNode) + assert isinstance(rel, ResolutionRelationship) + assert node.name == f"Test{i}" + assert rel.weight == i + 10 + + # Check summary (should remain as primitive types) + summary = complex_result["summary"] + assert isinstance(summary, dict) + assert summary["total_relations"] == 2 + assert isinstance(summary["node_names"], list) + assert summary["node_names"] == ["Test0", "Test1"] + + +@mark_sync_test +def test_deeply_nested_structures(): + """Test deeply nested structures to ensure recursive resolution works.""" + # Create test data + nodes = [] + for i in range(3): + node = ResolutionNode(name=f"Deep{i}", value=i * 50).save() + nodes.append(node) + + # Test deeply nested structure + results, _ = db.cypher_query( + """ + MATCH (n:ResolutionNode) + WITH n ORDER BY n.name + WITH collect(n) as level1 + RETURN { + level1: level1, + level2: { + nodes: level1, + metadata: { + level3: { + count: size(level1), + items: level1 + } + } + } + } as deep_result + """, + resolve_objects=True, + ) + + assert len(results) == 1 + deep_result = results[0][0] + + assert isinstance(deep_result, dict) + assert "level1" in deep_result + assert "level2" in deep_result + + # Check level1 (direct list of nodes) + level1 = deep_result["level1"] + assert isinstance(level1, list) + assert len(level1) == 3 + for i, node in enumerate(level1): + assert isinstance(node, ResolutionNode) + assert node.name == f"Deep{i}" + + # Check level2 (nested structure) + level2 = deep_result["level2"] + assert isinstance(level2, dict) + assert "nodes" in level2 + assert "metadata" in level2 + + # Check nodes in level2 + level2_nodes = level2["nodes"] + assert isinstance(level2_nodes, list) + assert len(level2_nodes) 
== 3 + for i, node in enumerate(level2_nodes): + assert isinstance(node, ResolutionNode) + assert node.name == f"Deep{i}" + + # Check metadata in level2 + metadata = level2["metadata"] + assert isinstance(metadata, dict) + assert "level3" in metadata + + level3 = metadata["level3"] + assert isinstance(level3, dict) + assert "count" in level3 + assert "items" in level3 + + assert level3["count"] == 3 + level3_items = level3["items"] + assert isinstance(level3_items, list) + assert len(level3_items) == 3 + for i, node in enumerate(level3_items): + assert isinstance(node, ResolutionNode) + assert node.name == f"Deep{i}" + + +@mark_sync_test +def test_collect_with_aggregation(): + """Test collect() with aggregation functions.""" + # Create test data + for i in range(5): + node = ResolutionNode(name=f"AggNode{i}", value=i * 10).save() + + # Test collect with aggregation + results, _ = db.cypher_query( + """ + MATCH (n:ResolutionNode) + WHERE n.name STARTS WITH 'Agg' + WITH n ORDER BY n.name + WITH collect(n) as all_nodes + RETURN { + nodes: all_nodes, + count: size(all_nodes), + total_value: reduce(total = 0, n in all_nodes | total + n.value), + names: [n in all_nodes | n.name] + } as aggregated_result + """, + resolve_objects=True, + ) + + assert len(results) == 1 + aggregated_result = results[0][0] + + assert isinstance(aggregated_result, dict) + assert "nodes" in aggregated_result + assert "count" in aggregated_result + assert "total_value" in aggregated_result + assert "names" in aggregated_result + + # Check nodes are resolved + nodes = aggregated_result["nodes"] + assert isinstance(nodes, list) + assert len(nodes) == 5 + for i, node in enumerate(nodes): + assert isinstance(node, ResolutionNode) + assert node.name == f"AggNode{i}" + assert node.value == i * 10 + + # Check aggregated values + assert aggregated_result["count"] == 5 + assert aggregated_result["total_value"] == 100 # 0+10+20+30+40 + assert aggregated_result["names"] == [ + "AggNode0", + "AggNode1", + "AggNode2", + "AggNode3", + "AggNode4", + ] + + +@mark_sync_test +def test_resolve_objects_false_comparison(): + """Test that resolve_objects=False returns raw Neo4j objects.""" + # Create test data + ResolutionNode(name="RawNode", value=123).save() + + # Test with resolve_objects=False + results_false, _ = db.cypher_query( + "MATCH (n:ResolutionNode) WHERE n.name = $name RETURN n", + {"name": "RawNode"}, + resolve_objects=False, + ) + + # Test with resolve_objects=True + results_true, _ = db.cypher_query( + "MATCH (n:ResolutionNode) WHERE n.name = $name RETURN n", + {"name": "RawNode"}, + resolve_objects=True, + ) + + # Compare results + raw_node = results_false[0][0] + resolved_node = results_true[0][0] + + # Raw node should be a Neo4j Node object + from neo4j.graph import Node + + assert isinstance(raw_node, Node) + assert raw_node["name"] == "RawNode" + assert raw_node["value"] == 123 + + # Resolved node should be a ResolutionNode instance + assert isinstance(resolved_node, ResolutionNode) + assert resolved_node.name == "RawNode" + assert resolved_node.value == 123 + + +@mark_sync_test +def test_empty_results(): + """Test object resolution with empty results.""" + # Test empty results + results, _ = db.cypher_query( + "MATCH (n:ResolutionNode) WHERE n.name = 'NonExistent' RETURN n", + resolve_objects=True, + ) + + assert len(results) == 0 + + +@mark_sync_test +def test_primitive_types_preserved(): + """Test that primitive types are preserved during object resolution.""" + # Create test data + 
ResolutionNode(name="PrimitiveTest", value=456).save() + + # Test with mixed primitive and node types + results, _ = db.cypher_query( + """ + MATCH (n:ResolutionNode) WHERE n.name = $name + RETURN n, n.value as int_val, n.name as str_val, true as bool_val, 3.14 as float_val + """, + {"name": "PrimitiveTest"}, + resolve_objects=True, + ) + + assert len(results) == 1 + node_result, int_val, str_val, bool_val, float_val = results[0] + + # Node should be resolved + assert isinstance(node_result, ResolutionNode) + assert node_result.name == "PrimitiveTest" + + # Primitives should remain primitive + assert isinstance(int_val, int) + assert int_val == 456 + + assert isinstance(str_val, str) + assert str_val == "PrimitiveTest" + + assert isinstance(bool_val, bool) + assert bool_val is True + + assert isinstance(float_val, float) + assert float_val == 3.14 diff --git a/test/sync_/test_properties.py b/test/sync_/test_properties.py index 633ceace..c20280b5 100644 --- a/test/sync_/test_properties.py +++ b/test/sync_/test_properties.py @@ -1,4 +1,4 @@ -from datetime import date, datetime, timedelta +from datetime import date, datetime, timedelta, timezone from test._async_compat import mark_sync_test from zoneinfo import ZoneInfo @@ -9,8 +9,8 @@ Relationship, StructuredNode, StructuredRel, - config, db, + get_config, ) from neomodel.contrib import SemiStructuredNode from neomodel.exceptions import ( @@ -137,7 +137,7 @@ def test_datetimes_timezones(): prop = DateTimeProperty() prop.name = "foo" prop.owner = FooBar - t = datetime.utcnow() + t = datetime.now(timezone.utc) gr = ZoneInfo("Europe/Athens") gb = ZoneInfo("Europe/London") dt1 = t.replace(tzinfo=gr) @@ -153,12 +153,13 @@ def test_datetimes_timezones(): default_now=True, default=datetime(1900, 1, 1, 0, 0, 0) ) - prev_force_timezone = config.FORCE_TIMEZONE - config.FORCE_TIMEZONE = True + config = get_config() + prev_force_timezone = config.force_timezone + config.force_timezone = True with raises(ValueError, match=r".*No timezone provided."): prop.deflate(datetime.now()) - config.FORCE_TIMEZONE = prev_force_timezone + config.force_timezone = prev_force_timezone def test_date(): @@ -379,26 +380,26 @@ class DefaultTestValueThree(StructuredNode): assert x.uid == "123" -class TestDBNamePropertyRel(StructuredRel): +class DBNamePropertyRel(StructuredRel): known_for = StringProperty(db_property="knownFor") # This must be defined outside of the test, otherwise the `Relationship` definition cannot look up -# `TestDBNamePropertyNode` -class TestDBNamePropertyNode(StructuredNode): +# `DBNamePropertyNode` +class DBNamePropertyNode(StructuredNode): name_ = StringProperty(db_property="name") - knows = Relationship("TestDBNamePropertyNode", "KNOWS", model=TestDBNamePropertyRel) + knows = Relationship("DBNamePropertyNode", "KNOWS", model=DBNamePropertyRel) @mark_sync_test def test_independent_property_name(): # -- test node -- - x = TestDBNamePropertyNode() + x = DBNamePropertyNode() x.name_ = "jim" x.save() # check database property name on low level - results, meta = db.cypher_query("MATCH (n:TestDBNamePropertyNode) RETURN n") + results, meta = db.cypher_query("MATCH (n:DBNamePropertyNode) RETURN n") node_properties = get_graph_entity_properties(results[0][0]) assert node_properties["name"] == "jim" assert "name_" not in node_properties @@ -406,8 +407,8 @@ def test_independent_property_name(): # check python class property name at a high level assert not hasattr(x, "name") assert hasattr(x, "name_") - assert 
(TestDBNamePropertyNode.nodes.filter(name_="jim").all())[0].name_ == x.name_ - assert (TestDBNamePropertyNode.nodes.get(name_="jim")).name_ == x.name_ + assert (DBNamePropertyNode.nodes.filter(name_="jim").all())[0].name_ == x.name_ + assert (DBNamePropertyNode.nodes.get(name_="jim")).name_ == x.name_ # -- test relationship -- @@ -417,7 +418,7 @@ def test_independent_property_name(): # check database property name on low level results, meta = db.cypher_query( - "MATCH (:TestDBNamePropertyNode)-[r:KNOWS]->(:TestDBNamePropertyNode) RETURN r" + "MATCH (:DBNamePropertyNode)-[r:KNOWS]->(:DBNamePropertyNode) RETURN r" ) rel_properties = get_graph_entity_properties(results[0][0]) assert rel_properties["knownFor"] == "10 years" @@ -432,15 +433,15 @@ def test_independent_property_name(): @mark_sync_test def test_independent_property_name_for_semi_structured(): - class TestDBNamePropertySemiStructuredNode(SemiStructuredNode): + class DBNamePropertySemiStructuredNode(SemiStructuredNode): title_ = StringProperty(db_property="title") - semi = TestDBNamePropertySemiStructuredNode(title_="sir", extra="data") + semi = DBNamePropertySemiStructuredNode(title_="sir", extra="data") semi.save() # check database property name on low level results, meta = db.cypher_query( - "MATCH (n:TestDBNamePropertySemiStructuredNode) RETURN n" + "MATCH (n:DBNamePropertySemiStructuredNode) RETURN n" ) node_properties = get_graph_entity_properties(results[0][0]) assert node_properties["title"] == "sir" @@ -451,13 +452,11 @@ class TestDBNamePropertySemiStructuredNode(SemiStructuredNode): assert hasattr(semi, "title_") assert not hasattr(semi, "title") assert hasattr(semi, "extra") - from_filter = ( - TestDBNamePropertySemiStructuredNode.nodes.filter(title_="sir").all() - )[0] + from_filter = (DBNamePropertySemiStructuredNode.nodes.filter(title_="sir").all())[0] assert from_filter.title_ == "sir" # assert not hasattr(from_filter, "title") assert from_filter.extra == "data" - from_get = TestDBNamePropertySemiStructuredNode.nodes.get(title_="sir") + from_get = DBNamePropertySemiStructuredNode.nodes.get(title_="sir") assert from_get.title_ == "sir" # assert not hasattr(from_get, "title") assert from_get.extra == "data" diff --git a/test/sync_/test_registry.py b/test/sync_/test_registry.py index d47b0ca8..c4413ef9 100644 --- a/test/sync_/test_registry.py +++ b/test/sync_/test_registry.py @@ -2,21 +2,8 @@ from pytest import raises, skip -from neomodel import ( - DateProperty, - IntegerProperty, - RelationshipTo, - StringProperty, - StructuredNode, - StructuredRel, - config, - db, -) -from neomodel.exceptions import ( - NodeClassAlreadyDefined, - NodeClassNotDefined, - RelationshipClassRedefined, -) +from neomodel import StringProperty, StructuredNode, db, get_config +from neomodel.exceptions import NodeClassAlreadyDefined, NodeClassNotDefined @mark_sync_test @@ -63,9 +50,10 @@ class PatientOneBis(StructuredNode): PatientOneBis(name="patient1.2").save() + config = get_config() # Now, we will test object resolution db.close_connection() - db.set_connection(url=f"{config.DATABASE_URL}/{db_one}") + db.set_connection(url=f"{config.database_url}/{db_one}") db.clear_neo4j_database() patient1 = PatientOne(name="patient1").save() patients, _ = db.cypher_query("MATCH (n:Patient) RETURN n", resolve_objects=True) @@ -73,14 +61,14 @@ class PatientOneBis(StructuredNode): assert patients[0][0] == patient1 db.close_connection() - db.set_connection(url=f"{config.DATABASE_URL}/{db_two}") + db.set_connection(url=f"{config.database_url}/{db_two}") 
db.clear_neo4j_database() patient2 = PatientTwo(identifier="patient2").save() patients, _ = db.cypher_query("MATCH (n:Patient) RETURN n", resolve_objects=True) assert patients[0][0] == patient2 db.close_connection() - db.set_connection(url=config.DATABASE_URL) + db.set_connection(url=config.database_url) @mark_sync_test diff --git a/test/sync_/test_transactions.py b/test/sync_/test_transactions.py index 71ce479f..827d3ba0 100644 --- a/test/sync_/test_transactions.py +++ b/test/sync_/test_transactions.py @@ -127,14 +127,14 @@ def test_bookmark_transaction_decorator(): def test_bookmark_transaction_as_a_context(): with db.transaction as transaction: APerson(name="Tanya").save() - assert isinstance(transaction.last_bookmark, Bookmarks) + assert isinstance(transaction.last_bookmarks, Bookmarks) assert APerson.nodes.filter(name="Tanya") with raises(UniqueProperty): with db.transaction as transaction: APerson(name="Tanya").save() - assert not hasattr(transaction, "last_bookmark") + assert transaction.last_bookmarks is None @pytest.fixture @@ -157,14 +157,14 @@ def test_bookmark_passed_in_to_context(spy_on_db_begin): pass assert (spy_on_db_begin)[-1] == ((), {"access_mode": None, "bookmarks": None}) - last_bookmark = transaction.last_bookmark + last_bookmarks = transaction.last_bookmarks - transaction.bookmarks = last_bookmark + transaction.bookmarks = last_bookmarks with transaction: pass assert spy_on_db_begin[-1] == ( (), - {"access_mode": None, "bookmarks": last_bookmark}, + {"access_mode": None, "bookmarks": last_bookmarks}, ) @@ -176,4 +176,4 @@ def test_query_inside_bookmark_transaction(): assert len([p.name for p in APerson.nodes]) == 2 - assert isinstance(transaction.last_bookmark, Bookmarks) + assert isinstance(transaction.last_bookmarks, Bookmarks) diff --git a/test/sync_/test_vectorfilter.py b/test/sync_/test_vectorfilter.py index e9e408a6..207bedb5 100644 --- a/test/sync_/test_vectorfilter.py +++ b/test/sync_/test_vectorfilter.py @@ -184,7 +184,7 @@ class ProductV(StructuredNode): description_embedding = ArrayProperty( FloatProperty(), vector_index=VectorIndex(dimensions=2) ) - suppliers = RelationshipFrom(SupplierV, "SUPPLIES", model=SuppliesVRel) + suppliers = RelationshipFrom(SupplierV, "SUPPLIESV", model=SuppliesVRel) db.install_labels(SupplierV) db.install_labels(SuppliesVRel) diff --git a/test/test_config_modernization.py b/test/test_config_modernization.py new file mode 100644 index 00000000..ac72831a --- /dev/null +++ b/test/test_config_modernization.py @@ -0,0 +1,677 @@ +""" +Tests for the modernized configuration system. + +This module tests the new dataclass-based configuration system with validation +and environment variable support. 
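+
+A minimal usage sketch of the modern API exercised below (all names are
+imported from neomodel in this module):
+
+    from neomodel import NeomodelConfig, get_config, set_config
+
+    set_config(NeomodelConfig(database_url="bolt://user:pass@localhost:7687"))
+    cfg = get_config()
+    cfg.update(force_timezone=True)  # update() validates the new values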
+""" + +import os +import warnings +from unittest.mock import patch + +import pytest + +from neomodel import NeomodelConfig, config, get_config, reset_config, set_config +from neomodel.config import clear_deprecation_warnings + +# Type ignore for dynamic module attributes created by config module replacement +# pylint: disable=no-member + + +class TestNeomodelConfig: + """Test the NeomodelConfig dataclass.""" + + def test_default_config(self): + """Test default configuration values.""" + config_obj = NeomodelConfig() + + assert config_obj.database_url == "bolt://neo4j:foobarbaz@localhost:7687" + assert config_obj.force_timezone is False + assert config_obj.soft_cardinality_check is False + assert config_obj.cypher_debug is False + assert config_obj.slow_queries == 0.0 + assert config_obj.connection_timeout == 30.0 + assert config_obj.max_connection_pool_size == 100 + + def test_config_validation(self): + """Test configuration validation.""" + # Test valid configuration + config_obj = NeomodelConfig( + database_url="bolt://test:test@localhost:7687", connection_timeout=60.0 + ) + assert config_obj.connection_timeout == 60.0 + + # Test invalid configuration + with pytest.raises(ValueError, match="connection_timeout must be positive"): + NeomodelConfig(connection_timeout=-1) + + with pytest.raises(ValueError, match="Invalid database URL"): + NeomodelConfig(database_url="invalid-url") + + # Test slow_queries validation + with pytest.raises(ValueError, match="slow_queries must be non-negative"): + NeomodelConfig(slow_queries=-1.0) + + # Test additional validation branches + with pytest.raises( + ValueError, match="connection_acquisition_timeout must be positive" + ): + NeomodelConfig(connection_acquisition_timeout=-1) + + with pytest.raises( + ValueError, match="max_connection_lifetime must be positive" + ): + NeomodelConfig(max_connection_lifetime=-1) + + with pytest.raises( + ValueError, match="max_connection_pool_size must be positive" + ): + NeomodelConfig(max_connection_pool_size=-1) + + with pytest.raises( + ValueError, match="max_transaction_retry_time must be positive" + ): + NeomodelConfig(max_transaction_retry_time=-1) + + # Test database URL validation with invalid format + with pytest.raises(ValueError, match="Invalid database URL format"): + NeomodelConfig(database_url="invalid") + + # Test database URL validation with missing scheme + with pytest.raises(ValueError, match="Invalid database URL format"): + NeomodelConfig(database_url="localhost:7687") + + # Test database URL validation with missing netloc + with pytest.raises(ValueError, match="Invalid database URL format"): + NeomodelConfig(database_url="bolt://") + + # Test database URL validation with exception handling + # Use a URL that will cause urlparse to raise an exception + with pytest.raises(ValueError, match="Invalid database URL"): + NeomodelConfig(database_url="bolt://[invalid") + + def test_from_env(self): + """Test loading configuration from environment variables.""" + env_vars = { + "NEOMODEL_DATABASE_URL": "bolt://env:test@localhost:7687", + "NEOMODEL_FORCE_TIMEZONE": "true", + "NEOMODEL_CONNECTION_TIMEOUT": "45.0", + "NEOMODEL_MAX_CONNECTION_POOL_SIZE": "50", + "NEOMODEL_CYPHER_DEBUG": "true", + "NEOMODEL_SLOW_QUERIES": "1.5", + } + + with patch.dict(os.environ, env_vars): + config_obj = NeomodelConfig.from_env() + + assert config_obj.database_url == "bolt://env:test@localhost:7687" + assert config_obj.force_timezone is True + assert config_obj.connection_timeout == 45.0 + assert 
config_obj.max_connection_pool_size == 50 + assert config_obj.cypher_debug is True + assert config_obj.slow_queries == 1.5 + + def test_to_dict(self): + """Test converting configuration to dictionary.""" + config_obj = NeomodelConfig( + database_url="bolt://test:test@localhost:7687", force_timezone=True + ) + + config_dict = config_obj.to_dict() + + assert config_dict["database_url"] == "bolt://test:test@localhost:7687" + assert config_dict["force_timezone"] is True + assert "driver" not in config_dict # Non-serializable values excluded + + def test_to_dict_excludes_non_serializable(self): + """Test that to_dict excludes non-serializable values.""" + config_obj = NeomodelConfig() + config_dict = config_obj.to_dict() + + # These should be excluded from serialization + excluded_fields = ["driver", "resolver", "trusted_certificates"] + for field in excluded_fields: + assert field not in config_dict + + # These should be included + included_fields = ["database_url", "force_timezone", "connection_timeout"] + for field in included_fields: + assert field in config_dict + + def test_update(self): + """Test updating configuration values.""" + config_obj = NeomodelConfig() + + config_obj.update( + database_url="bolt://updated:test@localhost:7687", force_timezone=True + ) + + assert config_obj.database_url == "bolt://updated:test@localhost:7687" + assert config_obj.force_timezone is True + + # Test validation on update + with pytest.raises(ValueError): + config_obj.update(connection_timeout=-1) + + def test_update_with_unknown_option(self): + """Test update method with unknown configuration option.""" + config_obj = NeomodelConfig() + + # This should trigger a warning but not fail + with pytest.warns( + UserWarning, match="Unknown configuration option: unknown_field" + ): + config_obj.update(unknown_field="value") + + # Original values should remain unchanged + assert config_obj.database_url == "bolt://neo4j:foobarbaz@localhost:7687" + + def test_setattr_initialization(self): + """Test __setattr__ method initialization logic.""" + # Test that _initialized is set correctly + config_obj = NeomodelConfig() + + # First attribute set should mark as initialized + config_obj.database_url = "bolt://test:test@localhost:7687" + assert hasattr(config_obj, "_initialized") + assert config_obj._initialized is True # pylint: disable=protected-access + + # Setting _initialized itself should not trigger validation + config_obj._initialized = False # pylint: disable=protected-access + assert config_obj._initialized is False # pylint: disable=protected-access + + def test_setattr_validation_skip(self): + """Test that validation is skipped during initialization.""" + # Create config without triggering validation during init + config_obj = NeomodelConfig() + + # Should not raise validation error during attribute setting + # because validation is skipped when _initialized is not set + config_obj.connection_timeout = -1 # This should not raise immediately + + # But validation should occur when explicitly called + with pytest.raises(ValueError): + config_obj._validate_config() # pylint: disable=protected-access + + +class TestBackwardCompatibility: + """Test backward compatibility with existing config usage.""" + + def test_module_level_access(self): + """Test that module-level attributes work as before.""" + # Test reading values + assert isinstance(config.DATABASE_URL, str) + assert isinstance(config.FORCE_TIMEZONE, bool) + assert isinstance(config.SOFT_CARDINALITY_CHECK, bool) + assert isinstance(config.CYPHER_DEBUG, 
bool) # type: ignore[attr-defined] + assert isinstance(config.SLOW_QUERIES, float) # type: ignore[attr-defined] + + # Test setting values + original_url = config.DATABASE_URL + config.DATABASE_URL = "bolt://test:test@localhost:7687" + assert config.DATABASE_URL == "bolt://test:test@localhost:7687" + + # Restore original value + config.DATABASE_URL = original_url + + def test_all_property_setters(self): + """Test all property setters for backward compatibility.""" + # Test DRIVER setter + original_driver = config.DRIVER + config.DRIVER = None + assert config.DRIVER is None + config.DRIVER = original_driver + + # Test DATABASE_NAME setter + original_name = config.DATABASE_NAME + config.DATABASE_NAME = "test_db" + assert config.DATABASE_NAME == "test_db" + config.DATABASE_NAME = original_name + + # Test CONNECTION_ACQUISITION_TIMEOUT setter + original_timeout = config.CONNECTION_ACQUISITION_TIMEOUT + config.CONNECTION_ACQUISITION_TIMEOUT = 120.0 + assert config.CONNECTION_ACQUISITION_TIMEOUT == 120.0 + config.CONNECTION_ACQUISITION_TIMEOUT = original_timeout + + # Test MAX_CONNECTION_LIFETIME setter + original_lifetime = config.MAX_CONNECTION_LIFETIME + config.MAX_CONNECTION_LIFETIME = 7200 + assert config.MAX_CONNECTION_LIFETIME == 7200 + config.MAX_CONNECTION_LIFETIME = original_lifetime + + # Test MAX_TRANSACTION_RETRY_TIME setter + original_retry = config.MAX_TRANSACTION_RETRY_TIME + config.MAX_TRANSACTION_RETRY_TIME = 60.0 + assert config.MAX_TRANSACTION_RETRY_TIME == 60.0 + config.MAX_TRANSACTION_RETRY_TIME = original_retry + + # Test RESOLVER setter + original_resolver = config.RESOLVER + config.RESOLVER = None + assert config.RESOLVER is None + config.RESOLVER = original_resolver + + # Test TRUSTED_CERTIFICATES setter + original_certs = config.TRUSTED_CERTIFICATES + config.TRUSTED_CERTIFICATES = None + assert config.TRUSTED_CERTIFICATES is None + config.TRUSTED_CERTIFICATES = original_certs + + # Test USER_AGENT setter + original_agent = config.USER_AGENT + config.USER_AGENT = "custom-agent/2.0" + assert config.USER_AGENT == "custom-agent/2.0" + config.USER_AGENT = original_agent + + # Test ENCRYPTED setter + original_encrypted = config.ENCRYPTED + config.ENCRYPTED = True + assert config.ENCRYPTED is True + config.ENCRYPTED = original_encrypted + + # Test KEEP_ALIVE setter + original_keep_alive = config.KEEP_ALIVE + config.KEEP_ALIVE = False + assert config.KEEP_ALIVE is False + config.KEEP_ALIVE = original_keep_alive + + # Test CYPHER_DEBUG setter + original_cypher_debug = config.CYPHER_DEBUG + config.CYPHER_DEBUG = True + assert config.CYPHER_DEBUG is True + config.CYPHER_DEBUG = original_cypher_debug + + # Test SLOW_QUERIES setter + original_slow_queries = config.SLOW_QUERIES + config.SLOW_QUERIES = 5.0 + assert config.SLOW_QUERIES == 5.0 + config.SLOW_QUERIES = original_slow_queries + + def test_custom_driver_configuration(self): + """Test configuration with a custom Neo4j driver.""" + from unittest.mock import Mock + + # Create a mock driver + mock_driver = Mock() + mock_driver.close = Mock() + + # Test setting driver via NeomodelConfig + config_obj = NeomodelConfig(driver=mock_driver) + assert config_obj.driver is mock_driver + + # Test setting driver via module-level attribute + original_driver = config.DRIVER + config.DRIVER = mock_driver + assert config.DRIVER is mock_driver + + # Test that driver is accessible through the config + current_config = get_config() + assert current_config.driver is mock_driver + + # Test that driver is excluded from serialization + config_dict = 
config_obj.to_dict() + assert "driver" not in config_dict + + # Restore original driver + config.DRIVER = original_driver + + def test_validation_on_set(self): + """Test that validation occurs when setting module-level attributes.""" + with pytest.raises(ValueError, match="connection_timeout must be positive"): + config.CONNECTION_TIMEOUT = -1 + + def test_validation_revert_on_error(self): + """Test that values are reverted when validation fails.""" + original_timeout = config.CONNECTION_TIMEOUT + + # This should fail and revert the value + with pytest.raises(ValueError): + config.CONNECTION_TIMEOUT = -1 + + # Value should be reverted to original + assert config.CONNECTION_TIMEOUT == original_timeout + + def test_validation_revert_multiple_attributes(self): + """Test validation and revert for multiple attributes.""" + original_values = { + "CONNECTION_TIMEOUT": config.CONNECTION_TIMEOUT, + "MAX_CONNECTION_POOL_SIZE": config.MAX_CONNECTION_POOL_SIZE, + } + + # Test that each invalid value is reverted + with pytest.raises(ValueError): + config.CONNECTION_TIMEOUT = -1 + + assert config.CONNECTION_TIMEOUT == original_values["CONNECTION_TIMEOUT"] + + with pytest.raises(ValueError): + config.MAX_CONNECTION_POOL_SIZE = -1 + + assert ( + config.MAX_CONNECTION_POOL_SIZE + == original_values["MAX_CONNECTION_POOL_SIZE"] + ) + + def test_unknown_config_warning(self): + """Test warning for unknown configuration options.""" + with pytest.warns(UserWarning, match="Unknown configuration option"): + config_obj = get_config() + config_obj.update(unknown_option="value") + + +class TestGlobalConfigManagement: + """Test global configuration management functions.""" + + def test_get_set_config(self): + """Test getting and setting global configuration.""" + # Get default config + config_obj = get_config() + assert isinstance(config_obj, NeomodelConfig) + + # Set custom config + custom_config = NeomodelConfig(database_url="bolt://custom:test@localhost:7687") + set_config(custom_config) + + assert get_config().database_url == "bolt://custom:test@localhost:7687" + + def test_reset_config(self): + """Test resetting configuration to defaults.""" + # Set custom config + custom_config = NeomodelConfig(database_url="bolt://custom:test@localhost:7687") + set_config(custom_config) + + # Reset to defaults + reset_config() + + # Should load from environment or use defaults + config_obj = get_config() + assert isinstance(config_obj, NeomodelConfig) + + def test_get_config_singleton(self): + """Test that get_config returns the same instance.""" + config1 = get_config() + config2 = get_config() + assert config1 is config2 + + def test_set_config_replaces_singleton(self): + """Test that set_config replaces the global instance.""" + original_config = get_config() + custom_config = NeomodelConfig(database_url="bolt://custom:test@localhost:7687") + + set_config(custom_config) + assert get_config() is custom_config + assert get_config() is not original_config + + +class TestEnvironmentVariableSupport: + """Test environment variable support.""" + + def test_env_var_loading(self): + """Test loading configuration from environment variables.""" + env_vars = { + "NEOMODEL_DATABASE_URL": "bolt://env:test@localhost:7687", + "NEOMODEL_FORCE_TIMEZONE": "true", + "NEOMODEL_SOFT_CARDINALITY_CHECK": "true", + "NEOMODEL_CYPHER_DEBUG": "true", + "NEOMODEL_SLOW_QUERIES": "2.0", + } + + with patch.dict(os.environ, env_vars): + reset_config() # Force reload from environment + + assert config.DATABASE_URL == "bolt://env:test@localhost:7687" + assert 
config.FORCE_TIMEZONE is True
+            assert config.SOFT_CARDINALITY_CHECK is True
+            assert config.CYPHER_DEBUG is True
+            assert config.SLOW_QUERIES == 2.0
+
+    def test_env_var_type_conversion(self):
+        """Test type conversion for environment variables."""
+        env_vars = {
+            "NEOMODEL_CONNECTION_TIMEOUT": "60.0",
+            "NEOMODEL_MAX_CONNECTION_POOL_SIZE": "200",
+            "NEOMODEL_ENCRYPTED": "true",
+            "NEOMODEL_KEEP_ALIVE": "false",
+            "NEOMODEL_CYPHER_DEBUG": "false",
+            "NEOMODEL_SLOW_QUERIES": "0.5",
+        }
+
+        with patch.dict(os.environ, env_vars):
+            reset_config()
+
+            assert config.CONNECTION_TIMEOUT == 60.0
+            assert config.MAX_CONNECTION_POOL_SIZE == 200
+            assert config.ENCRYPTED is True  # type: ignore[attr-defined]
+            assert config.KEEP_ALIVE is False  # type: ignore[attr-defined]
+            assert config.CYPHER_DEBUG is False  # type: ignore[attr-defined]
+            assert config.SLOW_QUERIES == 0.5  # type: ignore[attr-defined]
+
+    def test_env_var_boolean_conversion(self):
+        """Test boolean environment variable conversion edge cases."""
+        # Test various boolean representations
+        boolean_tests = [
+            ("true", True),
+            ("1", True),
+            ("yes", True),
+            ("on", True),
+            ("false", False),
+            ("0", False),
+            ("no", False),
+            ("off", False),
+            ("TRUE", True),  # Case insensitive
+            ("FALSE", False),
+        ]
+
+        for env_value, expected in boolean_tests:
+            env_vars = {
+                "NEOMODEL_FORCE_TIMEZONE": env_value,
+                "NEOMODEL_ENCRYPTED": env_value,
+                "NEOMODEL_KEEP_ALIVE": env_value,
+                "NEOMODEL_SOFT_CARDINALITY_CHECK": env_value,
+                "NEOMODEL_CYPHER_DEBUG": env_value,
+            }
+
+            with patch.dict(os.environ, env_vars):
+                reset_config()
+                assert config.FORCE_TIMEZONE == expected
+                assert config.ENCRYPTED == expected  # type: ignore[attr-defined]
+                assert config.KEEP_ALIVE == expected  # type: ignore[attr-defined]
+                assert config.SOFT_CARDINALITY_CHECK == expected
+                assert config.CYPHER_DEBUG == expected  # type: ignore[attr-defined]
+
+    def test_env_var_numeric_conversion(self):
+        """Test numeric environment variable conversion."""
+        env_vars = {
+            "NEOMODEL_CONNECTION_TIMEOUT": "45.5",
+            "NEOMODEL_MAX_CONNECTION_POOL_SIZE": "150",
+            "NEOMODEL_MAX_CONNECTION_LIFETIME": "7200",
+            "NEOMODEL_MAX_TRANSACTION_RETRY_TIME": "60.0",
+            "NEOMODEL_SLOW_QUERIES": "2.5",
+        }
+
+        with patch.dict(os.environ, env_vars):
+            reset_config()
+            assert config.CONNECTION_TIMEOUT == 45.5
+            assert config.MAX_CONNECTION_POOL_SIZE == 150
+            assert config.MAX_CONNECTION_LIFETIME == 7200  # type: ignore[attr-defined]
+            assert config.MAX_TRANSACTION_RETRY_TIME == 60.0  # type: ignore[attr-defined]
+            assert config.SLOW_QUERIES == 2.5  # type: ignore[attr-defined]
+
+    def test_env_var_string_conversion(self):
+        """Test string environment variable handling."""
+        env_vars = {
+            "NEOMODEL_DATABASE_URL": "bolt://custom:password@localhost:7687",
+            "NEOMODEL_DATABASE_NAME": "test_database",
+            "NEOMODEL_USER_AGENT": "custom-agent/1.0",
+        }
+
+        with patch.dict(os.environ, env_vars):
+            reset_config()
+            assert config.DATABASE_URL == "bolt://custom:password@localhost:7687"
+            assert config.DATABASE_NAME == "test_database"  # type: ignore[attr-defined]
+            assert config.USER_AGENT == "custom-agent/1.0"  # type: ignore[attr-defined]
+
+    def test_env_var_missing_fields(self):
+        """Test that missing environment variables use defaults."""
+        # patch.dict with clear=True empties os.environ for the duration of the
+        # block, so no NEOMODEL_* variables are visible and defaults apply
+        with patch.dict(os.environ, {}, clear=True):
+            reset_config()
+            # Should use default 
values + assert config.DATABASE_URL == "bolt://neo4j:foobarbaz@localhost:7687" + assert config.FORCE_TIMEZONE is False + assert config.CONNECTION_TIMEOUT == 30.0 + assert config.MAX_CONNECTION_POOL_SIZE == 100 + + +class TestIntegration: + """Test integration with existing neomodel functionality.""" + + def test_config_with_properties(self): + """Test that configuration works with neomodel properties.""" + from datetime import datetime + + from neomodel.properties import DateTimeProperty + + # Test FORCE_TIMEZONE functionality + prop = DateTimeProperty() + + # Default should not raise error + config.FORCE_TIMEZONE = False + result = prop.deflate(datetime.now()) + assert result is not None + + # With FORCE_TIMEZONE=True, should raise error for naive datetime + config.FORCE_TIMEZONE = True + with pytest.raises(Exception): # May be ValueError or DeflateError + prop.deflate(datetime.now()) + + # Restore default + config.FORCE_TIMEZONE = False + + def test_config_with_cardinality(self): + """Test that configuration works with cardinality checking.""" + # Test SOFT_CARDINALITY_CHECK functionality + original_value = config.SOFT_CARDINALITY_CHECK + + config.SOFT_CARDINALITY_CHECK = True + assert config.SOFT_CARDINALITY_CHECK is True + + config.SOFT_CARDINALITY_CHECK = False + assert config.SOFT_CARDINALITY_CHECK is False + + # Restore original value + config.SOFT_CARDINALITY_CHECK = original_value + + +class TestDeprecationWarnings: + """Test deprecation warnings for legacy configuration access.""" + + def setup_method(self): + """Clear deprecation warnings before each test.""" + clear_deprecation_warnings() + + def test_deprecation_warning_on_get(self): + """Test that deprecation warnings are issued when accessing legacy attributes.""" + with pytest.warns( + DeprecationWarning, match="Accessing config.DATABASE_URL is deprecated" + ): + _ = config.DATABASE_URL + + with pytest.warns( + DeprecationWarning, match="Accessing config.FORCE_TIMEZONE is deprecated" + ): + _ = config.FORCE_TIMEZONE + + with pytest.warns( + DeprecationWarning, match="Accessing config.CYPHER_DEBUG is deprecated" + ): + _ = config.CYPHER_DEBUG + + def test_deprecation_warning_on_set(self): + """Test that deprecation warnings are issued when setting legacy attributes.""" + with pytest.warns( + DeprecationWarning, match="Setting config.DATABASE_URL is deprecated" + ): + config.DATABASE_URL = "bolt://test:test@localhost:7687" + + with pytest.warns( + DeprecationWarning, match="Setting config.FORCE_TIMEZONE is deprecated" + ): + config.FORCE_TIMEZONE = True + + with pytest.warns( + DeprecationWarning, match="Setting config.SLOW_QUERIES is deprecated" + ): + config.SLOW_QUERIES = 1.0 + + def test_deprecation_warning_only_once_per_attribute(self): + """Test that deprecation warnings are only shown once per attribute.""" + # First access should show warning + with pytest.warns(DeprecationWarning): + _ = config.DATABASE_URL + + # Second access should not show warning + with warnings.catch_warnings(): + warnings.simplefilter("error") # Turn warnings into errors + _ = config.DATABASE_URL # Should not raise + + def test_deprecation_warning_message_content(self): + """Test that deprecation warning messages contain helpful migration information.""" + with pytest.warns(DeprecationWarning) as warning_info: + _ = config.DATABASE_URL + + warning = warning_info[0] + assert "Accessing config.DATABASE_URL is deprecated" in str(warning.message) + assert "from neomodel import get_config" in str(warning.message) + assert "config.database_url" in 
str(warning.message)
+
+    def test_deprecation_warning_message_for_setting(self):
+        """Test that deprecation warning messages for setting contain helpful migration information."""
+        with pytest.warns(DeprecationWarning) as warning_info:
+            config.FORCE_TIMEZONE = True
+
+        warning = warning_info[0]
+        assert "Setting config.FORCE_TIMEZONE is deprecated" in str(warning.message)
+        assert "from neomodel import get_config" in str(warning.message)
+        assert "config.force_timezone = value" in str(warning.message)
+
+    def test_clear_deprecation_warnings_resets_state(self):
+        """Test that clear_deprecation_warnings resets the warning state."""
+        # First access should show warning
+        with pytest.warns(DeprecationWarning):
+            _ = config.DATABASE_URL
+
+        # Clear warnings
+        clear_deprecation_warnings()
+
+        # Next access should show warning again
+        with pytest.warns(DeprecationWarning):
+            _ = config.DATABASE_URL
+
+    def test_modern_api_no_deprecation_warnings(self):
+        """Test that the modern API does not trigger deprecation warnings."""
+        with warnings.catch_warnings():
+            warnings.simplefilter("error")  # Turn warnings into errors
+
+            # Modern API should not trigger warnings
+            config_obj = get_config()
+            _ = config_obj.database_url
+            _ = config_obj.force_timezone
+            _ = config_obj.cypher_debug
+
+            config_obj.database_url = "bolt://test:test@localhost:7687"
+            config_obj.force_timezone = True
+            config_obj.slow_queries = 1.0
+
+    # Autouse fixture: resets the config to defaults after each test in this
+    # class, preventing interference with subsequent tests in other files.
+    @pytest.fixture(autouse=True)
+    def teardown(self):
+        yield
+        reset_config()
diff --git a/test/test_exceptions_additional.py b/test/test_exceptions_additional.py
new file mode 100644
index 00000000..00d7db63
--- /dev/null
+++ b/test/test_exceptions_additional.py
@@ -0,0 +1,334 @@
+"""
+Additional tests for neomodel.exceptions module to improve coverage.
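+
+Each test constructs an exception directly and checks its attributes and
+string representation, for example (illustrative):
+
+    exc = CardinalityViolation("OneOrMore", 0)
+    assert "Expected: OneOrMore, got: 0" in str(exc)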
+""" + +import pytest + +from neomodel import StructuredNode +from neomodel.exceptions import ( + AttemptedCardinalityViolation, + CardinalityViolation, + ConstraintValidationFailed, + DeflateConflict, + DeflateError, + DoesNotExist, + FeatureNotSupported, + InflateConflict, + InflateError, + ModelDefinitionException, + MultipleNodesReturned, + NeomodelException, + NotConnected, + RequiredProperty, + UniqueProperty, +) + + +def test_neomodel_exception(): + """Test NeomodelException base class.""" + exc = NeomodelException("Test message") + assert str(exc) == "Test message" + assert isinstance(exc, Exception) + + +def test_attempted_cardinality_violation(): + """Test AttemptedCardinalityViolation exception.""" + exc = AttemptedCardinalityViolation("Test message") + assert str(exc) == "Test message" + assert isinstance(exc, NeomodelException) + + +def test_cardinality_violation(): + """Test CardinalityViolation exception.""" + exc = CardinalityViolation("Test rel_manager", 5) + assert exc.rel_manager == "Test rel_manager" + assert exc.actual == "5" + assert "Expected: Test rel_manager, got: 5" in str(exc) + + +def test_cardinality_violation_str(): + """Test CardinalityViolation string representation.""" + exc = CardinalityViolation("OneOrMore", 0) + assert "Expected: OneOrMore, got: 0" in str(exc) + + +def test_model_definition_exception(): + """Test ModelDefinitionException initialization.""" + db_node = {"labels": ["TestNode"]} + registry = {frozenset(["TestNode"]): "TestClass"} + db_registry = {"test_db": {frozenset(["TestNode"]): "TestClass"}} + + exc = ModelDefinitionException(db_node, registry, db_registry) + assert exc.db_node_rel_class is db_node + assert exc.current_node_class_registry is registry + assert exc.current_db_specific_node_class_registry is db_registry + + +def test_model_definition_exception_get_node_class_registry_formatted(): + """Test ModelDefinitionException _get_node_class_registry_formatted method.""" + db_node = {"labels": ["TestNode"]} + registry = {frozenset(["TestNode"]): "TestClass"} + db_registry = {"test_db": {frozenset(["TestNode"]): "TestClass"}} + + exc = ModelDefinitionException(db_node, registry, db_registry) + formatted = exc._get_node_class_registry_formatted() # type: ignore + + assert "TestNode --> TestClass" in formatted + assert "Database-specific: test_db" in formatted + + +def test_constraint_validation_failed(): + """Test ConstraintValidationFailed exception.""" + exc = ConstraintValidationFailed("Test constraint error") + assert str(exc) == "Test constraint error" + assert isinstance(exc, NeomodelException) + + +def test_unique_property(): + """Test UniqueProperty exception.""" + exc = UniqueProperty("Test unique constraint error") + assert exc.message == "Test unique constraint error" + assert isinstance(exc, ConstraintValidationFailed) + + +def test_required_property(): + """Test RequiredProperty exception.""" + + class TestClass: + pass + + exc = RequiredProperty("test_prop", TestClass) + assert exc.property_name == "test_prop" + assert exc.node_class is TestClass + assert "property 'test_prop' on objects of class TestClass" in str(exc) + + +def test_required_property_str(): + """Test RequiredProperty string representation.""" + + class TestClass: + pass + + exc = RequiredProperty("age", TestClass) + assert "property 'age' on objects of class TestClass" in str(exc) + + +def test_inflate_error(): + """Test InflateError exception.""" + + class TestClass: + pass + + exc = InflateError("test_prop", TestClass, "Test inflate error", "test_obj") 
+ assert exc.property_name == "test_prop" + assert exc.node_class is TestClass + assert exc.msg == "Test inflate error" + assert exc.obj == "'test_obj'" + assert ( + "property 'test_prop' on 'test_obj' of class 'TestClass': Test inflate error" + in str(exc) + ) + + +def test_inflate_error_str(): + """Test InflateError string representation.""" + + class TestClass: + pass + + exc = InflateError("name", TestClass, "Invalid value", None) + assert "property 'name' on None of class 'TestClass': Invalid value" in str(exc) + + +def test_deflate_error(): + """Test DeflateError exception.""" + + class TestClass: + pass + + exc = DeflateError("test_prop", TestClass, "Test deflate error", "test_obj") + assert exc.property_name == "test_prop" + assert exc.node_class is TestClass + assert exc.msg == "Test deflate error" + assert exc.obj == "'test_obj'" + assert ( + "property 'test_prop' on 'test_obj' of class 'TestClass': Test deflate error" + in str(exc) + ) + + +def test_deflate_error_str(): + """Test DeflateError string representation.""" + + class TestClass: + pass + + exc = DeflateError("age", TestClass, "Invalid age", None) + assert "property 'age' on None of class 'TestClass': Invalid age" in str(exc) + + +def test_inflate_conflict(): + """Test InflateConflict exception.""" + + class TestClass: + pass + + exc = InflateConflict(TestClass, "test_prop", "test_value", "test_nid") + assert ( + str(exc) + == "Found conflict with node test_nid, has property 'test_prop' with value 'test_value' although class TestClass already has a property 'test_prop'" + ) + assert isinstance(exc, NeomodelException) + + +def test_deflate_conflict(): + """Test DeflateConflict exception.""" + + class TestClass: + pass + + exc = DeflateConflict(TestClass, "test_prop", "test_value", "test_nid") + assert ( + str(exc) + == "Found trying to set property 'test_prop' with value 'test_value' on node test_nid although class TestClass already has a property 'test_prop'" + ) + assert isinstance(exc, NeomodelException) + + +def test_multiple_nodes_returned(): + """Test MultipleNodesReturned exception.""" + exc = MultipleNodesReturned("Test multiple nodes error") + assert str(exc) == "Test multiple nodes error" + assert isinstance(exc, NeomodelException) + + +def test_does_not_exist(): + """Test DoesNotExist exception.""" + + class TestClass: + pass + + # Set the model class + DoesNotExist._model_class = TestClass + + exc = DoesNotExist("Test does not exist error") + assert exc.message == "Test does not exist error" + assert isinstance(exc, NeomodelException) + + +def test_does_not_exist_no_model_class(): + """Test DoesNotExist exception without model class.""" + # Reset the model class + DoesNotExist._model_class = None + + with pytest.raises(RuntimeError, match="This class hasn't been setup properly"): + DoesNotExist("Test error") + + +def test_not_connected(): + """Test NotConnected exception.""" + + class MockNode1: + def __init__(self, element_id): + self.element_id = element_id + + class MockNode2: + def __init__(self, element_id): + self.element_id = element_id + + node1 = MockNode1("id1") + node2 = MockNode2("id2") + exc = NotConnected("connect", node1, node2) + assert exc.action == "connect" + assert exc.node1 is node1 + assert exc.node2 is node2 + assert ( + "Error performing 'connect' - Node id1 of type 'MockNode1' is not connected to id2 of type 'MockNode2'" + in str(exc) + ) + + +def test_not_connected_str(): + """Test NotConnected string representation.""" + + class MockNode1: + def __init__(self, element_id): + 
self.element_id = element_id + + class MockNode2: + def __init__(self, element_id): + self.element_id = element_id + + exc = NotConnected("delete", MockNode1("id1"), MockNode2("id2")) + assert ( + "Error performing 'delete' - Node id1 of type 'MockNode1' is not connected to id2 of type 'MockNode2'" + in str(exc) + ) + + +def test_feature_not_supported(): + """Test FeatureNotSupported exception.""" + exc = FeatureNotSupported("Test feature not supported") + assert exc.message == "Test feature not supported" + assert isinstance(exc, NeomodelException) + + +def test_feature_not_supported_str(): + """Test FeatureNotSupported string representation.""" + exc = FeatureNotSupported("Vector indexes not supported in this version") + assert exc.message == "Vector indexes not supported in this version" + + +def test_exception_string_representations(): + """Test string representations of various exceptions.""" + # Test with different message types + exc = NeomodelException("Simple message") + assert str(exc) == "Simple message" + + exc = NeomodelException("") + assert str(exc) == "" + + exc = NeomodelException(None) + assert str(exc) == "None" + + +def test_cardinality_violation_with_different_types(): + """Test CardinalityViolation with different actual value types.""" + exc = CardinalityViolation("OneOrMore", "zero") + assert exc.actual == "zero" + assert "Expected: OneOrMore, got: zero" in str(exc) + + exc = CardinalityViolation("One", 1) + assert exc.actual == "1" + assert "Expected: One, got: 1" in str(exc) + + +def test_model_definition_exception_with_empty_registries(): + """Test ModelDefinitionException with empty registries.""" + db_node = {"labels": ["TestNode"]} + registry = {} + db_registry = {} + + exc = ModelDefinitionException(db_node, registry, db_registry) + formatted = exc._get_node_class_registry_formatted() # type: ignore + assert formatted == "" + + +def test_model_definition_exception_with_multiple_entries(): + """Test ModelDefinitionException with multiple registry entries.""" + db_node = {"labels": ["TestNode"]} + registry = {frozenset(["Node1"]): "Class1", frozenset(["Node2"]): "Class2"} + db_registry = { + "db1": {frozenset(["Node3"]): "Class3"}, + "db2": {frozenset(["Node4"]): "Class4"}, + } + + exc = ModelDefinitionException(db_node, registry, db_registry) + formatted = exc._get_node_class_registry_formatted() # type: ignore + + assert "Node1 --> Class1" in formatted + assert "Node2 --> Class2" in formatted + assert "Database-specific: db1" in formatted + assert "Database-specific: db2" in formatted + assert "Node3 --> Class3" in formatted + assert "Node4 --> Class4" in formatted diff --git a/test/test_properties_additional.py b/test/test_properties_additional.py new file mode 100644 index 00000000..0632bb95 --- /dev/null +++ b/test/test_properties_additional.py @@ -0,0 +1,278 @@ +""" +Additional tests for neomodel.properties module to improve coverage. 
+""" + +import pytest + +from neomodel.properties import ( + FulltextIndex, + Property, + StringProperty, + VectorIndex, + validator, +) + + +def test_validator_decorator_invalid_function(): + """Test validator decorator with invalid function name.""" + + def invalid_function(): + pass + + # This should raise a ValueError because the function name is not "inflate" or "deflate" + with pytest.raises(ValueError, match="Unknown Property method"): + validator(invalid_function) + + +def test_validator_decorator_inflate_error(): + """Test validator decorator with inflate function that raises an exception.""" + + class TestProperty(Property): + def inflate(self, value, rethrow=False): + raise ValueError("Test error") + + def deflate(self, value, rethrow=False): + return value + + # Apply validator to inflate method + TestProperty.inflate = validator(TestProperty.inflate) + + prop = TestProperty() + prop.name = "test" + prop.owner = "TestClass" + + with pytest.raises(Exception): # Should raise InflateError + prop.inflate("test_value") + + +def test_validator_decorator_deflate_error(): + """Test validator decorator with deflate function that raises an exception.""" + + class TestProperty(Property): + def inflate(self, value, rethrow=False): + return value + + def deflate(self, value, rethrow=False): + raise ValueError("Test error") + + # Apply validator to deflate method + TestProperty.deflate = validator(TestProperty.deflate) + + prop = TestProperty() + prop.name = "test" + prop.owner = "TestClass" + + with pytest.raises(Exception): # Should raise DeflateError + prop.deflate("test_value") + + +def test_fulltext_index_initialization(): + """Test FulltextIndex initialization.""" + fti = FulltextIndex() + assert fti.analyzer == "standard-no-stop-words" + assert fti.eventually_consistent is False + + +def test_fulltext_index_initialization_with_params(): + """Test FulltextIndex initialization with parameters.""" + fti = FulltextIndex(analyzer="english", eventually_consistent=True) + assert fti.analyzer == "english" + assert fti.eventually_consistent is True + + +def test_vector_index_initialization(): + """Test VectorIndex initialization.""" + vi = VectorIndex() + assert vi.dimensions == 1536 + assert vi.similarity_function == "cosine" + + +def test_vector_index_initialization_with_params(): + """Test VectorIndex initialization with custom parameters.""" + vi = VectorIndex(dimensions=512, similarity_function="euclidean") + assert vi.dimensions == 512 + assert vi.similarity_function == "euclidean" + + +def test_property_initialization_mutually_exclusive_required_default(): + """Test Property initialization with mutually exclusive required and default.""" + with pytest.raises(ValueError, match="mutually exclusive"): + StringProperty(required=True, default="test") + + +def test_property_initialization_mutually_exclusive_unique_index_index(): + """Test Property initialization with mutually exclusive unique_index and index.""" + with pytest.raises(ValueError, match="mutually exclusive"): + StringProperty(unique_index=True, index=True) + + +def test_property_default_value_callable(): + """Test Property default_value with callable default.""" + + def get_default(): + return "callable_default" + + prop = StringProperty(default=get_default) + assert prop.default_value() == "callable_default" + + +def test_property_default_value_non_callable(): + """Test Property default_value with non-callable default.""" + prop = StringProperty(default="static_default") + assert prop.default_value() == "static_default" + + 
+def test_property_default_value_no_default(): + """Test Property default_value with no default.""" + prop = StringProperty() + with pytest.raises(ValueError, match="No default value specified"): + prop.default_value() + + +def test_property_get_db_property_name(): + """Test Property get_db_property_name method.""" + prop = StringProperty(db_property="db_name") + assert prop.get_db_property_name("attribute_name") == "db_name" + + prop_no_db = StringProperty() + assert prop_no_db.get_db_property_name("attribute_name") == "attribute_name" + + +def test_property_is_indexed(): + """Test Property is_indexed property.""" + prop_indexed = StringProperty(index=True) + assert prop_indexed.is_indexed is True + + prop_unique = StringProperty(unique_index=True) + assert prop_unique.is_indexed is True + + prop_not_indexed = StringProperty() + assert prop_not_indexed.is_indexed is False + + +def test_property_initialization_with_all_params(): + """Test Property initialization with all parameters.""" + fti = FulltextIndex() + vi = VectorIndex() + + prop = StringProperty( + name="test_prop", + owner="TestClass", + unique_index=True, + index=False, + fulltext_index=fti, + vector_index=vi, + required=True, + default=None, + db_property="db_prop", + label="Test Label", + help_text="Test help", + ) + + assert prop.name == "test_prop" + assert prop.owner == "TestClass" + assert prop.unique_index is True + assert prop.index is False + assert prop.fulltext_index is fti + assert prop.vector_index is vi + assert prop.required is True + assert prop.default is None + assert prop.db_property == "db_prop" + assert prop.label == "Test Label" + assert prop.help_text == "Test help" + + +def test_property_initialization_with_kwargs(): + """Test Property initialization with additional kwargs.""" + prop = StringProperty(custom_param="custom_value", another_param=123) + assert hasattr(prop, "custom_param") + assert hasattr(prop, "another_param") + assert getattr(prop, "custom_param") == "custom_value" + assert getattr(prop, "another_param") == 123 + + +def test_property_has_default(): + """Test Property has_default property.""" + prop_with_default = StringProperty(default="test") + assert prop_with_default.has_default is True + + prop_without_default = StringProperty() + assert prop_without_default.has_default is False + + +def test_property_initialization_edge_cases(): + """Test Property initialization with edge cases.""" + # Test with empty string default + prop = StringProperty(default="") + assert prop.has_default is True + assert prop.default_value() == "" + + # Test with None default + prop = StringProperty(default=None) + assert prop.has_default is False + with pytest.raises(ValueError, match="No default value specified"): + prop.default_value() + + +def test_property_initialization_with_indexes(): + """Test Property initialization with various index configurations.""" + # Test with only index + prop = StringProperty(index=True) + assert prop.index is True + assert prop.unique_index is False + assert prop.is_indexed is True + + # Test with only unique_index + prop = StringProperty(unique_index=True) + assert prop.unique_index is True + assert prop.index is False + assert prop.is_indexed is True + + # Test with neither + prop = StringProperty() + assert prop.index is False + assert prop.unique_index is False + assert prop.is_indexed is False + + +def test_property_initialization_with_required(): + """Test Property initialization with required parameter.""" + # Test with required=True + prop = 
StringProperty(required=True) + assert prop.required is True + + # Test with required=False + prop = StringProperty(required=False) + assert prop.required is False + + # Test default required value + prop = StringProperty() + assert prop.required is False + + +def test_property_initialization_with_fulltext_index(): + """Test Property initialization with fulltext_index.""" + fti = FulltextIndex(analyzer="test_analyzer") + prop = StringProperty(fulltext_index=fti) + assert prop.fulltext_index is fti + + +def test_property_initialization_with_vector_index(): + """Test Property initialization with vector_index.""" + vi = VectorIndex(dimensions=256) + prop = StringProperty(vector_index=vi) + assert prop.vector_index is vi + + +def test_property_initialization_with_db_property(): + """Test Property initialization with db_property.""" + prop = StringProperty(db_property="custom_db_name") + assert prop.db_property == "custom_db_name" + assert prop.get_db_property_name("attribute_name") == "custom_db_name" + + +def test_property_initialization_with_label_and_help_text(): + """Test Property initialization with label and help_text.""" + prop = StringProperty(label="Test Label", help_text="Test help text") + assert prop.label == "Test Label" + assert prop.help_text == "Test help text" diff --git a/test/test_scripts.py b/test/test_scripts.py index 23a973e7..ca7bfb5e 100644 --- a/test/test_scripts.py +++ b/test/test_scripts.py @@ -8,9 +8,9 @@ StringProperty, StructuredNode, StructuredRel, - config, + get_config, ) -from neomodel.sync_.core import db +from neomodel.sync_.database import db class ScriptsTestRel(StructuredRel): @@ -74,7 +74,7 @@ def test_neomodel_remove_labels(): assert result.returncode == 0 result = subprocess.run( - ["neomodel_remove_labels", "--db", config.DATABASE_URL], + ["neomodel_remove_labels", "--db", get_config().database_url], capture_output=True, text=True, check=False, @@ -141,7 +141,7 @@ def test_neomodel_inspect_database(script_flavour): ) # Test the console output version of the script - args_list = ["neomodel_inspect_database", "--db", config.DATABASE_URL] + args_list = ["neomodel_inspect_database", "--db", get_config().database_url] if script_flavour == "_light": args_list += ["--no-rel-props", "--no-rel-cardinality"] result = subprocess.run( diff --git a/test/test_util.py b/test/test_util.py new file mode 100644 index 00000000..d0833fb3 --- /dev/null +++ b/test/test_util.py @@ -0,0 +1,129 @@ +""" +Simple tests for neomodel.util module to improve coverage. 
+""" + +import inspect +import unittest +import warnings +from types import FrameType + +from neomodel.util import ( + RelationshipDirection, + _UnsavedNode, + classproperty, + deprecated, + enumerate_traceback, + get_graph_entity_properties, + version_tag_to_integer, +) + + +class TestUtil(unittest.TestCase): + """Test cases for neomodel.util module.""" + + def test_relationship_direction_enum(self): + """Test RelationshipDirection enum values.""" + self.assertEqual(RelationshipDirection.OUTGOING, 1) + self.assertEqual(RelationshipDirection.INCOMING, -1) + self.assertEqual(RelationshipDirection.EITHER, 0) + + def test_deprecated_decorator(self): + """Test the deprecated decorator functionality.""" + + @deprecated("This function is deprecated") + def deprecated_function(): + return "test" + + # Test that the function still works + self.assertEqual(deprecated_function(), "test") + + # Test that a deprecation warning is issued + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + deprecated_function() + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[0].category, DeprecationWarning)) + self.assertIn("This function is deprecated", str(w[0].message)) + + def test_classproperty(self): + """Test classproperty decorator.""" + + class TestClass: + _value = "test" + + @classproperty + def value(cls): + return cls._value + + self.assertEqual(TestClass.value, "test") + + # Test that it works on instances too + instance = TestClass() + self.assertEqual(instance.value, "test") + + def test_unsaved_node(self): + """Test _UnsavedNode class.""" + node = _UnsavedNode() + self.assertEqual(repr(node), "") + self.assertEqual(str(node), "") + + def test_get_graph_entity_properties(self): + """Test get_graph_entity_properties function.""" + + # Mock entity with properties + class MockEntity: + def __init__(self): + self._properties = {"name": "test", "age": 25} + + entity = MockEntity() + properties = get_graph_entity_properties(entity) + self.assertEqual(properties, {"name": "test", "age": 25}) + + def test_enumerate_traceback(self): + """Test enumerate_traceback function.""" + + def test_function(): + current_frame = inspect.currentframe() + return list(enumerate_traceback(current_frame)) + + result = test_function() + self.assertGreater(len(result), 0) + self.assertTrue( + all(isinstance(item, tuple) and len(item) == 2 for item in result) + ) + self.assertTrue(all(isinstance(depth, int) for depth, frame in result)) + self.assertTrue(all(isinstance(frame, FrameType) for depth, frame in result)) + + def test_version_tag_to_integer(self): + """Test version_tag_to_integer function.""" + # Test basic version conversion + self.assertEqual(version_tag_to_integer("5.4.0"), 50400) + self.assertEqual(version_tag_to_integer("4.0.0"), 40000) + self.assertEqual(version_tag_to_integer("3.2.1"), 30201) + + # Test with less than 3 components + self.assertEqual(version_tag_to_integer("5.4"), 50400) + self.assertEqual(version_tag_to_integer("5"), 50000) + + # Test with aura suffix + self.assertEqual(version_tag_to_integer("5.14-aura"), 51400) + self.assertEqual(version_tag_to_integer("4.0-aura"), 40000) + + # Test edge cases + self.assertEqual(version_tag_to_integer("0.0.0"), 0) + self.assertEqual(version_tag_to_integer("1.0.0"), 10000) + + def test_version_tag_to_integer_invalid_input(self): + """Test version_tag_to_integer with invalid input.""" + with self.assertRaises(ValueError): + version_tag_to_integer("invalid") + + with self.assertRaises(ValueError): + 
version_tag_to_integer("5.4.invalid") + + with self.assertRaises(ValueError): + version_tag_to_integer("") + + +if __name__ == "__main__": + unittest.main()