Skip to content
Open
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions aidb/engine/engine.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from aidb.engine.approx_aggregate_join_engine import ApproximateAggregateJoinEngine
import networkx as nx
from sqlalchemy.schema import ForeignKeyConstraint
from sqlalchemy.sql import delete

from aidb.engine.approx_aggregate_join_engine import \
ApproximateAggregateJoinEngine
from aidb.engine.approx_select_engine import ApproxSelectEngine
from aidb.engine.limit_engine import LimitEngine
from aidb.engine.non_select_query_engine import NonSelectQueryEngine
from aidb.utils.asyncio import asyncio_run
from aidb.inference.bound_inference_service import CachedBoundInferenceService
from aidb.query.query import Query
from aidb.utils.asyncio import asyncio_run
from aidb.utils.logger import logger


class Engine(LimitEngine, NonSelectQueryEngine, ApproxSelectEngine, ApproximateAggregateJoinEngine):
Expand Down Expand Up @@ -42,3 +49,37 @@ def execute(self, query: str, **kwargs):
raise e
finally:
self.__del__()

async def clear_ml_cache(self, service_name_list = None):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason this function is so complicated?

Copy link
Contributor Author

@hjk1030 hjk1030 May 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most of this function is building the foreign key relationship graph since deleting the data referenced by another output table will cause error. I believe this is necessary unless there exists such a graph already.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there other functions that build the fk relationship graph? If so, that should be refactored

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'''
Clear the cache and output table if the ML model has changed.
Delete the tables following the inference services' topological order to maintain integrity during deletion.
service_name_list: the name of all the changed services. A list of str or None.
If the service name list is not given, the output for all the services will be cleared.
'''
async with self._sql_engine.begin() as conn:
service_ordering = self._config.inference_topological_order
if service_name_list is None:
service_name_list = [bounded_service.service.name for bounded_service in service_ordering]
service_name_list = set(service_name_list)

# Get all the services that need to be cleared because of foreign key constraints
inference_graph = self._config.inference_graph
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The edge in inference_graph doesn't represent two nodes have foreign key constraints

for bounded_service in service_ordering:
if bounded_service.service.name in service_name_list:
for input_column in bounded_service.binding.input_columns:
for in_edge in inference_graph.in_edges(input_column):
service_name_list.add(inference_graph.get_edge_data(*in_edge)['bound_service'])

# Clear the services in reversed topological order
for bounded_service in reversed(service_ordering):
if isinstance(bounded_service, CachedBoundInferenceService):
if bounded_service.service.name in service_name_list:
asyncio_run(conn.execute(delete(bounded_service._cache_table)))
output_tables_to_be_deleted = set()
for output_column in bounded_service.binding.output_columns:
output_tables_to_be_deleted.add(output_column.split('.')[0])
for table_name in output_tables_to_be_deleted:
asyncio_run(conn.execute(delete(bounded_service._tables[table_name]._table)))
else:
logger.debug(f"Service binding for {bounded_service.service.name} is not cached")
10 changes: 8 additions & 2 deletions launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@

import pandas as pd

from aidb_utilities.command_line_setup.command_line_setup import command_line_utility
from aidb.utils.asyncio import asyncio_run
from aidb_utilities.aidb_setup.aidb_factory import AIDB
from aidb_utilities.command_line_setup.command_line_setup import \
command_line_utility
from aidb_utilities.db_setup.blob_table import BaseTablesSetup
from aidb_utilities.db_setup.create_tables import create_output_tables
from aidb.utils.asyncio import asyncio_run


def setup_blob_tables(config):
Expand All @@ -22,6 +23,7 @@ def setup_blob_tables(config):
parser.add_argument("--setup-blob-table", action='store_true')
parser.add_argument("--setup-output-tables", action='store_true')
parser.add_argument("--verbose", action='store_true')
parser.add_argument("--clear-cache", nargs='*')
args = parser.parse_args()

config = importlib.import_module(args.config)
Expand All @@ -33,4 +35,8 @@ def setup_blob_tables(config):
asyncio_run(create_output_tables(config.DB_URL, config.DB_NAME, config.tables))

aidb_engine = AIDB.from_config(args.config, args.verbose)

if args.clear_cache is not None:
asyncio_run(aidb_engine.clear_ml_cache(None if len(args.clear_cache) == 0 else args.clear_cache))

command_line_utility(aidb_engine)
114 changes: 104 additions & 10 deletions tests/tests_caching_logic.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from multiprocessing import Process
import os
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you added this test to Github Action?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. The test is modified based on an original test and I verified it is executed.

from sqlalchemy.sql import text
import time
import unittest
from multiprocessing import Process
from unittest import IsolatedAsyncioTestCase

from sqlalchemy.sql import text

from aidb.utils.asyncio import asyncio_run
from aidb.utils.logger import logger
from tests.inference_service_utils.inference_service_setup import register_inference_services
from tests.inference_service_utils.http_inference_service_setup import run_server
from tests.inference_service_utils.http_inference_service_setup import \
run_server
from tests.inference_service_utils.inference_service_setup import \
register_inference_services
from tests.utils import setup_gt_and_aidb_engine, setup_test_logger

setup_test_logger('caching_logic')
Expand Down Expand Up @@ -43,28 +47,118 @@ async def test_num_infer_calls(self):
'''SELECT * FROM objects00 WHERE object_name='car' AND frame < 400;'''
),
]

# Get the initial call count since inference services may be called by other tests before
initial_infer_one_calls = aidb_engine._config.inference_services["objects00"].infer_one.calls

# May have cache before test so clear them
asyncio_run(aidb_engine.clear_ml_cache())

# no service calls before executing query
assert aidb_engine._config.inference_services["objects00"].infer_one.calls == 0

calls = [20, 27]
calls = [[initial_infer_one_calls + 20, initial_infer_one_calls + 40],
[initial_infer_one_calls + 47, initial_infer_one_calls + 74]]
# First 300 need 20 calls, 300 to 400 need 7 calls
for index, (query_type, aidb_query, exact_query) in enumerate(queries):
logger.info(f'Running query {exact_query} in ground truth database')
# Run the query on the ground truth database
async with gt_engine.begin() as conn:
gt_res = await conn.execute(text(exact_query))
gt_res = gt_res.fetchall()
# Run the query on the aidb database
logger.info(f'Running query {aidb_query} in aidb database')
logger.info(f'Running initial query {aidb_query} in aidb database')
aidb_res = aidb_engine.execute(aidb_query)
assert len(gt_res) == len(aidb_res)
# running the same query, so number of inference calls should remain same
# temporarily commenting this out because we no longer call infer_one
assert aidb_engine._config.inference_services["objects00"].infer_one.calls == calls[index]
assert aidb_engine._config.inference_services["objects00"].infer_one.calls == calls[index][0]
logger.info(f'Running cached query {aidb_query} in aidb database')
aidb_res = aidb_engine.execute(aidb_query)
assert len(gt_res) == len(aidb_res)
# run again, because cache exists, there should be no new calls
assert aidb_engine._config.inference_services["objects00"].infer_one.calls == calls[index][0]
asyncio_run(aidb_engine.clear_ml_cache())
logger.info(f'Running uncached query {aidb_query} in aidb database')
aidb_res = aidb_engine.execute(aidb_query)
assert len(gt_res) == len(aidb_res)
# cleared cache, so should accumulate new calls same as the first call
assert aidb_engine._config.inference_services["objects00"].infer_one.calls == calls[index][1]
del gt_engine
del aidb_engine
p.terminate()
p.join()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, I cannot terminate the test server completely without this. I'm not sure whether it's a bug or I'm not using the correct way to write multiple tests.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check this in depth

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the test server wasn't properly closed before. The server will still occupy the port if only the terminate() function is called. However most test classes only have one unit test or the service is not called, so it did not cause any problems.


async def test_only_one_service_deleted(self):
'''
Testing whether cache for other service remains when only one service is deleted.
Do query on two different services first. Then delete cache for one service.
Finally do query on these services again and check whether the call count is correct.
'''
dirname = os.path.dirname(__file__)
data_dir = os.path.join(dirname, 'data/jackson')

p = Process(target=run_server, args=[str(data_dir)])
p.start()
time.sleep(1)
db_url_list = [POSTGRESQL_URL]

for db_url in db_url_list:
gt_engine, aidb_engine = await setup_gt_and_aidb_engine(db_url, data_dir)
register_inference_services(aidb_engine, data_dir, batch_supported=False)

queries = [
(
'full_scan',
'''SELECT * FROM lights01 WHERE light_1='red' AND frame < 300;''',
'''SELECT * FROM lights01 WHERE light_1='red' AND frame < 300;'''
),
(
'full_scan',
'''SELECT * FROM counts03 WHERE count = 1 AND frame < 300;''',
'''SELECT * FROM counts03 WHERE count = 1 AND frame < 300;'''
),
]

# Get the initial call count since inference services may be called by other tests before
# all the infer_one call use the same counter, so checking only one of them should be enough
initial_infer_one_calls = aidb_engine._config.inference_services["counts03"].infer_one.calls

# May have cache before test so clear them
asyncio_run(aidb_engine.clear_ml_cache())

calls = [[initial_infer_one_calls + 20, initial_infer_one_calls + 40],
[initial_infer_one_calls + 60, initial_infer_one_calls + 60]]
# each query calls 20 inference.
# For the first two queries, all the inference is needed.
# The third query need to infer again but the fourth don't.
for index, (query_type, aidb_query, exact_query) in enumerate(queries):
logger.info(f'Running query {exact_query} in ground truth database')
# Run the query on the ground truth database
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Write comments in the correct locations.

async with gt_engine.begin() as conn:
gt_res = await conn.execute(text(exact_query))
gt_res = gt_res.fetchall()
logger.info(f'Running initial query {aidb_query} in aidb database')
# Run the query on the aidb database
aidb_res = aidb_engine.execute(aidb_query)
assert len(gt_res) == len(aidb_res)
assert aidb_engine._config.inference_services["counts03"].infer_one.calls == calls[0][index]

asyncio_run(aidb_engine.clear_ml_cache(["lights01"]))

for index, (query_type, aidb_query, exact_query) in enumerate(queries):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactor the code. Could this loop be merged into previous one?

logger.info(f'Running query {exact_query} in ground truth database')
# Run the query on the ground truth database
async with gt_engine.begin() as conn:
gt_res = await conn.execute(text(exact_query))
gt_res = gt_res.fetchall()
logger.info(f'Running query {aidb_query} in aidb database after cache deleted')
# Run the query on the aidb database
aidb_res = aidb_engine.execute(aidb_query)
assert len(gt_res) == len(aidb_res)
assert aidb_engine._config.inference_services["counts03"].infer_one.calls == calls[1][index]

del gt_engine
del aidb_engine
p.terminate()
p.join()

if __name__ == '__main__':
unittest.main()