Skip to content

Commit cd740d0

Browse files
committed
add unit tests for disambiguator and incorrect query
1 parent 0d89182 commit cd740d0

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

pydeequ/analyzers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ class CustomSql(_AnalyzerObject):
366366
The expression must return a single value.
367367
368368
:param str expression: A SQL expression to execute.
369-
:param str where: A label used to distinguish this metric
369+
:param str disambiguator: A label used to distinguish this metric
370370
when running multiple custom SQL analyzers. Defaults to "*".
371371
"""
372372

tests/test_analyzers.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import pytest
55
from pyspark.sql import Row
6+
from pyspark.errors import AnalysisException
67

78
from pydeequ import PyDeequSession
89
from pydeequ.analyzers import (
@@ -119,7 +120,7 @@ def CustomSql(self, expression, disambiguator="*"):
119120
result_json = AnalyzerContext.successMetricsAsJson(self.spark, result)
120121
df_from_json = self.spark.read.json(self.sc.parallelize([result_json]))
121122
self.assertEqual(df_from_json.select("value").collect(), result_df.select("value").collect())
122-
return result_df.select("value").collect()
123+
return result_df.select("value", "instance").collect()
123124

124125
def Datatype(self, column, where=None):
125126
result = self.AnalysisRunner.onData(self.df).addAnalyzer(DataType(column, where)).run()
@@ -309,14 +310,25 @@ def test_fail_CountDistinct(self):
309310

310311
def test_CustomSql(self):
311312
self.df.createOrReplaceTempView("input_table")
312-
self.assertEqual(self.CustomSql("SELECT SUM(b) FROM input_table"), [Row(value=6.0)])
313-
self.assertEqual(self.CustomSql("SELECT AVG(LENGTH(a)) FROM input_table"), [Row(value=3.0)])
314-
self.assertEqual(self.CustomSql("SELECT MAX(c) FROM input_table"), [Row(value=6.0)])
313+
self.assertEqual(self.CustomSql("SELECT SUM(b) FROM input_table"), [Row(value=6.0, instance="*")])
314+
self.assertEqual(
315+
self.CustomSql("SELECT AVG(LENGTH(a)) FROM input_table", disambiguator="foo"),
316+
[Row(value=3.0, instance="foo")]
317+
)
318+
self.assertEqual(
319+
self.CustomSql("SELECT MAX(c) FROM input_table", disambiguator="bar"),
320+
[Row(value=6.0, instance="bar")]
321+
)
315322

316323
@pytest.mark.xfail(reason="@unittest.expectedFailure")
317324
def test_fail_CustomSql(self):
318325
self.assertEqual(self.CustomSql("SELECT SUM(b) FROM input_table"), [Row(value=1.0)])
319326

327+
@pytest.mark.xfail(reason="@unittest.expectedFailure")
328+
def test_fail_CustomSql_incorrect_query(self):
329+
with self.assertRaises(AnalysisException):
330+
self.CustomSql("SELECT SUM(b)")
331+
320332
def test_DataType(self):
321333
self.assertEqual(
322334
self.Datatype("b"),

0 commit comments

Comments
 (0)