diff --git a/integration_tests/tests/dbt_project.py b/integration_tests/tests/dbt_project.py
index 3eba93fc2..a871f0f53 100644
--- a/integration_tests/tests/dbt_project.py
+++ b/integration_tests/tests/dbt_project.py
@@ -109,6 +109,7 @@ def test(
         materialization: str = "table",  # Only relevant if as_model=True
         test_vars: Optional[dict] = None,
         elementary_enabled: bool = True,
+        model_config: Optional[Dict[str, Any]] = None,
         *,
         multiple_results: Literal[False] = False,
     ) -> Dict[str, Any]:
@@ -128,6 +129,7 @@ def test(
         materialization: str = "table",  # Only relevant if as_model=True
         test_vars: Optional[dict] = None,
         elementary_enabled: bool = True,
+        model_config: Optional[Dict[str, Any]] = None,
         *,
         multiple_results: Literal[True],
     ) -> List[Dict[str, Any]]:
@@ -146,6 +148,7 @@ def test(
         materialization: str = "table",  # Only relevant if as_model=True
         test_vars: Optional[dict] = None,
         elementary_enabled: bool = True,
+        model_config: Optional[Dict[str, Any]] = None,
         *,
         multiple_results: bool = False,
     ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
@@ -161,6 +164,9 @@ def test(
         test_args = test_args or {}
         table_yaml: Dict[str, Any] = {"name": test_id}
 
+        if model_config:
+            table_yaml.update(model_config)
+
         if columns:
             table_yaml["columns"] = columns
 
diff --git a/integration_tests/tests/test_sampling_pii.py b/integration_tests/tests/test_sampling_pii.py
new file mode 100644
index 000000000..8b7a60874
--- /dev/null
+++ b/integration_tests/tests/test_sampling_pii.py
@@ -0,0 +1,110 @@
+import json
+
+import pytest
+from dbt_project import DbtProject
+
+COLUMN_NAME = "value"
+
+# str.format() template: "{{{{"/"}}}}" yield Jinja "{{"/"}}"; {test_id} is substituted.
+SAMPLES_QUERY = """
+    with latest_elementary_test_result as (
+        select id
+        from {{{{ ref("elementary_test_results") }}}}
+        where lower(table_name) = lower('{test_id}')
+        order by created_at desc, id desc
+        limit 1
+    )
+
+    select result_row
+    from {{{{ ref("test_result_rows") }}}}
+    where elementary_test_results_id in (select * from latest_elementary_test_result)
+"""
+
+TEST_SAMPLE_ROW_COUNT = 7  # passed to each test as the test_sample_row_count var
+
+
+@pytest.mark.skip_targets(["clickhouse"])
+def test_sampling_pii_disabled(test_id: str, dbt_project: DbtProject):
+    """PII-tagged model, protection enabled: a failing test stores no samples."""
+    null_count = 50  # every row is null, so the not_null test fails
+    data = [{COLUMN_NAME: None} for _ in range(null_count)]
+
+    test_result = dbt_project.test(
+        test_id,
+        "not_null",
+        dict(column_name=COLUMN_NAME),
+        data=data,
+        as_model=True,
+        model_config={"config": {"tags": ["pii"]}},  # listed in pii_table_tags
+        test_vars={
+            "enable_elementary_test_materialization": True,
+            "test_sample_row_count": TEST_SAMPLE_ROW_COUNT,
+            "disable_samples_on_pii_tables": True,
+            "pii_table_tags": ["pii", "sensitive"],
+        },
+    )
+    assert test_result["status"] == "fail"
+
+    samples = [
+        json.loads(row["result_row"])
+        for row in dbt_project.run_query(SAMPLES_QUERY.format(test_id=test_id))
+    ]
+    assert len(samples) == 0  # a PII match forces sample_limit to 0
+
+
+@pytest.mark.skip_targets(["clickhouse"])
+def test_sampling_non_pii_enabled(test_id: str, dbt_project: DbtProject):
+    """Non-PII model, protection enabled: a failing test still stores samples."""
+    null_count = 50  # more failing rows than TEST_SAMPLE_ROW_COUNT
+    data = [{COLUMN_NAME: None} for _ in range(null_count)]
+
+    test_result = dbt_project.test(
+        test_id,
+        "not_null",
+        dict(column_name=COLUMN_NAME),
+        data=data,
+        as_model=True,
+        model_config={"config": {"tags": ["normal"]}},  # not listed in pii_table_tags
+        test_vars={
+            "enable_elementary_test_materialization": True,
+            "test_sample_row_count": TEST_SAMPLE_ROW_COUNT,
+            "disable_samples_on_pii_tables": True,
+            "pii_table_tags": ["pii", "sensitive"],
+        },
+    )
+    assert test_result["status"] == "fail"
+
+    samples = [
+        json.loads(row["result_row"])
+        for row in dbt_project.run_query(SAMPLES_QUERY.format(test_id=test_id))
+    ]
+    assert len(samples) == TEST_SAMPLE_ROW_COUNT  # capped by test_sample_row_count
+
+
+@pytest.mark.skip_targets(["clickhouse"])
+def test_sampling_pii_feature_disabled(test_id: str, dbt_project: DbtProject):
+    """PII-tagged model, protection disabled: a failing test still stores samples."""
+    null_count = 50  # more failing rows than TEST_SAMPLE_ROW_COUNT
+    data = [{COLUMN_NAME: None} for _ in range(null_count)]
+
+    test_result = dbt_project.test(
+        test_id,
+        "not_null",
+        dict(column_name=COLUMN_NAME),
+        data=data,
+        as_model=True,
+        model_config={"config": {"tags": ["pii"]}},  # listed in pii_table_tags
+        test_vars={
+            "enable_elementary_test_materialization": True,
+            "test_sample_row_count": TEST_SAMPLE_ROW_COUNT,
+            "disable_samples_on_pii_tables": False,  # feature switched off
+            "pii_table_tags": ["pii", "sensitive"],
+        },
+    )
+    assert test_result["status"] == "fail"
+
+    samples = [
+        json.loads(row["result_row"])
+        for row in dbt_project.run_query(SAMPLES_QUERY.format(test_id=test_id))
+    ]
+    assert len(samples) == TEST_SAMPLE_ROW_COUNT  # capped by test_sample_row_count
diff --git a/macros/edr/materializations/test/test.sql b/macros/edr/materializations/test/test.sql
index a63f89f0a..abf585cc9 100644
--- a/macros/edr/materializations/test/test.sql
+++ b/macros/edr/materializations/test/test.sql
@@ -50,7 +50,11 @@
 
 {% macro handle_dbt_test(flattened_test, materialization_macro) %}
   {% set result = materialization_macro() %}
-  {% set result_rows = elementary.query_test_result_rows(sample_limit=elementary.get_config_var('test_sample_row_count'),
+  {% set sample_limit = elementary.get_config_var('test_sample_row_count') %}
+  {% if elementary.is_pii_table(flattened_test) %}
+    {% set sample_limit = 0 %}
+  {% endif %}
+  {% set result_rows = elementary.query_test_result_rows(sample_limit=sample_limit,
                                                          ignore_passed_tests=true) %}
   {% set elementary_test_results_row = elementary.get_dbt_test_result_row(flattened_test, result_rows) %}
   {% do elementary.cache_elementary_test_results_rows([elementary_test_results_row]) %}
diff --git a/macros/edr/system/system_utils/get_config_var.sql b/macros/edr/system/system_utils/get_config_var.sql
index deedd2f24..69d6577f8 100644
--- a/macros/edr/system/system_utils/get_config_var.sql
+++ b/macros/edr/system/system_utils/get_config_var.sql
@@ -64,7 +64,9 @@
     },
     'include_other_warehouse_specific_columns': false,
     'fail_on_zero': false,
-    'anomaly_exclude_metrics': none
+    'anomaly_exclude_metrics': none,
+    'disable_samples_on_pii_tables': false,
+
'pii_table_tags': ['pii']
   } %}
   {{- return(default_config) -}}
 {%- endmacro -%}
diff --git a/macros/edr/system/system_utils/is_pii_table.sql b/macros/edr/system/system_utils/is_pii_table.sql
new file mode 100644
index 000000000..6fa3ad324
--- /dev/null
+++ b/macros/edr/system/system_utils/is_pii_table.sql
@@ -0,0 +1,26 @@
+{% macro is_pii_table(flattened_test) %}
+  {#
+    Decide whether result samples must be withheld for the model under test.
+
+    Returns false unless the `disable_samples_on_pii_tables` config var is truthy.
+    Otherwise returns true when the `model_tags` of `flattened_test` share at least
+    one tag with the `pii_table_tags` config var. Tags are compared
+    case-insensitively; each side may be a list, a single string, or none.
+  #}
+  {% set disable_samples_on_pii_tables = elementary.get_config_var('disable_samples_on_pii_tables') %}
+  {% if not disable_samples_on_pii_tables %}
+    {% do return(false) %}
+  {% endif %}
+
+  {# A bare string such as `pii_table_tags: pii` must count as one tag, not as a sequence of characters. #}
+  {% set raw_pii_tags = elementary.get_config_var('pii_table_tags') %}
+  {% set pii_table_tags = (raw_pii_tags if raw_pii_tags is iterable and raw_pii_tags is not string else ([] if raw_pii_tags is none else [raw_pii_tags])) | map('lower') | list %}
+
+  {% set raw_model_tags = elementary.insensitive_get_dict_value(flattened_test, 'model_tags', []) %}
+  {% set model_tags = (raw_model_tags if raw_model_tags is iterable and raw_model_tags is not string else ([] if raw_model_tags is none else [raw_model_tags])) | map('lower') | list %}
+
+  {% set intersection = elementary.lists_intersection(model_tags, pii_table_tags) %}
+  {% set is_pii = intersection | length > 0 %}
+
+  {% do return(is_pii) %}
+{% endmacro %}