diff --git a/django_mongodb/expressions.py b/django_mongodb/expressions.py index 957c5f155..2d14f3d16 100644 --- a/django_mongodb/expressions.py +++ b/django_mongodb/expressions.py @@ -203,20 +203,26 @@ def when(self, compiler, connection): def value(self, compiler, connection): # noqa: ARG001 value = self.value + if isinstance(value, int): + # Wrap numbers in $literal to prevent ambiguity when Value appears in + # $project. + return {"$literal": value} if isinstance(value, Decimal): - value = Decimal128(value) - elif isinstance(value, datetime.date): + return Decimal128(value) + if isinstance(value, datetime.datetime): + return value + if isinstance(value, datetime.date): # Turn dates into datetimes since BSON doesn't support dates. - value = datetime.datetime.combine(value, datetime.datetime.min.time()) - elif isinstance(value, datetime.time): + return datetime.datetime.combine(value, datetime.datetime.min.time()) + if isinstance(value, datetime.time): # Turn times into datetimes since BSON doesn't support times. - value = datetime.datetime.combine(datetime.datetime.min.date(), value) - elif isinstance(value, datetime.timedelta): + return datetime.datetime.combine(datetime.datetime.min.date(), value) + if isinstance(value, datetime.timedelta): # DurationField stores milliseconds rather than microseconds. - value /= datetime.timedelta(milliseconds=1) - elif isinstance(value, UUID): - value = value.hex - return {"$literal": value} + return value / datetime.timedelta(milliseconds=1) + if isinstance(value, UUID): + return value.hex + return value def register_expressions(): diff --git a/tests/expressions_/__init__.py b/tests/expressions_/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/expressions_/test_value.py b/tests/expressions_/test_value.py new file mode 100644 index 000000000..c57c2f032 --- /dev/null +++ b/tests/expressions_/test_value.py @@ -0,0 +1,43 @@ +import datetime +import uuid +from decimal import Decimal + +from bson import Decimal128 +from django.db.models import Value +from django.test import SimpleTestCase + + +class ValueTests(SimpleTestCase): + def test_date(self): + self.assertEqual( + Value(datetime.date(2025, 1, 1)).as_mql(None, None), + datetime.datetime(2025, 1, 1), + ) + + def test_datetime(self): + self.assertEqual( + Value(datetime.datetime(2025, 1, 1, 9, 8, 7)).as_mql(None, None), + datetime.datetime(2025, 1, 1, 9, 8, 7), + ) + + def test_decimal(self): + self.assertEqual(Value(Decimal("1.0")).as_mql(None, None), Decimal128("1.0")) + + def test_time(self): + self.assertEqual( + Value(datetime.time(9, 8, 7)).as_mql(None, None), + datetime.datetime(1, 1, 1, 9, 8, 7), + ) + + def test_timedelta(self): + self.assertEqual(Value(datetime.timedelta(3600)).as_mql(None, None), 311040000000.0) + + def test_int(self): + self.assertEqual(Value(1).as_mql(None, None), {"$literal": 1}) + + def test_str(self): + self.assertEqual(Value("foo").as_mql(None, None), "foo") + + def test_uuid(self): + value = uuid.UUID(int=1) + self.assertEqual(Value(value).as_mql(None, None), "00000000000000000000000000000001")