|
5 | 5 | import operator |
6 | 6 | import warnings |
7 | 7 | from functools import reduce |
| 8 | +from typing import Iterable |
8 | 9 |
|
9 | 10 | from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured |
10 | 11 | from django.db import models |
@@ -217,32 +218,68 @@ def get_schema_operation_parameters(self, view): |
217 | 218 | ] |
218 | 219 |
|
219 | 220 |
|
| 221 | +class OrderingExpressionFactory: |
| 222 | + |
| 223 | + def __init__(self, queryset_value: str = '', nulls_as_low: bool = True): |
| 224 | + self.queryset_value = queryset_value |
| 225 | + self.nulls_as_low = nulls_as_low |
| 226 | + |
| 227 | + def get_expression(self, given_value: str): |
| 228 | + queryset_value = self.queryset_value or given_value |
| 229 | + if given_value.startswith('-'): |
| 230 | + return models.F(queryset_value.strip('-')).desc(nulls_last=self.nulls_as_low) |
| 231 | + else: |
| 232 | + return models.F(queryset_value).asc(nulls_first=self.nulls_as_low) |
| 233 | + |
| 234 | + |
220 | 235 | class OrderingFilter(BaseFilterBackend): |
| 236 | + ''' |
| 237 | + If you use this class, the ordering fields can be a tuple of |
| 238 | + (<query_value>, <OrderingExpressionFactory instance>). Using `OrderingExpressionFactory`, |
| 239 | + you can map query_param values to certain values in the queryset and whether `null` values |
| 240 | + are considered as low or high. |
| 241 | +
|
| 242 | + Using this class, you cas set `ordering_prefix (Iterable)` attribute to the view. This prefix would be |
| 243 | + considered as the first ordering expressions. |
| 244 | + ''' |
221 | 245 | # The URL query parameter used for the ordering. |
222 | 246 | ordering_param = api_settings.ORDERING_PARAM |
223 | 247 | ordering_fields = None |
224 | 248 | ordering_title = _('Ordering') |
225 | 249 | ordering_description = _('Which field to use when ordering the results.') |
226 | 250 | template = 'rest_framework/filters/ordering.html' |
227 | 251 |
|
228 | | - def get_ordering(self, request, queryset, view): |
| 252 | + def get_ordering(self, request, queryset, view): # returns an iterable of expressions for ordering |
229 | 253 | """ |
230 | 254 | Ordering is set by a comma delimited ?ordering=... query parameter. |
231 | 255 |
|
232 | 256 | The `ordering` query parameter can be overridden by setting |
233 | 257 | the `ordering_param` value on the OrderingFilter or by |
234 | 258 | specifying an `ORDERING_PARAM` value in the API settings. |
235 | 259 | """ |
236 | | - params = request.query_params.get(self.ordering_param) |
237 | | - if params: |
238 | | - fields = [param.strip() for param in params.split(',')] |
239 | | - ordering = self.remove_invalid_fields(queryset, fields, view, request) |
240 | | - if ordering: |
241 | | - return ordering |
| 260 | + params = request.query_params.get(self.ordering_param, '') |
| 261 | + if params or getattr(view, 'ordering_prefix', []): |
| 262 | + params = [param.strip() for param in params.split(',')] |
| 263 | + ordering_expressions = self.get_ordering_expressions(queryset, params, view, request) |
| 264 | + if ordering_expressions: |
| 265 | + return ordering_expressions |
242 | 266 |
|
243 | 267 | # No ordering was included, or all the ordering fields were invalid |
244 | 268 | return self.get_default_ordering(view) |
245 | 269 |
|
| 270 | + def get_ordering_expressions(self, queryset, fields_in_query: Iterable[str], view, request): |
| 271 | + valid_fields = dict(self.get_valid_fields(queryset, view, {'request': request})) |
| 272 | + |
| 273 | + ordering_expressions = list(getattr(view, 'ordering_prefix', [])) or [] |
| 274 | + for field in fields_in_query: |
| 275 | + exp = valid_fields.get(field, None) or valid_fields.get(field[1:], None) |
| 276 | + if exp: |
| 277 | + ordering_expressions.append( |
| 278 | + exp.get_expression(field) if isinstance(exp, OrderingExpressionFactory) |
| 279 | + else field |
| 280 | + ) |
| 281 | + return ordering_expressions |
| 282 | + |
246 | 283 | def get_default_ordering(self, view): |
247 | 284 | ordering = getattr(view, 'ordering', None) |
248 | 285 | if isinstance(ordering, str): |
|
0 commit comments