Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 69 additions & 23 deletions django_mongodb_backend/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import time
from urllib.parse import parse_qsl, quote, urlsplit

import django
from django.conf import settings
Expand Down Expand Up @@ -31,39 +32,84 @@ def check_django_compatability():
def parse_uri(uri, *, db_name=None, options=None, test=None):
"""
Convert the given uri into a dictionary suitable for Django's DATABASES
setting.
setting. Keep query string args on HOST (not in OPTIONS).

Behavior:
- Non-SRV: HOST = "<host[,host2:port2]><?query>", no scheme/path.
- SRV: HOST = "mongodb+srv://<fqdn><?query>", no path.
- NAME is db_name if provided else the db in the URI path (required).
- If the URI has a db path and no authSource in the query, append it.
- options kwarg merges by appending to the HOST query (last-one-wins for
single-valued options), without re-encoding existing query content.
- PORT is set only for single-host URIs; multi-host and SRV => PORT=None.
"""
uri = pymongo_parse_uri(uri)
host = None
port = None
if uri["fqdn"]:
# This is a SRV URI and the host is the fqdn.
host = f"mongodb+srv://{uri['fqdn']}"
parsed = pymongo_parse_uri(uri)
split = urlsplit(uri)

# Keep the original query string verbatim to avoid breaking special
# options like readPreferenceTags.
query_str = split.query or ""

# Determine NAME; must come from db_name or the URI path.
db = db_name or parsed.get("database")
if not db:
raise ImproperlyConfigured("You must provide the db_name parameter.")

# Helper: check if a key is present in the existing query (case-sensitive).
def query_has_key(key: str) -> bool:
return any(k == key for k, _ in parse_qsl(query_str, keep_blank_values=True))

# If URI path had a database and no authSource is present, append it.
if parsed.get("database") and not query_has_key("authSource"):
suffix = f"authSource={quote(parsed['database'], safe='')}"
query_str = f"{query_str}&{suffix}" if query_str else suffix

# Merge options by appending them (so "last one wins" for single-valued opts).
if options:
for k, v in options.items():
# Convert value to string as expected in URIs.
v_str = ("true" if v else "false") if isinstance(v, bool) else str(v)
# Preserve ':' and ',' unescaped (important for readPreferenceTags).
v_enc = quote(v_str, safe=":,")
pair = f"{k}={v_enc}"
query_str = f"{query_str}&{pair}" if query_str else pair

# Build HOST (and PORT) based on SRV vs. standard.
if parsed.get("fqdn"): # SRV URI
host_base = f"mongodb+srv://{parsed['fqdn']}"
port = None
else:
nodelist = uri.get("nodelist")
nodelist = parsed.get("nodelist") or []
if len(nodelist) == 1:
host, port = nodelist[0]
h, p = nodelist[0]
host_base = h
port = p
elif len(nodelist) > 1:
host = ",".join([f"{host}:{port}" for host, port in nodelist])
db_name = db_name or uri["database"]
if not db_name:
raise ImproperlyConfigured("You must provide the db_name parameter.")
opts = uri.get("options")
if options:
opts.update(options)
# Ensure explicit ports for each host (default 27017 if missing).
parts = [f"{h}:{(p if p is not None else 27017)}" for h, p in nodelist]
host_base = ",".join(parts)
port = None
else:
# Fallback for unusual/invalid URIs.
host_base = split.netloc.split("@")[-1]
port = None

host_with_query = f"{host_base}?{query_str}" if query_str else host_base

settings_dict = {
"ENGINE": "django_mongodb_backend",
"NAME": db_name,
"HOST": host,
"NAME": db,
"HOST": host_with_query,
"PORT": port,
"USER": uri.get("username"),
"PASSWORD": uri.get("password"),
"OPTIONS": opts,
"USER": parsed.get("username"),
"PASSWORD": parsed.get("password"),
# Options remain empty; all query args live in HOST.
"OPTIONS": {},
}
if "authSource" not in settings_dict["OPTIONS"] and uri["database"]:
settings_dict["OPTIONS"]["authSource"] = uri["database"]

if test:
settings_dict["TEST"] = test

return settings_dict


Expand Down
117 changes: 98 additions & 19 deletions tests/backend_/utils/test_parse_uri.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,18 @@ def test_simple_uri(self):
settings_dict = parse_uri("mongodb://cluster0.example.mongodb.net/myDatabase")
self.assertEqual(settings_dict["ENGINE"], "django_mongodb_backend")
self.assertEqual(settings_dict["NAME"], "myDatabase")
self.assertEqual(settings_dict["HOST"], "cluster0.example.mongodb.net")
self.assertEqual(settings_dict["OPTIONS"], {"authSource": "myDatabase"})
# Default authSource derived from URI path db is appended to HOST
self.assertEqual(
settings_dict["HOST"], "cluster0.example.mongodb.net?authSource=myDatabase"
)
self.assertEqual(settings_dict["OPTIONS"], {})

def test_db_name(self):
settings_dict = parse_uri("mongodb://cluster0.example.mongodb.net/", db_name="myDatabase")
self.assertEqual(settings_dict["ENGINE"], "django_mongodb_backend")
self.assertEqual(settings_dict["NAME"], "myDatabase")
self.assertEqual(settings_dict["HOST"], "cluster0.example.mongodb.net")
# No default authSource injected when the URI has no database path
self.assertEqual(settings_dict["OPTIONS"], {})

def test_db_name_overrides_default_auth_db(self):
Expand All @@ -28,8 +32,11 @@ def test_db_name_overrides_default_auth_db(self):
)
self.assertEqual(settings_dict["ENGINE"], "django_mongodb_backend")
self.assertEqual(settings_dict["NAME"], "myDatabase")
self.assertEqual(settings_dict["HOST"], "cluster0.example.mongodb.net")
self.assertEqual(settings_dict["OPTIONS"], {"authSource": "default_auth_db"})
# authSource defaults to the database from the URI, not db_name
self.assertEqual(
settings_dict["HOST"], "cluster0.example.mongodb.net?authSource=default_auth_db"
)
self.assertEqual(settings_dict["OPTIONS"], {})

def test_no_database(self):
msg = "You must provide the db_name parameter."
Expand All @@ -43,55 +50,71 @@ def test_srv_uri_with_options(self):
with patch("dns.resolver.resolve"):
settings_dict = parse_uri(uri)
self.assertEqual(settings_dict["NAME"], "my_database")
self.assertEqual(settings_dict["HOST"], "mongodb+srv://cluster0.example.mongodb.net")
# HOST includes scheme + fqdn only (no path), with query
# preserved and default authSource appended
self.assertTrue(
settings_dict["HOST"].startswith("mongodb+srv://cluster0.example.mongodb.net?")
)
self.assertIn("retryWrites=true", settings_dict["HOST"])
self.assertIn("w=majority", settings_dict["HOST"])
self.assertIn("authSource=my_database", settings_dict["HOST"])
self.assertEqual(settings_dict["USER"], "my_user")
self.assertEqual(settings_dict["PASSWORD"], "my_password")
self.assertIsNone(settings_dict["PORT"])
self.assertEqual(
settings_dict["OPTIONS"],
{"authSource": "my_database", "retryWrites": True, "w": "majority", "tls": True},
)
# No options copied into OPTIONS; they live in HOST query
self.assertEqual(settings_dict["OPTIONS"], {})

def test_localhost(self):
settings_dict = parse_uri("mongodb://localhost/db")
self.assertEqual(settings_dict["HOST"], "localhost")
# Default authSource appended to HOST
self.assertEqual(settings_dict["HOST"], "localhost?authSource=db")
self.assertEqual(settings_dict["PORT"], 27017)

def test_localhost_with_port(self):
settings_dict = parse_uri("mongodb://localhost:27018/db")
self.assertEqual(settings_dict["HOST"], "localhost")
# HOST omits the path and port, keeps only host + query
self.assertEqual(settings_dict["HOST"], "localhost?authSource=db")
self.assertEqual(settings_dict["PORT"], 27018)

def test_hosts_with_ports(self):
settings_dict = parse_uri("mongodb://localhost:27017,localhost:27018/db")
self.assertEqual(settings_dict["HOST"], "localhost:27017,localhost:27018")
# For multi-host, PORT is None and HOST carries the full host list plus query
self.assertEqual(settings_dict["HOST"], "localhost:27017,localhost:27018?authSource=db")
self.assertEqual(settings_dict["PORT"], None)

def test_hosts_without_ports(self):
settings_dict = parse_uri("mongodb://host1.net,host2.net/db")
self.assertEqual(settings_dict["HOST"], "host1.net:27017,host2.net:27017")
# Default ports are added to each host in HOST, plus the query
self.assertEqual(settings_dict["HOST"], "host1.net:27017,host2.net:27017?authSource=db")
self.assertEqual(settings_dict["PORT"], None)

def test_auth_source_in_query_string(self):
settings_dict = parse_uri("mongodb://localhost/?authSource=auth", db_name="db")
self.assertEqual(settings_dict["NAME"], "db")
self.assertEqual(settings_dict["OPTIONS"], {"authSource": "auth"})
# Keep original query intact in HOST; do not duplicate into OPTIONS
self.assertEqual(settings_dict["HOST"], "localhost?authSource=auth")
self.assertEqual(settings_dict["OPTIONS"], {})

def test_auth_source_in_query_string_overrides_defaultauthdb(self):
settings_dict = parse_uri("mongodb://localhost/db?authSource=auth")
self.assertEqual(settings_dict["NAME"], "db")
self.assertEqual(settings_dict["OPTIONS"], {"authSource": "auth"})
# Query-provided authSource overrides default; kept in HOST only
self.assertEqual(settings_dict["HOST"], "localhost?authSource=auth")
self.assertEqual(settings_dict["OPTIONS"], {})

def test_options_kwarg(self):
options = {"authSource": "auth", "retryWrites": True}
settings_dict = parse_uri(
"mongodb://cluster0.example.mongodb.net/myDatabase?retryWrites=false&retryReads=true",
options=options,
)
self.assertEqual(
settings_dict["OPTIONS"],
{"authSource": "auth", "retryWrites": True, "retryReads": True},
)
# options kwarg overrides same-key query params; query-only keys are kept.
# All options live in HOST's query string; OPTIONS is empty.
self.assertTrue(settings_dict["HOST"].startswith("cluster0.example.mongodb.net?"))
self.assertIn("authSource=auth", settings_dict["HOST"])
self.assertIn("retryWrites=true", settings_dict["HOST"]) # overridden
self.assertIn("retryReads=true", settings_dict["HOST"]) # preserved
self.assertEqual(settings_dict["OPTIONS"], {})

def test_test_kwarg(self):
settings_dict = parse_uri("mongodb://localhost/db", test={"NAME": "test_db"})
Expand All @@ -105,3 +128,59 @@ def test_invalid_credentials(self):
def test_no_scheme(self):
with self.assertRaisesMessage(pymongo.errors.InvalidURI, "Invalid URI scheme"):
parse_uri("cluster0.example.mongodb.net")

def test_read_preference_tags_in_host_query_allows_mongoclient_construction(self):
"""
Ensure readPreferenceTags preserved in the HOST query string can be parsed by
MongoClient without raising validation errors, and result in correct tag sets.
This verifies we no longer rely on pymongo's normalized options dict for tags.
"""
cases = [
(
"mongodb://localhost/?readPreference=secondary&readPreferenceTags=dc:ny,other:sf&readPreferenceTags=dc:2,other:1",
[{"dc": "ny", "other": "sf"}, {"dc": "2", "other": "1"}],
),
(
"mongodb://localhost/?retryWrites=true&readPreference=secondary&readPreferenceTags=nodeType:ANALYTICS&w=majority&appName=sniply-production",
[{"nodeType": "ANALYTICS"}],
),
]

for uri, expected_tags in cases:
with self.subTest(uri=uri):
# Baseline: demonstrate why relying on parsed options can be problematic.
parsed = pymongo.uri_parser.parse_uri(uri)
# Some PyMongo versions normalize this into a dict (invalid as a kwarg),
# others into a list. If it's a dict, passing it as a kwarg will raise a
# ValueError as shown in the issue.
# We only assert no crash in our new path below; this is informational.
if isinstance(parsed["options"].get("readPreferenceTags"), dict):
with self.assertRaises(ValueError):
pymongo.MongoClient(
readPreferenceTags=parsed["options"]["readPreferenceTags"]
)

# New behavior: keep the raw query on HOST, not in OPTIONS.
settings_dict = parse_uri(uri, db_name="db")
host_with_query = settings_dict["HOST"]

# Compose a full URI for MongoClient (non-SRV -> prepend scheme and
# ensure "/?" before query)
if host_with_query.startswith("mongodb+srv://"):
full_uri = host_with_query # SRV already includes scheme
else:
if "?" in host_with_query:
base, q = host_with_query.split("?", 1)
full_uri = f"mongodb://{base}/?{q}"
else:
full_uri = f"mongodb://{host_with_query}/"

# Constructing MongoClient should not raise, and should reflect the read
# preference + tags.
client = pymongo.MongoClient(full_uri, serverSelectionTimeoutMS=1)
try:
doc = client.read_preference.document
self.assertEqual(doc.get("mode"), "secondary")
self.assertEqual(doc.get("tags"), expected_tags)
finally:
client.close()