|
24 | 24 | sys.path[0:0] = [""]
|
25 | 25 |
|
26 | 26 | from test import unittest
|
27 |
| -from unittest.mock import patch |
| 27 | +from unittest.mock import MagicMock, patch |
28 | 28 |
|
29 | 29 | from bson.binary import JAVA_LEGACY
|
30 | 30 | from pymongo import ReadPreference
|
@@ -559,6 +559,42 @@ def test_allow_srv_hosts_with_fewer_than_three_dot_separated_parts(self):
|
559 | 559 | parse_uri("mongodb+srv://localhost/")
|
560 | 560 | parse_uri("mongodb+srv://mongo.local/")
|
561 | 561 |
|
| 562 | + def test_error_when_return_address_does_not_end_with_srv_domain(self): |
| 563 | + test_cases = [ |
| 564 | + { |
| 565 | + "query": "_mongodb._tcp.localhost", |
| 566 | + "mock_target": "localhost.mongodb", |
| 567 | + "expected_error": "Invalid SRV host", |
| 568 | + }, |
| 569 | + { |
| 570 | + "query": "_mongodb._tcp.blogs.mongodb.com", |
| 571 | + "mock_target": "blogs.evil.com", |
| 572 | + "expected_error": "Invalid SRV host", |
| 573 | + }, |
| 574 | + { |
| 575 | + "query": "_mongodb._tcp.blogs.mongo.local", |
| 576 | + "mock_target": "test_1.evil.com", |
| 577 | + "expected_error": "Invalid SRV host", |
| 578 | + }, |
| 579 | + ] |
| 580 | + for case in test_cases: |
| 581 | + with patch("dns.resolver.resolve") as mock_resolver: |
| 582 | + |
| 583 | + def mock_resolve(query, record_type, *args, **kwargs): |
| 584 | + mock_srv = MagicMock() |
| 585 | + mock_srv.target.to_text.return_value = case["mock_target"] |
| 586 | + return [mock_srv] |
| 587 | + |
| 588 | + mock_resolver.side_effect = mock_resolve |
| 589 | + domain = case["query"].split("._tcp.")[1] |
| 590 | + connection_string = f"mongodb+srv://{domain}" |
| 591 | + try: |
| 592 | + parse_uri(connection_string) |
| 593 | + except ConfigurationError as e: |
| 594 | + self.assertIn(case["expected_error"], str(e)) |
| 595 | + else: |
| 596 | + self.fail(f"ConfigurationError was not raised for query: {case['query']}") |
| 597 | + |
562 | 598 |
|
563 | 599 | if __name__ == "__main__":
|
564 | 600 | unittest.main()
|
0 commit comments