22# -*- coding: utf-8 -*-
33
44from pymongo import MongoClient
5- import logging
5+ from . errors import ConnectionError
66import time
77
88
9- class MongodbConnector :
9+ class MongoDB :
1010 def __init__ (self ,
11- host = None ,
11+ host : str = None ,
1212 port = None ,
13- default_database = None ,
14- default_collection = None ,
15- connect = False ,
16- endpoint = None ):
13+ default_database : str = None ,
14+ default_collection : str = None ,
15+ connect : bool = False ,
16+ endpoint : str = None ):
1717
1818 if endpoint is not None :
1919 host , port = endpoint .split (':' )
@@ -33,7 +33,7 @@ def __init__(self,
3333 def client (self ):
3434 return self ._client
3535
36- def _get_collection (self , database_name , collection_name ):
36+ def _get_collection (self , database_name : str , collection_name : str ):
3737 self .open_connection ()
3838 database_name , collection_name = \
3939 self ._get_database_and_collection_names (database_name ,
@@ -42,19 +42,19 @@ def _get_collection(self, database_name, collection_name):
4242 collection = getattr (database , collection_name )
4343 return collection
4444
45- def set_defaults (self , database_name , database_collection = None ):
45+ def set_defaults (self , database_name : str , database_collection : str = None ):
4646 self ._default_database = database_name
4747 if database_collection != None :
4848 self ._default_collection = database_collection
4949
50- def open_connection (self , attempts = 10 ):
50+ def open_connection (self , attempts : int = 10 ):
5151 if not self ._client :
5252 try :
5353 self ._client = MongoClient (** self ._config )
5454 self ._client .server_info ()
5555 except Exception :
5656 if attempts == 0 :
57- logging . exception ( '' )
57+ raise ConnectionError
5858 else :
5959 time .sleep (1 )
6060 self .open_connection (attempts - 1 )
@@ -63,14 +63,21 @@ def close_connection(self):
6363 if self ._client :
6464 self ._client .close ()
6565
66- def _get_database_and_collection_names (self , database_name , collection_name ):
66+ def _get_database_and_collection_names (
67+ self ,
68+ database_name : str ,
69+ collection_name : str ,
70+ ):
6771 if not database_name and hasattr (self , '_default_database' ):
6872 database_name = self ._default_database
6973 if not collection_name and hasattr (self , '_default_collection' ):
7074 collection_name = self ._default_collection
7175 return database_name , collection_name
7276
73- def get_and_close (self , query = None , database_name = None , collection_name = None ):
77+ def get_and_close (self ,
78+ query = None ,
79+ database_name : str = None ,
80+ collection_name : str = None ):
7481 result = self .get (query , database_name , collection_name )
7582 self .close_connection ()
7683 return result
@@ -82,13 +89,15 @@ def exists(self, query, database_name=None, collection_name=None):
8289 return True
8390
8491 def create_index (self ,
85- attribute = None ,
92+ attribute : str = None ,
8693 keys = None ,
87- database_name = None ,
88- collection_name = None ,
89- unique = False ,
90- type_ = 'asc' ,
91- background = True ):
94+ database_name : str = None ,
95+ collection_name : str = None ,
96+ unique : bool = False ,
97+ type_ : str = None ,
98+ background : bool = True ):
99+ if type_ is None :
100+ type_ = 'asc'
92101 if not keys and not attribute :
93102 raise TypeError
94103 if not keys :
@@ -102,15 +111,15 @@ def create_index(self,
102111
103112 def get (self ,
104113 query = None ,
105- database_name = None ,
106- collection_name = None ,
107- sort = None ,
108- sort_attribute = None ,
109- sort_type = None ,
110- limit = None ,
114+ database_name : str = None ,
115+ collection_name : str = None ,
116+ sort : str = None ,
117+ sort_attribute : str = None ,
118+ sort_type : str = None ,
119+ limit : int = None ,
111120 index = None ,
112- index_attribute = None ,
113- index_type = None ):
121+ index_attribute : str = None ,
122+ index_type : str = None ):
114123 collection = self ._get_collection (database_name , collection_name )
115124 if sort_attribute and sort_type :
116125 if sort_type == 'desc' :
@@ -139,12 +148,14 @@ def get(self,
139148
140149 def get_random (self ,
141150 query = None ,
142- database_name = None ,
143- collection_name = None ,
151+ database_name : str = None ,
152+ collection_name : str = None ,
144153 sort = None ,
145- sort_attribute = None ,
146- sort_type = None ,
147- limit = 1 ):
154+ sort_attribute : str = None ,
155+ sort_type : str = None ,
156+ limit : int = None ):
157+ if limit is None :
158+ limit = 1
148159 collection = self ._get_collection (database_name , collection_name )
149160 if sort_attribute and sort_type :
150161 if sort_type == 'desc' :
@@ -160,30 +171,56 @@ def get_random(self,
160171 result = collection .aggregate (operation )
161172 return result
162173
163- def update (self , query , value , database_name = None , collection_name = None ):
174+ def update (self ,
175+ query ,
176+ value ,
177+ database_name : str = None ,
178+ collection_name : str = None ):
164179 collection = self ._get_collection (database_name , collection_name )
165180 collection .update (query , {'$set' : value }, upsert = True )
166181
167- def push (self , query , key , value , database_name = None , collection_name = None ):
182+ def push (self ,
183+ query ,
184+ key : str ,
185+ value ,
186+ database_name : str = None ,
187+ collection_name : str = None ):
168188 collection = self ._get_collection (database_name , collection_name )
169189 collection .update (query , {'$push' : {key : {'$each' : value }}})
170190
171- def put (self , value , query = None , database_name = None , collection_name = None ):
191+ def put (self ,
192+ value ,
193+ query = None ,
194+ database_name : str = None ,
195+ collection_name : str = None ):
172196 if query :
173197 self .update (query , value , database_name , collection_name )
174198 else :
175199 self .insert (value , database_name , collection_name )
176200
177- def remove (self , query = None , database_name = None , collection_name = None ):
201+ def remove (self ,
202+ query = None ,
203+ database_name : str = None ,
204+ collection_name : str = None ):
178205 collection = self ._get_collection (database_name , collection_name )
179206 collection .remove (query )
180207
181- def insert (self , value , database_name = None , collection_name = None ):
208+ def insert (self ,
209+ value ,
210+ database_name : str = None ,
211+ collection_name : str = None ):
182212 collection = self ._get_collection (database_name , collection_name )
183213 collection .insert_one (dict (value ))
184214
185- def count (self , query = None , database_name = None , collection_name = None ):
215+ def count (self ,
216+ query = None ,
217+ database_name : str = None ,
218+ collection_name : str = None ):
186219 collection = self ._get_collection (database_name , collection_name )
187220 if not query :
188221 query = {}
189222 return collection .count_documents (filter = query )
223+
224+
225+ # Backwards compatible
226+ MongodbConnector = MongoDB
0 commit comments