Skip to content

Commit ade2ddd

Browse files
authored
Merge pull request #790 from elementary-data/unstructured_data_tests
Added unstructred data validation tests for Snowflake, Databricks and…
2 parents f1602cf + d8c2767 commit ade2ddd

File tree

2 files changed

+97
-0
lines changed

2 files changed

+97
-0
lines changed
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
{% test ai_data_validation(model, column_name, expectation_prompt, llm_model_name=none, prompt_context='') %}
2+
{{ config(tags = ['elementary-tests']) }}
3+
{%- if execute and elementary.is_test_command() and elementary.is_elementary_enabled() %}
4+
{% set model_relation = elementary.get_model_relation_for_test(model, context["model"]) %}
5+
{% if not model_relation %}
6+
{{ exceptions.raise_compiler_error("Unsupported model: " ~ model ~ " (this might happen if you override 'ref' or 'source')") }}
7+
{% endif %}
8+
9+
{%- set full_table_name = elementary.relation_to_full_name(model_relation) %}
10+
11+
{# Prompt to supply to the LLM #}
12+
{% set prompt_context_part = prompt_context ~ " " if prompt_context else "" %}
13+
{% set prompt_template = "You are a data validator that should reply with string true if the expectation is met or the string false otherwise. " ~ prompt_context_part ~ "You got the following expectation: " ~ expectation_prompt ~ ". Your only role is to determine if the following text meets this expectation: " %}
14+
15+
{{ elementary.generate_ai_data_validation_sql(model, column_name, prompt_template, llm_model_name) }}
16+
17+
{%- else %}
18+
19+
{#- test must run an sql query -#}
20+
{{ elementary.no_results_query() }}
21+
22+
{%- endif %}
23+
{% endtest %}
24+
25+
26+
{% macro generate_ai_data_validation_sql(model, column_name, prompt_template, llm_model_name) %}
27+
{{ return(adapter.dispatch('generate_ai_data_validation_sql', 'elementary')(model, column_name, prompt_template, llm_model_name)) }}
28+
{% endmacro %}
29+
30+
{% macro default__generate_ai_data_validation_sql(model, column_name, prompt_template, llm_model_name) %}
31+
{{ exceptions.raise_compiler_error("AI data validation is not supported for target: " ~ target.type) }}
32+
{% endmacro %}
33+
34+
{% macro snowflake__generate_ai_data_validation_sql(model, column_name, prompt_template, llm_model_name) %}
35+
{% set default_snowflake_model_name = 'claude-3-5-sonnet' %}
36+
{% set chosen_llm_model_name = llm_model_name if llm_model_name is not none and llm_model_name|trim != '' else default_snowflake_model_name %}
37+
38+
with ai_data_validation_results as (
39+
select
40+
snowflake.cortex.complete(
41+
'{{ chosen_llm_model_name }}',
42+
concat('{{ prompt_template }}', {{ column_name }}::text)
43+
) as result
44+
from {{ model }}
45+
)
46+
47+
select *
48+
from ai_data_validation_results
49+
where lower(result) like '%false%'
50+
{% endmacro %}
51+
52+
{% macro databricks__generate_ai_data_validation_sql(model, column_name, prompt_template, llm_model_name) %}
53+
{% set default_databricks_model_name = 'databricks-meta-llama-3-3-70b-instruct' %}
54+
{% set chosen_llm_model_name = llm_model_name if llm_model_name is not none and llm_model_name|trim != '' else default_databricks_model_name %}
55+
56+
with ai_data_validation_results as (
57+
select
58+
ai_query(
59+
'{{ chosen_llm_model_name }}',
60+
concat('{{ prompt_template }}', cast({{ column_name }} as string))
61+
) as result
62+
from {{ model }}
63+
)
64+
65+
select *
66+
from ai_data_validation_results
67+
where lower(result) like '%false%'
68+
{% endmacro %}
69+
70+
71+
{% macro bigquery__generate_ai_data_validation_sql(model, column_name, prompt_template, llm_model_name) %}
72+
{% set default_bigquery_model_name = 'gemini-1.5-pro' %}
73+
{% set chosen_llm_model_name = llm_model_name if llm_model_name is not none and llm_model_name|trim != '' else default_bigquery_model_name %}
74+
75+
with ai_data_validation_results as (
76+
SELECT ml_generate_text_llm_result as result
77+
FROM
78+
ML.GENERATE_TEXT(
79+
MODEL `{{model.schema}}.{{chosen_llm_model_name}}`,
80+
(
81+
SELECT
82+
CONCAT(
83+
'{{ prompt_template }}',
84+
{{column_name}}) AS prompt
85+
FROM {{model}}),
86+
STRUCT(TRUE AS flatten_json_output))
87+
)
88+
89+
select *
90+
from ai_data_validation_results
91+
where lower(result) like '%false%'
92+
{% endmacro %}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{% test unstructured_data_validation(model, column_name, expectation_prompt, llm_model_name=none) %}
2+
{{ config(tags = ['elementary-tests']) }}
3+
{% set prompt_context = "You are a data validator specializing in validating unstructured data." %}
4+
{{ return(elementary.test_ai_data_validation(model, column_name, expectation_prompt, llm_model_name, prompt_context)) }}
5+
{% endtest %}

0 commit comments

Comments
 (0)