|
3 | 3 | returned by list views. |
4 | 4 | """ |
5 | 5 | import operator |
| 6 | +from typing import Iterable |
6 | 7 | import warnings |
7 | 8 | from functools import reduce |
8 | 9 |
|
@@ -216,33 +217,68 @@ def get_schema_operation_parameters(self, view): |
216 | 217 | }, |
217 | 218 | ] |
218 | 219 |
|
| 220 | +class OrderingExpressionFactory: |
| 221 | + |
| 222 | + def __init__(self, queryset_value: str = '', nulls_as_low: bool = True): |
| 223 | + self.queryset_value = queryset_value |
| 224 | + self.nulls_as_low = nulls_as_low |
| 225 | + |
| 226 | + def get_expression(self, given_value:str): |
| 227 | + queryset_value = self.queryset_value or given_value |
| 228 | + if given_value.startswith('-'): |
| 229 | + return models.F(queryset_value.strip('-')).desc(nulls_last=self.nulls_as_low) |
| 230 | + else: |
| 231 | + return models.F(queryset_value).asc(nulls_first=self.nulls_as_low) |
| 232 | + |
219 | 233 |
|
220 | 234 | class OrderingFilter(BaseFilterBackend): |
| 235 | + ''' |
| 236 | + If you use this class, the ordering fields can be a tuple of |
| 237 | + (<query_value>, <OrderingExpressionFactory instance>). Using `OrderingExpressionFactory`, |
| 238 | + you can map query_param values to certain values in the queryset and whether `null` values |
| 239 | + are considered as low or high. |
| 240 | +
|
| 241 | + Using this class, you cas set `ordering_prefix (Iterable)` attribute to the view. This prefix would be |
| 242 | + considered as the first ordering expressions. |
| 243 | + ''' |
221 | 244 | # The URL query parameter used for the ordering. |
222 | 245 | ordering_param = api_settings.ORDERING_PARAM |
223 | 246 | ordering_fields = None |
224 | 247 | ordering_title = _('Ordering') |
225 | 248 | ordering_description = _('Which field to use when ordering the results.') |
226 | 249 | template = 'rest_framework/filters/ordering.html' |
227 | 250 |
|
228 | | - def get_ordering(self, request, queryset, view): |
| 251 | + def get_ordering(self, request, queryset, view): # returns an iterable of expressions for ordering |
229 | 252 | """ |
230 | 253 | Ordering is set by a comma delimited ?ordering=... query parameter. |
231 | 254 |
|
232 | 255 | The `ordering` query parameter can be overridden by setting |
233 | 256 | the `ordering_param` value on the OrderingFilter or by |
234 | 257 | specifying an `ORDERING_PARAM` value in the API settings. |
235 | 258 | """ |
236 | | - params = request.query_params.get(self.ordering_param) |
237 | | - if params: |
238 | | - fields = [param.strip() for param in params.split(',')] |
239 | | - ordering = self.remove_invalid_fields(queryset, fields, view, request) |
240 | | - if ordering: |
241 | | - return ordering |
| 259 | + params = request.query_params.get(self.ordering_param, '') |
| 260 | + if params or getattr(view, 'ordering_prefix', []): |
| 261 | + params = [param.strip() for param in params.split(',')] |
| 262 | + ordering_expressions = self.get_ordering_expressions(queryset, params, view, request) |
| 263 | + if ordering_expressions: |
| 264 | + return ordering_expressions |
242 | 265 |
|
243 | 266 | # No ordering was included, or all the ordering fields were invalid |
244 | 267 | return self.get_default_ordering(view) |
245 | 268 |
|
| 269 | + def get_ordering_expressions(self, queryset, fields_in_query: Iterable[str], view, request): |
| 270 | + valid_fields = dict(self.get_valid_fields(queryset, view, {'request': request})) |
| 271 | + |
| 272 | + ordering_expressions = list(getattr(view, 'ordering_prefix', [])) or [] |
| 273 | + for field in fields_in_query: |
| 274 | + exp = valid_fields.get(field, None) or valid_fields.get(field[1:], None) |
| 275 | + if exp: |
| 276 | + ordering_expressions.append( |
| 277 | + exp.get_expression(field) if isinstance(exp, OrderingExpressionFactory) |
| 278 | + else field |
| 279 | + ) |
| 280 | + return ordering_expressions |
| 281 | + |
246 | 282 | def get_default_ordering(self, view): |
247 | 283 | ordering = getattr(view, 'ordering', None) |
248 | 284 | if isinstance(ordering, str): |
|
0 commit comments