diff --git a/integration_tests/tests/test_disable_samples_config.py b/integration_tests/tests/test_disable_samples_config.py
new file mode 100644
index 000000000..1e8467eb3
--- /dev/null
+++ b/integration_tests/tests/test_disable_samples_config.py
@@ -0,0 +1,126 @@
+import json
+
+import pytest
+from dbt_project import DbtProject
+
+COLUMN_NAME = "sensitive_data"
+
+# The Jinja braces are doubled twice on purpose: str.format() collapses
+# "{{{{" to "{{", which is what must reach the Jinja renderer for ref().
+SAMPLES_QUERY = """
+    with latest_elementary_test_result as (
+        select id
+        from {{{{ ref("elementary_test_results") }}}}
+        where lower(table_name) = lower('{test_id}')
+        order by created_at desc
+        limit 1
+    )
+
+    select result_row
+    from {{{{ ref("test_result_rows") }}}}
+    where elementary_test_results_id in (select * from latest_elementary_test_result)
+"""
+
+
+@pytest.mark.skip_targets(["clickhouse"])
+def test_disable_samples_config_prevents_sampling(
+    test_id: str, dbt_project: DbtProject
+):
+    """A column configured with disable_samples=True stores no sample rows."""
+    null_count = 20
+    data = [{COLUMN_NAME: None} for _ in range(null_count)]
+
+    columns = [
+        {
+            "name": COLUMN_NAME,
+            "config": {"disable_samples": True},
+            "tests": [{"not_null": {}}],
+        }
+    ]
+
+    test_result = dbt_project.test(
+        test_id,
+        "not_null",
+        columns=columns,
+        data=data,
+        test_vars={
+            "enable_elementary_test_materialization": True,
+            "test_sample_row_count": 5,
+        },
+    )
+    assert test_result["status"] == "fail"
+
+    samples = [
+        json.loads(row["result_row"])
+        for row in dbt_project.run_query(SAMPLES_QUERY.format(test_id=test_id))
+    ]
+    assert len(samples) == 0
+
+
+@pytest.mark.skip_targets(["clickhouse"])
+def test_disable_samples_false_allows_sampling(test_id: str, dbt_project: DbtProject):
+    """A column configured with disable_samples=False still stores sample rows."""
+    null_count = 20
+    data = [{COLUMN_NAME: None} for _ in range(null_count)]
+
+    columns = [
+        {
+            "name": COLUMN_NAME,
+            "config": {"disable_samples": False},
+            "tests": [{"not_null": {}}],
+        }
+    ]
+
+    test_result = dbt_project.test(
+        test_id,
+        "not_null",
+        columns=columns,
+        data=data,
+        test_vars={
+            "enable_elementary_test_materialization": True,
+            "test_sample_row_count": 5,
+        },
+    )
+    assert test_result["status"] == "fail"
+
+    samples = [
+        json.loads(row["result_row"])
+        for row in dbt_project.run_query(SAMPLES_QUERY.format(test_id=test_id))
+    ]
+    assert len(samples) == 5
+    assert all(row == {COLUMN_NAME: None} for row in samples)
+
+
+@pytest.mark.skip_targets(["clickhouse"])
+def test_disable_samples_config_overrides_pii_tags(
+    test_id: str, dbt_project: DbtProject
+):
+    """disable_samples=True suppresses samples on a column that is also tagged pii."""
+    null_count = 20
+    data = [{COLUMN_NAME: None} for _ in range(null_count)]
+
+    columns = [
+        {
+            "name": COLUMN_NAME,
+            "config": {"disable_samples": True, "tags": ["pii"]},
+            "tests": [{"not_null": {}}],
+        }
+    ]
+
+    test_result = dbt_project.test(
+        test_id,
+        "not_null",
+        columns=columns,
+        data=data,
+        test_vars={
+            "enable_elementary_test_materialization": True,
+            "test_sample_row_count": 5,
+        },
+    )
+    assert test_result["status"] == "fail"
+
+    samples = [
+        json.loads(row["result_row"])
+        for row in dbt_project.run_query(SAMPLES_QUERY.format(test_id=test_id))
+    ]
+    assert len(samples) == 0
diff --git a/macros/edr/materializations/test/test.sql b/macros/edr/materializations/test/test.sql
index a63f89f0a..bcbb58831 100644
--- a/macros/edr/materializations/test/test.sql
+++ b/macros/edr/materializations/test/test.sql
@@ -111,6 +111,12 @@
         {% do elementary.debug_log("Skipping sample query because the test passed.") %}
         {% do return([]) %}
     {% endif %}
+
+    {% if elementary.is_sampling_disabled_for_column(elementary.flatten_test(model)) %}
+        {% do elementary.debug_log("Skipping sample query because disable_samples is true for this column.") %}
+        {% do return([]) %}
+    {% endif %}
+
     {% set query %}
         with test_results as (
             {{ sql }}
@@ -120,6 +126,28 @@
     {% do return(elementary.agate_to_dicts(elementary.run_query(query))) %}
 {% endmacro %}
 
+{% macro is_sampling_disabled_for_column(flattened_test) %}
+    {#
+      Returns the `disable_samples` value from the config of the tested column on the parent model.
+      Returns false when the test has no column or parent model, or the parent node has no columns.
+    #}
+    {% set test_column_name = elementary.insensitive_get_dict_value(flattened_test, 'test_column_name') %}
+    {% set parent_model_unique_id = elementary.insensitive_get_dict_value(flattened_test, 'parent_model_unique_id') %}
+
+    {% if not test_column_name or not parent_model_unique_id %}
+        {% do return(false) %}
+    {% endif %}
+
+    {% set parent_model = elementary.get_node(parent_model_unique_id) %}
+    {% if parent_model and parent_model.get('columns') %}
+        {% set column_config = parent_model.get('columns', {}).get(test_column_name, {}).get('config', {}) %}
+        {% set disable_samples = elementary.safe_get_with_default(column_config, 'disable_samples', false) %}
+        {% do return(disable_samples) %}
+    {% endif %}
+
+    {% do return(false) %}
+{% endmacro %}
+
 {% macro cache_elementary_test_results_rows(elementary_test_results_rows) %}
     {% do elementary.get_cache("elementary_test_results").update({model.unique_id: elementary_test_results_rows}) %}
 {% endmacro %}