Skip to content

Commit 81c759a

Browse files
authored
PYTHON-2878 Allow passing dict to sort/create_index/hint (#1389)
1 parent 2f13aee commit 81c759a

File tree

6 files changed

+62
-28
lines changed

6 files changed

+62
-28
lines changed

pymongo/collection.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -798,7 +798,7 @@ def _update(
798798
"Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands."
799799
)
800800
if not isinstance(hint, str):
801-
hint = helpers._index_document(hint) # type: ignore[assignment]
801+
hint = helpers._index_document(hint)
802802
update_doc["hint"] = hint
803803
command = SON([("update", self.name), ("ordered", ordered), ("updates", [update_doc])])
804804
if let is not None:
@@ -1277,7 +1277,7 @@ def _delete(
12771277
"Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands."
12781278
)
12791279
if not isinstance(hint, str):
1280-
hint = helpers._index_document(hint) # type: ignore[assignment]
1280+
hint = helpers._index_document(hint)
12811281
delete_doc["hint"] = hint
12821282
command = SON([("delete", self.name), ("ordered", ordered), ("deletes", [delete_doc])])
12831283

@@ -3097,7 +3097,7 @@ def __find_and_modify(
30973097
cmd["upsert"] = upsert
30983098
if hint is not None:
30993099
if not isinstance(hint, str):
3100-
hint = helpers._index_document(hint) # type: ignore[assignment]
3100+
hint = helpers._index_document(hint)
31013101

31023102
write_concern = self._write_concern_for_cmd(cmd, session)
31033103

pymongo/cursor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,9 @@ def close(self) -> None:
157157
self.conn = None
158158

159159

160-
_Sort = Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]]
160+
_Sort = Union[
161+
Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]], Mapping[str, Any]
162+
]
161163
_Hint = Union[str, _Sort]
162164

163165

pymongo/helpers.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,10 @@ def _index_list(
110110
else:
111111
if isinstance(key_or_list, str):
112112
return [(key_or_list, ASCENDING)]
113-
if isinstance(key_or_list, abc.ItemsView):
114-
return list(key_or_list)
113+
elif isinstance(key_or_list, abc.ItemsView):
114+
return list(key_or_list) # type: ignore[arg-type]
115+
elif isinstance(key_or_list, abc.Mapping):
116+
return list(key_or_list.items())
115117
elif not isinstance(key_or_list, (list, tuple)):
116118
raise TypeError("if no direction is specified, key_or_list must be an instance of list")
117119
values: list[tuple[str, int]] = []
@@ -127,33 +129,40 @@ def _index_document(index_list: _IndexList) -> SON[str, Any]:
127129
128130
Takes a list of (key, direction) pairs.
129131
"""
130-
if isinstance(index_list, abc.Mapping):
132+
if not isinstance(index_list, (list, tuple, abc.Mapping)):
131133
raise TypeError(
132-
"passing a dict to sort/create_index/hint is not "
133-
"allowed - use a list of tuples instead. did you "
134-
"mean %r?" % list(index_list.items())
134+
"must use a dictionary or a list of (key, direction) pairs, not: " + repr(index_list)
135135
)
136-
elif not isinstance(index_list, (list, tuple)):
137-
raise TypeError("must use a list of (key, direction) pairs, not: " + repr(index_list))
138136
if not len(index_list):
139-
raise ValueError("key_or_list must not be the empty list")
137+
raise ValueError("key_or_list must not be empty")
140138

141139
index: SON[str, Any] = SON()
142-
for item in index_list:
143-
if isinstance(item, str):
144-
item = (item, ASCENDING)
145-
key, value = item
146-
if not isinstance(key, str):
147-
raise TypeError("first item in each key pair must be an instance of str")
148-
if not isinstance(value, (str, int, abc.Mapping)):
149-
raise TypeError(
150-
"second item in each key pair must be 1, -1, "
151-
"'2d', or another valid MongoDB index specifier."
152-
)
153-
index[key] = value
140+
141+
if isinstance(index_list, abc.Mapping):
142+
for key in index_list:
143+
value = index_list[key]
144+
_validate_index_key_pair(key, value)
145+
index[key] = value
146+
else:
147+
for item in index_list:
148+
if isinstance(item, str):
149+
item = (item, ASCENDING)
150+
key, value = item
151+
_validate_index_key_pair(key, value)
152+
index[key] = value
154153
return index
155154

156155

156+
def _validate_index_key_pair(key: Any, value: Any) -> None:
157+
if not isinstance(key, str):
158+
raise TypeError("first item in each key pair must be an instance of str")
159+
if not isinstance(value, (str, int, abc.Mapping)):
160+
raise TypeError(
161+
"second item in each key pair must be 1, -1, "
162+
"'2d', or another valid MongoDB index specifier."
163+
)
164+
165+
157166
def _check_command_response(
158167
response: _DocumentOut,
159168
max_wire_version: Optional[int],

pymongo/operations.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,10 @@
3737
from bson.son import SON
3838
from pymongo.bulk import _Bulk
3939

40-
# Hint supports index name, "myIndex", or list of either strings or index pairs: [('x', 1), ('y', -1), 'z'']
41-
_IndexList = Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]]
40+
# Hint supports index name, "myIndex", a list of either strings or index pairs: [('x', 1), ('y', -1), 'z''], or a dictionary
41+
_IndexList = Union[
42+
Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]], Mapping[str, Any]
43+
]
4244
_IndexKeyHint = Union[str, _IndexList]
4345

4446

test/test_client.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1836,6 +1836,28 @@ def test_handshake_08_invalid_aws_ec2(self):
18361836
None,
18371837
)
18381838

1839+
def test_dict_hints(self):
1840+
c = rs_or_single_client()
1841+
try:
1842+
c.t.t.find(hint={"x": 1})
1843+
except Exception:
1844+
self.fail("passing a dictionary hint to find failed!")
1845+
1846+
def test_dict_hints_sort(self):
1847+
c = rs_or_single_client()
1848+
try:
1849+
result = c.t.t.find()
1850+
result.sort({"x": 1})
1851+
except Exception:
1852+
self.fail("passing a dictionary to sort failed!")
1853+
1854+
def test_dict_hints_create_index(self):
1855+
c = rs_or_single_client()
1856+
try:
1857+
c.t.t.create_index({"x": pymongo.ASCENDING})
1858+
except Exception:
1859+
self.fail("passing a dictionary to create_index failed!")
1860+
18391861

18401862
class TestExhaustCursor(IntegrationTest):
18411863
"""Test that clients properly handle errors from exhaust cursors."""

test/test_collection.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,6 @@ def test_create_index(self):
276276
db = self.db
277277

278278
self.assertRaises(TypeError, db.test.create_index, 5)
279-
self.assertRaises(TypeError, db.test.create_index, {"hello": 1})
280279
self.assertRaises(ValueError, db.test.create_index, [])
281280

282281
db.test.drop_indexes()

0 commit comments

Comments
 (0)