Skip to content

Commit 0d89182

Browse files
committed
add customsql analyzer
1 parent ca8e9e1 commit 0d89182

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

pydeequ/analyzers.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,30 @@ def _analyzer_jvm(self):
360360
return self._deequAnalyzers.CountDistinct(to_scala_seq(self._jvm, self.columns))
361361

362362

363+
class CustomSql(_AnalyzerObject):
    """
    A custom SQL-based analyzer executing a provided SQL expression.
    The expression must return a single value.

    :param str expression: A SQL expression to execute.
    :param str disambiguator: A label used to distinguish this metric
        when running multiple custom SQL analyzers. Defaults to "*".
    """

    def __init__(self, expression: str, disambiguator: str = "*"):
        self.expression = expression
        self.disambiguator = disambiguator

    @property
    def _analyzer_jvm(self):
        """
        Build the underlying JVM CustomSql analyzer for this expression.

        :return: the JVM CustomSql analyzer object
        """
        return self._deequAnalyzers.CustomSql(self.expression, self.disambiguator)
385+
386+
363387
class DataType(_AnalyzerObject):
364388
"""
365389
Data Type Analyzer. Returns the datatypes of column

tests/test_analyzers.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Compliance,
1515
Correlation,
1616
CountDistinct,
17+
CustomSql,
1718
DataType,
1819
Distinctness,
1920
Entropy,
@@ -111,6 +112,14 @@ def CountDistinct(self, columns):
111112
df_from_json = self.spark.read.json(self.sc.parallelize([result_json]))
112113
self.assertEqual(df_from_json.select("value").collect(), result_df.select("value").collect())
113114
return result_df.select("value").collect()
115+
116+
def CustomSql(self, expression, disambiguator="*"):
    # Run a CustomSql analyzer over the shared test frame, cross-check the
    # DataFrame and JSON metric outputs agree, and return the metric values.
    analysis = self.AnalysisRunner.onData(self.df)
    analysis = analysis.addAnalyzer(CustomSql(expression, disambiguator))
    outcome = analysis.run()
    metrics_df = AnalyzerContext.successMetricsAsDataFrame(self.spark, outcome)
    metrics_json = AnalyzerContext.successMetricsAsJson(self.spark, outcome)
    json_df = self.spark.read.json(self.sc.parallelize([metrics_json]))
    self.assertEqual(json_df.select("value").collect(), metrics_df.select("value").collect())
    return metrics_df.select("value").collect()
114123

115124
def Datatype(self, column, where=None):
116125
result = self.AnalysisRunner.onData(self.df).addAnalyzer(DataType(column, where)).run()
@@ -298,6 +307,16 @@ def test_CountDistinct(self):
298307
def test_fail_CountDistinct(self):
299308
self.assertEqual(self.CountDistinct("b"), [Row(value=1.0)])
300309

310+
def test_CustomSql(self):
    # Expose the test frame to Spark SQL so the analyzer expressions can query it.
    self.df.createOrReplaceTempView("input_table")
    cases = [
        ("SELECT SUM(b) FROM input_table", 6.0),
        ("SELECT AVG(LENGTH(a)) FROM input_table", 3.0),
        ("SELECT MAX(c) FROM input_table", 6.0),
    ]
    for query, expected in cases:
        self.assertEqual(self.CustomSql(query), [Row(value=expected)])
315+
316+
@pytest.mark.xfail(reason="@unittest.expectedFailure")
def test_fail_CustomSql(self):
    """Deliberately wrong expectation: SUM(b) over the test frame is 6.0, not 1.0."""
    # Register the temp view locally so this test does not depend on
    # test_CustomSql having run first — it should xfail on the wrong
    # value, not on a missing table.
    self.df.createOrReplaceTempView("input_table")
    self.assertEqual(self.CustomSql("SELECT SUM(b) FROM input_table"), [Row(value=1.0)])
319+
301320
def test_DataType(self):
302321
self.assertEqual(
303322
self.Datatype("b"),

0 commit comments

Comments
 (0)