Skip to content

Commit 04a03bd

Browse files
author
ci bot
committed
Merge branch 'validate-params' into 'enterprise'
fix(nav): validate query parameters See merge request dkinternal/testgen/dataops-testgen!286
2 parents c603bd3 + fdea797 commit 04a03bd

File tree

11 files changed

+47
-9
lines changed

11 files changed

+47
-9
lines changed

testgen/common/models/profiling_run.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,13 @@ def select_minimal_where(
131131
@classmethod
132132
@st.cache_data(show_spinner=False)
133133
def select_summary(
134-
cls, project_code: str, table_group_id: str | None = None, profiling_run_ids: list[str] | None = None
134+
cls, project_code: str, table_group_id: str | UUID | None = None, profiling_run_ids: list[str] | None = None
135135
) -> Iterable[ProfilingRunSummary]:
136+
if (table_group_id and not is_uuid4(table_group_id)) or (
137+
profiling_run_ids and not all(is_uuid4(run_id) for run_id in profiling_run_ids)
138+
):
139+
return []
140+
136141
query = f"""
137142
WITH profile_anomalies AS (
138143
SELECT profile_anomaly_results.profile_run_id,

testgen/common/models/scores.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
"transform_level",
5151
"data_product",
5252
]
53+
ScoreTypes = Literal["score", "cde_score"]
5354

5455

5556
class ScoreCategory(enum.Enum):
@@ -344,7 +345,7 @@ def get_score_card_issues(
344345

345346
dq_dimension_filter = ""
346347
if group_by == "dq_dimension":
347-
dq_dimension_filter = f" AND dq_dimension = '{value_}'"
348+
dq_dimension_filter = " AND dq_dimension = :value"
348349

349350
query = (
350351
read_template_sql_file(query_template_file, sub_directory="score_cards")
@@ -594,7 +595,7 @@ def filter(
594595
*,
595596
definition_id: str,
596597
category: Categories,
597-
score_type: Literal["score", "cde_score"],
598+
score_type: ScoreTypes,
598599
) -> Iterable[Self]:
599600
items = []
600601
db_session = get_current_session()

testgen/common/models/test_run.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,13 @@ def select_summary(
125125
test_suite_id: str | None = None,
126126
test_run_ids: list[str] | None = None,
127127
) -> Iterable[TestRunSummary]:
128+
if (
129+
(table_group_id and not is_uuid4(table_group_id))
130+
or (test_suite_id and not is_uuid4(test_suite_id))
131+
or (test_run_ids and not all(is_uuid4(run_id) for run_id in test_run_ids))
132+
):
133+
return []
134+
128135
query = f"""
129136
WITH run_results AS (
130137
SELECT test_run_id,

testgen/common/models/test_suite.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from testgen.common.models import get_current_session
1212
from testgen.common.models.custom_types import NullIfEmptyString, YNString
1313
from testgen.common.models.entity import ENTITY_HASH_FUNCS, Entity, EntityMinimal
14+
from testgen.utils import is_uuid4
1415

1516

1617
@dataclass
@@ -83,7 +84,10 @@ def select_minimal_where(
8384

8485
@classmethod
8586
@st.cache_data(show_spinner=False)
86-
def select_summary(cls, project_code: str, table_group_id: str | None = None) -> Iterable[TestSuiteSummary]:
87+
def select_summary(cls, project_code: str, table_group_id: str | UUID | None = None) -> Iterable[TestSuiteSummary]:
88+
if table_group_id and not is_uuid4(table_group_id):
89+
return []
90+
8791
query = f"""
8892
WITH last_run AS (
8993
SELECT test_runs.test_suite_id,

testgen/ui/navigation/router.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import testgen.ui.navigation.page
99
from testgen.common.mixpanel_service import MixpanelService
10+
from testgen.common.models.project import Project
1011
from testgen.ui.session import session
1112
from testgen.utils.singleton import Singleton
1213

@@ -114,8 +115,8 @@ def navigate(self, /, to: str, with_args: dict = {}) -> None: # noqa: B006
114115
def navigate_with_warning(self, warning: str, to: str, with_args: dict = {}) -> None: # noqa: B006
115116
st.warning(warning)
116117
time.sleep(3)
117-
self.navigate(to, with_args)
118-
118+
session.sidebar_project = session.sidebar_project or Project.select_where()[0].project_code
119+
self.navigate(to, {"project_code": session.sidebar_project, **with_args})
119120

120121
def set_query_params(self, with_args: dict) -> None:
121122
params = st.query_params

testgen/ui/views/data_catalog.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ def get_table_group_columns(table_group_id: str) -> list[dict]:
426426

427427

428428
def get_selected_item(selected: str, table_group_id: str) -> dict | None:
429-
if not selected or not is_uuid4(table_group_id):
429+
if not selected or "_" not in selected or not is_uuid4(table_group_id):
430430
return None
431431

432432
item_type, item_id = selected.split("_", 2)

testgen/ui/views/profiling_runs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def render(self, project_code: str, table_group_id: str | None = None, **_kwargs
5959
with group_filter_column:
6060
table_groups = TableGroup.select_minimal_where(TableGroup.project_code == project_code)
6161
table_groups_df = to_dataframe(table_groups, TableGroupMinimal.columns())
62+
table_groups_df["id"] = table_groups_df["id"].apply(lambda x: str(x))
6263
table_group_id = testgen.select(
6364
options=table_groups_df,
6465
value_column="id",

testgen/ui/views/score_details.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import typing
23
from io import BytesIO
34
from typing import ClassVar
45

@@ -8,7 +9,14 @@
89
from testgen.commands.run_refresh_score_cards_results import run_recalculate_score_card
910
from testgen.common.mixpanel_service import MixpanelService
1011
from testgen.common.models import with_database_session
11-
from testgen.common.models.scores import ScoreCategory, ScoreDefinition, ScoreDefinitionBreakdownItem, SelectedIssue
12+
from testgen.common.models.scores import (
13+
Categories,
14+
ScoreCategory,
15+
ScoreDefinition,
16+
ScoreDefinitionBreakdownItem,
17+
ScoreTypes,
18+
SelectedIssue,
19+
)
1220
from testgen.ui.components import widgets as testgen
1321
from testgen.ui.components.widgets.download_dialog import FILE_DATA_TYPE, download_dialog, zip_multi_file_data
1422
from testgen.ui.navigation.page import Page
@@ -60,6 +68,9 @@ def render(
6068
],
6169
)
6270

71+
if category not in typing.get_args(Categories):
72+
category = None
73+
6374
if not category and score_definition.category:
6475
category = score_definition.category.value
6576

@@ -72,6 +83,8 @@ def render(
7283
with st.spinner(text="Loading data :gray[:small[(This might take a few minutes)]] ..."):
7384
user_can_edit = user_session_service.user_can_edit()
7485
score_card = format_score_card(score_definition.as_cached_score_card())
86+
if score_type not in typing.get_args(ScoreTypes):
87+
score_type = None
7588
if not score_type:
7689
score_type = "cde_score" if score_card["cde_score"] and not score_card["score"] else "score"
7790
if not drilldown:

testgen/ui/views/score_explorer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def render(
105105
score_definition.name = name
106106
score_definition.total_score = total_score and total_score.lower() == "true"
107107
score_definition.cde_score = cde_score and cde_score.lower() == "true"
108-
score_definition.category = ScoreCategory(category) if category else None
108+
score_definition.category = ScoreCategory(category) if category in [cat.value for cat in ScoreCategory] else None
109109

110110
if filters:
111111
applied_filters: list[dict] = try_json(filters, default=[])

testgen/ui/views/table_groups.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class TableGroupsPage(Page):
3030
can_activate: typing.ClassVar = [
3131
lambda: session.authentication_status,
3232
lambda: not user_session_service.user_has_catalog_role(),
33+
lambda: "project_code" in st.query_params,
3334
]
3435
menu_item = MenuItem(
3536
icon="table_view",
@@ -44,6 +45,9 @@ def render(self, project_code: str, connection_id: str | None = None, **_kwargs)
4445

4546
user_can_edit = user_session_service.user_can_edit()
4647
project_summary = Project.get_summary(project_code)
48+
if connection_id and not connection_id.isdigit():
49+
connection_id = None
50+
4751
if connection_id:
4852
table_groups = TableGroup.select_minimal_where(
4953
TableGroup.project_code == project_code,

0 commit comments

Comments
 (0)