Skip to content

Commit 6841531

Browse files
Add disable_samples column configuration flag
- Add is_sampling_disabled_for_column macro to check column config
- Modify query_test_result_rows to skip sampling when disable_samples=true
- Add integration tests for disable_samples functionality
- Tests cover three cases: disable_samples=true prevents sampling, disable_samples=false allows normal sampling, and disable_samples=true overrides PII tags

Co-Authored-By: Yosef Arbiv <[email protected]>
1 parent dc8fba5 commit 6841531

File tree

2 files changed

+145
-0
lines changed

2 files changed

+145
-0
lines changed
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import json
2+
3+
import pytest
4+
from dbt_project import DbtProject
5+
6+
# Name of the single column that the tests in this file seed and test.
COLUMN_NAME = "sensitive_data"

# Selects the sample rows stored for the most recent Elementary test result of
# a given table.
#
# The template goes through ``str.format`` (to inject ``test_id``) before it is
# handed to ``dbt_project.run_query``. ``format`` collapses every doubled brace
# into a single one, so the Jinja delimiters are written with quadrupled braces:
# after formatting they become ``{{ ref(...) }}`` again. With only doubled
# braces the formatted query would contain ``{ ref(...) }``, which is not a
# Jinja expression and would reach the warehouse unrendered.
SAMPLES_QUERY = """
    with latest_elementary_test_result as (
        select id
        from {{{{ ref("elementary_test_results") }}}}
        where lower(table_name) = lower('{test_id}')
        order by created_at desc
        limit 1
    )

    select result_row
    from {{{{ ref("test_result_rows") }}}}
    where elementary_test_results_id in (select * from latest_elementary_test_result)
"""
21+
22+
23+
@pytest.mark.skip_targets(["clickhouse"])
def test_disable_samples_config_prevents_sampling(
    test_id: str, dbt_project: DbtProject
):
    """Check that no sample rows are stored when the column sets disable_samples.

    Runs a ``not_null`` test over 20 null values on a column configured with
    ``disable_samples: True``, expects the test to fail, and then expects the
    samples query to return no rows.
    """
    rows_with_nulls = [{COLUMN_NAME: None} for _ in range(20)]
    column_spec = {
        "name": COLUMN_NAME,
        "config": {"disable_samples": True},
        "tests": [{"not_null": {}}],
    }
    sampling_vars = {
        "enable_elementary_test_materialization": True,
        "test_sample_row_count": 5,
    }

    outcome = dbt_project.test(
        test_id,
        "not_null",
        columns=[column_spec],
        data=rows_with_nulls,
        test_vars=sampling_vars,
    )
    assert outcome["status"] == "fail"

    # The column opted out of sampling, so nothing should have been stored.
    stored_rows = dbt_project.run_query(SAMPLES_QUERY.format(test_id=test_id))
    samples = [json.loads(stored["result_row"]) for stored in stored_rows]
    assert len(samples) == 0
55+
56+
57+
@pytest.mark.skip_targets(["clickhouse"])
def test_disable_samples_false_allows_sampling(test_id: str, dbt_project: DbtProject):
    """Check that samples are still stored when disable_samples is False.

    Runs a ``not_null`` test over 20 null values on a column configured with
    ``disable_samples: False``, expects the test to fail, and then expects
    exactly ``test_sample_row_count`` (5) stored samples, each holding the null
    column value.
    """
    rows_with_nulls = [{COLUMN_NAME: None} for _ in range(20)]
    column_spec = {
        "name": COLUMN_NAME,
        "config": {"disable_samples": False},
        "tests": [{"not_null": {}}],
    }
    sampling_vars = {
        "enable_elementary_test_materialization": True,
        "test_sample_row_count": 5,
    }

    outcome = dbt_project.test(
        test_id,
        "not_null",
        columns=[column_spec],
        data=rows_with_nulls,
        test_vars=sampling_vars,
    )
    assert outcome["status"] == "fail"

    stored_rows = dbt_project.run_query(SAMPLES_QUERY.format(test_id=test_id))
    samples = [json.loads(stored["result_row"]) for stored in stored_rows]
    # Sample count is capped by test_sample_row_count, not by the 20 failing rows.
    assert len(samples) == 5
    assert all(sample == {COLUMN_NAME: None} for sample in samples)
88+
89+
90+
@pytest.mark.skip_targets(["clickhouse"])
def test_disable_samples_config_overrides_pii_tags(
    test_id: str, dbt_project: DbtProject
):
    """Check that no samples are stored for a pii-tagged column with disable_samples.

    Runs a ``not_null`` test over 20 null values on a column whose config sets
    both ``disable_samples: True`` and ``tags: ["pii"]``, expects the test to
    fail, and then expects the samples query to return no rows.
    """
    rows_with_nulls = [{COLUMN_NAME: None} for _ in range(20)]
    column_spec = {
        "name": COLUMN_NAME,
        "config": {"disable_samples": True, "tags": ["pii"]},
        "tests": [{"not_null": {}}],
    }
    sampling_vars = {
        "enable_elementary_test_materialization": True,
        "test_sample_row_count": 5,
    }

    outcome = dbt_project.test(
        test_id,
        "not_null",
        columns=[column_spec],
        data=rows_with_nulls,
        test_vars=sampling_vars,
    )
    assert outcome["status"] == "fail"

    # NOTE(review): no PII-related var is set here, so this presumably only
    # shows that the pii tag does not re-enable sampling - confirm intent.
    stored_rows = dbt_project.run_query(SAMPLES_QUERY.format(test_id=test_id))
    samples = [json.loads(stored["result_row"]) for stored in stored_rows]
    assert len(samples) == 0

macros/edr/materializations/test/test.sql

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,12 @@
111111
{% do elementary.debug_log("Skipping sample query because the test passed.") %}
112112
{% do return([]) %}
113113
{% endif %}
114+
115+
{% if elementary.is_sampling_disabled_for_column(elementary.flatten_test(model)) %}
116+
{% do elementary.debug_log("Skipping sample query because disable_samples is true for this column.") %}
117+
{% do return([]) %}
118+
{% endif %}
119+
114120
{% set query %}
115121
with test_results as (
116122
{{ sql }}
@@ -120,6 +126,24 @@
120126
{% do return(elementary.agate_to_dicts(elementary.run_query(query))) %}
121127
{% endmacro %}
122128

129+
{% macro is_sampling_disabled_for_column(flattened_test) %}
  {#
    Looks up the column the test targets on the test's parent model node and
    returns the `disable_samples` value from that column's `config`
    (requested from safe_get_with_default with a default of false).
    Returns false when the test has no column name or parent model id, or
    when the parent node is missing or has no columns.
  #}
  {% set column_name = elementary.insensitive_get_dict_value(flattened_test, 'test_column_name') %}
  {% set model_unique_id = elementary.insensitive_get_dict_value(flattened_test, 'parent_model_unique_id') %}

  {# Both a column name and a parent model id are needed for the lookup. #}
  {% if not (column_name and model_unique_id) %}
    {% do return(false) %}
  {% endif %}

  {% set model_node = elementary.get_node(model_unique_id) %}
  {% set model_columns = model_node.get('columns') if model_node else none %}
  {% if not model_columns %}
    {% do return(false) %}
  {% endif %}

  {% set column_node = model_columns.get(column_name, {}) %}
  {% set column_config = column_node.get('config', {}) %}
  {% do return(elementary.safe_get_with_default(column_config, 'disable_samples', false)) %}
{% endmacro %}
146+
123147
{% macro cache_elementary_test_results_rows(elementary_test_results_rows) %}
  {# Stores the given rows in the "elementary_test_results" cache, keyed by the unique id of the current `model`. #}
  {% set results_cache = elementary.get_cache("elementary_test_results") %}
  {% do results_cache.update({model.unique_id: elementary_test_results_rows}) %}
{% endmacro %}

0 commit comments

Comments
 (0)