Skip to content

Commit 4335f9e

Browse files
committed
[NRL-1860] Switch nft seed script to use poisson distribution for counts. Output .csv file for test once nft table has been seeded
1 parent 6bfc4ac commit 4335f9e

File tree

1 file changed

+83
-26
lines changed

1 file changed

+83
-26
lines changed

scripts/seed_nft_tables.py

Lines changed: 83 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1+
import csv
12
from datetime import datetime, timedelta, timezone
23
from itertools import cycle
34
from math import gcd
45
from random import shuffle
5-
from typing import Any
6+
from typing import Any, Iterator
67

78
import boto3
89
import fire
910

11+
# import json
12+
import numpy as np
13+
1014
from nrlf.consumer.fhir.r4.model import DocumentReference
1115
from nrlf.core.constants import (
1216
CATEGORY_ATTRIBUTES,
@@ -145,7 +149,7 @@ def _populate_seed_table(
145149
px_with_pointers: int,
146150
pointers_per_px: float = 1.0,
147151
type_dists: dict[str, int] = DEFAULT_TYPE_DISTRIBUTIONS,
148-
custodian_dists: dict[str, int] = DEFAULT_CUSTODIAN_DISTRIBUTIONS,
152+
custodian_dists: dict[str, dict[str, int]] = DEFAULT_CUSTODIAN_DISTRIBUTIONS,
149153
):
150154
"""
151155
Seeds a table with example data for non-functional testing.
@@ -155,25 +159,40 @@ def _populate_seed_table(
155159
# set up iterations
156160
type_iter = _set_up_cyclical_iterator(type_dists)
157161
custodian_iters = _set_up_custodian_iterators(custodian_dists)
158-
count_iter = _set_up_cyclical_iterator(DEFAULT_COUNT_DISTRIBUTIONS)
162+
# count_iter = _set_up_cyclical_iterator(DEFAULT_COUNT_DISTRIBUTIONS)
163+
count_iter = _get_pointer_count_poisson_distributions(
164+
px_with_pointers, pointers_per_px
165+
)
166+
# count_iter = _get_pointer_count_negbinom_distributions(px_with_pointers, pointers_per_px)
159167
testnum_cls = TestNhsNumbersIterator()
160168
testnum_iter = iter(testnum_cls)
161169

162170
px_counter = 0
163171
doc_ref_target = int(pointers_per_px * px_with_pointers)
164172
print(
165-
f"Will upsert {doc_ref_target} test pointers for {px_with_pointers} patients."
173+
f"Will upsert ~{doc_ref_target} test pointers for {px_with_pointers} patients."
166174
)
167175
doc_ref_counter = 0
168176
batch_counter = 0
169177

178+
pointer_data: list[list[str]] = []
179+
170180
start_time = datetime.now(tz=timezone.utc)
171181

172-
batch_upsert_items = []
173-
while px_counter <= px_with_pointers:
182+
batch_upsert_items: list[dict[str, Any]] = []
183+
while px_counter < px_with_pointers:
174184
pointers_for_px = int(next(count_iter))
185+
175186
if batch_counter + pointers_for_px > 25 or px_counter == px_with_pointers:
176-
resource.batch_write_item(RequestItems={table_name: batch_upsert_items})
187+
response = resource.batch_write_item(
188+
RequestItems={table_name: batch_upsert_items}
189+
)
190+
191+
if response.get("UnprocessedItems"):
192+
logger.error(
193+
f"Unprocessed items in batch write: {len(response.get('UnprocessedItems'))}"
194+
)
195+
177196
batch_upsert_items = []
178197
batch_counter = 0
179198

@@ -189,54 +208,92 @@ def _populate_seed_table(
189208
)
190209
put_req = {"PutRequest": {"Item": pointer.model_dump()}}
191210
batch_upsert_items.append(put_req)
211+
pointer_data.append(
212+
[
213+
pointer.id,
214+
pointer.type,
215+
pointer.custodian,
216+
pointer.nhs_number,
217+
]
218+
)
192219
px_counter += 1
193220

221+
if px_counter % 1000 == 0:
222+
print(".", end="", flush=True)
223+
if px_counter % 100000 == 0:
224+
print(f" {px_counter} patients processed")
225+
226+
print(" Done.")
227+
194228
end_time = datetime.now(tz=timezone.utc)
195229
print(
196230
f"Created {doc_ref_counter} pointers in {timedelta.total_seconds(end_time - start_time)} seconds."
197231
)
198232

233+
with open("./seed-nft-pointers.csv", "w") as f:
234+
writer = csv.writer(f)
235+
writer.writerow(["pointer_id", "pointer_type", "custodian", "nhs_number"])
236+
writer.writerows(pointer_data)
237+
print(f"Pointer data saved to ./seed-nft-pointers.csv") # noqa
238+
199239

200-
def _set_up_cyclical_iterator(dists: dict[str, int]) -> Iterator[str]:
    """
    Build an endless iterator over the keys of ``dists``, weighted by their
    relative frequencies.

    The frequencies are first reduced by their greatest common divisor, each key
    is repeated that reduced number of times, and the resulting pool is shuffled
    once before being cycled through indefinitely.
    This should result in more live-like data than e.g. creating a bulk amount of
    each pointer type/custodian in series, and it means each batch will contain a
    representative sample of the distribution.
    """
    divisor = gcd(*dists.values())
    # One copy of each key per reduced unit of weight.
    pool: list[str] = [
        name for name, weight in dists.items() for _ in range(weight // divisor)
    ]
    shuffle(pool)
    return cycle(pool)
213253

214254

255+
# def _get_pointer_count_negbinom_distributions(num_of_patients: int, pointers_per_px: float) -> cycle:
256+
# dispersion = 2 # lower = more variance; higher = closer to Poisson
257+
# p = dispersion / (dispersion + pointers_per_px)
258+
# n = dispersion
259+
260+
# p_count_distr = np.random.negative_binomial(n=n, p=p, size=num_of_patients)
261+
# return cycle(p_count_distr)
262+
263+
264+
def _get_pointer_count_poisson_distributions(
    num_of_patients: int, pointers_per_px: float
) -> Iterator[int]:
    """
    Generate per-patient pointer counts from a shifted, clipped Poisson distribution.

    ``num_of_patients`` counts are sampled as ``Poisson(pointers_per_px - 1) + 1``,
    so no patient has zero pointers, and are then clipped to the range 1-4.
    Clipping the upper tail means the realised mean falls below
    ``pointers_per_px`` when the target is large.

    Returns an iterator that cycles endlessly through the sampled counts as
    plain Python ints. With ``num_of_patients == 0`` the iterator is empty.

    Raises:
        ValueError: if ``pointers_per_px`` is below 1, since the shifted
            distribution cannot average fewer than one pointer per patient.
    """
    if pointers_per_px < 1:
        raise ValueError(
            f"pointers_per_px must be at least 1 (got {pointers_per_px}): "
            "every seeded patient has at least one pointer"
        )

    p_count_distr = np.random.poisson(lam=pointers_per_px - 1, size=num_of_patients) + 1
    p_count_distr = np.clip(p_count_distr, a_min=1, a_max=4)
    # tolist() converts numpy integer scalars to Python ints, matching the annotation.
    return cycle(p_count_distr.tolist())
270+
271+
215272
def _set_up_custodian_iterators(
    custodian_dists: dict[str, dict[str, int]]
) -> dict[str, Iterator[str]]:
    """
    Build one cyclical custodian iterator per pointer type.

    ``custodian_dists`` maps each pointer type to a dict of custodians and their
    relative frequencies; the result maps the same pointer types to the iterator
    produced by ``_set_up_cyclical_iterator`` for that inner dict.
    """
    return {
        pointer_type: _set_up_cyclical_iterator(type_dists)
        for pointer_type, type_dists in custodian_dists.items()
    }
224281

225282

226-
def _set_up_count_iterator(pointers_per_px: float) -> Iterator[str]:
    """
    Given a target average number of pointers per patient,
    generates a distribution of counts per individual patient.

    The distribution is worked out for a notional group of 100 patients, each
    having 1, 2 or 3 pointers. The returned iterator yields those counts as the
    strings "1", "2" and "3", cycling endlessly via ``_set_up_cyclical_iterator``.
    """
    # Pointers needed beyond the one-per-patient baseline; no patients can have zero pointers.
    # round() rather than int(): e.g. (1.15 - 1.0) * 100 is 14.99... in floating
    # point and would otherwise truncate to 14.
    extra_per_hundred = round((pointers_per_px - 1.0) * 100)

    counts: dict[str, int] = {}
    # A "3" patient accounts for two extra pointers, a "2" patient for one.
    counts["3"] = extra_per_hundred // 10
    counts["2"] = extra_per_hundred - 2 * counts["3"]
    # Keys are strings throughout; indexing with ints here raised KeyError.
    counts["1"] = 100 - counts["2"] - counts["3"]
    return _set_up_cyclical_iterator(counts)
283+
# def _set_up_count_iterator(pointers_per_px: float) -> iter:
284+
# """
285+
# Given a target average number of pointers per patient,
286+
# generates a distribution of counts per individual patient.
287+
# """
288+
#
289+
# extra_per_hundred = int(
290+
# (pointers_per_px - 1.0) * 100
291+
# ) # no patients can have zero pointers
292+
# counts = {}
293+
# counts["3"] = extra_per_hundred // 10
294+
# counts["2"] = extra_per_hundred - 2 * counts["3"]
295+
# counts["1"] = 100 - counts["2"] - counts["3"]
296+
# return _set_up_cyclical_iterator(counts)
240297

241298

242299
if __name__ == "__main__":

0 commit comments

Comments
 (0)