Skip to content

Commit 2ec9c85

Browse files
committed
fix(identity-vault): pagination while scanning overrides in the API
In `parallel_dynamo.py`, we were paginating over the results for our callers. While this simplifies the callers' code, it also increases their run time, which is an issue when paired with AWS API Gateway. The fix here is to do what we say we're going to do: allow the users to paginate if they wish. We provide a way for users to specify where their next page should start, via `scan`'s `exclusive_start_key` argument. We hint as much in the return, naming this key `nextPage`. But, doing this requires fixing our segmentation logic! Each segment receives its own `ExclusiveStartKey`, which needs to be passed back the next time we run the scan operation for that segment. Jira: IAM-1793
1 parent 87c7030 commit 2ec9c85

File tree

5 files changed

+200
-37
lines changed

5 files changed

+200
-37
lines changed

python-modules/cis_identity_vault/cis_identity_vault/models/user.py

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -317,15 +317,65 @@ def all(self):
317317
users.extend(response["Items"])
318318
return users
319319

320-
def _last_evaluated_to_friendly(self, last_evaluated_key):
321-
if last_evaluated_key is None:
320+
def _last_evaluated_to_friendly(self, last_evaluated_keys):
321+
"""
322+
Received from Dynamo, and serialized into something our clients can
323+
understand (or rather, use: this _should_ be an opaque token to
324+
clients).
325+
326+
When we're paginating through Dynamo, each segment returns a
327+
`LastEvaluatedKey`, which we need to specify as `ExclusiveStartKey` in
328+
subsequent requests. Each `ExclusiveStartKey` is segment-specific,
329+
hence the care here to serialize these in the order returned.
330+
331+
* `None`, indicating that we've completely finished, there are no more
332+
results any segment can return;
333+
* `list[Optional[Any]]`, indicating that _some_ segments have results
334+
left.
335+
336+
Our clients' pagination logic (at least as seen by various publishers),
337+
will consider the query over once we return `None`.
338+
"""
339+
if not last_evaluated_keys:
322340
return None
341+
next_page = []
342+
for last_evaluated_key in last_evaluated_keys:
343+
# A signal that the segment is done scanning.
344+
if last_evaluated_key is None:
345+
id = ""
346+
else:
347+
id = last_evaluated_key["id"]["S"]
348+
next_page.append(id)
349+
# If there are at all any segments left with work, then continue.
350+
if any(next_page):
351+
return ",".join(next_page)
323352
else:
324-
return last_evaluated_key["id"]["S"]
353+
return None
325354

326355
def _next_page_to_dynamodb(self, next_page):
327-
if next_page is not None:
328-
return {"id": {"S": next_page}}
356+
"""
357+
Received from _clients_, and deserialized into something our parallel
358+
Dynamo code understands.
359+
360+
A complication here is that we can't reuse `None`, since that would
361+
cause a segment to start from the beginning. So, we use a sentinel
362+
value of `"done"` to signal to the parallel Dynamo code that we should
363+
skip this segment.
364+
365+
When Dynamo returns `None`, that means _all_ segments are done. If it
366+
returns a `list[Optional[Any]]`, that means that we can still make
367+
progress on some segments.
368+
"""
369+
if not next_page:
370+
return None
371+
exclusive_start_keys = []
372+
for last_evaluated_key in next_page.split(","):
373+
if last_evaluated_key == "":
374+
id = "done"
375+
else:
376+
id = {"id": {"S": last_evaluated_key}}
377+
exclusive_start_keys.append(id)
378+
return exclusive_start_keys
329379

330380
def all_filtered(self, connection_method=None, active=None, next_page=None):
331381
"""
@@ -352,7 +402,7 @@ def all_filtered(self, connection_method=None, active=None, next_page=None):
352402
filter_expression=filter_expression,
353403
expression_attr=expression_attr,
354404
projection_expression=projection_expression,
355-
exclusive_start_key=next_page,
405+
exclusive_start_keys=next_page,
356406
)
357407
return dict(users=response["users"], nextPage=self._last_evaluated_to_friendly(response.get("nextPage")))
358408

python-modules/cis_identity_vault/cis_identity_vault/parallel_dynamo.py

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -36,35 +36,55 @@ def get_segment(
3636

3737
logger.debug("Running parallel scan with kwargs: {}".format(scan_kwargs))
3838
response = dynamodb_client.scan(**scan_kwargs)
39-
users = response.get("Items", [])
39+
# Return a dictionary of users, since sets can only contain hashable types
40+
# (and lists and dicts are not).
41+
users = {
42+
user["id"]["S"]: user
43+
for user in response.get("Items", [])
44+
}
4045
last_evaluated_key = response.get("LastEvaluatedKey")
4146

42-
while last_evaluated_key is not None:
43-
scan_kwargs["ExclusiveStartKey"] = last_evaluated_key
44-
response = dynamodb_client.scan(**scan_kwargs)
45-
users.extend(response.get("Items", []))
46-
last_evaluated_key = response.get("LastEvaluatedKey")
47-
48-
logger.debug("Running thread_id: {}".format(thread_id))
47+
logger.debug("Finished thread_id: {}, with nextPage: {}".format(thread_id, last_evaluated_key))
4948
return result_queue.put(dict(users=users, nextPage=last_evaluated_key, segment=thread_id))
5049

5150

5251
def scan(
53-
dynamodb_client, table_name, filter_expression, expression_attr, projection_expression, exclusive_start_key=None
52+
dynamodb_client, table_name, filter_expression, expression_attr, projection_expression, exclusive_start_keys=None
5453
):
5554
logger.debug("Creating new threads and queue.")
5655
result_queue = queue.Queue()
5756

58-
# The worker pool size should be equal to the max_segments. Ideally we want one segment per worker.
59-
pool_size = 128
60-
max_segments = 128
57+
# We use one worker per segment.
58+
max_segments = 48
6159

62-
users = []
63-
last_evaluated_key = None
60+
users = dict()
61+
last_evaluated_keys = [None] * max_segments
6462
threads = []
6563

66-
for thread_id in range(0, pool_size):
64+
# If this is the first request, then we'll receive a None from our
65+
# caller.
66+
if exclusive_start_keys is None:
67+
exclusive_start_keys = [None] * max_segments
68+
69+
# When we're continuing, we signal that a segment has no more work to
70+
# complete if its ESK is "done". If _all_ of the segments have that, then
71+
# we're at the end of our result set.
72+
elif all(map(lambda esk: esk == "done", exclusive_start_keys)):
73+
return dict(users=[], nextPage=None)
74+
75+
for thread_id in range(0, max_segments):
6776
# What are we passing to each threaded function.
77+
try:
78+
exclusive_start_key = exclusive_start_keys[thread_id]
79+
except IndexError:
80+
logger.critical("Someone may be DOSing us or not doing pagination properly.")
81+
raise
82+
83+
# If we explicitly read a "done", then this is a signal that the
84+
# segment has no more records.
85+
if exclusive_start_key == "done":
86+
logger.debug(f"skipping thread {thread_id}")
87+
continue
6888

6989
thread_args = (
7090
result_queue,
@@ -102,13 +122,10 @@ def scan(
102122
while not result_queue.empty():
103123
logger.debug("Results queue is not empty.")
104124
result = result_queue.get()
105-
users_additional = result.get("users")
106-
users.extend(users_additional)
107-
if result.get("segment") == max_segments - 1:
108-
logger.debug("This is the last segment.")
109-
last_evaluated_key = result.get("nextPage")
110-
logger.debug("Last evaluated key in page was: {}".format(last_evaluated_key))
125+
users.update(result.get("users", {}))
126+
segment = result.get("segment")
127+
last_evaluated_keys[segment] = result.get("nextPage")
111128
result_queue.task_done()
112129

113130
logger.debug("Results queue is empty.")
114-
return dict(users=users, nextPage=last_evaluated_key)
131+
return dict(users=users.values(), nextPage=last_evaluated_keys)

python-modules/cis_profile_retrieval_service/cis_profile_retrieval_service/v2_api.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import orjson
23

34
from flask import Flask
@@ -370,7 +371,11 @@ def version():
370371

371372

372373
def main():
373-
app.run(host="0.0.0.0", debug=True)
374+
# DEBT: I think this has been fixed in a later version of Flask.
375+
# We don't call this in production, but instead lean on Serverless' WSGI handler.
376+
host, _, port = os.environ.get("SERVER_NAME", "127.0.0.1:5000").partition(":")
377+
port = int(port)
378+
app.run(host=host, port=port, debug=True)
374379

375380

376381
if __name__ == "__main__":
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""
2+
n.b. Do not run this as a part of your development cycle. Do not run this
3+
regularly. This can, and will, insert weird data into CIS, if misconfigured.
4+
5+
A pseudo-e2e test, where we run the code locally but use dev/stage resources.
6+
This is the code equivalent of running with scissors.
7+
8+
Requires the following environment variables:
9+
10+
CIS_ENVIRONMENT="testing"
11+
CIS_SEED_API_DATA="false"
12+
PERSON_API_ADVANCED_SEARCH="true"
13+
PERSON_API_INITIALIZE_VAULT="false"
14+
PERSON_API_JWT_VALIDATION="false"
15+
SERVER_NAME="127.0.0.1:8000"
16+
17+
Run the server with:
18+
19+
python cis_profile_retrieval_service/v2_api.py
20+
21+
Run the test with:
22+
23+
PSEUDO_E2E=yes pytest --log-cli-level=DEBUG tests/test_e2e.py
24+
25+
You'll also need an active AWS session, which is left as an exercise for the
26+
user.
27+
"""
28+
29+
import logging
30+
import os
31+
import pytest
32+
import requests
33+
34+
from tests.fake_auth0 import FakeBearer
35+
36+
logging.basicConfig(level=logging.INFO, format="%(asctime)s:%(levelname)s:%(name)s:%(message)s")
37+
logging.getLogger("faker.factory").setLevel(logging.INFO)
38+
logging.getLogger("urllib3").setLevel(logging.INFO)
39+
logger = logging.getLogger(__name__)
40+
41+
42+
@pytest.fixture
43+
def auth_headers():
44+
bearer = FakeBearer()
45+
token = bearer.generate_bearer_with_scope("display:all search:all")
46+
headers = {"Authorization": f"Bearer {token}"}
47+
return headers
48+
49+
50+
@pytest.mark.skipif(os.environ.get("PSEUDO_E2E") is None, reason="Not running in pseudo-E2E mode.")
51+
def test_retrieve_single_profile(auth_headers):
52+
res = requests.get(
53+
"http://localhost:8000/v2/users/id/all?connectionMethod=ad&active=True",
54+
headers=auth_headers,
55+
).json()
56+
next_page = res.get("nextPage")
57+
pages = 1
58+
while next_page:
59+
res = requests.get(
60+
f"http://localhost:8000/v2/users/id/all?connectionMethod=ad&active=True&nextPage={next_page}",
61+
headers=auth_headers,
62+
).json()
63+
next_page = res.get("nextPage")
64+
pages += 1
65+
assert pages >= 2, "Did not iterate through any pages."

python-modules/cis_profile_retrieval_service/tests/test_v2_api_pagination.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
2-
A simplification of the existing v2_api tests. We reuse Dynamo and the users we
3-
create across each test, since otherwise it'll take a while.
2+
A simplification of the existing v2_api tests. We take care to reuse Dynamo and
3+
the users we create across each test, since otherwise it'll take a while when
4+
adding additional tests.
45
56
See `DEBT` notes for where I ran into weirdness.
67
@@ -63,31 +64,56 @@ def identity_vault(environment):
6364
vault_client.find_or_create()
6465
# DEBT: requires side-effects.
6566
from cis_profile_retrieval_service.common import seed
66-
# DEBT?: doesn't seemingly generate only `number_of_fake_users` users.
67+
6768
# DEBT: doesn't generate `ad|Mozilla-LDAP` users.
68-
seed(number_of_fake_users=128)
69+
seed(number_of_fake_users=256)
6970
return (boto3.client("dynamodb"), boto3.resource("dynamodb"))
7071

7172

7273
@pytest.fixture
7374
def app(environment, monkeypatch):
7475
bearer = FakeBearer()
7576
token = bearer.generate_bearer_with_scope("display:all search:all")
76-
headers = {
77-
"Authorization": f"Bearer {token}"
78-
}
77+
headers = {"Authorization": f"Bearer {token}"}
7978
monkeypatch.setattr("cis_profile_retrieval_service.idp.get_jwks", lambda: json_form_of_pk)
8079
# DEBT: requires side-effects.
8180
from cis_profile_retrieval_service import v2_api
81+
8282
v2_api.app.testing = True
8383
return (headers, v2_api.app.test_client())
8484

8585

8686
def test_existing(identity_vault, app):
87+
"""
88+
As it turns out, our pagination was broken.
89+
90+
The `v2/users/id/all` endpoint, as written, did the right thing: paginate
91+
through pages, skipping empty pages.
92+
93+
The scan logic [0] has a slight bug in it, where it would continue
94+
paginating. Since it was doing this pagination for us, at a lower level, we
95+
simply weren't propagating any `nextPage` tokens forward.
96+
97+
Pagination mystery: solved.
98+
99+
[0]: python-modules/cis_identity_vault/cis_identity_vault/parallel_dynamo.py
100+
"""
87101
headers, client = app
88102
# DEBT: see note above, about not generating `ad|Mozilla-LDAP` users.
89103
results = client.get(
90104
"/v2/users/id/all?connectionMethod=github&active=True",
91105
headers=headers,
92106
follow_redirects=True,
93-
)
107+
).json
108+
next_page = results.get("nextPage")
109+
pages = 1
110+
# Pagination, as implemented elsewhere.
111+
while next_page:
112+
results = client.get(
113+
f"/v2/users/id/all?connectionMethod=github&active=True&nextPage={next_page}",
114+
headers=headers,
115+
follow_redirects=True,
116+
).json
117+
next_page = results.get("nextPage")
118+
pages += 1
119+
assert pages >= 2, "Did not paginate!"

0 commit comments

Comments
 (0)