Skip to content

Commit 38ea749

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Extend Bigquery detect_anomalies tool to support future data anomaly detection
ARIMA supports both historical data and future data anomaly detection. This CL add how the tool support future table anomaly detection. PiperOrigin-RevId: 827803748
1 parent d2888a3 commit 38ea749

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

src/google/adk/tools/bigquery/query_tool.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,6 +1100,7 @@ def detect_anomalies(
11001100
times_series_timestamp_col: str,
11011101
times_series_data_col: str,
11021102
horizon: Optional[int] = 10,
1103+
target_data: Optional[str] = None,
11031104
times_series_id_cols: Optional[list[str]] = None,
11041105
anomaly_prob_threshold: Optional[float] = 0.95,
11051106
*,
@@ -1121,6 +1122,9 @@ def detect_anomalies(
11211122
numerical values to be forecasted and anomaly detected.
11221123
horizon (int, optional): The number of time steps to forecast into the
11231124
future. Defaults to 10.
1125+
target_data (str, optional): The table id of the BigQuery table containing
1126+
the target time series data or a query statement that select the target
1127+
data.
11241128
times_series_id_cols (list, optional): The column names of the id columns
11251129
to indicate each time series when there are multiple time series in the
11261130
table. All elements must be strings. Defaults to None.
@@ -1264,6 +1268,18 @@ def detect_anomalies(
12641268
anomaly_detection_query = f"""
12651269
SELECT * FROM ML.DETECT_ANOMALIES(MODEL {model_name}, STRUCT({anomaly_prob_threshold} AS anomaly_prob_threshold))
12661270
"""
1271+
if target_data:
1272+
trimmed_upper_target_data = target_data.strip().upper()
1273+
if trimmed_upper_target_data.startswith(
1274+
"SELECT"
1275+
) or trimmed_upper_target_data.startswith("WITH"):
1276+
target_data_source = f"({target_data})"
1277+
else:
1278+
target_data_source = f"SELECT * FROM `{target_data}`"
1279+
1280+
anomaly_detection_query = f"""
1281+
SELECT * FROM ML.DETECT_ANOMALIES(MODEL {model_name}, STRUCT({anomaly_prob_threshold} AS anomaly_prob_threshold), {target_data_source})
1282+
"""
12671283

12681284
# Create a session and run the create model query.
12691285
original_write_mode = settings.write_mode

tests/unittests/tools/bigquery/test_bigquery_query_tool.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1509,6 +1509,62 @@ def test_detect_anomalies_with_custom_params(mock_uuid, mock_execute_sql):
15091509
)
15101510

15111511

1512+
# detect_anomalies calls execute_sql twice. We need to test that
1513+
# the queries are properly constructed and call execute_sql with the correct
1514+
# parameters exactly twice.
1515+
@mock.patch("google.adk.tools.bigquery.query_tool.execute_sql", autospec=True)
1516+
@mock.patch("uuid.uuid4", autospec=True)
1517+
def test_detect_anomalies_on_target_table(mock_uuid, mock_execute_sql):
1518+
"""Test time series anomaly detection tool with target data is provided."""
1519+
mock_credentials = mock.MagicMock(spec=Credentials)
1520+
mock_settings = BigQueryToolConfig(write_mode=WriteMode.PROTECTED)
1521+
mock_tool_context = mock.create_autospec(ToolContext, instance=True)
1522+
mock_uuid.return_value = "test_uuid"
1523+
mock_execute_sql.return_value = {"status": "SUCCESS"}
1524+
1525+
history_data_query = "SELECT * FROM `test-dataset.history-table`"
1526+
target_data_query = "SELECT * FROM `test-dataset.target-table`"
1527+
detect_anomalies(
1528+
project_id="test-project",
1529+
history_data=history_data_query,
1530+
times_series_timestamp_col="ts_timestamp",
1531+
times_series_data_col="ts_data",
1532+
times_series_id_cols=["dim1", "dim2"],
1533+
horizon=20,
1534+
target_data=target_data_query,
1535+
anomaly_prob_threshold=0.8,
1536+
credentials=mock_credentials,
1537+
settings=mock_settings,
1538+
tool_context=mock_tool_context,
1539+
)
1540+
1541+
expected_create_model_query = """
1542+
CREATE TEMP MODEL detect_anomalies_model_test_uuid
1543+
OPTIONS (MODEL_TYPE = 'ARIMA_PLUS', TIME_SERIES_TIMESTAMP_COL = 'ts_timestamp', TIME_SERIES_DATA_COL = 'ts_data', HORIZON = 20, TIME_SERIES_ID_COL = ['dim1', 'dim2'])
1544+
AS (SELECT * FROM `test-dataset.history-table`)
1545+
"""
1546+
1547+
expected_anomaly_detection_query = """
1548+
SELECT * FROM ML.DETECT_ANOMALIES(MODEL detect_anomalies_model_test_uuid, STRUCT(0.8 AS anomaly_prob_threshold), (SELECT * FROM `test-dataset.target-table`))
1549+
"""
1550+
1551+
assert mock_execute_sql.call_count == 2
1552+
mock_execute_sql.assert_any_call(
1553+
"test-project",
1554+
expected_create_model_query,
1555+
mock_credentials,
1556+
mock_settings,
1557+
mock_tool_context,
1558+
)
1559+
mock_execute_sql.assert_any_call(
1560+
"test-project",
1561+
expected_anomaly_detection_query,
1562+
mock_credentials,
1563+
mock_settings,
1564+
mock_tool_context,
1565+
)
1566+
1567+
15121568
def test_detect_anomalies__with_invalid_id_cols():
15131569
"""Test time series anomaly detection tool invocation with invalid times_series_id_cols."""
15141570
mock_credentials = mock.MagicMock(spec=Credentials)

0 commit comments

Comments
 (0)