-
-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Description
Hey, I appreciate I'm sidestepping the discussion phase, but I think this is a fairly clear bug (and there's no test coverage for it).
The CursorPagination class makes two assumptions that don't always hold: that there is exactly one ordering field, and that this field is unique. The documentation doesn't make clear that these constraints must be enforced. More importantly, CursorPagination's default ordering is `-created`, which has no reason to be unique.
This is an example of where it fails:
import base64
import itertools
import re
from base64 import b64encode
from urllib import parse

import pytest
from django.db import models
from rest_framework import generics
from rest_framework.filters import OrderingFilter
from rest_framework.pagination import Cursor
from rest_framework.pagination import CursorPagination as BrokenCursorPagination
from rest_framework.pagination import HeaderCursorPagination
from rest_framework.permissions import AllowAny
from rest_framework.serializers import ModelSerializer
from rest_framework.test import APIRequestFactory
# Request factory used by the test's ``_request`` helper to build GET requests for the view.
factory = APIRequestFactory()
class ExampleModel(models.Model):
    """Model exercised by the cursor-pagination test.

    No column other than the primary key carries a uniqueness constraint,
    so ``timestamp`` values can repeat across rows.
    """

    # Don't use an auto field because we can't reset
    # sequences and that's needed for this test
    id = models.IntegerField(primary_key=True)
    # Payload value; the test fills it with 1-10.
    field = models.IntegerField()
    # Used as the ``ordering`` field; the test fills it with only 1 or 2.
    timestamp = models.IntegerField()

    class Meta:
        # Explicit app label; presumably needed because this test module is
        # not part of an installed app -- confirm against the test settings.
        app_label = "test_app"
class SerializerCls(ModelSerializer):
    """Serialize every ``ExampleModel`` field (``id``, ``field``, ``timestamp``)."""

    class Meta:
        model = ExampleModel
        # "__all__" exposes every model field, matching the dicts the test
        # builds with ``.values("timestamp", "id", "field")``.
        fields = "__all__"
def create_cursor(offset, reverse, position):
    """Return the encoded cursor string sent in the paginator's cursor query parameter.

    Mirrors the encoding in ``rest_framework.pagination``. Only non-default
    parts are emitted, in the order ``o`` (offset, when non-zero), ``r``
    (when reversed) and ``p`` (position, when not ``None``). The resulting
    query string is base64-encoded and returned as ASCII text.
    """
    cursor = Cursor(offset=offset, reverse=reverse, position=position)
    # Each entry is (token name, encoded value, whether to include it); the
    # comprehension keeps insertion order, so the query string is o, r, p.
    selected = {
        name: value
        for name, value, wanted in (
            ("o", str(cursor.offset), cursor.offset != 0),
            ("r", "1", bool(cursor.reverse)),
            ("p", cursor.position, cursor.position is not None),
        )
        if wanted
    }
    raw_query = parse.urlencode(selected, doseq=True).encode("ascii")
    return b64encode(raw_query).decode("ascii")
def decode_cursor(response):
    """Extract the next/previous cursors from the ``link`` header of *response*.

    Each ``<url>; rel="name"`` entry is parsed. The ``cursor`` query parameter
    of ``url`` is base64-decoded, and its ``o`` (offset), ``r`` (reverse) and
    ``p`` (position) tokens are packed into a ``Cursor``.

    Returns a class object created with ``type()`` that has two attributes.
    ``next`` holds the cursor of the ``next`` link and ``prev`` holds the
    cursor of the ``previous`` link. Either is ``None`` when that link is
    missing.
    """
    by_rel = {}
    for hit in re.finditer('<(.*?)>; rel="(.*?)"', response["link"]):
        url, rel = hit.groups()
        encoded = parse.parse_qs(parse.urlparse(url).query)["cursor"][0]
        # The decoded payload is bytes, so parse_qsl yields bytes keys/values.
        tokens = dict(parse.parse_qsl(base64.decodebytes(encoded.encode())))

        offset = tokens.get(b"o", 0)
        if offset:
            offset = int(offset)
        reverse = tokens.get(b"r", False)
        if reverse:
            reverse = int(reverse)
        by_rel[rel] = Cursor(
            offset=offset,
            reverse=reverse,
            position=tokens.get(b"p", None),
        )

    attributes = {"next": by_rel.get("next"), "prev": by_rel.get("previous")}
    return type("prev_next_stuct", (object,), attributes)
@pytest.mark.django_db
def test_filtered_items_are_paginated():
    """Cursor-paginate a non-unique ``ordering`` forwards and then backwards.

    Twenty rows are created whose ``timestamp`` is only ever 1 or 2. The
    view is ordered with ``?ordering=timestamp`` through ``OrderingFilter``
    and paginated two rows per page.

    Two walks are checked. The first follows ``next`` links from the first
    page to the last. The second follows ``previous`` links from the last
    page back to the first. Every page in both walks must equal the matching
    slice of the ``order_by("timestamp", "id")`` result, and each walk must
    cover all rows.
    """

    class PaginationCls(HeaderCursorPagination):
        page_size = 2
        max_page_size = 20
        offset_cutoff = 6

    # ``field`` runs 1-10 and ``timestamp`` is 1 or 2, so ten rows share each
    # timestamp value. Primary keys are assigned manually, starting at 1.
    ExampleModel.objects.bulk_create(
        [
            ExampleModel(id=pk, field=field_1, timestamp=field_2)
            for pk, (field_1, field_2) in enumerate(
                itertools.product(range(1, 11), range(1, 3)), start=1
            )
        ]
    )

    view = generics.ListAPIView.as_view(
        serializer_class=SerializerCls,
        queryset=ExampleModel.objects.all(),
        pagination_class=PaginationCls,
        permission_classes=(AllowAny,),
        filter_backends=[OrderingFilter],
    )

    def _request(offset, reverse, position):
        """GET the view with an encoded cursor and ``ordering=timestamp``."""
        return view(
            factory.get(
                "/",
                {
                    PaginationCls.cursor_query_param: create_cursor(
                        offset, reverse, position
                    ),
                    "ordering": "timestamp",
                },
            )
        )

    # This is the result we would expect
    expected_result = list(
        ExampleModel.objects.order_by("timestamp", "id").values(
            "timestamp",
            "id",
            "field",
        )
    )
    assert expected_result == [
        {"field": 1, "id": 1, "timestamp": 1},
        {"field": 2, "id": 3, "timestamp": 1},
        {"field": 3, "id": 5, "timestamp": 1},
        {"field": 4, "id": 7, "timestamp": 1},
        {"field": 5, "id": 9, "timestamp": 1},
        {"field": 6, "id": 11, "timestamp": 1},
        {"field": 7, "id": 13, "timestamp": 1},
        {"field": 8, "id": 15, "timestamp": 1},
        {"field": 9, "id": 17, "timestamp": 1},
        {"field": 10, "id": 19, "timestamp": 1},
        {"field": 1, "id": 2, "timestamp": 2},
        {"field": 2, "id": 4, "timestamp": 2},
        {"field": 3, "id": 6, "timestamp": 2},
        {"field": 4, "id": 8, "timestamp": 2},
        {"field": 5, "id": 10, "timestamp": 2},
        {"field": 6, "id": 12, "timestamp": 2},
        {"field": 7, "id": 14, "timestamp": 2},
        {"field": 8, "id": 16, "timestamp": 2},
        {"field": 9, "id": 18, "timestamp": 2},
        {"field": 10, "id": 20, "timestamp": 2},
    ]

    total = len(expected_result)
    # A walk that makes progress needs at most ``total`` pages. The loops are
    # capped so that a cursor which never terminates fails instead of hanging.
    max_pages = total + 1

    # Forward walk. The last page has no ``next`` link but is still checked.
    response = _request(0, False, None)
    position = 0
    for _ in range(max_pages):
        page = response.data
        assert expected_result[position : position + len(page)] == page
        position += len(page)
        next_cursor = decode_cursor(response).next
        if not next_cursor:
            break
        response = _request(*next_cursor)
    else:
        pytest.fail("forward cursor pagination did not terminate")
    # Without this, a paginator that stops early would pass silently.
    assert position == total

    # Backward walk, starting from the last page reached above. The first
    # page has no ``previous`` link but is still checked.
    for _ in range(max_pages):
        page = response.data
        assert expected_result[position - len(page) : position] == page
        position -= len(page)
        prev_cursor = decode_cursor(response).prev
        if not prev_cursor:
            break
        response = _request(*prev_cursor)
    else:
        pytest.fail("backward cursor pagination did not terminate")
    assert position == 0
The cursor the paginator builds leads to a filter like `timestamp > 1`. That isn't enough information to paginate a result set like this one, where ten rows share each timestamp value.
I've written a fix that builds a compound cursor from all of the ordering fields plus the primary key. It then filters with the equivalent of a tuple comparison between the cursor and the table rows, built from `Q` objects. It works, and this test passes with the fix.
Let me know if you'd like the patch.