Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions integration_tests/tests/test_disable_samples_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import json

import pytest
from dbt_project import DbtProject

COLUMN_NAME = "sensitive_data"

SAMPLES_QUERY = """
with latest_elementary_test_result as (
select id
from {{ ref("elementary_test_results") }}
where lower(table_name) = lower('{test_id}')
order by created_at desc
limit 1
)

select result_row
from {{ ref("test_result_rows") }}
where elementary_test_results_id in (select * from latest_elementary_test_result)
"""


@pytest.mark.skip_targets(["clickhouse"])
def test_disable_samples_config_prevents_sampling(
test_id: str, dbt_project: DbtProject
):
null_count = 20
data = [{COLUMN_NAME: None} for _ in range(null_count)]

columns = [
{
"name": COLUMN_NAME,
"config": {"disable_samples": True},
"tests": [{"not_null": {}}],
}
]

test_result = dbt_project.test(
test_id,
"not_null",
columns=columns,
data=data,
test_vars={
"enable_elementary_test_materialization": True,
"test_sample_row_count": 5,
},
)
assert test_result["status"] == "fail"

samples = [
json.loads(row["result_row"])
for row in dbt_project.run_query(SAMPLES_QUERY.format(test_id=test_id))
]
assert len(samples) == 0


@pytest.mark.skip_targets(["clickhouse"])
def test_disable_samples_false_allows_sampling(test_id: str, dbt_project: DbtProject):
null_count = 20
data = [{COLUMN_NAME: None} for _ in range(null_count)]

columns = [
{
"name": COLUMN_NAME,
"config": {"disable_samples": False},
"tests": [{"not_null": {}}],
}
]

test_result = dbt_project.test(
test_id,
"not_null",
columns=columns,
data=data,
test_vars={
"enable_elementary_test_materialization": True,
"test_sample_row_count": 5,
},
)
assert test_result["status"] == "fail"

samples = [
json.loads(row["result_row"])
for row in dbt_project.run_query(SAMPLES_QUERY.format(test_id=test_id))
]
assert len(samples) == 5
assert all([row == {COLUMN_NAME: None} for row in samples])


@pytest.mark.skip_targets(["clickhouse"])
def test_disable_samples_config_overrides_pii_tags(
test_id: str, dbt_project: DbtProject
):
null_count = 20
data = [{COLUMN_NAME: None} for _ in range(null_count)]

columns = [
{
"name": COLUMN_NAME,
"config": {"disable_samples": True, "tags": ["pii"]},
"tests": [{"not_null": {}}],
}
]

test_result = dbt_project.test(
test_id,
"not_null",
columns=columns,
data=data,
test_vars={
"enable_elementary_test_materialization": True,
"test_sample_row_count": 5,
},
)
assert test_result["status"] == "fail"

samples = [
json.loads(row["result_row"])
for row in dbt_project.run_query(SAMPLES_QUERY.format(test_id=test_id))
]
assert len(samples) == 0
24 changes: 24 additions & 0 deletions macros/edr/materializations/test/test.sql
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@
{% do elementary.debug_log("Skipping sample query because the test passed.") %}
{% do return([]) %}
{% endif %}

{% if elementary.is_sampling_disabled_for_column(elementary.flatten_test(model)) %}
{% do elementary.debug_log("Skipping sample query because disable_samples is true for this column.") %}
{% do return([]) %}
{% endif %}

{% set query %}
with test_results as (
{{ sql }}
Expand All @@ -120,6 +126,24 @@
{% do return(elementary.agate_to_dicts(elementary.run_query(query))) %}
{% endmacro %}

{% macro is_sampling_disabled_for_column(flattened_test) %}
{% set test_column_name = elementary.insensitive_get_dict_value(flattened_test, 'test_column_name') %}
{% set parent_model_unique_id = elementary.insensitive_get_dict_value(flattened_test, 'parent_model_unique_id') %}

{% if not test_column_name or not parent_model_unique_id %}
{% do return(false) %}
{% endif %}

{% set parent_model = elementary.get_node(parent_model_unique_id) %}
{% if parent_model and parent_model.get('columns') %}
{% set column_config = parent_model.get('columns', {}).get(test_column_name, {}).get('config', {}) %}
{% set disable_samples = elementary.safe_get_with_default(column_config, 'disable_samples', false) %}
{% do return(disable_samples) %}
{% endif %}

{% do return(false) %}
{% endmacro %}

{% macro cache_elementary_test_results_rows(elementary_test_results_rows) %}
{% do elementary.get_cache("elementary_test_results").update({model.unique_id: elementary_test_results_rows}) %}
{% endmacro %}
Loading