Skip to content

Commit b274ea0

Browse files
committed
linting fix #2
1 parent a9a69e0 commit b274ea0

File tree

2 files changed

+168
-149
lines changed

2 files changed

+168
-149
lines changed

metrics/interfaces/management/commands/seed_random.py

Lines changed: 147 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def _seed_metrics_data(
121121
scale_config: dict[str, int],
122122
truncate_first: bool,
123123
progress_callback: Callable[[str], None] | None = None,
124-
) -> dict[str, int]: # noqa: PLR0914
124+
) -> dict[str, int]:
125125
"""Seed supporting metric models and time series rows for the selected scale."""
126126
if progress_callback is not None:
127127
progress_callback("Preparing metric taxonomy and geography records...")
@@ -130,39 +130,7 @@ def _seed_metrics_data(
130130
if truncate_first:
131131
cls._truncate_metrics_data()
132132

133-
(
134-
theme_names,
135-
sub_theme_rows,
136-
topic_rows,
137-
) = cls._build_theme_hierarchy_records()
138-
themes = cls._bulk_create(
139-
Theme,
140-
[Theme(name=name) for name in theme_names],
141-
)
142-
themes_by_name = {theme.name: theme for theme in themes}
143-
144-
sub_themes = cls._bulk_create(
145-
SubTheme,
146-
[
147-
SubTheme(name=name, theme=themes_by_name[theme_name])
148-
for name, theme_name in sub_theme_rows
149-
],
150-
)
151-
sub_themes_by_key = {
152-
(sub_theme.name, sub_theme.theme.name): sub_theme
153-
for sub_theme in sub_themes
154-
}
155-
156-
topics = cls._bulk_create(
157-
Topic,
158-
[
159-
Topic(
160-
name=topic_name,
161-
sub_theme=sub_themes_by_key[(sub_theme_name, theme_name)],
162-
)
163-
for topic_name, sub_theme_name, theme_name in topic_rows
164-
],
165-
)
133+
themes, sub_themes, topics = cls._seed_theme_hierarchy()
166134

167135
metrics = cls._bulk_create(
168136
Metric,
@@ -175,34 +143,7 @@ def _seed_metrics_data(
175143
],
176144
)
177145

178-
geography_seed_values = cls._build_geography_seed_values(
179-
count=scale_config["geographies"]
180-
)
181-
geography_type_names = {
182-
record["geography_type"] for record in geography_seed_values
183-
}
184-
geography_types = cls._bulk_create(
185-
GeographyType,
186-
[GeographyType(name=name) for name in sorted(geography_type_names)],
187-
)
188-
geography_types_by_name = {
189-
geography_type.name: geography_type
190-
for geography_type in geography_types
191-
}
192-
193-
geographies = cls._bulk_create(
194-
Geography,
195-
[
196-
Geography(
197-
name=record["name"],
198-
geography_code=record["geography_code"],
199-
geography_type=geography_types_by_name[
200-
record["geography_type"]
201-
],
202-
)
203-
for record in geography_seed_values
204-
],
205-
)
146+
geographies = cls._seed_geographies(count=scale_config["geographies"])
206147

207148
stratum = Stratum.objects.create(name="All")
208149
age = Age.objects.create(name="All ages")
@@ -252,10 +193,9 @@ def _seed_time_series_rows(
252193
age: Age,
253194
days: int,
254195
progress_callback: Callable[[str], None] | None = None,
255-
) -> tuple[int, int]: # noqa: PLR0914
196+
) -> tuple[int, int]:
256197
frequency = TimePeriod.Weekly.value
257-
today = date.today()
258-
start_date = today - timedelta(days=days - 1)
198+
start_date = date.today() - timedelta(days=days - 1)
259199
batch_size = 5000
260200
core_rows: list[CoreTimeSeries] = []
261201
api_rows: list[APITimeSeries] = []
@@ -266,81 +206,29 @@ def _seed_time_series_rows(
266206
log_interval = max(1, total_metrics // 10) if total_metrics else 1
267207

268208
for metric_index, metric in enumerate(metrics, start=1):
269-
topic = metric.topic
270-
sub_theme = topic.sub_theme
271-
theme = sub_theme.theme
272-
273-
for geography in geographies:
274-
for day_offset in range(days):
275-
current_date = start_date + timedelta(days=day_offset)
276-
base_value = random.uniform(5.0, 250.0) # noqa: S311 # nosec B311
277-
metric_value = round(
278-
base_value
279-
+ random.uniform(-10.0, 10.0), # noqa: S311 # nosec B311
280-
2,
281-
)
282-
epidemiological_week = current_date.isocalendar().week
283-
284-
core_rows.append(
285-
CoreTimeSeries(
286-
metric=metric,
287-
metric_frequency=frequency,
288-
geography=geography,
289-
stratum=stratum,
290-
age=age,
291-
sex=None,
292-
year=current_date.year,
293-
month=current_date.month,
294-
epiweek=epidemiological_week,
295-
date=current_date,
296-
metric_value=Decimal(str(metric_value)),
297-
is_public=True,
298-
)
299-
)
300-
301-
if len(core_rows) >= batch_size:
302-
CoreTimeSeries.objects.bulk_create(
303-
core_rows, batch_size=batch_size
304-
)
305-
core_count += len(core_rows)
306-
core_rows = []
307-
308-
api_rows.append(
309-
APITimeSeries(
310-
metric_frequency=frequency,
311-
age=age.name,
312-
month=current_date.month,
313-
geography_code=geography.geography_code,
314-
metric_group=None,
315-
theme=theme.name,
316-
sub_theme=sub_theme.name,
317-
topic=topic.name,
318-
geography_type=geography.geography_type.name,
319-
geography=geography.name,
320-
metric=metric.name,
321-
stratum=stratum.name,
322-
sex=None,
323-
year=current_date.year,
324-
epiweek=epidemiological_week,
325-
date=current_date,
326-
metric_value=float(metric_value),
327-
is_public=True,
328-
)
329-
)
330-
331-
if len(api_rows) >= batch_size:
332-
APITimeSeries.objects.bulk_create(
333-
api_rows, batch_size=batch_size
334-
)
335-
api_count += len(api_rows)
336-
api_rows = []
337-
338-
if (
339-
progress_callback is not None
340-
and (
341-
metric_index == total_metrics
342-
or metric_index % log_interval == 0
343-
)
209+
for core_row, api_row in cls._build_time_series_rows_for_metric(
210+
metric=metric,
211+
geographies=geographies,
212+
stratum=stratum,
213+
age=age,
214+
days=days,
215+
start_date=start_date,
216+
frequency=frequency,
217+
):
218+
core_rows.append(core_row)
219+
if len(core_rows) >= batch_size:
220+
CoreTimeSeries.objects.bulk_create(core_rows, batch_size=batch_size)
221+
core_count += len(core_rows)
222+
core_rows = []
223+
224+
api_rows.append(api_row)
225+
if len(api_rows) >= batch_size:
226+
APITimeSeries.objects.bulk_create(api_rows, batch_size=batch_size)
227+
api_count += len(api_rows)
228+
api_rows = []
229+
230+
if progress_callback is not None and (
231+
metric_index == total_metrics or metric_index % log_interval == 0
344232
):
345233
processed_row_count = metric_index * len(geographies) * days
346234
progress_callback(
@@ -365,6 +253,123 @@ def _seed_time_series_rows(
365253

366254
return core_count, api_count
367255

256+
@classmethod
def _seed_theme_hierarchy(cls) -> tuple[list[Theme], list[SubTheme], list[Topic]]:
    """Insert the theme, sub-theme and topic rows, parents before children.

    Names and parent links are taken from ``_build_theme_hierarchy_records``.
    Each level is inserted through ``_bulk_create`` and then indexed by name
    so that the next level down can resolve its parent instance.

    Returns:
        The ``(themes, sub_themes, topics)`` values handed back by
        ``_bulk_create`` for each level, in that order.
    """
    theme_names, sub_theme_rows, topic_rows = cls._build_theme_hierarchy_records()

    themes = cls._bulk_create(Theme, [Theme(name=name) for name in theme_names])
    theme_lookup = {theme.name: theme for theme in themes}

    pending_sub_themes = [
        SubTheme(name=sub_theme_name, theme=theme_lookup[parent_theme_name])
        for sub_theme_name, parent_theme_name in sub_theme_rows
    ]
    sub_themes = cls._bulk_create(SubTheme, pending_sub_themes)

    # Keyed on the (sub-theme name, theme name) pair, which is the same pair
    # carried by every entry in ``topic_rows``.
    sub_theme_lookup = {
        (created.name, created.theme.name): created for created in sub_themes
    }

    pending_topics = [
        Topic(
            name=topic_name,
            sub_theme=sub_theme_lookup[(sub_theme_name, parent_theme_name)],
        )
        for topic_name, sub_theme_name, parent_theme_name in topic_rows
    ]
    topics = cls._bulk_create(Topic, pending_topics)

    return themes, sub_themes, topics
283+
284+
@classmethod
def _seed_geographies(cls, *, count: int) -> list[Geography]:
    """Insert the geography types and geographies used by the seeded data.

    Args:
        count: Forwarded unchanged to ``_build_geography_seed_values`` to
            decide how many geography records are generated.

    Returns:
        Whatever ``_bulk_create`` returns for the ``Geography`` records.
    """
    seed_records = cls._build_geography_seed_values(count=count)

    # De-duplicate the type names, then sort them so that the insertion order
    # of ``GeographyType`` rows does not depend on set iteration order.
    unique_type_names = {record["geography_type"] for record in seed_records}
    pending_types = [
        GeographyType(name=type_name) for type_name in sorted(unique_type_names)
    ]
    geography_types = cls._bulk_create(GeographyType, pending_types)
    type_lookup = {created.name: created for created in geography_types}

    pending_geographies = [
        Geography(
            name=record["name"],
            geography_code=record["geography_code"],
            geography_type=type_lookup[record["geography_type"]],
        )
        for record in seed_records
    ]
    return cls._bulk_create(Geography, pending_geographies)
308+
309+
@classmethod
def _build_time_series_rows_for_metric(
    cls,
    *,
    metric: Metric,
    geographies: list[Geography],
    stratum: Stratum,
    age: Age,
    days: int,
    start_date: date,
    frequency: str,
) -> Iterable[tuple[CoreTimeSeries, APITimeSeries]]:
    """Lazily yield one unsaved (core, API) row pair per geography per day.

    For every geography and every day offset in ``range(days)`` a random
    value is drawn: a baseline in ``[5.0, 250.0]`` plus a jitter in
    ``[-10.0, 10.0]``, rounded to two decimal places. The same value is used
    for both rows of the pair. Nothing is written to the database here; the
    model instances are only constructed and yielded.

    Args:
        metric: Metric the rows belong to; its topic, sub-theme and theme
            names are copied onto each ``APITimeSeries`` row.
        geographies: Geographies to generate rows for.
        stratum: Stratum attached to every row.
        age: Age attached to every row.
        days: Number of consecutive days to generate per geography.
        start_date: Date of the first generated day.
        frequency: Value stored as ``metric_frequency`` on every row.

    Yields:
        ``(CoreTimeSeries, APITimeSeries)`` tuples, geography by geography
        and in ascending date order within each geography.
    """
    topic = metric.topic
    sub_theme = topic.sub_theme
    theme = sub_theme.theme

    for geography in geographies:
        for offset in range(days):
            row_date = start_date + timedelta(days=offset)

            # Two independent draws, baseline first and jitter second.
            baseline = random.uniform(5.0, 250.0)  # noqa: S311 # nosec B311
            jitter = random.uniform(-10.0, 10.0)  # noqa: S311 # nosec B311
            value = round(baseline + jitter, 2)
            iso_week = row_date.isocalendar().week

            core_row = CoreTimeSeries(
                metric=metric,
                metric_frequency=frequency,
                geography=geography,
                stratum=stratum,
                age=age,
                sex=None,
                year=row_date.year,
                month=row_date.month,
                epiweek=iso_week,
                date=row_date,
                # Built from the string form of the rounded float.
                metric_value=Decimal(str(value)),
                is_public=True,
            )
            api_row = APITimeSeries(
                metric_frequency=frequency,
                age=age.name,
                month=row_date.month,
                geography_code=geography.geography_code,
                metric_group=None,
                theme=theme.name,
                sub_theme=sub_theme.name,
                topic=topic.name,
                geography_type=geography.geography_type.name,
                geography=geography.name,
                metric=metric.name,
                stratum=stratum.name,
                sex=None,
                year=row_date.year,
                epiweek=iso_week,
                date=row_date,
                metric_value=float(value),
                is_public=True,
            )
            yield core_row, api_row
372+
368373
@staticmethod
369374
def _bulk_create(model, records: Iterable):
370375
"""Materialise and bulk insert a sequence of model instances."""
@@ -386,9 +391,7 @@ def _build_theme_hierarchy_records(
386391
)
387392
for sub_theme_name in child_theme_group.return_list():
388393
child_to_parent[sub_theme_name] = resolved_parent
389-
normalised_to_child[cls._normalise_key(sub_theme_name)] = (
390-
sub_theme_name
391-
)
394+
normalised_to_child[cls._normalise_key(sub_theme_name)] = sub_theme_name
392395

393396
topic_rows: list[tuple[str, str, str]] = []
394397
sub_theme_pairs: set[tuple[str, str]] = set()
@@ -418,9 +421,7 @@ def _build_geography_seed_values(cls, *, count: int) -> list[dict[str, str]]:
418421
{
419422
"name": "United Kingdom",
420423
"geography_code": UNITED_KINGDOM_GEOGRAPHY_CODE,
421-
"geography_type": (
422-
validation_enums.GeographyType.UNITED_KINGDOM.value
423-
),
424+
"geography_type": (validation_enums.GeographyType.UNITED_KINGDOM.value),
424425
}
425426
]
426427

tests/unit/metrics/interfaces/management/test_seed_random.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -320,8 +320,12 @@ def test_seed_time_series_rows_flushes_remainder(
320320
assert api_count == 1
321321
spy_core_time_series.objects.bulk_create.assert_called_once()
322322
spy_api_time_series.objects.bulk_create.assert_called_once()
323-
progress_messages = [call.args[0] for call in spy_progress_callback.call_args_list]
324-
assert any(message.startswith("Processed 1/1 metrics") for message in progress_messages)
323+
progress_messages = [
324+
call.args[0] for call in spy_progress_callback.call_args_list
325+
]
326+
assert any(
327+
message.startswith("Processed 1/1 metrics") for message in progress_messages
328+
)
325329
assert any(message.startswith("Inserted ") for message in progress_messages)
326330

327331
@mock.patch(f"{MODULE_PATH}.APITimeSeries")
@@ -415,7 +419,21 @@ def test_build_theme_hierarchy_records_contains_expected_real_values():
415419

416420
assert "infectious_disease" in theme_names
417421
assert any(sub_theme == "respiratory" for sub_theme, _ in sub_theme_rows)
418-
assert any(topic == "COVID-19" and sub_theme == "respiratory" for topic, sub_theme, _ in topic_rows)
422+
assert any(
423+
topic == "COVID-19" and sub_theme == "respiratory"
424+
for topic, sub_theme, _ in topic_rows
425+
)
426+
427+
428+
def test_build_theme_hierarchy_records_skips_unmatched_topic_group():
    """
    Given `validation_enums.Topic` patched to hold a single topic group
        named "DOES_NOT_MATCH_CHILD_THEME"
    When `Command._build_theme_hierarchy_records()` is called
    Then the returned topic rows are empty
    """
    # Given
    # `name` is assigned after construction: passing `name=` to `Mock(...)`
    # names the mock itself instead of creating a `.name` attribute.
    unmatched_group = mock.Mock()
    unmatched_group.name = "DOES_NOT_MATCH_CHILD_THEME"
    unmatched_group.return_list.return_value = ["dummy-topic"]
    topic_enum_path = f"{MODULE_PATH}.validation_enums.Topic"

    # When
    with mock.patch(topic_enum_path, [unmatched_group]):
        _theme_names, _sub_theme_rows, topic_rows = (
            Command._build_theme_hierarchy_records()
        )

    # Then
    assert topic_rows == []
419437

420438

421439
def test_build_geography_seed_values_returns_required_count():

0 commit comments

Comments
 (0)