Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 124 additions & 35 deletions query/services/ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Dict

from asyncpg import Connection
import asyncpg

from query.tables.product_nutrient import (
NUTRIENT_TAG,
Expand Down Expand Up @@ -96,7 +97,7 @@ async def import_with_filter(
if config_settings.SKIP_DATA_MIGRATIONS:
return

# Keep a not of the last message id at the start of the upgrade as we want to re-play any messages
# Keep a note of the last message id at the start of the upgrade as we want to re-play any messages
# that were processed by the old version after this point
await set_pre_migration_message_id()
else:
Expand All @@ -108,6 +109,10 @@ async def import_with_filter(
await transaction.execute(
"CREATE TEMP TABLE product_temp (id int PRIMARY KEY, last_updated timestamptz, data jsonb)"
)
# Commit the temporary table so it isn't rolled back on failure. Stays active for the session
await transaction.execute("COMMIT")
await transaction.execute("BEGIN TRANSACTION")

try:
process_id = await get_process_id(transaction)
projection = {key: True for key in tags if key != PRODUCT_TAG}
Expand All @@ -128,6 +133,7 @@ async def import_with_filter(
for obsolete in [False, True]:
update_count = 0
skip_count = 0
product_updates = []
async with find_products(filter, projection, obsolete) as cursor:
async for product_data in cursor:
product_code = product_data["code"]
Expand Down Expand Up @@ -169,28 +175,30 @@ async def import_with_filter(
tag_data = product_data.get(tag, None)
strip_nuls(tag_data, f"Product: {product_code}, tag: {tag}")

await transaction.execute(
"INSERT INTO product_temp (id, last_updated, data) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING",
existing_product["id"],
last_updated,
product_data,
product_updates.append(
[existing_product["id"], last_updated, product_data]
)

update_count += 1
if not (update_count % batch_size):
await apply_staged_changes(
await apply_product_updates(
transaction,
product_updates,
obsolete,
update_count,
process_id,
source,
tags,
)

# Apply any remaining staged changes
if update_count % batch_size:
await apply_staged_changes(
transaction, obsolete, update_count, process_id, source, tags
await apply_product_updates(
transaction,
product_updates,
obsolete,
process_id,
source,
tags,
)
if skip_count % batch_size:
logger.info(
Expand Down Expand Up @@ -219,39 +227,120 @@ async def import_with_filter(
return max_last_updated


async def apply_staged_changes(
transaction: Connection, obsolete, update_count, process_id, source, tags
async def apply_product_updates(
transaction: Connection,
product_updates,
obsolete,
process_id,
source,
tags,
):
""" "Copies data from the product_temp temporary table to the relational tables.
Assumes that a basic product record has already been created"""
# Analyze the temp table first as this improves the generated query plans
await transaction.execute("ANALYZE product_temp")

log = logger.debug
if PRODUCT_TAG in tags:
await update_products_from_staging(
transaction, log, obsolete, process_id, source
)
"""Inserts data from product_updates into the the product_temp temporary table
and then copies from here to the relational tables.
Assumes that a minimal product record has already been created"""

remaining_updates = []
# We remember the last SQL error so that if we have a failure, retry successfully, but only have
# one remaining product to test, then we know the error must relate to that one
last_sqlerror = None
retrying = False
while len(product_updates):
try:
await transaction.executemany(
"""INSERT INTO product_temp (id, last_updated, data)
values ($1, $2, $3) ON CONFLICT DO NOTHING""",
product_updates,
)

# Analyze the temp table first as this improves the generated query plans
await transaction.execute("ANALYZE product_temp")

if retrying:
# We have to re-create any minimal products as they will have been rolled back
# if we have had an error in this batch
await transaction.execute(
"""INSERT INTO product (id, code)
SELECT id, data->>'code'
FROM product_temp pt
WHERE NOT EXISTS (SELECT * FROM product p WHERE p.id = pt.id)"""
)

if INGREDIENTS_TAG in tags:
await create_ingredients_from_staging(transaction, log, obsolete)
log = logger.debug
if PRODUCT_TAG in tags:
await update_products_from_staging(
transaction, log, obsolete, process_id, source
)

await create_tags_from_staging(transaction, log, obsolete, tags)
if INGREDIENTS_TAG in tags:
await create_ingredients_from_staging(transaction, log, obsolete)

if COUNTRIES_TAG in tags:
await fixup_product_countries(transaction, obsolete)
await create_tags_from_staging(transaction, log, obsolete, tags)

if NUTRIENT_TAG in tags:
await create_product_nutrients_from_staging(transaction, log)
if COUNTRIES_TAG in tags:
await fixup_product_countries(transaction, obsolete)

await transaction.execute("TRUNCATE TABLE product_temp")
if NUTRIENT_TAG in tags:
await create_product_nutrients_from_staging(transaction, log)

# Start a new transaction for the next batch
# The calling process will commit the final transaction
await transaction.execute("COMMIT")
await transaction.execute("BEGIN TRANSACTION")
await transaction.execute("TRUNCATE TABLE product_temp")

# Start a new transaction for the next batch
# The calling process will commit the final transaction
await transaction.execute("COMMIT")
await transaction.execute("BEGIN TRANSACTION")

logger.info(f"Imported {update_count}{' obsolete' if obsolete else ''} products")
product_count = len(product_updates)
logger.info(
f"Imported {product_count}{' obsolete' if obsolete else ''} products"
)

del product_updates[:]
if len(remaining_updates):
if len(remaining_updates) == 1:
# We have saved everything except our bad product. No need to keep going
logger.error(
f"Error updating product: {remaining_updates[0][2]['code']}, {repr(last_sqlerror)}"
)
elif product_count == 1:
# We will previously have tried a batch of 2, which presumably failed, so we can assume the
# product with the error is the first one in the remaining updates
logger.error(
f"Error updating product: {remaining_updates[0][2]['code']}, {repr(last_sqlerror)}"
)
del product_updates[:]
product_updates.extend(remaining_updates[1:])
del remaining_updates[:]
else:
# Move the first half of remaining back in for retry
# As mentioned below, if there are, say, 3 left then we want
# to retry the first two. Don't need -1 on the range as the upper is not inclusive
next_retry_count = (len(remaining_updates) + 1) // 2
product_updates[0:0] = remaining_updates[:next_retry_count]
del remaining_updates[:next_retry_count]

except asyncpg.PostgresError as sql_error:
last_sqlerror = sql_error
await transaction.execute("ROLLBACK")
if len(product_updates) == 1:
# We have found our bad product
logger.error(
f"Error updating product: {product_updates[0][2]['code']}, {repr(last_sqlerror)}"
)
last_sqlerror = None
del product_updates[:]
product_updates.extend(remaining_updates)
del remaining_updates[:]
else:
# move the last half of the products into remaining. We want our penultimate retry to be on
# a group of two so that if we then split again and succeed we know the problem was with the second one
# Hence if we have 3 updates to try we want to then split it into 2 and 1 for retries
next_retry_count = (len(product_updates) + 1) // 2
remaining_updates[0:0] = product_updates[next_retry_count:]
del product_updates[next_retry_count:]

# Start a new transaction
await transaction.execute("BEGIN TRANSACTION")
retrying = True


import_running = False
Expand Down
72 changes: 66 additions & 6 deletions query/services/ingestion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from datetime import datetime, timezone
from unittest.mock import Mock, patch

from query.models.query import Filter
from query.services import query
from query.tables.nutrient import get_nutrient
from query.tables.product_nutrient import get_product_nutrients

Expand Down Expand Up @@ -62,9 +64,7 @@ def get_test_products():
f"{random_code()}": {
"value": random.uniform(100, 0.000001),
},
"invalid": {
"value": "100g"
}
"invalid": {"value": "100g"},
}
}
},
Expand Down Expand Up @@ -236,9 +236,7 @@ def next_process_id(_):
)

# Should ignore old schema if both are present
carbohydrate_nutrient = await get_nutrient(
transaction, "carbohydrates"
)
carbohydrate_nutrient = await get_nutrient(transaction, "carbohydrates")
found_carbohydrate = [
item
for item in existing_product_nutrients
Expand Down Expand Up @@ -780,3 +778,65 @@ async def test_event_load_should_restore_deleted_products(
product_countries = await get_product_countries(transaction, deleted_product)
assert len(product_countries) == 1
assert product_countries[0]["obsolete"] == False


@patch.object(ingestion, "find_products")
@patch.object(ingestion, "create_product_nutrients_from_staging")
@patch.object(ingestion, "logger")
async def assert_for_failing_product_indices(
product_count: int,
error_indices: list,
logger_mock: Mock,
update_nutrients_mock: Mock,
find_products_mock: Mock,
):
products = []
owner = random_code()
for i in range(product_count):
products.append(
{
"code": random_code(),
"last_updated_t": last_updated,
"owners_tags": owner,
}
)

patch_context_manager(find_products_mock, mock_cursor(products))

error_products = []
for i in error_indices:
error_products.append(products[i]["code"])

async def error_on_nutrient(transaction, _):
# The following SQL will fail if the temp table contains the rogue product
await transaction.execute(
"""INSERT INTO product_nutrient (product_id)
SELECT 0 FROM product_temp
WHERE data->>'code' = ANY($1)
""",
error_products,
)

update_nutrients_mock.side_effect = error_on_nutrient

await ingestion.import_from_mongo("")

response = await query.count(Filter(owners_tags=owner))
assert response == product_count - len(error_indices)

error_calls = logger_mock.error.call_args_list
assert len(error_calls) == len(error_indices)
for i in range(len(error_indices)):
assert error_products[i] in error_calls[i][0][0]


async def test_skips_products_where_sql_fails():
    """A single failing product part-way through the batch is skipped and logged."""
    await assert_for_failing_product_indices(10, [7])


async def test_skips_products_where_multiple_sql_fails():
    """Failures at both ends of the batch (first and last product) are each skipped and logged."""
    await assert_for_failing_product_indices(10, [0, 9])

async def test_skips_products_where_second_product_fails():
    """A failure on the second product exercises the retry bisection ending in a batch of two."""
    await assert_for_failing_product_indices(5, [1])

Loading