diff --git a/query/services/ingestion.py b/query/services/ingestion.py index 7cb831c..d8207f7 100644 --- a/query/services/ingestion.py +++ b/query/services/ingestion.py @@ -6,6 +6,7 @@ from typing import Dict from asyncpg import Connection +import asyncpg from query.tables.product_nutrient import ( NUTRIENT_TAG, @@ -96,7 +97,7 @@ async def import_with_filter( if config_settings.SKIP_DATA_MIGRATIONS: return - # Keep a not of the last message id at the start of the upgrade as we want to re-play any messages + # Keep a note of the last message id at the start of the upgrade as we want to re-play any messages # that were processed by the old version after this point await set_pre_migration_message_id() else: @@ -108,6 +109,10 @@ async def import_with_filter( await transaction.execute( "CREATE TEMP TABLE product_temp (id int PRIMARY KEY, last_updated timestamptz, data jsonb)" ) + # Commit the temporary table so it isn't rolled back on failure. Stays active for the session + await transaction.execute("COMMIT") + await transaction.execute("BEGIN TRANSACTION") + try: process_id = await get_process_id(transaction) projection = {key: True for key in tags if key != PRODUCT_TAG} @@ -128,6 +133,7 @@ async def import_with_filter( for obsolete in [False, True]: update_count = 0 skip_count = 0 + product_updates = [] async with find_products(filter, projection, obsolete) as cursor: async for product_data in cursor: product_code = product_data["code"] @@ -169,19 +175,16 @@ async def import_with_filter( tag_data = product_data.get(tag, None) strip_nuls(tag_data, f"Product: {product_code}, tag: {tag}") - await transaction.execute( - "INSERT INTO product_temp (id, last_updated, data) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING", - existing_product["id"], - last_updated, - product_data, + product_updates.append( + [existing_product["id"], last_updated, product_data] ) update_count += 1 if not (update_count % batch_size): - await apply_staged_changes( + await apply_product_updates( 
transaction, + product_updates, obsolete, - update_count, process_id, source, tags, @@ -189,8 +192,13 @@ # Apply any remaining staged changes if update_count % batch_size: - await apply_staged_changes( - transaction, obsolete, update_count, process_id, source, tags + await apply_product_updates( + transaction, + product_updates, + obsolete, + process_id, + source, + tags, ) if skip_count % batch_size: logger.info( @@ -219,39 +227,120 @@ return max_last_updated -async def apply_staged_changes( - transaction: Connection, obsolete, update_count, process_id, source, tags +async def apply_product_updates( + transaction: Connection, + product_updates, + obsolete, + process_id, + source, + tags, ): - """ "Copies data from the product_temp temporary table to the relational tables. - Assumes that a basic product record has already been created""" - # Analyze the temp table first as this improves the generated query plans - await transaction.execute("ANALYZE product_temp") - - log = logger.debug - if PRODUCT_TAG in tags: - await update_products_from_staging( - transaction, log, obsolete, process_id, source - ) + """Inserts data from product_updates into the product_temp temporary table + and then copies from here to the relational tables. 
+ Assumes that a minimal product record has already been created""" + + remaining_updates = [] + # We remember the last SQL error so that if we have a failure, retry successfully but only have + # one remaining product to test then we know the error must relate to this one + last_sqlerror = None + retrying = False + while len(product_updates): + try: + await transaction.executemany( + """INSERT INTO product_temp (id, last_updated, data) + values ($1, $2, $3) ON CONFLICT DO NOTHING""", + product_updates, + ) + + # Analyze the temp table first as this improves the generated query plans + await transaction.execute("ANALYZE product_temp") + + if retrying: + # We have to re-create any minimal products as they will have been rolled back + # if we have had an error in this batch + await transaction.execute( + """INSERT INTO product (id, code) + SELECT id, data->>'code' + FROM product_temp pt + WHERE NOT EXISTS (SELECT * FROM product p WHERE p.id = pt.id)""" + ) - if INGREDIENTS_TAG in tags: - await create_ingredients_from_staging(transaction, log, obsolete) + log = logger.debug + if PRODUCT_TAG in tags: + await update_products_from_staging( + transaction, log, obsolete, process_id, source + ) - await create_tags_from_staging(transaction, log, obsolete, tags) + if INGREDIENTS_TAG in tags: + await create_ingredients_from_staging(transaction, log, obsolete) - if COUNTRIES_TAG in tags: - await fixup_product_countries(transaction, obsolete) + await create_tags_from_staging(transaction, log, obsolete, tags) - if NUTRIENT_TAG in tags: - await create_product_nutrients_from_staging(transaction, log) + if COUNTRIES_TAG in tags: + await fixup_product_countries(transaction, obsolete) - await transaction.execute("TRUNCATE TABLE product_temp") + if NUTRIENT_TAG in tags: + await create_product_nutrients_from_staging(transaction, log) - # Start a new transaction for the next batch - # The calling process will commit the final transaction - await transaction.execute("COMMIT") - await 
transaction.execute("BEGIN TRANSACTION") + await transaction.execute("TRUNCATE TABLE product_temp") + + # Start a new transaction for the next batch + # The calling process will commit the final transaction + await transaction.execute("COMMIT") + await transaction.execute("BEGIN TRANSACTION") - logger.info(f"Imported {update_count}{' obsolete' if obsolete else ''} products") + product_count = len(product_updates) + logger.info( + f"Imported {product_count}{' obsolete' if obsolete else ''} products" + ) + + del product_updates[:] + if len(remaining_updates): + if len(remaining_updates) == 1: + # We have saved everything except our bad product. No need to keep going + logger.error( + f"Error updating product: {remaining_updates[0][2]['code']}, {repr(last_sqlerror)}" + ) + elif product_count == 1: + # We previously will have tried a batch of 2 which presumably failed, so we can assume the + # Product with the error is the first one in the remaining updates + logger.error( + f"Error updating product: {remaining_updates[0][2]['code']}, {repr(last_sqlerror)}" + ) + del product_updates[:] + product_updates.extend(remaining_updates[1:]) + del remaining_updates[:] + else: + # Move the first half of remaining back in for retry + # As mentioned below, if there are, say, 3 left then we want + # to retry the first two. 
Don't need -1 on the range as the upper is not inclusive + next_retry_count = (len(remaining_updates) + 1) // 2 + product_updates[0:0] = remaining_updates[:next_retry_count] + del remaining_updates[:next_retry_count] + + except asyncpg.PostgresError as sql_error: + last_sqlerror = sql_error + await transaction.execute("ROLLBACK") + if len(product_updates) == 1: + # We have found our bad product + logger.error( + f"Error updating product: {product_updates[0][2]['code']}, {repr(last_sqlerror)}" + ) + last_sqlerror = None + del product_updates[:] + product_updates.extend(remaining_updates) + del remaining_updates[:] + else: + # move the last half of the products into remaining. We want our penultimate retry to be on + # a group of two so that if we then split again and succeed we know the problem was with the second one + # Hence if we have 3 updates to try we want to then split it into 2 and 1 for retries + next_retry_count = (len(product_updates) + 1) // 2 + remaining_updates[0:0] = product_updates[next_retry_count:] + del product_updates[next_retry_count:] + + # Start a new transaction + await transaction.execute("BEGIN TRANSACTION") + retrying = True import_running = False diff --git a/query/services/ingestion_test.py b/query/services/ingestion_test.py index c6e6a9a..2e40bfa 100644 --- a/query/services/ingestion_test.py +++ b/query/services/ingestion_test.py @@ -5,6 +5,8 @@ from datetime import datetime, timezone from unittest.mock import Mock, patch +from query.models.query import Filter +from query.services import query from query.tables.nutrient import get_nutrient from query.tables.product_nutrient import get_product_nutrients @@ -62,9 +64,7 @@ def get_test_products(): f"{random_code()}": { "value": random.uniform(100, 0.000001), }, - "invalid": { - "value": "100g" - } + "invalid": {"value": "100g"}, } } }, @@ -236,9 +236,7 @@ def next_process_id(_): ) # Should ignore old schema if both are present - carbohydrate_nutrient = await get_nutrient( - transaction, 
"carbohydrates" - ) + carbohydrate_nutrient = await get_nutrient(transaction, "carbohydrates") found_carbohydrate = [ item for item in existing_product_nutrients @@ -780,3 +778,65 @@ async def test_event_load_should_restore_deleted_products( product_countries = await get_product_countries(transaction, deleted_product) assert len(product_countries) == 1 assert product_countries[0]["obsolete"] == False + + +@patch.object(ingestion, "find_products") +@patch.object(ingestion, "create_product_nutrients_from_staging") +@patch.object(ingestion, "logger") +async def assert_for_failing_product_indices( + product_count: int, + error_indices: list, + logger_mock: Mock, + update_nutrients_mock: Mock, + find_products_mock: Mock, +): + products = [] + owner = random_code() + for i in range(product_count): + products.append( + { + "code": random_code(), + "last_updated_t": last_updated, + "owners_tags": owner, + } + ) + + patch_context_manager(find_products_mock, mock_cursor(products)) + + error_products = [] + for i in error_indices: + error_products.append(products[i]["code"]) + + async def error_on_nutrient(transaction, _): + # The following SQL will fail if the temp table contains the rogue product + await transaction.execute( + """INSERT INTO product_nutrient (product_id) + SELECT 0 FROM product_temp + WHERE data->>'code' = ANY($1) + """, + error_products, + ) + + update_nutrients_mock.side_effect = error_on_nutrient + + await ingestion.import_from_mongo("") + + response = await query.count(Filter(owners_tags=owner)) + assert response == product_count - len(error_indices) + + error_calls = logger_mock.error.call_args_list + assert len(error_calls) == len(error_indices) + for i in range(len(error_indices)): + assert error_products[i] in error_calls[i][0][0] + + +async def test_skips_products_where_sql_fails(): + await assert_for_failing_product_indices(10, [7]) + + +async def test_skips_products_where_multiple_sql_fails(): + await assert_for_failing_product_indices(10, [0, 
9]) + +async def test_skips_products_where_second_product_fails(): + await assert_for_failing_product_indices(5, [1]) +