Skip to content

Commit ff68925

Browse files
committed
PYTHON-1831 Refactor CRUD v2 to use base SpecRunner class
1 parent 086b600 commit ff68925

File tree

2 files changed

+56
-217
lines changed

2 files changed

+56
-217
lines changed

test/test_crud_v2.py

Lines changed: 19 additions & 211 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,15 @@
1414

1515
"""Test the collection module."""
1616

17-
import json
1817
import os
1918
import sys
2019

2120
sys.path[0:0] = [""]
2221

23-
from bson.py3compat import iteritems
24-
from pymongo import operations, WriteConcern
25-
from pymongo.command_cursor import CommandCursor
26-
from pymongo.cursor import Cursor
27-
from pymongo.errors import PyMongoError
28-
from pymongo.read_concern import ReadConcern
29-
from pymongo.results import _WriteResult, BulkWriteResult
22+
from test import unittest
23+
from test.utils import TestCreator
24+
from test.utils_spec_runner import SpecRunner
3025

31-
from test import unittest, client_context, IntegrationTest
32-
from test.utils import (camel_to_snake, camel_to_upper_camel,
33-
camel_to_snake_args, drop_collections,
34-
parse_collection_options, rs_client,
35-
OvertCommandListener, TestCreator)
3626

3727
# Location of JSON test specifications.
3828
_TEST_PATH = os.path.join(
@@ -43,216 +33,34 @@
4333
TEST_COLLECTION = 'testcollection'
4434

4535

46-
class TestAllScenarios(IntegrationTest):
47-
def run_operation(self, collection, test):
48-
# Iterate over all operations.
49-
for opdef in test['operations']:
50-
# Convert command from CamelCase to pymongo.collection method.
51-
operation = camel_to_snake(opdef['name'])
36+
class TestSpec(SpecRunner):
37+
def get_scenario_db_name(self, scenario_def):
38+
"""Crud spec says database_name is optional."""
39+
return scenario_def.get('database_name', TEST_DB)
5240

53-
# Get command handle on target entity (collection/database).
54-
target_object = opdef.get('object', 'collection')
55-
if target_object == 'database':
56-
cmd = getattr(collection.database, operation)
57-
elif target_object == 'collection':
58-
collection = collection.with_options(**dict(
59-
parse_collection_options(opdef.get(
60-
'collectionOptions', {}))))
61-
cmd = getattr(collection, operation)
62-
else:
63-
self.fail("Unknown object name %s" % (target_object,))
41+
def get_scenario_coll_name(self, scenario_def):
42+
"""Crud spec says collection_name is optional."""
43+
return scenario_def.get('collection_name', TEST_COLLECTION)
6444

65-
# Convert arguments to snake_case and handle special cases.
66-
arguments = opdef['arguments']
67-
options = arguments.pop("options", {})
45+
def get_object_name(self, op):
46+
"""Crud spec says object is optional and defaults to 'collection'."""
47+
return op.get('object', 'collection')
6848

69-
for option_name in options:
70-
arguments[camel_to_snake(option_name)] = options[option_name]
71-
72-
if operation == "bulk_write":
73-
# Parse each request into a bulk write model.
74-
requests = []
75-
for request in arguments["requests"]:
76-
bulk_model = camel_to_upper_camel(request["name"])
77-
bulk_class = getattr(operations, bulk_model)
78-
bulk_arguments = camel_to_snake_args(request["arguments"])
79-
requests.append(bulk_class(**bulk_arguments))
80-
arguments["requests"] = requests
81-
else:
82-
for arg_name in list(arguments):
83-
c2s = camel_to_snake(arg_name)
84-
# PyMongo accepts sort as list of tuples.
85-
if arg_name == "sort":
86-
sort_dict = arguments[arg_name]
87-
arguments[arg_name] = list(iteritems(sort_dict))
88-
# Named "key" instead not fieldName.
89-
if arg_name == "fieldName":
90-
arguments["key"] = arguments.pop(arg_name)
91-
# Aggregate uses "batchSize", while find uses batch_size.
92-
elif arg_name == "batchSize" and operation == "aggregate":
93-
continue
94-
# Requires boolean returnDocument.
95-
elif arg_name == "returnDocument":
96-
arguments[c2s] = arguments.pop(arg_name) == "After"
97-
else:
98-
arguments[c2s] = arguments.pop(arg_name)
99-
100-
if opdef.get('error') is True:
101-
with self.assertRaises(PyMongoError):
102-
cmd(**arguments)
103-
else:
104-
result = cmd(**arguments)
105-
self.check_result(opdef.get('result'), result)
106-
107-
def check_result(self, expected_result, result):
108-
if expected_result is None:
109-
return True
110-
111-
if isinstance(result, Cursor) or isinstance(result, CommandCursor):
112-
return list(result) == expected_result
113-
114-
elif isinstance(result, _WriteResult):
115-
for res in expected_result:
116-
prop = camel_to_snake(res)
117-
# SPEC-869: Only BulkWriteResult has upserted_count.
118-
if (prop == "upserted_count" and
119-
not isinstance(result, BulkWriteResult)):
120-
if result.upserted_id is not None:
121-
upserted_count = 1
122-
else:
123-
upserted_count = 0
124-
if upserted_count != expected_result[res]:
125-
return False
126-
elif prop == "inserted_ids":
127-
# BulkWriteResult does not have inserted_ids.
128-
if isinstance(result, BulkWriteResult):
129-
if len(expected_result[res]) != result.inserted_count:
130-
return False
131-
else:
132-
# InsertManyResult may be compared to [id1] from the
133-
# crud spec or {"0": id1} from the retryable write spec.
134-
ids = expected_result[res]
135-
if isinstance(ids, dict):
136-
ids = [ids[str(i)] for i in range(len(ids))]
137-
if ids != result.inserted_ids:
138-
return False
139-
elif prop == "upserted_ids":
140-
# Convert indexes from strings to integers.
141-
ids = expected_result[res]
142-
expected_ids = {}
143-
for str_index in ids:
144-
expected_ids[int(str_index)] = ids[str_index]
145-
if expected_ids != result.upserted_ids:
146-
return False
147-
elif getattr(result, prop) != expected_result[res]:
148-
return False
149-
return True
150-
else:
151-
if not expected_result:
152-
return result is None
153-
else:
154-
return result == expected_result
155-
156-
def check_events(self, expected_events, listener):
157-
res = listener.results
158-
if not len(expected_events):
159-
return
160-
161-
# Expectations only have CommandStartedEvents.
162-
self.assertEqual(len(res['started']), len(expected_events))
163-
for i, expectation in enumerate(expected_events):
164-
event_type = next(iter(expectation))
165-
event = res['started'][i]
166-
167-
# The tests substitute 42 for any number other than 0.
168-
if (event.command_name == 'getMore'
169-
and event.command['getMore']):
170-
event.command['getMore'] = 42
171-
elif event.command_name == 'killCursors':
172-
event.command['cursors'] = [42]
173-
# Add upsert and multi fields back into expectations.
174-
elif event.command_name == 'update':
175-
updates = expectation[event_type]['command']['updates']
176-
for update in updates:
177-
update.setdefault('upsert', False)
178-
update.setdefault('multi', False)
179-
180-
# Replace afterClusterTime: 42 with actual afterClusterTime.
181-
expected_cmd = expectation[event_type]['command']
182-
expected_read_concern = expected_cmd.get('readConcern')
183-
if expected_read_concern is not None:
184-
time = expected_read_concern.get('afterClusterTime')
185-
if time == 42:
186-
actual_time = event.command.get(
187-
'readConcern', {}).get('afterClusterTime')
188-
if actual_time is not None:
189-
expected_read_concern['afterClusterTime'] = actual_time
190-
191-
for attr, expected in expectation[event_type].items():
192-
actual = getattr(event, attr)
193-
if isinstance(expected, dict):
194-
for key, val in expected.items():
195-
if val is None:
196-
if key in actual:
197-
self.fail("Unexpected key [%s] in %r" % (
198-
key, actual))
199-
elif key not in actual:
200-
self.fail("Expected key [%s] in %r" % (
201-
key, actual))
202-
else:
203-
self.assertEqual(val, actual[key],
204-
"Key [%s] in %s" % (key, actual))
205-
else:
206-
self.assertEqual(actual, expected)
49+
def get_outcome_coll_name(self, outcome, collection):
50+
"""Crud spec says outcome has an optional 'collection.name'."""
51+
return outcome['collection'].get('name', collection.name)
20752

20853

20954
def create_test(scenario_def, test, name):
21055
def run_scenario(self):
211-
listener = OvertCommandListener()
212-
# New client, to avoid interference from pooled sessions.
213-
# Convert test['clientOptions'] to dict to avoid a Jython bug using "**"
214-
# with ScenarioDict.
215-
client = rs_client(event_listeners=[listener],
216-
**dict(test.get('clientOptions', {})))
217-
# Close the client explicitly to avoid having too many threads open.
218-
self.addCleanup(client.close)
219-
220-
# Get database and collection objects.
221-
database = getattr(
222-
client, scenario_def.get('database_name', TEST_DB))
223-
drop_collections(database)
224-
collection = getattr(
225-
database, scenario_def.get('collection_name', TEST_COLLECTION))
226-
227-
# Populate collection with data and run test.
228-
collection.with_options(
229-
write_concern=WriteConcern(w="majority")).insert_many(
230-
scenario_def.get('data', []))
231-
listener.results.clear()
232-
self.run_operation(collection, test)
233-
234-
# Assert expected events.
235-
self.check_events(test.get('expectations', {}), listener)
236-
237-
# Assert final state is expected.
238-
expected_outcome = test.get('outcome', {}).get('collection')
239-
if expected_outcome is not None:
240-
collname = expected_outcome.get('name')
241-
if collname is not None:
242-
o_collection = getattr(database, collname)
243-
else:
244-
o_collection = collection
245-
o_collection = o_collection.with_options(
246-
read_concern=ReadConcern(level="local"))
247-
self.assertEqual(list(o_collection.find()),
248-
expected_outcome['data'])
56+
self.run_scenario(scenario_def, test)
24957

25058
return run_scenario
25159

25260

253-
test_creator = TestCreator(create_test, TestAllScenarios, _TEST_PATH)
61+
test_creator = TestCreator(create_test, TestSpec, _TEST_PATH)
25462
test_creator.create_tests()
25563

25664

25765
if __name__ == "__main__":
258-
unittest.main()
66+
unittest.main()

test/utils_spec_runner.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,13 @@ def check_result(self, expected_result, result):
170170
else:
171171
self.assertEqual(result, expected_result)
172172

173+
def get_object_name(self, op):
174+
"""Allow CRUD spec to override handling of 'object'
175+
176+
Transaction spec says 'object' is required.
177+
"""
178+
return op['object']
179+
173180
def run_operation(self, sessions, collection, operation):
174181
original_collection = collection
175182
name = camel_to_snake(operation['name'])
@@ -200,7 +207,7 @@ def parse_options(opts):
200207
collection = collection.with_options(
201208
**dict(parse_options(operation['collectionOptions'])))
202209

203-
object_name = operation['object']
210+
object_name = self.get_object_name(operation)
204211
if object_name == 'gridfsbucket':
205212
# Only create the GridFSBucket when we need it (for the gridfs
206213
# retryable reads tests).
@@ -340,6 +347,13 @@ def check_events(self, test, listener, session_ids):
340347
event.command['getMore'] = 42
341348
elif event.command_name == 'killCursors':
342349
event.command['cursors'] = [42]
350+
elif event.command_name == 'update':
351+
# TODO: remove this once PYTHON-1744 is done.
352+
# Add upsert and multi fields back into expectations.
353+
updates = expectation[event_type]['command']['updates']
354+
for update in updates:
355+
update.setdefault('upsert', False)
356+
update.setdefault('multi', False)
343357

344358
# Replace afterClusterTime: 42 with actual afterClusterTime.
345359
expected_cmd = expectation[event_type]['command']
@@ -384,6 +398,18 @@ def maybe_skip_scenario(self, test):
384398
if test.get('skipReason'):
385399
raise unittest.SkipTest(test.get('skipReason'))
386400

401+
def get_scenario_db_name(self, scenario_def):
402+
"""Allow CRUD spec to override a test's database name."""
403+
return scenario_def['database_name']
404+
405+
def get_scenario_coll_name(self, scenario_def):
406+
"""Allow CRUD spec to override a test's collection name."""
407+
return scenario_def['collection_name']
408+
409+
def get_outcome_coll_name(self, outcome, collection):
410+
"""Allow CRUD spec to override outcome collection."""
411+
return collection.name
412+
387413
def run_scenario(self, scenario_def, test):
388414
self.maybe_skip_scenario(test)
389415
listener = OvertCommandListener()
@@ -406,7 +432,7 @@ def run_scenario(self, scenario_def, test):
406432
self.kill_all_sessions()
407433
self.addCleanup(self.kill_all_sessions)
408434

409-
database_name = scenario_def['database_name']
435+
database_name = self.get_scenario_db_name(scenario_def)
410436
write_concern_db = client_context.client.get_database(
411437
database_name, write_concern=WriteConcern(w='majority'))
412438
if 'bucket_name' in scenario_def:
@@ -419,7 +445,7 @@ def run_scenario(self, scenario_def, test):
419445
write_concern_db['fs.chunks'].insert_many(data['fs.chunks'])
420446
write_concern_db['fs.files'].insert_many(data['fs.files'])
421447
else:
422-
collection_name = scenario_def['collection_name']
448+
collection_name = self.get_scenario_coll_name(scenario_def)
423449
write_concern_coll = write_concern_db[collection_name]
424450
write_concern_coll.drop()
425451
write_concern_db.create_collection(collection_name)
@@ -491,14 +517,19 @@ def run_scenario(self, scenario_def, test):
491517
'configureFailPoint': 'failCommand', 'mode': 'off'})
492518

493519
# Assert final state is expected.
494-
expected_c = test['outcome'].get('collection')
520+
outcome = test['outcome']
521+
expected_c = outcome.get('collection')
495522
if expected_c is not None:
523+
outcome_coll_name = self.get_outcome_coll_name(
524+
outcome, collection)
525+
496526
# Read from the primary with local read concern to ensure causal
497527
# consistency.
498-
primary_coll = collection.with_options(
528+
outcome_coll = collection.database.get_collection(
529+
outcome_coll_name,
499530
read_preference=ReadPreference.PRIMARY,
500531
read_concern=ReadConcern('local'))
501-
self.assertEqual(list(primary_coll.find()), expected_c['data'])
532+
self.assertEqual(list(outcome_coll.find()), expected_c['data'])
502533

503534

504535
def expect_any_error(op):

0 commit comments

Comments
 (0)