Skip to content

Commit bf1122f

Browse files
committed
Add full custom query for observations
1 parent 0d02b76 commit bf1122f

File tree

6 files changed

+66
-40
lines changed

6 files changed

+66
-40
lines changed

src/domain/models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,12 @@ class Observation(Base):
176176
comment="Custom query to calculate the test_ids for the experiment. "
177177
"The query will be used in jinja template instead of the default query",
178178
)
179+
full_custom_query = Column(
180+
Text,
181+
nullable=True,
182+
comment="Full custom query to calculate metrics. "
183+
"If provided, this query will be used instead of the template-based query.",
184+
)
179185

180186
metric_tags = Column(JSON, nullable=True, comment="Filter metrics by tags")
181187
metric_groups = Column(JSON, nullable=True, comment="Filter metrics by groups")

src/services/runners/renderer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ def render(
4646
Returns:
4747
str: The rendered SQL query string.
4848
"""
49+
if obs.full_custom_query:
50+
return str(obs.full_custom_query)
51+
4952
calc_scenario_path = self.templates_config.scenarios.get(str(obs.calculation_scenario))
5053
if not calc_scenario_path:
5154
error_message = (

src/ui/observations/inputs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ class ObservationFormData:
138138
audience_tables: list[str] | None
139139
filters: list[str] | None
140140
custom_test_ids_query: str | None
141+
full_custom_query: str | None
141142
metric_tags: list[str] | None
142143
metric_groups: list[str] | None
143144

@@ -345,6 +346,9 @@ def render(
345346
"Supports JOINs, CTEs, subqueries, and complex aggregations. "
346347
"Modifying operations (INSERT, UPDATE, DELETE) are not allowed.",
347348
)
349+
full_custom_query = st.text_area(
350+
"Full Custom Query", value="", key="full_custom_query_input_key"
351+
)
348352

349353
metric_tags: list[str] | None = st.multiselect(
350354
"Metric Tags",
@@ -425,6 +429,7 @@ def render(
425429
audience_tables=[t.strip() for t in audience_tables.split("\n") if t.strip()],
426430
filters=[f.strip() for f in filters.split("\n") if f.strip()],
427431
custom_test_ids_query=custom_test_ids_query,
432+
full_custom_query=full_custom_query,
428433
metric_tags=metric_tags,
429434
metric_groups=metric_groups,
430435
),

tests/services/runners/conftest.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -82,24 +82,31 @@ def fake_metric_results():
8282
@pytest.fixture
8383
def mock_observation():
8484
"""Fixture providing mock observation object for testing."""
85-
return Observation(
86-
id=1,
87-
experiment_id=1,
88-
name="Test Observation",
89-
db_experiment_name="test_experiment_db_name",
90-
split_id="user_id",
91-
calculation_scenario="base",
92-
exposure_start_datetime=datetime(2024, 1, 1),
93-
exposure_end_datetime=datetime(2024, 1, 31),
94-
calc_start_datetime=datetime(2024, 1, 1),
95-
calc_end_datetime=datetime(2024, 1, 31),
96-
exposure_event="view",
97-
audience_tables=["active_users"],
98-
filters=["platform='web'", "(country='US' or country='CA')"],
99-
custom_test_ids_query=None,
100-
metric_tags=["main"],
101-
metric_groups=["core"],
102-
)
85+
86+
def _mock_observation(**kwargs):
87+
default_attrs = {
88+
"id": 1,
89+
"experiment_id": 1,
90+
"name": "Test Observation",
91+
"db_experiment_name": "test_experiment_db_name",
92+
"split_id": "user_id",
93+
"calculation_scenario": "base",
94+
"exposure_start_datetime": datetime(2024, 1, 1),
95+
"exposure_end_datetime": datetime(2024, 1, 31),
96+
"calc_start_datetime": datetime(2024, 1, 1),
97+
"calc_end_datetime": datetime(2024, 1, 31),
98+
"exposure_event": "view",
99+
"audience_tables": ["active_users"],
100+
"filters": ["platform='web'", "(country='US' or country='CA')"],
101+
"custom_test_ids_query": None,
102+
"full_custom_query": None,
103+
"metric_tags": ["main"],
104+
"metric_groups": ["core"],
105+
}
106+
default_attrs.update(kwargs)
107+
return Observation(**default_attrs)
108+
109+
return _mock_observation
103110

104111

105112
@pytest.fixture

tests/services/runners/test_executors.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,17 @@ def test_mock_calculation_runner_run_calculations_regular(
1616
):
1717
"""Test calculation runner executes regular calculations successfully."""
1818
mock_calculation_runner.exec._connector.fetch_results.return_value = mock_metric_results
19+
obs = mock_observation()
1920
# Execute
20-
result = mock_calculation_runner.run_calculation(
21-
obs=mock_observation, purpose=CalculationPurpose.REGULAR
22-
)
21+
result = mock_calculation_runner.run_calculation(obs=obs, purpose=CalculationPurpose.REGULAR)
2322

2423
assert result.job_id == 1
2524
assert result.success is True
2625
assert len(result.metric_results) == 2
2726
assert result.metric_results[0].metric_name == "conversion_rate"
2827
# get result of job
2928
executed_job = JobHandler(engine).get(1)
30-
rendered_query = mock_calculation_runner.renderer.render(
31-
obs=mock_observation, purpose=CalculationPurpose.REGULAR
32-
)
29+
rendered_query = mock_calculation_runner.renderer.render(obs=obs, purpose=CalculationPurpose.REGULAR)
3330
mock_calculation_runner.exec._connector.fetch_results.assert_called_once_with(rendered_query)
3431
assert executed_job.status == JobStatus.COMPLETED
3532

@@ -45,9 +42,10 @@ def test_mock_calculation_runner_run_calculations_planning(
4542
):
4643
"""Test calculation runner executes planning calculations without storing."""
4744
mock_calculation_runner.exec._connector.fetch_results.return_value = mock_metric_results
45+
obs = mock_observation()
4846
# Execute
4947
result = mock_calculation_runner.run_calculation(
50-
obs=mock_observation,
48+
obs=obs,
5149
purpose=CalculationPurpose.PLANNING,
5250
)
5351
assert isinstance(result.metric_results[0], MetricResult)
@@ -66,7 +64,8 @@ def test_mock_calculation_runner_run_calculations_planning_failed(
6664
):
6765
"""Test calculation runner handles database errors properly."""
6866
mock_calculation_runner.exec._connector.fetch_results.side_effect = Exception("Database error")
69-
mock_calculation_runner.run_calculation(obs=mock_observation, purpose=CalculationPurpose.REGULAR)
67+
obs = mock_observation()
68+
mock_calculation_runner.run_calculation(obs=obs, purpose=CalculationPurpose.REGULAR)
7069
executed_job = JobHandler(engine).get(1)
7170
assert executed_job.status == JobStatus.FAILED
7271
assert "Database error" in executed_job.error_message
@@ -76,10 +75,9 @@ def test_mock_calculation_runner_run_calculations_planning_failed(
7675

7776
def test_calculation_runner_job_creation_failure(mock_calculation_runner, mock_observation):
7877
"""Test that CalculationRunner handles job creation failure."""
78+
obs = mock_observation()
7979
with patch("src.services.runners.executors.JobHandler.create", return_value=None):
80-
result = mock_calculation_runner.run_calculation(
81-
obs=mock_observation, purpose=CalculationPurpose.REGULAR
82-
)
80+
result = mock_calculation_runner.run_calculation(obs=obs, purpose=CalculationPurpose.REGULAR)
8381

8482
assert result.success is False
8583
assert result.job_id is None
@@ -88,10 +86,9 @@ def test_calculation_runner_job_creation_failure(mock_calculation_runner, mock_o
8886

8987
def test_calculation_runner_job_creation_db_exception(mock_calculation_runner, mock_observation):
9088
"""Test that CalculationRunner handles job creation failure due to DB exception."""
89+
obs = mock_observation()
9190
with patch("src.services.runners.executors.JobHandler.create", side_effect=Exception("DB is down")):
92-
result = mock_calculation_runner.run_calculation(
93-
obs=mock_observation, purpose=CalculationPurpose.REGULAR
94-
)
91+
result = mock_calculation_runner.run_calculation(obs=obs, purpose=CalculationPurpose.REGULAR)
9592

9693
assert result.success is False
9794
assert result.job_id is None
@@ -100,10 +97,9 @@ def test_calculation_runner_job_creation_db_exception(mock_calculation_runner, m
10097

10198
def test_calculation_runner_render_failure(mock_calculation_runner, mock_observation, engine, tables):
10299
"""Test that CalculationRunner handles query rendering failure."""
100+
obs = mock_observation()
103101
with patch.object(mock_calculation_runner.renderer, "render", side_effect=Exception("Template error")):
104-
result = mock_calculation_runner.run_calculation(
105-
obs=mock_observation, purpose=CalculationPurpose.REGULAR
106-
)
102+
result = mock_calculation_runner.run_calculation(obs=obs, purpose=CalculationPurpose.REGULAR)
107103

108104
assert result.success is False
109105
assert result.job_id == 1
@@ -119,14 +115,13 @@ def test_calculation_runner_store_metrics_failure(
119115
):
120116
"""Test that CalculationRunner handles storing metrics failure."""
121117
mock_calculation_runner.exec._connector.fetch_results.return_value = mock_metric_results
118+
obs = mock_observation()
122119
with patch.object(
123120
mock_calculation_runner.jobs,
124121
"store_metrics",
125122
side_effect=Exception("DB connection failed"),
126123
):
127-
result = mock_calculation_runner.run_calculation(
128-
obs=mock_observation, purpose=CalculationPurpose.REGULAR
129-
)
124+
result = mock_calculation_runner.run_calculation(obs=obs, purpose=CalculationPurpose.REGULAR)
130125

131126
assert result.success is False
132127
assert result.job_id == 1

tests/services/runners/test_renderer.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
def test_renderer(query_renderer, mock_observation):
55
"""Test rendering base calculation query for planning purpose"""
6-
query = query_renderer.render(obs=mock_observation, purpose=CalculationPurpose.PLANNING)
6+
obs = mock_observation()
7+
query = query_renderer.render(obs=obs, purpose=CalculationPurpose.PLANNING)
78
# TODO: fix mock observation and fix fake template according to custom_test_ids_query
89
assert isinstance(query, str)
910
assert "user_id as split_id" in query
@@ -15,8 +16,9 @@ def test_renderer(query_renderer, mock_observation):
1516

1617
def test_renderer_with_metric_names(query_renderer, mock_observation):
1718
"""Test rendering base calculation query with specific metric names"""
19+
obs = mock_observation()
1820
query = query_renderer.render(
19-
obs=mock_observation,
21+
obs=obs,
2022
purpose=CalculationPurpose.PLANNING,
2123
experiment_metric_names=["click_through_rate"],
2224
)
@@ -27,3 +29,11 @@ def test_renderer_with_metric_names(query_renderer, mock_observation):
2729
assert exp_alias not in query
2830
for user_alias in ["product_purchase_cnt", "session_duration"]:
2931
assert user_alias not in query
32+
33+
34+
def test_renderer_with_full_custom_query(query_renderer, mock_observation):
35+
"""Test that full_custom_query bypasses template rendering"""
36+
custom_query = "SELECT 1"
37+
obs = mock_observation(full_custom_query=custom_query)
38+
query = query_renderer.render(obs=obs, purpose=CalculationPurpose.REGULAR)
39+
assert query == custom_query

0 commit comments

Comments
 (0)