diff --git a/elementary/monitor/cli.py b/elementary/monitor/cli.py index 2b786786e..ef87a0c90 100644 --- a/elementary/monitor/cli.py +++ b/elementary/monitor/cli.py @@ -144,10 +144,12 @@ def decorator(func): "--select", type=str, default=None, - help="Filter the report by last_invocation / invocation_id: / invocation_time:." - if cmd in (Command.REPORT, Command.SEND_REPORT) - else "DEPRECATED! Please use --filters instead! - Filter the alerts by tags: / owners: / models: / " - "statuses: / resource_types:.", + help=( + "Filter the report by last_invocation / invocation_id: / invocation_time:." + if cmd in (Command.REPORT, Command.SEND_REPORT) + else "DEPRECATED! Please use --filters instead! - Filter the alerts by tags: / owners: / models: / " + "statuses: / resource_types:." + ), )(func) return func @@ -364,7 +366,9 @@ def monitor( alert_filters = FiltersSchema() if bool(filters) or bool(excludes): - alert_filters = FiltersSchema.from_cli_params(filters, excludes) + alert_filters = FiltersSchema.from_cli_params( + filters, excludes, config, anonymous_tracking + ) elif select is not None: click.secho( '\n"--select" is deprecated and won\'t be supported in the near future.\n' diff --git a/elementary/monitor/data_monitoring/schema.py b/elementary/monitor/data_monitoring/schema.py index 25bc536fd..55691c650 100644 --- a/elementary/monitor/data_monitoring/schema.py +++ b/elementary/monitor/data_monitoring/schema.py @@ -190,8 +190,13 @@ def validate_report_selector(self) -> None: @staticmethod def from_cli_params( - cli_filters: Tuple[str], cli_excludes: Tuple[str] + cli_filters: Tuple[str], + cli_excludes: Tuple[str], + config: Optional[Any] = None, + tracking: Optional[Any] = None, ) -> "FiltersSchema": + from elementary.monitor.data_monitoring.selector_filter import SelectorFilter + all_filters: list[tuple[str, FilterType]] = [] for cli_filter in cli_filters: all_filters.append((cli_filter, FilterType.IS)) @@ -206,6 +211,7 @@ def from_cli_params( models = [] statuses = [] resource_types = [] + node_names = [] for cli_filter, filter_type in all_filters: tags_match = FiltersSchema._match_filter_regex( @@ -226,7 +232,22 @@ def from_cli_params( filter_string=cli_filter, regex=re.compile(r"models:(.*)") ) if models_match: - models.append(FilterSchema(values=models_match, type=filter_type)) + model_value = ( + models_match[0] + if len(models_match) == 1 + else ",".join(models_match) + ) + if ( + config + and filter_type == FilterType.IS + and SelectorFilter._has_graph_operators(model_value) + ): + selector_filter = SelectorFilter(config, tracking, model_value) + filter_result = selector_filter.get_filter() + if filter_result.node_names: + node_names.extend(filter_result.node_names) + else: + models.append(FilterSchema(values=models_match, type=filter_type)) continue statuses_match = FiltersSchema._match_filter_regex( @@ -269,6 +290,7 @@ def from_cli_params( models=models, statuses=statuses, resource_types=resource_types, + node_names=node_names, ) @staticmethod diff --git a/elementary/monitor/data_monitoring/selector_filter.py b/elementary/monitor/data_monitoring/selector_filter.py index 4c976c7a6..3e60ebf62 100644 --- a/elementary/monitor/data_monitoring/selector_filter.py +++ b/elementary/monitor/data_monitoring/selector_filter.py @@ -99,12 +99,25 @@ def _parse_selector(self, selector: Optional[str] = None) -> FiltersSchema: selector=selector, ) elif model_match: - if self.tracking: - self.tracking.set_env("select_method", "model") - data_monitoring_filter = FiltersSchema( - models=[FilterSchema(values=[model_match.group(1)])], - selector=selector, - ) + model_value = model_match.group(1) + if self.selector_fetcher and self._has_graph_operators(model_value): + if self.tracking: + self.tracking.set_env( + "select_method", "model with graph operators" + ) + node_names = self.selector_fetcher.get_selector_results( + selector=model_value + ) + data_monitoring_filter = FiltersSchema( + node_names=node_names, selector=selector + ) + else: + if self.tracking: + self.tracking.set_env("select_method", "model") + data_monitoring_filter = FiltersSchema( + models=[FilterSchema(values=[model_value])], + selector=selector, + ) elif statuses_match: if self.tracking: self.tracking.set_env("select_method", "statuses") @@ -148,6 +161,10 @@ def _create_user_dbt_runner(self, config: Config) -> Optional[BaseDbtRunner]: def get_filter(self) -> FiltersSchema: return self.filter + @staticmethod + def _has_graph_operators(selector: str) -> bool: + return "+" in selector + @staticmethod def _can_use_fetcher(selector): non_dbt_selectors = [ diff --git a/tests/unit/monitor/data_monitoring/test_filters_schema.py b/tests/unit/monitor/data_monitoring/test_filters_schema.py index f729c689f..97ee11018 100644 --- a/tests/unit/monitor/data_monitoring/test_filters_schema.py +++ b/tests/unit/monitor/data_monitoring/test_filters_schema.py @@ -1,6 +1,10 @@ +from unittest.mock import patch + import pytest from elementary.monitor.data_monitoring.schema import FiltersSchema, FilterType +from tests.mocks.anonymous_tracking_mock import MockAnonymousTracking +from tests.mocks.config_mock import MockConfig def test_empty_from_cli_params(): @@ -310,3 +314,102 @@ def test_exclude_statuses_filters(): ["fail", "error", "runtime error", "warn"] ) assert filter_schema.statuses[1].type == FilterType.IS + + +def test_models_with_graph_operators_from_cli_params(): + with patch( + "elementary.clients.dbt.command_line_dbt_runner.CommandLineDbtRunner.ls" + ) as mock_ls: + mock_ls.return_value = ["model.customers", "model.orders", "model.payments"] + + config = MockConfig("mock_project_dir") + tracking = MockAnonymousTracking() + + cli_filter = ("models:customers+",) + cli_excludes = () + filter_schema = FiltersSchema.from_cli_params( + cli_filter, cli_excludes, config, tracking + ) + assert len(filter_schema.tags) == 0 + assert len(filter_schema.models) == 0 + assert len(filter_schema.owners) == 0 + assert len(filter_schema.node_names) == 3 + assert sorted(filter_schema.node_names) == sorted( + ["model.customers", "model.orders", "model.payments"] + ) + assert len(filter_schema.statuses) == 1 + assert sorted(filter_schema.statuses[0].values) == sorted( + ["fail", "error", "runtime error", "warn"] + ) + assert len(filter_schema.resource_types) == 0 + + +def test_models_with_upstream_graph_operators_from_cli_params(): + with patch( + "elementary.clients.dbt.command_line_dbt_runner.CommandLineDbtRunner.ls" + ) as mock_ls: + mock_ls.return_value = [ + "model.raw_customers", + "model.stg_customers", + "model.customers", + ] + + config = MockConfig("mock_project_dir") + tracking = MockAnonymousTracking() + + cli_filter = ("models:+customers",) + cli_excludes = () + filter_schema = FiltersSchema.from_cli_params( + cli_filter, cli_excludes, config, tracking + ) + assert len(filter_schema.tags) == 0 + assert len(filter_schema.models) == 0 + assert len(filter_schema.owners) == 0 + assert len(filter_schema.node_names) == 3 + assert sorted(filter_schema.node_names) == sorted( + ["model.raw_customers", "model.stg_customers", "model.customers"] + ) + assert len(filter_schema.statuses) == 1 + assert sorted(filter_schema.statuses[0].values) == sorted( + ["fail", "error", "runtime error", "warn"] + ) + assert len(filter_schema.resource_types) == 0 + + +def test_models_without_graph_operators_from_cli_params_no_config(): + cli_filter = ("models:customers+",) + cli_excludes = () + filter_schema = FiltersSchema.from_cli_params(cli_filter, cli_excludes) + assert len(filter_schema.tags) == 0 + assert len(filter_schema.models) == 1 + assert filter_schema.models[0].values == ["customers+"] + assert len(filter_schema.owners) == 0 + assert len(filter_schema.node_names) == 0 + assert len(filter_schema.statuses) == 1 + assert sorted(filter_schema.statuses[0].values) == sorted( + ["fail", "error", "runtime error", "warn"] + ) + assert len(filter_schema.resource_types) == 0 + + +def test_exclude_models_with_graph_operators_from_cli_params(): + """Test that graph operators in excludes are NOT resolved to node_names""" + config = MockConfig("mock_project_dir") + tracking = MockAnonymousTracking() + + cli_filter = () + cli_excludes = ("models:customers+",) + filter_schema = FiltersSchema.from_cli_params( + cli_filter, cli_excludes, config, tracking + ) + assert len(filter_schema.tags) == 0 + assert len(filter_schema.models) == 1 + assert filter_schema.models[0].values == ["customers+"] + assert filter_schema.models[0].type == FilterType.IS_NOT + assert len(filter_schema.owners) == 0 + assert len(filter_schema.node_names) == 0 + assert len(filter_schema.statuses) == 1 + assert sorted(filter_schema.statuses[0].values) == sorted( + ["fail", "error", "runtime error", "warn"] + ) + assert len(filter_schema.resource_types) == 0 diff --git a/tests/unit/monitor/data_monitoring/test_selector_filter.py b/tests/unit/monitor/data_monitoring/test_selector_filter.py index b9f1cf78c..cc7e53c8d 100644 --- a/tests/unit/monitor/data_monitoring/test_selector_filter.py +++ b/tests/unit/monitor/data_monitoring/test_selector_filter.py @@ -265,3 +265,65 @@ def dbt_runner_with_models_mock() -> Generator[MagicMock, None, None]: ) as mock_ls: mock_ls.return_value = ["node_name_1", "node_name_2"] yield mock_ls + + +def test_parse_selector_with_graph_operators_downstream( + dbt_runner_with_models_mock, anonymous_tracking_mock +): + config = MockConfig("mock_project_dir") + + data_monitoring_filter = SelectorFilter( + tracking=anonymous_tracking_mock, + config=config, + selector="model:customers+", + ) + + assert data_monitoring_filter.get_filter().node_names == [ + "node_name_1", + "node_name_2", + ] + assert data_monitoring_filter.get_filter().selector == "model:customers+" + + +def test_parse_selector_with_graph_operators_upstream( + dbt_runner_with_models_mock, anonymous_tracking_mock +): + config = MockConfig("mock_project_dir") + + data_monitoring_filter = SelectorFilter( + tracking=anonymous_tracking_mock, + config=config, + selector="model:+customers", + ) + + assert data_monitoring_filter.get_filter().node_names == [ + "node_name_1", + "node_name_2", + ] + assert data_monitoring_filter.get_filter().selector == "model:+customers" + + +def test_parse_selector_with_graph_operators_both( + dbt_runner_with_models_mock, anonymous_tracking_mock +): + config = MockConfig("mock_project_dir") + + data_monitoring_filter = SelectorFilter( + tracking=anonymous_tracking_mock, + config=config, + selector="model:+customers+", + ) + + assert data_monitoring_filter.get_filter().node_names == [ + "node_name_1", + "node_name_2", + ] + assert data_monitoring_filter.get_filter().selector == "model:+customers+" + + +def test_has_graph_operators(): + assert SelectorFilter._has_graph_operators("customers+") is True + assert SelectorFilter._has_graph_operators("+customers") is True + assert SelectorFilter._has_graph_operators("+customers+") is True + assert SelectorFilter._has_graph_operators("customers") is False + assert SelectorFilter._has_graph_operators("my_model") is False