diff --git a/django_mongodb/base.py b/django_mongodb/base.py index 8efcbb8e6..180ef87cb 100644 --- a/django_mongodb/base.py +++ b/django_mongodb/base.py @@ -12,7 +12,7 @@ from .operations import DatabaseOperations from .query_utils import regex_match from .schema import DatabaseSchemaEditor -from .utils import CollectionDebugWrapper +from .utils import OperationDebugWrapper class Cursor: @@ -137,9 +137,14 @@ def __init__(self, *args, **kwargs): def get_collection(self, name, **kwargs): collection = Collection(self.database, name, **kwargs) if self.queries_logged: - collection = CollectionDebugWrapper(collection, self) + collection = OperationDebugWrapper(self, collection) return collection + def get_database(self): + if self.queries_logged: + return OperationDebugWrapper(self) + return self.database + def __getattr__(self, attr): """ Connect to the database the first time `connection` or `database` are diff --git a/django_mongodb/schema.py b/django_mongodb/schema.py index 82592820c..3132ce20d 100644 --- a/django_mongodb/schema.py +++ b/django_mongodb/schema.py @@ -4,9 +4,15 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): + def get_collection(self, name): + return self.connection.get_collection(name) + + def get_database(self): + return self.connection.get_database() + @wrap_database_errors def create_model(self, model): - self.connection.database.create_collection(model._meta.db_table) + self.get_database().create_collection(model._meta.db_table) # Make implicit M2M tables. for field in model._meta.local_many_to_many: if field.remote_field.through._meta.auto_created: @@ -17,7 +23,7 @@ def delete_model(self, model): for field in model._meta.local_many_to_many: if field.remote_field.through._meta.auto_created: self.delete_model(field.remote_field.through) - self.connection.database[model._meta.db_table].drop() + self.get_collection(model._meta.db_table).drop() def add_field(self, model, field): # Create implicit M2M tables. @@ -26,7 +32,7 @@ def add_field(self, model, field): return # Set default value on existing documents. if column := field.column: - self.connection.database[model._meta.db_table].update_many( + self.get_collection(model._meta.db_table).update_many( {}, [{"$set": {column: self.effective_default(field)}}] ) @@ -41,7 +47,7 @@ def _alter_field( new_db_params, strict=False, ): - collection = self.connection.database[model._meta.db_table] + collection = self.get_collection(model._meta.db_table) # Have they renamed the column? if old_field.column != new_field.column: collection.update_many({}, {"$rename": {old_field.column: new_field.column}}) @@ -59,7 +65,7 @@ def remove_field(self, model, field): return # Unset field on existing documents. if column := field.column: - self.connection.database[model._meta.db_table].update_many({}, {"$unset": {column: ""}}) + self.get_collection(model._meta.db_table).update_many({}, {"$unset": {column: ""}}) def alter_index_together(self, model, old_index_together, new_index_together): pass @@ -85,4 +91,4 @@ def remove_constraint(self, model, constraint): def alter_db_table(self, model, old_db_table, new_db_table): if old_db_table == new_db_table: return - self.connection.database[old_db_table].rename(new_db_table) + self.get_collection(old_db_table).rename(new_db_table) diff --git a/django_mongodb/utils.py b/django_mongodb/utils.py index 4dd68258b..10d745799 100644 --- a/django_mongodb/utils.py +++ b/django_mongodb/utils.py @@ -25,13 +25,16 @@ def check_django_compatability(): ) -class CollectionDebugWrapper: - def __init__(self, collection, db): +class OperationDebugWrapper: + def __init__(self, db, collection=None): self.collection = collection self.db = db + use_collection = collection is not None + self.collection_name = f"{collection.name}." if use_collection else "" + self.wrapped = self.collection if use_collection else self.db.database def __getattr__(self, attr): - return getattr(self.collection, attr) + return getattr(self.wrapped, attr) def profile_call(self, func, args=(), kwargs=None): start = time.monotonic() @@ -43,8 +46,8 @@ def log(self, op, duration, args, kwargs=None): # If kwargs are used by any operations in the future, they must be # added to this logging. msg = "(%.3f) %s" - args = ", ".join(str(arg) for arg in args) - operation = f"{self.collection.name}.{op}({args})" + args = ", ".join(repr(arg) for arg in args) + operation = f"db.{self.collection_name}{op}({args})" if len(settings.DATABASES) > 1: msg += f"; alias={self.db.alias}" self.db.queries_log.append( @@ -66,7 +69,7 @@ def log(self, op, duration, args, kwargs=None): def logging_wrapper(method): def wrapper(self, *args, **kwargs): - func = getattr(self.collection, method) + func = getattr(self.wrapped, method) # Collection.insert_many() mutates args (the documents) by adding # _id. deepcopy() to avoid logging that version. original_args = copy.deepcopy(args) @@ -78,8 +81,11 @@ def wrapper(self, *args, **kwargs): # These are the operations that this backend uses. aggregate = logging_wrapper("aggregate") + create_collection = logging_wrapper("create_collection") + drop = logging_wrapper("drop") insert_many = logging_wrapper("insert_many") delete_many = logging_wrapper("delete_many") + rename = logging_wrapper("rename") update_many = logging_wrapper("update_many") del logging_wrapper