diff --git a/django_mongodb_backend/utils.py b/django_mongodb_backend/utils.py
index 634ec234b..11412a666 100644
--- a/django_mongodb_backend/utils.py
+++ b/django_mongodb_backend/utils.py
@@ -1,5 +1,6 @@
 import copy
 import time
+from urllib.parse import quote
 
 import django
 from django.conf import settings
@@ -31,39 +32,84 @@ def check_django_compatability():
 def parse_uri(uri, *, db_name=None, options=None, test=None):
     """
     Convert the given uri into a dictionary suitable for Django's DATABASES
-    setting.
+    setting. Keep query string args on HOST (not in OPTIONS).
+
+    Behavior:
+    - Non-SRV: HOST = "<host>[?<query>]", no scheme/path.
+    - SRV: HOST = "mongodb+srv://<fqdn>[?<query>]", no path.
+    - NAME is db_name if provided else the db in the URI path (required).
+    - If the URI has a db path and neither the query nor the options kwarg
+      sets authSource (option names are case-insensitive), append it.
+    - options kwarg entries replace same-named query args, so HOST never
+      carries duplicates; other query args are kept verbatim. A list or
+      tuple value becomes one repeated arg per item.
+    - PORT is set only for single-host URIs; multi-host and SRV => PORT=None.
     """
-    uri = pymongo_parse_uri(uri)
-    host = None
-    port = None
-    if uri["fqdn"]:
-        # This is a SRV URI and the host is the fqdn.
-        host = f"mongodb+srv://{uri['fqdn']}"
+    parsed = pymongo_parse_uri(uri)
+    uri_db = parsed.get("database")
+
+    # Determine NAME; must come from db_name or the URI path.
+    name = db_name or uri_db
+    if not name:
+        raise ImproperlyConfigured("You must provide the db_name parameter.")
+
+    # Keep the original query args verbatim (no decode/re-encode round trip)
+    # to avoid breaking special options like readPreferenceTags. The query is
+    # cut at the first "?"; urlsplit() is avoided because it can reject valid
+    # multi-host URIs that contain IPv6 literals.
+    args = [arg for arg in uri.partition("?")[2].split("&") if arg]
+
+    def arg_key(arg):
+        # Option names are case-insensitive in MongoDB connection strings.
+        return arg.partition("=")[0].lower()
+
+    def encode(value):
+        # URIs spell booleans in lowercase. Preserve ':' and ',' unescaped
+        # (important for readPreferenceTags).
+        if isinstance(value, bool):
+            value = "true" if value else "false"
+        return quote(str(value), safe=":,")
+
+    if options:
+        # Drop the args being overridden rather than appending duplicates.
+        overridden = {str(key).lower() for key in options}
+        args = [arg for arg in args if arg_key(arg) not in overridden]
+        for key, value in options.items():
+            values = value if isinstance(value, (list, tuple)) else [value]
+            args.extend(f"{key}={encode(item)}" for item in values)
+
+    # Default authSource to the db in the URI path unless one is already set.
+    if uri_db and not any(arg_key(arg) == "authsource" for arg in args):
+        args.append(f"authSource={quote(uri_db, safe='')}")
+
+    # Build HOST (and PORT) based on SRV vs. standard.
+    port = None
+    if parsed.get("fqdn"):  # SRV URI
+        host = f"mongodb+srv://{parsed['fqdn']}"
     else:
-        nodelist = uri.get("nodelist")
+        nodelist = parsed.get("nodelist") or []
         if len(nodelist) == 1:
             host, port = nodelist[0]
-        elif len(nodelist) > 1:
-            host = ",".join([f"{host}:{port}" for host, port in nodelist])
-    db_name = db_name or uri["database"]
-    if not db_name:
-        raise ImproperlyConfigured("You must provide the db_name parameter.")
-    opts = uri.get("options")
-    if options:
-        opts.update(options)
+        else:
+            # Spell out each host's port; Unix socket paths have none.
+            host = ",".join(h if p is None else f"{h}:{p}" for h, p in nodelist)
+    if args:
+        host = f"{host}?{'&'.join(args)}"
+
     settings_dict = {
         "ENGINE": "django_mongodb_backend",
-        "NAME": db_name,
+        "NAME": name,
         "HOST": host,
         "PORT": port,
-        "USER": uri.get("username"),
-        "PASSWORD": uri.get("password"),
-        "OPTIONS": opts,
+        "USER": parsed.get("username"),
+        "PASSWORD": parsed.get("password"),
+        # Options remain empty; all query args live in HOST.
+        "OPTIONS": {},
     }
-    if "authSource" not in settings_dict["OPTIONS"] and uri["database"]:
-        settings_dict["OPTIONS"]["authSource"] = uri["database"]
+
     if test:
         settings_dict["TEST"] = test
+
     return settings_dict
 
 
diff --git a/tests/backend_/utils/test_parse_uri.py b/tests/backend_/utils/test_parse_uri.py
index 804c4efcb..04d84b8be 100644
--- a/tests/backend_/utils/test_parse_uri.py
+++ b/tests/backend_/utils/test_parse_uri.py
@@ -12,14 +12,18 @@ def test_simple_uri(self):
         settings_dict = parse_uri("mongodb://cluster0.example.mongodb.net/myDatabase")
         self.assertEqual(settings_dict["ENGINE"], "django_mongodb_backend")
         self.assertEqual(settings_dict["NAME"], "myDatabase")
-        self.assertEqual(settings_dict["HOST"], "cluster0.example.mongodb.net")
-        self.assertEqual(settings_dict["OPTIONS"], {"authSource": "myDatabase"})
+        # Default authSource derived from URI path db is appended to HOST
+        self.assertEqual(
+            settings_dict["HOST"], "cluster0.example.mongodb.net?authSource=myDatabase"
+        )
+        self.assertEqual(settings_dict["OPTIONS"], {})
 
     def test_db_name(self):
         settings_dict = parse_uri("mongodb://cluster0.example.mongodb.net/", db_name="myDatabase")
         self.assertEqual(settings_dict["ENGINE"], "django_mongodb_backend")
         self.assertEqual(settings_dict["NAME"], "myDatabase")
         self.assertEqual(settings_dict["HOST"], "cluster0.example.mongodb.net")
+        # No default authSource injected when the URI has no database path
         self.assertEqual(settings_dict["OPTIONS"], {})
 
     def test_db_name_overrides_default_auth_db(self):
@@ -28,8 +32,11 @@ def test_db_name_overrides_default_auth_db(self):
         )
         self.assertEqual(settings_dict["ENGINE"], "django_mongodb_backend")
         self.assertEqual(settings_dict["NAME"], "myDatabase")
-        self.assertEqual(settings_dict["HOST"], "cluster0.example.mongodb.net")
-        self.assertEqual(settings_dict["OPTIONS"], {"authSource": "default_auth_db"})
+        # authSource defaults to the database from the URI, not db_name
+        self.assertEqual(
+            settings_dict["HOST"], "cluster0.example.mongodb.net?authSource=default_auth_db"
+        )
+        self.assertEqual(settings_dict["OPTIONS"], {})
 
     def test_no_database(self):
         msg = "You must provide the db_name parameter."
@@ -43,44 +50,57 @@ def test_srv_uri_with_options(self):
         with patch("dns.resolver.resolve"):
             settings_dict = parse_uri(uri)
         self.assertEqual(settings_dict["NAME"], "my_database")
-        self.assertEqual(settings_dict["HOST"], "mongodb+srv://cluster0.example.mongodb.net")
+        # HOST includes scheme + fqdn only (no path), with query
+        # preserved and default authSource appended
+        self.assertTrue(
+            settings_dict["HOST"].startswith("mongodb+srv://cluster0.example.mongodb.net?")
+        )
+        self.assertIn("retryWrites=true", settings_dict["HOST"])
+        self.assertIn("w=majority", settings_dict["HOST"])
+        self.assertIn("authSource=my_database", settings_dict["HOST"])
         self.assertEqual(settings_dict["USER"], "my_user")
         self.assertEqual(settings_dict["PASSWORD"], "my_password")
         self.assertIsNone(settings_dict["PORT"])
-        self.assertEqual(
-            settings_dict["OPTIONS"],
-            {"authSource": "my_database", "retryWrites": True, "w": "majority", "tls": True},
-        )
+        # No options copied into OPTIONS; they live in HOST query
+        self.assertEqual(settings_dict["OPTIONS"], {})
 
     def test_localhost(self):
         settings_dict = parse_uri("mongodb://localhost/db")
-        self.assertEqual(settings_dict["HOST"], "localhost")
+        # Default authSource appended to HOST
+        self.assertEqual(settings_dict["HOST"], "localhost?authSource=db")
         self.assertEqual(settings_dict["PORT"], 27017)
 
     def test_localhost_with_port(self):
         settings_dict = parse_uri("mongodb://localhost:27018/db")
-        self.assertEqual(settings_dict["HOST"], "localhost")
+        # HOST omits the path and port, keeps only host + query
+        self.assertEqual(settings_dict["HOST"], "localhost?authSource=db")
         self.assertEqual(settings_dict["PORT"], 27018)
 
     def test_hosts_with_ports(self):
         settings_dict = parse_uri("mongodb://localhost:27017,localhost:27018/db")
-        self.assertEqual(settings_dict["HOST"], "localhost:27017,localhost:27018")
+        # For multi-host, PORT is None and HOST carries the full host list plus query
+        self.assertEqual(settings_dict["HOST"], "localhost:27017,localhost:27018?authSource=db")
         self.assertEqual(settings_dict["PORT"], None)
 
     def test_hosts_without_ports(self):
         settings_dict = parse_uri("mongodb://host1.net,host2.net/db")
-        self.assertEqual(settings_dict["HOST"], "host1.net:27017,host2.net:27017")
+        # Default ports are added to each host in HOST, plus the query
+        self.assertEqual(settings_dict["HOST"], "host1.net:27017,host2.net:27017?authSource=db")
         self.assertEqual(settings_dict["PORT"], None)
 
     def test_auth_source_in_query_string(self):
         settings_dict = parse_uri("mongodb://localhost/?authSource=auth", db_name="db")
         self.assertEqual(settings_dict["NAME"], "db")
-        self.assertEqual(settings_dict["OPTIONS"], {"authSource": "auth"})
+        # Keep original query intact in HOST; do not duplicate into OPTIONS
+        self.assertEqual(settings_dict["HOST"], "localhost?authSource=auth")
+        self.assertEqual(settings_dict["OPTIONS"], {})
 
     def test_auth_source_in_query_string_overrides_defaultauthdb(self):
         settings_dict = parse_uri("mongodb://localhost/db?authSource=auth")
         self.assertEqual(settings_dict["NAME"], "db")
-        self.assertEqual(settings_dict["OPTIONS"], {"authSource": "auth"})
+        # Query-provided authSource overrides default; kept in HOST only
+        self.assertEqual(settings_dict["HOST"], "localhost?authSource=auth")
+        self.assertEqual(settings_dict["OPTIONS"], {})
 
     def test_options_kwarg(self):
         options = {"authSource": "auth", "retryWrites": True}
@@ -88,10 +108,16 @@ def test_options_kwarg(self):
             "mongodb://cluster0.example.mongodb.net/myDatabase?retryWrites=false&retryReads=true",
             options=options,
         )
-        self.assertEqual(
-            settings_dict["OPTIONS"],
-            {"authSource": "auth", "retryWrites": True, "retryReads": True},
-        )
+        # options kwarg overrides same-key query params; query-only keys are kept.
+        # All options live in HOST's query string; OPTIONS is empty.
+        self.assertTrue(settings_dict["HOST"].startswith("cluster0.example.mongodb.net?"))
+        self.assertIn("authSource=auth", settings_dict["HOST"])
+        self.assertIn("retryWrites=true", settings_dict["HOST"])  # overridden
+        self.assertIn("retryReads=true", settings_dict["HOST"])  # preserved
+        # Overridden and defaulted args must not linger as duplicates.
+        self.assertNotIn("retryWrites=false", settings_dict["HOST"])
+        self.assertNotIn("authSource=myDatabase", settings_dict["HOST"])
+        self.assertEqual(settings_dict["OPTIONS"], {})
 
     def test_test_kwarg(self):
         settings_dict = parse_uri("mongodb://localhost/db", test={"NAME": "test_db"})
@@ -105,3 +131,64 @@ def test_invalid_credentials(self):
     def test_no_scheme(self):
         with self.assertRaisesMessage(pymongo.errors.InvalidURI, "Invalid URI scheme"):
             parse_uri("cluster0.example.mongodb.net")
+
+    def test_auth_source_option_name_is_case_insensitive(self):
+        settings_dict = parse_uri("mongodb://localhost/db?authsource=auth")
+        # Option names are case-insensitive: no default authSource is appended
+        self.assertEqual(settings_dict["HOST"], "localhost?authsource=auth")
+
+    def test_read_preference_tags_in_host_query_allows_mongoclient_construction(self):
+        """
+        Ensure readPreferenceTags preserved in the HOST query string can be parsed by
+        MongoClient without raising validation errors, and result in correct tag sets.
+        This verifies we no longer rely on pymongo's normalized options dict for tags.
+        """
+        cases = [
+            (
+                "mongodb://localhost/?readPreference=secondary&readPreferenceTags=dc:ny,other:sf&readPreferenceTags=dc:2,other:1",
+                [{"dc": "ny", "other": "sf"}, {"dc": "2", "other": "1"}],
+            ),
+            (
+                "mongodb://localhost/?retryWrites=true&readPreference=secondary&readPreferenceTags=nodeType:ANALYTICS&w=majority&appName=sniply-production",
+                [{"nodeType": "ANALYTICS"}],
+            ),
+        ]
+
+        for uri, expected_tags in cases:
+            with self.subTest(uri=uri):
+                # Baseline: demonstrate why relying on parsed options can be problematic.
+                parsed = pymongo.uri_parser.parse_uri(uri)
+                # Some PyMongo versions normalize this into a dict (invalid as a kwarg),
+                # others into a list. If it's a dict, passing it as a kwarg will raise a
+                # ValueError as shown in the issue.
+                # We only assert no crash in our new path below; this is informational.
+                if isinstance(parsed["options"].get("readPreferenceTags"), dict):
+                    with self.assertRaises(ValueError):
+                        pymongo.MongoClient(
+                            readPreferenceTags=parsed["options"]["readPreferenceTags"]
+                        )
+
+                # New behavior: keep the raw query on HOST, not in OPTIONS.
+                settings_dict = parse_uri(uri, db_name="db")
+                host_with_query = settings_dict["HOST"]
+
+                # Compose a full URI for MongoClient (non-SRV -> prepend scheme and
+                # ensure "/?" before query)
+                if host_with_query.startswith("mongodb+srv://"):
+                    full_uri = host_with_query  # SRV already includes scheme
+                else:
+                    if "?" in host_with_query:
+                        base, q = host_with_query.split("?", 1)
+                        full_uri = f"mongodb://{base}/?{q}"
+                    else:
+                        full_uri = f"mongodb://{host_with_query}/"
+
+                # Constructing MongoClient should not raise, and should reflect the read
+                # preference + tags.
+                client = pymongo.MongoClient(full_uri, serverSelectionTimeoutMS=1)
+                try:
+                    doc = client.read_preference.document
+                    self.assertEqual(doc.get("mode"), "secondary")
+                    self.assertEqual(doc.get("tags"), expected_tags)
+                finally:
+                    client.close()