|
| 1 | +from datetime import datetime |
| 2 | + |
| 3 | +from django.conf import settings |
1 | 4 | from django.db import NotSupportedError
|
| 5 | +from django.db.models import DateField, DateTimeField, TimeField |
2 | 6 | from django.db.models.expressions import Func
|
3 | 7 | from django.db.models.functions.comparison import Cast, Coalesce, Greatest, Least, NullIf
|
4 | 8 | from django.db.models.functions.datetime import (
|
@@ -195,6 +199,33 @@ def trunc(self, compiler, connection):
|
195 | 199 | return {"$dateTrunc": lhs_mql}
|
196 | 200 |
|
197 | 201 |
|
| 202 | +def trunc_convert_value(self, value, expression, connection): |
| 203 | + if connection.vendor == "mongodb": |
| 204 | + # A custom TruncBase.convert_value() for MongoDB. |
| 205 | + if value is None: |
| 206 | + return None |
| 207 | + convert_to_tz = settings.USE_TZ and self.get_tzname() != "UTC" |
| 208 | + if isinstance(self.output_field, DateTimeField): |
| 209 | + if convert_to_tz: |
| 210 | + # Unlike other databases, MongoDB returns the value in UTC, |
| 211 | + # so rather than setting the time zone equal to self.tzinfo, |
| 212 | + # the value must be converted to tzinfo. |
| 213 | + value = value.astimezone(self.tzinfo) |
| 214 | + elif isinstance(value, datetime): |
| 215 | + if isinstance(self.output_field, DateField): |
| 216 | + if convert_to_tz: |
| 217 | + value = value.astimezone(self.tzinfo) |
| 218 | + # Truncate for Trunc(..., output_field=DateField) |
| 219 | + value = value.date() |
| 220 | + elif isinstance(self.output_field, TimeField): |
| 221 | + if convert_to_tz: |
| 222 | + value = value.astimezone(self.tzinfo) |
| 223 | + # Truncate for Trunc(..., output_field=TimeField) |
| 224 | + value = value.time() |
| 225 | + return value |
| 226 | + return self.convert_value(value, expression, connection) |
| 227 | + |
| 228 | + |
198 | 229 | def trunc_date(self, compiler, connection):
|
199 | 230 | # Cast to date rather than truncate to date.
|
200 | 231 | lhs_mql = process_lhs(self, compiler, connection)
|
@@ -254,6 +285,7 @@ def register_functions():
|
254 | 285 | Substr.as_mql = substr
|
255 | 286 | Trim.as_mql = trim("trim")
|
256 | 287 | TruncBase.as_mql = trunc
|
| 288 | + TruncBase.convert_value = trunc_convert_value |
257 | 289 | TruncDate.as_mql = trunc_date
|
258 | 290 | TruncTime.as_mql = trunc_time
|
259 | 291 | Upper.as_mql = preserve_null("toUpper")
|
0 commit comments