Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions elementary/monitor/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,12 @@ def decorator(func):
"--select",
type=str,
default=None,
help="Filter the report by last_invocation / invocation_id:<INVOCATION_ID> / invocation_time:<INVOCATION_TIME>."
if cmd in (Command.REPORT, Command.SEND_REPORT)
else "DEPRECATED! Please use --filters instead! - Filter the alerts by tags:<TAGS> / owners:<OWNERS> / models:<MODELS> / "
"statuses:<warn/fail/error/skipped> / resource_types:<model/test>.",
help=(
"Filter the report by last_invocation / invocation_id:<INVOCATION_ID> / invocation_time:<INVOCATION_TIME>."
if cmd in (Command.REPORT, Command.SEND_REPORT)
else "DEPRECATED! Please use --filters instead! - Filter the alerts by tags:<TAGS> / owners:<OWNERS> / models:<MODELS> / "
"statuses:<warn/fail/error/skipped> / resource_types:<model/test>."
),
)(func)
return func

Expand Down Expand Up @@ -364,7 +366,9 @@ def monitor(

alert_filters = FiltersSchema()
if bool(filters) or bool(excludes):
alert_filters = FiltersSchema.from_cli_params(filters, excludes)
alert_filters = FiltersSchema.from_cli_params(
filters, excludes, config, anonymous_tracking
)
elif select is not None:
click.secho(
'\n"--select" is deprecated and won\'t be supported in the near future.\n'
Expand Down
26 changes: 24 additions & 2 deletions elementary/monitor/data_monitoring/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,13 @@ def validate_report_selector(self) -> None:

@staticmethod
def from_cli_params(
cli_filters: Tuple[str], cli_excludes: Tuple[str]
cli_filters: Tuple[str],
cli_excludes: Tuple[str],
config: Optional[Any] = None,
tracking: Optional[Any] = None,
) -> "FiltersSchema":
from elementary.monitor.data_monitoring.selector_filter import SelectorFilter

all_filters: list[tuple[str, FilterType]] = []
for cli_filter in cli_filters:
all_filters.append((cli_filter, FilterType.IS))
Expand All @@ -206,6 +211,7 @@ def from_cli_params(
models = []
statuses = []
resource_types = []
node_names = []

for cli_filter, filter_type in all_filters:
tags_match = FiltersSchema._match_filter_regex(
Expand All @@ -226,7 +232,22 @@ def from_cli_params(
filter_string=cli_filter, regex=re.compile(r"models:(.*)")
)
if models_match:
models.append(FilterSchema(values=models_match, type=filter_type))
model_value = (
models_match[0]
if len(models_match) == 1
else ",".join(models_match)
)
if (
config
and filter_type == FilterType.IS
and SelectorFilter._has_graph_operators(model_value)
):
selector_filter = SelectorFilter(config, tracking, model_value)
filter_result = selector_filter.get_filter()
if filter_result.node_names:
node_names.extend(filter_result.node_names)
else:
models.append(FilterSchema(values=models_match, type=filter_type))
continue

statuses_match = FiltersSchema._match_filter_regex(
Expand Down Expand Up @@ -269,6 +290,7 @@ def from_cli_params(
models=models,
statuses=statuses,
resource_types=resource_types,
node_names=node_names,
)

@staticmethod
Expand Down
29 changes: 23 additions & 6 deletions elementary/monitor/data_monitoring/selector_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,25 @@ def _parse_selector(self, selector: Optional[str] = None) -> FiltersSchema:
selector=selector,
)
elif model_match:
if self.tracking:
self.tracking.set_env("select_method", "model")
data_monitoring_filter = FiltersSchema(
models=[FilterSchema(values=[model_match.group(1)])],
selector=selector,
)
model_value = model_match.group(1)
if self.selector_fetcher and self._has_graph_operators(model_value):
if self.tracking:
self.tracking.set_env(
"select_method", "model with graph operators"
)
node_names = self.selector_fetcher.get_selector_results(
selector=model_value
)
data_monitoring_filter = FiltersSchema(
node_names=node_names, selector=selector
)
else:
if self.tracking:
self.tracking.set_env("select_method", "model")
data_monitoring_filter = FiltersSchema(
models=[FilterSchema(values=[model_value])],
selector=selector,
)
elif statuses_match:
if self.tracking:
self.tracking.set_env("select_method", "statuses")
Expand Down Expand Up @@ -148,6 +161,10 @@ def _create_user_dbt_runner(self, config: Config) -> Optional[BaseDbtRunner]:
def get_filter(self) -> FiltersSchema:
    """Return the filter schema held in this instance's ``filter`` attribute."""
    current_filter = self.filter
    return current_filter

@staticmethod
def _has_graph_operators(selector: str) -> bool:
    """Report whether ``selector`` contains a dbt graph operator.

    Two operators are recognised:

    * ``+`` - ancestors/descendants, including the depth-limited ``2+model``
      and ``model+3`` forms (they still contain ``+``).
    * ``@`` - the "at" operator (``@model``).

    Args:
        selector: The selector value with its method prefix (for example
            ``model:``) already stripped off.

    Returns:
        ``True`` if any graph operator character occurs in ``selector``,
        ``False`` otherwise (including for an empty string).
    """
    # Checking only "+" would make "@model" fall through as a literal model name.
    graph_operators = ("+", "@")
    return any(operator in selector for operator in graph_operators)

@staticmethod
def _can_use_fetcher(selector):
non_dbt_selectors = [
Expand Down
103 changes: 103 additions & 0 deletions tests/unit/monitor/data_monitoring/test_filters_schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from unittest.mock import patch

import pytest

from elementary.monitor.data_monitoring.schema import FiltersSchema, FilterType
from tests.mocks.anonymous_tracking_mock import MockAnonymousTracking
from tests.mocks.config_mock import MockConfig


def test_empty_from_cli_params():
Expand Down Expand Up @@ -310,3 +314,102 @@ def test_exclude_statuses_filters():
["fail", "error", "runtime error", "warn"]
)
assert filter_schema.statuses[1].type == FilterType.IS


def test_models_with_graph_operators_from_cli_params():
    """``models:customers+`` with a config ends up as node names, not a models filter.

    dbt's ``ls`` is patched, so the node names below stand in for whatever dbt
    would return for the selector.
    """
    dbt_nodes = ["model.customers", "model.orders", "model.payments"]
    ls_target = (
        "elementary.clients.dbt.command_line_dbt_runner.CommandLineDbtRunner.ls"
    )
    with patch(ls_target) as ls_stub:
        ls_stub.return_value = list(dbt_nodes)

        filter_schema = FiltersSchema.from_cli_params(
            ("models:customers+",),
            (),
            MockConfig("mock_project_dir"),
            MockAnonymousTracking(),
        )

        # The selector is carried only through node_names; the other filter
        # lists stay empty.
        for empty_field in ("tags", "models", "owners"):
            assert len(getattr(filter_schema, empty_field)) == 0
        assert len(filter_schema.node_names) == 3
        assert sorted(filter_schema.node_names) == sorted(dbt_nodes)

        # A single statuses filter is present even though none was passed.
        assert len(filter_schema.statuses) == 1
        expected_statuses = sorted(["fail", "error", "runtime error", "warn"])
        assert sorted(filter_schema.statuses[0].values) == expected_statuses
        assert len(filter_schema.resource_types) == 0


def test_models_with_upstream_graph_operators_from_cli_params():
    """``models:+customers`` with a config ends up as node names, not a models filter.

    dbt's ``ls`` is patched, so the node names below stand in for whatever dbt
    would return for the selector.
    """
    dbt_nodes = [
        "model.raw_customers",
        "model.stg_customers",
        "model.customers",
    ]
    ls_target = (
        "elementary.clients.dbt.command_line_dbt_runner.CommandLineDbtRunner.ls"
    )
    with patch(ls_target) as ls_stub:
        ls_stub.return_value = list(dbt_nodes)

        filter_schema = FiltersSchema.from_cli_params(
            ("models:+customers",),
            (),
            MockConfig("mock_project_dir"),
            MockAnonymousTracking(),
        )

        # The selector is carried only through node_names; the other filter
        # lists stay empty.
        for empty_field in ("tags", "models", "owners"):
            assert len(getattr(filter_schema, empty_field)) == 0
        assert len(filter_schema.node_names) == 3
        assert sorted(filter_schema.node_names) == sorted(dbt_nodes)

        # A single statuses filter is present even though none was passed.
        assert len(filter_schema.statuses) == 1
        expected_statuses = sorted(["fail", "error", "runtime error", "warn"])
        assert sorted(filter_schema.statuses[0].values) == expected_statuses
        assert len(filter_schema.resource_types) == 0


def test_models_without_graph_operators_from_cli_params_no_config():
    """With no config passed, ``models:customers+`` is kept as a literal models filter."""
    filter_schema = FiltersSchema.from_cli_params(("models:customers+",), ())

    # The raw value, "+" included, is kept in the models filter.
    assert len(filter_schema.models) == 1
    assert filter_schema.models[0].values == ["customers+"]

    for empty_field in ("tags", "owners", "node_names", "resource_types"):
        assert len(getattr(filter_schema, empty_field)) == 0

    # A single statuses filter is present even though none was passed.
    assert len(filter_schema.statuses) == 1
    expected_statuses = sorted(["fail", "error", "runtime error", "warn"])
    assert sorted(filter_schema.statuses[0].values) == expected_statuses


def test_exclude_models_with_graph_operators_from_cli_params():
    """Test that graph operators in excludes are NOT resolved to node_names.

    dbt's ``ls`` is patched for two reasons: the test can then assert that an
    exclude never triggers a dbt lookup, and a regression cannot shell out to
    a real dbt project while the test runs.
    """
    config = MockConfig("mock_project_dir")
    tracking = MockAnonymousTracking()

    cli_filter = ()
    cli_excludes = ("models:customers+",)
    with patch(
        "elementary.clients.dbt.command_line_dbt_runner.CommandLineDbtRunner.ls"
    ) as mock_ls:
        filter_schema = FiltersSchema.from_cli_params(
            cli_filter, cli_excludes, config, tracking
        )

    # Excludes must stay literal model filters; no dbt selector resolution.
    mock_ls.assert_not_called()

    assert len(filter_schema.tags) == 0
    assert len(filter_schema.models) == 1
    assert filter_schema.models[0].values == ["customers+"]
    assert filter_schema.models[0].type == FilterType.IS_NOT
    assert len(filter_schema.owners) == 0
    assert len(filter_schema.node_names) == 0
    assert len(filter_schema.statuses) == 1
    assert sorted(filter_schema.statuses[0].values) == sorted(
        ["fail", "error", "runtime error", "warn"]
    )
    assert len(filter_schema.resource_types) == 0
62 changes: 62 additions & 0 deletions tests/unit/monitor/data_monitoring/test_selector_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,65 @@ def dbt_runner_with_models_mock() -> Generator[MagicMock, None, None]:
) as mock_ls:
mock_ls.return_value = ["node_name_1", "node_name_2"]
yield mock_ls


def test_parse_selector_with_graph_operators_downstream(
    dbt_runner_with_models_mock, anonymous_tracking_mock
):
    """``model:customers+`` yields the node names returned by the mocked dbt ``ls``."""
    selector = "model:customers+"

    selector_filter = SelectorFilter(
        tracking=anonymous_tracking_mock,
        config=MockConfig("mock_project_dir"),
        selector=selector,
    )
    parsed = selector_filter.get_filter()

    assert parsed.node_names == ["node_name_1", "node_name_2"]
    # The original selector string is kept on the resulting schema.
    assert parsed.selector == selector


def test_parse_selector_with_graph_operators_upstream(
    dbt_runner_with_models_mock, anonymous_tracking_mock
):
    """``model:+customers`` yields the node names returned by the mocked dbt ``ls``."""
    selector = "model:+customers"

    selector_filter = SelectorFilter(
        tracking=anonymous_tracking_mock,
        config=MockConfig("mock_project_dir"),
        selector=selector,
    )
    parsed = selector_filter.get_filter()

    assert parsed.node_names == ["node_name_1", "node_name_2"]
    # The original selector string is kept on the resulting schema.
    assert parsed.selector == selector


def test_parse_selector_with_graph_operators_both(
    dbt_runner_with_models_mock, anonymous_tracking_mock
):
    """``model:+customers+`` yields the node names returned by the mocked dbt ``ls``."""
    selector = "model:+customers+"

    selector_filter = SelectorFilter(
        tracking=anonymous_tracking_mock,
        config=MockConfig("mock_project_dir"),
        selector=selector,
    )
    parsed = selector_filter.get_filter()

    assert parsed.node_names == ["node_name_1", "node_name_2"]
    # The original selector string is kept on the resulting schema.
    assert parsed.selector == selector


def test_has_graph_operators():
    """``_has_graph_operators`` flags selectors containing ``+`` and nothing else here."""
    expectations = (
        ("customers+", True),
        ("+customers", True),
        ("+customers+", True),
        ("customers", False),
        ("my_model", False),
    )
    for selector, expected in expectations:
        # Identity check keeps the original strictness: a real bool is required.
        assert SelectorFilter._has_graph_operators(selector) is expected
Loading