async def clear_ml_cache(self, services_to_clear=None):
    '''
    Clear the cache tables and output tables of inference services, e.g. after
    the ML model behind a service has changed.

    Steps:
    1. Collect the output tables directly written by the selected services.
    2. Expand that set over the foreign-key edges of the table graph and over
       services that share an output table, until nothing new is added.
    3. Delete all rows from the cache table of every affected cached service.
    4. Delete all rows from the affected output tables, following the
       topological order of the table graph.

    :param services_to_clear: names of the changed services (an iterable of str),
      or None to clear every bound service. The argument is copied, never mutated.
      Output of services that were not named may be cleared as well, because of
      foreign-key constraints or because they share an output table.
    '''
    config = self._config
    if services_to_clear is None:
        services_to_clear = [bound_service.service.name for bound_service in config.inference_bindings]
    # Private copy: the traversal below adds further services to it.
    services_to_clear = set(services_to_clear)

    # table name -> names of the cached services writing columns of that table
    table_related_service = {table_name: set() for table_name in config.tables}
    # service name -> names of the tables the service writes to
    output_tables = {service_name: set() for service_name in config.inference_services}
    tables_to_clear = set()

    for bound_service in config.inference_bindings:
        if not isinstance(bound_service, CachedBoundInferenceService):
            logger.debug(f"Service binding for {bound_service.service.name} is not cached")
            continue
        service_name = bound_service.service.name
        # Output columns are '<table>.<column>'; only the table part matters here.
        for output_column in bound_service.binding.output_columns:
            output_tables[service_name].add(output_column.split('.')[0])
        for output_table_name in output_tables[service_name]:
            table_related_service[output_table_name].add(service_name)
        # Seed the traversal with the tables of the explicitly selected services.
        if service_name in services_to_clear:
            tables_to_clear.update(output_tables[service_name])

    # Breadth-first expansion over fk and shared-service constraints.
    table_graph = config.table_graph
    table_queue = deque(tables_to_clear)

    def add_table_to_queue(table):
        # Enqueue each table at most once.
        if table not in tables_to_clear:
            tables_to_clear.add(table)
            table_queue.append(table)

    while table_queue:
        table_name = table_queue.popleft()
        # Sources of in-edges are treated as dependents of this table.
        # NOTE(review): presumably edges point from the referencing table to the
        # referenced one - confirm against the table_graph construction.
        for in_edge in table_graph.in_edges(table_name):
            add_table_to_queue(in_edge[0])
        # Every service writing to this table loses output, so its cache and
        # all of its other output tables have to be cleared too.
        services_to_clear.update(table_related_service[table_name])
        for service_name in table_related_service[table_name]:
            for table_to_add in output_tables[service_name]:
                add_table_to_queue(table_to_add)

    async with self._sql_engine.begin() as conn:
        # Delete cache tables. Uncached bindings are skipped, consistent with the
        # collection loop above.
        # NOTE(review): assumes `_cache_table` exists only on cached bindings.
        for bound_service in config.inference_bindings:
            if (isinstance(bound_service, CachedBoundInferenceService)
                    and bound_service.service.name in services_to_clear):
                # Already inside a coroutine: await instead of re-entering the loop.
                await conn.execute(delete(bound_service._cache_table))

        # Delete output tables; with the edge direction assumed above, dependents
        # come before the tables they reference.
        for table_name in nx.topological_sort(table_graph):
            if table_name in tables_to_clear:
                await conn.execute(delete(config.tables[table_name]._table))
a/tests/tests_caching_logic.py +++ b/tests/tests_caching_logic.py @@ -1,13 +1,17 @@ -from multiprocessing import Process import os -from sqlalchemy.sql import text import time import unittest +from multiprocessing import Process from unittest import IsolatedAsyncioTestCase +from sqlalchemy.sql import text + +from aidb.utils.asyncio import asyncio_run from aidb.utils.logger import logger -from tests.inference_service_utils.inference_service_setup import register_inference_services -from tests.inference_service_utils.http_inference_service_setup import run_server +from tests.inference_service_utils.http_inference_service_setup import \ + run_server +from tests.inference_service_utils.inference_service_setup import \ + register_inference_services from tests.utils import setup_gt_and_aidb_engine, setup_test_logger setup_test_logger('caching_logic') @@ -43,11 +47,16 @@ async def test_num_infer_calls(self): '''SELECT * FROM objects00 WHERE object_name='car' AND frame < 400;''' ), ] + + # Get the initial call count since inference services may be called by other tests before + initial_infer_one_calls = aidb_engine._config.inference_services["objects00"].infer_one.calls + + # May have cache before test so clear them + asyncio_run(aidb_engine.clear_ml_cache()) - # no service calls before executing query - assert aidb_engine._config.inference_services["objects00"].infer_one.calls == 0 - - calls = [20, 27] + calls = [[initial_infer_one_calls + 20, initial_infer_one_calls + 40], + [initial_infer_one_calls + 47, initial_infer_one_calls + 74]] + # First 300 need 20 calls, 300 to 400 need 7 calls for index, (query_type, aidb_query, exact_query) in enumerate(queries): logger.info(f'Running query {exact_query} in ground truth database') # Run the query on the ground truth database @@ -55,16 +64,91 @@ async def test_num_infer_calls(self): gt_res = await conn.execute(text(exact_query)) gt_res = gt_res.fetchall() # Run the query on the aidb database - logger.info(f'Running query 
async def test_only_one_service_deleted(self):
    '''
    Test that the cache of other services survives when only one service is cleared.

    Query two different services, clear the cache of one of them, query both
    again and check the accumulated inference call count after every query.
    '''
    dirname = os.path.dirname(__file__)
    data_dir = os.path.join(dirname, 'data/jackson')

    p = Process(target=run_server, args=[str(data_dir)])
    p.start()
    try:
        # Give the inference server a moment to start.
        time.sleep(1)
        db_url_list = [POSTGRESQL_URL]

        for db_url in db_url_list:
            gt_engine, aidb_engine = await setup_gt_and_aidb_engine(db_url, data_dir)
            register_inference_services(aidb_engine, data_dir, batch_supported=False)
            queries = [
                (
                    'full_scan',
                    '''SELECT * FROM lights01 WHERE light_1='red' AND frame < 300;''',
                    '''SELECT * FROM lights01 WHERE light_1='red' AND frame < 300;'''
                ),
                (
                    'full_scan',
                    '''SELECT * FROM counts03 WHERE count = 1 AND frame < 300;''',
                    '''SELECT * FROM counts03 WHERE count = 1 AND frame < 300;'''
                ),
            ]

            # Get the initial call count since inference services may have been
            # called by other tests before. All the infer_one calls use the same
            # counter, so checking only one of them should be enough.
            initial_infer_one_calls = aidb_engine._config.inference_services["counts03"].infer_one.calls

            # There may be cached results from before the test, so clear them.
            await aidb_engine.clear_ml_cache()

            # Each query needs 20 inference calls.
            # Round 0: nothing is cached, both queries run inference.
            # Round 1: lights01 was cleared and runs inference again, counts03
            # is still cached and adds no calls.
            calls = [[initial_infer_one_calls + 20, initial_infer_one_calls + 40],
                     [initial_infer_one_calls + 60, initial_infer_one_calls + 60]]
            for round_index in range(2):
                for index, (_query_type, aidb_query, exact_query) in enumerate(queries):
                    logger.info(f'Running query {exact_query} in ground truth database')
                    # Run the query on the ground truth database
                    async with gt_engine.begin() as conn:
                        gt_res = await conn.execute(text(exact_query))
                        gt_res = gt_res.fetchall()
                    logger.info(f'Running query {aidb_query} in aidb database')
                    # Run the query on the aidb database
                    aidb_res = aidb_engine.execute(aidb_query)
                    assert len(gt_res) == len(aidb_res)
                    # Check the call number
                    assert aidb_engine._config.inference_services["counts03"].infer_one.calls == calls[round_index][index]
                # Clear the cache for one of the services and retain the other one
                await aidb_engine.clear_ml_cache(["lights01"])

            del gt_engine
            del aidb_engine
    finally:
        # Always stop the server process, even when setup or an assertion fails.
        p.terminate()
        p.join()