Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions pymongo/asynchronous/srv_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,6 @@ def __init__(
except Exception:
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None
self.__slen = len(self.__plist)
if self.__slen < 2:
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,))

async def get_options(self) -> Optional[str]:
from dns import resolver
Expand Down Expand Up @@ -131,6 +129,11 @@ async def _get_srv_response_and_hosts(
) -> tuple[resolver.Answer, list[tuple[str, Any]]]:
results = await self._resolve_uri(encapsulate_errors)

if self.__fqdn == results[0].target.to_text():
raise ConfigurationError(
"Invalid SRV host: return address is identical to SRV hostname"
)

# Construct address tuples
nodes = [
(maybe_decode(res.target.to_text(omit_final_dot=True)), res.port) # type: ignore[attr-defined]
Expand Down
7 changes: 5 additions & 2 deletions pymongo/synchronous/srv_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,6 @@ def __init__(
except Exception:
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None
self.__slen = len(self.__plist)
if self.__slen < 2:
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,))

def get_options(self) -> Optional[str]:
from dns import resolver
Expand Down Expand Up @@ -131,6 +129,11 @@ def _get_srv_response_and_hosts(
) -> tuple[resolver.Answer, list[tuple[str, Any]]]:
results = self._resolve_uri(encapsulate_errors)

if self.__fqdn == results[0].target.to_text():
raise ConfigurationError(
"Invalid SRV host: return address is identical to SRV hostname"
)

# Construct address tuples
nodes = [
(maybe_decode(res.target.to_text(omit_final_dot=True)), res.port) # type: ignore[attr-defined]
Expand Down
6 changes: 0 additions & 6 deletions test/asynchronous/test_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,6 @@ def create_tests(cls):

class TestParsingErrors(AsyncPyMongoTestCase):
async def test_invalid_host(self):
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb is not"):
client = self.simple_client("mongodb+srv://mongodb")
await client.aconnect()
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb.com is not"):
client = self.simple_client("mongodb+srv://mongodb.com")
await client.aconnect()
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"):
client = self.simple_client("mongodb+srv://127.0.0.1")
await client.aconnect()
Expand Down
6 changes: 0 additions & 6 deletions test/test_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,6 @@ def create_tests(cls):

class TestParsingErrors(PyMongoTestCase):
def test_invalid_host(self):
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb is not"):
client = self.simple_client("mongodb+srv://mongodb")
client._connect()
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb.com is not"):
client = self.simple_client("mongodb+srv://mongodb.com")
client._connect()
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"):
client = self.simple_client("mongodb+srv://127.0.0.1")
client._connect()
Expand Down
67 changes: 67 additions & 0 deletions test/test_uri_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
sys.path[0:0] = [""]

from test import unittest
from unittest.mock import MagicMock, patch

from bson.binary import JAVA_LEGACY
from pymongo import ReadPreference
Expand Down Expand Up @@ -553,6 +554,72 @@ def test_port_with_whitespace(self):
with self.assertRaisesRegex(ValueError, r"Port contains whitespace character: '\\n'"):
parse_uri("mongodb://localhost:27\n017")

def test_allow_srv_hosts_with_fewer_than_three_dot_separated_parts(self):
with patch("dns.resolver.resolve"):
parse_uri("mongodb+srv://localhost/")
parse_uri("mongodb+srv://mongo.local/")

def test_error_when_return_address_does_not_end_with_srv_domain(self):
test_cases = [
{
"query": "_mongodb._tcp.localhost",
"mock_target": "localhost.mongodb",
"expected_error": "Invalid SRV host",
},
{
"query": "_mongodb._tcp.blogs.mongodb.com",
"mock_target": "blogs.evil.com",
"expected_error": "Invalid SRV host",
},
{
"query": "_mongodb._tcp.blogs.mongo.local",
"mock_target": "test_1.evil.com",
"expected_error": "Invalid SRV host",
},
{
"query": "_mongodb._tcp.localhost",
"mock_target": "localhost",
"expected_error": "Invalid SRV host",
},
{
"query": "_mongodb._tcp.mongo.local",
"mock_target": "mongo.local",
"expected_error": "Invalid SRV host",
},
{
"query": "_mongodb._tcp.localhost",
"mock_target": "test_1.cluster_1localhost",
"expected_error": "Invalid SRV host",
},
{
"query": "_mongodb._tcp.mongo.local",
"mock_target": "test_1.my_hostmongo.local",
"expected_error": "Invalid SRV host",
},
{
"query": "_mongodb._tcp.blogs.mongodb.com",
"mock_target": "cluster.testmongodb.com",
"expected_error": "Invalid SRV host",
},
]
for case in test_cases:
with patch("dns.resolver.resolve") as mock_resolver:

def mock_resolve(query, record_type, *args, **kwargs):
mock_srv = MagicMock()
mock_srv.target.to_text.return_value = case["mock_target"]
return [mock_srv]

mock_resolver.side_effect = mock_resolve
domain = case["query"].split("._tcp.")[1]
connection_string = f"mongodb+srv://{domain}"
try:
parse_uri(connection_string)
except ConfigurationError as e:
self.assertIn(case["expected_error"], str(e))
else:
self.fail(f"ConfigurationError was not raised for query: {case['query']}")


if __name__ == "__main__":
unittest.main()
Loading