diff --git a/README.md b/README.md index fad0a9b42..d41820995 100644 --- a/README.md +++ b/README.md @@ -110,17 +110,59 @@ to this: DATABASES = { "default": { "ENGINE": "django_mongodb", + "HOST": "mongodb+srv://cluster0.example.mongodb.net", "NAME": "my_database", "USER": "my_user", "PASSWORD": "my_password", - "OPTIONS": {...}, + "PORT": 27017, + "OPTIONS": { + # Example: + "retryWrites": "true", + "w": "majority", + "tls": "false", + }, }, } ``` +For a localhost configuration, you can omit `HOST` or specify +`"HOST": "localhost"`. + +`HOST` only needs a scheme prefix for SRV connections (`mongodb+srv://`). A +`mongodb://` prefix is never required. + `OPTIONS` is an optional dictionary of parameters that will be passed to [`MongoClient`](https://pymongo.readthedocs.io/en/stable/api/pymongo/mongo_client.html). +`USER`, `PASSWORD`, and `PORT` (if 27017) may also be optional. + +For a replica set or sharded cluster where you have multiple hosts, include +all of them in `HOST`, e.g. +`"mongodb://mongos0.example.com:27017,mongos1.example.com:27017"`. + +Alternatively, if you prefer to simply paste in a MongoDB URI rather than parse +it into the format above, you can use: + +```python +import django_mongodb + +MONGODB_URI = "mongodb+srv://my_user:my_password@cluster0.example.mongodb.net/myDatabase?retryWrites=true&w=majority&tls=false" +DATABASES["default"] = django_mongodb.parse_uri(MONGODB_URI) +``` + +This constructs a `DATABASES` setting equivalent to the first example. + +#### `django_mongodb.parse_uri(uri, conn_max_age=0, test=None)` + +`parse_uri()` provides a few options to customize the resulting `DATABASES` +setting, but for maximum flexibility, construct `DATABASES` manually as +described above. + +- Use `conn_max_age` to configure [persistent database connections]( + https://docs.djangoproject.com/en/stable/ref/databases/#persistent-database-connections). +- Use `test` to provide a dictionary of [settings for test databases]( + https://docs.djangoproject.com/en/stable/ref/settings/#test). + Congratulations, your project is ready to go! ## Notes on Django QuerySets diff --git a/django_mongodb/__init__.py b/django_mongodb/__init__.py index 31d8f2d3a..5ee8bc1c6 100644 --- a/django_mongodb/__init__.py +++ b/django_mongodb/__init__.py @@ -2,7 +2,7 @@ # Check Django compatibility before other imports which may fail if the # wrong version of Django is installed. -from .utils import check_django_compatability +from .utils import check_django_compatability, parse_uri check_django_compatability() @@ -14,6 +14,8 @@ from .lookups import register_lookups # noqa: E402 from .query import register_nodes # noqa: E402 +__all__ = ["parse_uri"] + register_aggregates() register_expressions() register_fields() diff --git a/django_mongodb/utils.py b/django_mongodb/utils.py index b4d87cc7d..a8b4b95d9 100644 --- a/django_mongodb/utils.py +++ b/django_mongodb/utils.py @@ -6,6 +6,7 @@ from django.core.exceptions import ImproperlyConfigured from django.db.backends.utils import logger from django.utils.version import get_version_tuple +from pymongo.uri_parser import parse_uri as pymongo_parse_uri def check_django_compatability(): @@ -25,6 +26,38 @@ def check_django_compatability(): ) +def parse_uri(uri, conn_max_age=0, test=None): + """ + Convert the given uri into a dictionary suitable for Django's DATABASES + setting. + """ + uri = pymongo_parse_uri(uri) + host = None + port = None + if uri["fqdn"]: + # This is a SRV URI and the host is the fqdn. + host = f"mongodb+srv://{uri['fqdn']}" + else: + nodelist = uri.get("nodelist") + if len(nodelist) == 1: + host, port = nodelist[0] + elif len(nodelist) > 1: + host = ",".join([f"{host}:{port}" for host, port in nodelist]) + settings_dict = { + "ENGINE": "django_mongodb", + "NAME": uri["database"], + "HOST": host, + "PORT": port, + "USER": uri.get("username"), + "PASSWORD": uri.get("password"), + "OPTIONS": uri.get("options"), + "CONN_MAX_AGE": conn_max_age, + } + if test: + settings_dict["TEST"] = test + return settings_dict + + def set_wrapped_methods(cls): """Initialize the wrapped methods on cls.""" if hasattr(cls, "logging_wrapper"): diff --git a/tests/backend_/utils/__init__.py b/tests/backend_/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/backend_/utils/test_parse_uri.py b/tests/backend_/utils/test_parse_uri.py new file mode 100644 index 000000000..01a479fb7 --- /dev/null +++ b/tests/backend_/utils/test_parse_uri.py @@ -0,0 +1,71 @@ +from unittest.mock import patch + +import pymongo +from django.test import SimpleTestCase + +from django_mongodb import parse_uri + + +class ParseURITests(SimpleTestCase): + def test_simple_uri(self): + settings_dict = parse_uri("mongodb://cluster0.example.mongodb.net/myDatabase") + self.assertEqual(settings_dict["ENGINE"], "django_mongodb") + self.assertEqual(settings_dict["NAME"], "myDatabase") + self.assertEqual(settings_dict["HOST"], "cluster0.example.mongodb.net") + + def test_no_database(self): + settings_dict = parse_uri("mongodb://cluster0.example.mongodb.net") + self.assertIsNone(settings_dict["NAME"]) + self.assertEqual(settings_dict["HOST"], "cluster0.example.mongodb.net") + + def test_srv_uri_with_options(self): + uri = "mongodb+srv://my_user:my_password@cluster0.example.mongodb.net/my_database?retryWrites=true&w=majority" + # patch() prevents a crash when PyMongo attempts to resolve the + # nonexistent SRV record. + with patch("dns.resolver.resolve"): + settings_dict = parse_uri(uri) + self.assertEqual(settings_dict["NAME"], "my_database") + self.assertEqual(settings_dict["HOST"], "mongodb+srv://cluster0.example.mongodb.net") + self.assertEqual(settings_dict["USER"], "my_user") + self.assertEqual(settings_dict["PASSWORD"], "my_password") + self.assertIsNone(settings_dict["PORT"]) + self.assertEqual( + settings_dict["OPTIONS"], {"retryWrites": True, "w": "majority", "tls": True} + ) + + def test_localhost(self): + settings_dict = parse_uri("mongodb://localhost") + self.assertEqual(settings_dict["HOST"], "localhost") + self.assertEqual(settings_dict["PORT"], 27017) + + def test_localhost_with_port(self): + settings_dict = parse_uri("mongodb://localhost:27018") + self.assertEqual(settings_dict["HOST"], "localhost") + self.assertEqual(settings_dict["PORT"], 27018) + + def test_hosts_with_ports(self): + settings_dict = parse_uri("mongodb://localhost:27017,localhost:27018") + self.assertEqual(settings_dict["HOST"], "localhost:27017,localhost:27018") + self.assertEqual(settings_dict["PORT"], None) + + def test_hosts_without_ports(self): + settings_dict = parse_uri("mongodb://host1.net,host2.net") + self.assertEqual(settings_dict["HOST"], "host1.net:27017,host2.net:27017") + self.assertEqual(settings_dict["PORT"], None) + + def test_conn_max_age(self): + settings_dict = parse_uri("mongodb://localhost", conn_max_age=600) + self.assertEqual(settings_dict["CONN_MAX_AGE"], 600) + + def test_test_kwarg(self): + settings_dict = parse_uri("mongodb://localhost", test={"NAME": "test_db"}) + self.assertEqual(settings_dict["TEST"], {"NAME": "test_db"}) + + def test_invalid_credentials(self): + msg = "The empty string is not valid username." + with self.assertRaisesMessage(pymongo.errors.InvalidURI, msg): + parse_uri("mongodb://:@localhost") + + def test_no_scheme(self): + with self.assertRaisesMessage(pymongo.errors.InvalidURI, "Invalid URI scheme"): + parse_uri("cluster0.example.mongodb.net")