Skip to content

Commit b7d547f

Browse files
authored
Added patching logic to __main__.py (#9)
1 parent 4af38dd commit b7d547f

File tree

1 file changed

+28
-35
lines changed

1 file changed

+28
-35
lines changed

pymongoexplain/__main__.py

Lines changed: 28 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -13,48 +13,41 @@
1313
# limitations under the License.
1414

1515

16-
from pymongo import monitoring, MongoClient
1716
from pymongo.collection import Collection
17+
from .explainable_collection import ExplainCollection
18+
1819
import sys
19-
from bson.son import SON
20+
import logging
2021

21-
class CommandLogger(monitoring.CommandListener):
22-
def __init__(self):
23-
self.payloads = []
2422

25-
def started(self, event):
26-
print("started")
27-
self.payloads.append(event.command)
23+
'''This module allows for pymongo scripts to run with with pymongoexplain,
24+
explaining each command as it occurs.
25+
'''
2826

29-
def succeeded(self, event):
30-
pass
3127

32-
def failed(self, event):
33-
pass
28+
FORMAT = '%(asctime)s %(levelname)s %(module)s %(message)s'
29+
logging.basicConfig(format=FORMAT, level=logging.INFO)
3430

35-
if __name__ == '__main__':
31+
old_function_names = ["update_one", "replace_one", "update_many", "delete_one",
32+
"delete_many", "aggregate", "watch", "find", "find_one",
33+
"find_one_and_delete", "find_one_and_replace",
34+
"find_one_and_update", "count_documents",
35+
"estimated_document_count", "distinct"]
36+
old_functions = [getattr(Collection, i) for i in old_function_names]
37+
38+
39+
def make_func(old_func, old_func_name):
40+
def new_func(self: Collection, *args, **kwargs):
41+
res = getattr(ExplainCollection(self),old_func_name)(*args, **kwargs)
42+
logging.info("%s explain response: %s", old_func_name, res)
43+
return old_func(self, *args, **kwargs)
44+
return new_func
3645

46+
47+
for old_func, old_func_name in zip(old_functions, old_function_names):
48+
setattr(Collection, old_func_name, make_func(old_func, old_func_name))
49+
50+
if __name__ == '__main__':
3751
for file in sys.argv[1:]:
3852
with open(file) as f:
39-
logger = CommandLogger()
40-
monitoring.register(logger)
41-
l = ""
42-
for line in f.readlines():
43-
l = l+line
44-
try:
45-
exec(l)
46-
l = ""
47-
except:
48-
continue
49-
collection = [i for i in locals().values() if type(i)
50-
==Collection][0]
51-
print(collection)
52-
for payload in logger.payloads:
53-
payload = SON([("explain", payload), ("verbosity", "queryPlanner")])
54-
print(payload)
55-
print(collection.database.command(payload))
56-
logger.payloads = []
57-
58-
59-
60-
print(sys.argv)
53+
exec(f.read())

0 commit comments

Comments
 (0)