
Commit 0c0e1e4

andreimek and andrzej.nescior authored
Feat: support where filter (#176)
* added support to where method and tests
* corrected method
* added tests and changed flow structure
* corrected failing test

Co-authored-by: andrzej.nescior <[email protected]>
1 parent e74e974 commit 0c0e1e4

2 files changed: 49 additions & 0 deletions

pydeequ/checks.py

Lines changed: 8 additions & 0 deletions
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 from enum import Enum

+from py4j.protocol import Py4JError
 from pyspark.sql import SparkSession

 from pydeequ.check_functions import is_one
@@ -116,6 +117,13 @@ def addConstraint(self, constraint):
         self.constraints.append(constraint)
         self._Check = constraint._Check

+    def where(self, filter: str):
+        try:
+            self._Check = self._Check.where(filter)
+        except Py4JError:
+            raise TypeError(f"Method doesn't exist in {self._Check.getClass()}, class has to be filterable")
+        return self
+
     def addFilterableContstraint(self, creationFunc):
         """Adds a constraint that can subsequently be replaced with a filtered version
         :param creationFunc:
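
For readers landing on this commit, here is a minimal usage sketch of the new chainable where filter. It is not taken from the repository docs: the SparkSession setup, the toy DataFrame, and its column names are assumptions for illustration; only the hasSize(...).where(...) chaining comes from the change above.

import pydeequ
from pydeequ.checks import Check, CheckLevel
from pydeequ.verification import VerificationResult, VerificationSuite
from pyspark.sql import SparkSession

# Hypothetical session and data: any SparkSession carrying the Deequ jar
# and any DataFrame will do.
spark = (SparkSession.builder
         .config("spark.jars.packages", pydeequ.deequ_maven_coord)
         .config("spark.jars.excludes", pydeequ.f2j_maven_coord)
         .getOrCreate())
df = spark.createDataFrame(
    [("foo", "true"), ("bar", "true"), ("baz", "false")], ["a", "boolean"]
)

# hasSize alone would count all 3 rows; .where(...) restricts the size check
# to rows matching the SQL predicate, so the expected size here is 2.
check = Check(spark, CheckLevel.Warning, "filtered size check")
result = (VerificationSuite(spark)
          .onData(df)
          .addCheck(check.hasSize(lambda x: x == 2.0).where("boolean='true'"))
          .run())
VerificationResult.checkResultsAsDataFrame(spark, result).show()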

tests/test_checks.py

Lines changed: 41 additions & 0 deletions
@@ -467,6 +467,12 @@ def hasNumberOfDistinctValues(self, column, assertion, binningUdf, maxBins, hint
         df = VerificationResult.checkResultsAsDataFrame(self.spark, result)
         return df.select("constraint_status").collect()

+    def where(self, assertion, filter, hint=None):
+        check = Check(self.spark, CheckLevel.Warning, "test where")
+        result = VerificationSuite(self.spark).onData(self.df).addCheck(check.hasSize(assertion, hint).where(filter)).run()
+        df = VerificationResult.checkResultsAsDataFrame(self.spark, result)
+        return df.select("constraint_status").collect()
+
     def test_hasSize(self):
         self.assertEqual(self.hasSize(lambda x: x == 3.0), [Row(constraint_status="Success")])
         self.assertEqual(
@@ -1245,6 +1251,41 @@ def test_fail_isGreaterThanOrEqualTo(self):
         )
         self.assertEqual(self.isGreaterThanOrEqualTo("h", "f", lambda x: x == 1), [Row(constraint_status="Success")])

+    def test_where(self):
+        self.assertEqual(self.where(lambda x: x == 2.0, "boolean='true'", "column 'boolean' has two values true"),
+                         [Row(constraint_status="Success")])
+        self.assertEqual(
+            self.where(lambda x: x == 3.0, "d=5", "column 'd' has three values 3"),
+            [Row(constraint_status="Success")],
+        )
+        self.assertEqual(
+            self.where(lambda x: x == 2.0, "ssn='000-00-0000'", "column 'ssn' has one value 000-00-0000"),
+            [Row(constraint_status="Failure")],
+        )
+        check = Check(self.spark, CheckLevel.Warning, "test where").hasMin("f", lambda x: x == 2, "The f has min value 2 because of the additional filter").where('f>=2')
+        result = VerificationSuite(self.spark).onData(self.df).addCheck(check.isGreaterThan("e", "h", lambda x: x == 1, "Column H is not smaller than Column E")).run()
+        df = VerificationResult.checkResultsAsDataFrame(self.spark, result)
+        self.assertEqual(
+            df.select("constraint_status").collect(),
+            [Row(constraint_status="Success"), Row(constraint_status="Failure")],
+        )
+        with self.assertRaises(TypeError):
+            Check(self.spark, CheckLevel.Warning, "test where").kllSketchSatisfies(
+                "b", lambda x: x.parameters().apply(0) == 1.0, KLLParameters(self.spark, 2, 0.64, 2)
+            ).where("d=5")
+
+    @pytest.mark.xfail(reason="@unittest.expectedFailure")
+    def test_fail_where(self):
+        self.assertEqual(self.where(lambda x: x == 2.0, "boolean='false'", "column 'boolean' has one value false"),
+                         [Row(constraint_status="Success")])
+        self.assertEqual(
+            self.where(lambda x: x == 3.0, "a='bar'", "column 'a' has one value 'bar'"),
+            [Row(constraint_status="Success")],
+        )
+        self.assertEqual(
+            self.where(lambda x: x == 1.0, "f=1", "column 'f' has one value 1"),
+            [Row(constraint_status="Failure")],
+        )
     # def test_hasNumberOfDistinctValues(self):
     #     #Todo: test binningUDf
     #     self.assertEqual(self.hasNumberOfDistinctValues('b', lambda x: x == 3, None, 3, "Column B has 3 distinct values"),
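
The tests above also pin down the error path: where delegates to the underlying JVM check object, and if that object is not filterable the Py4JError is re-raised as a TypeError. Below is a hedged sketch of what a caller would see, reusing the spark session from the previous snippet and assuming KLLParameters is importable from pydeequ.analyzers.

from pydeequ.analyzers import KLLParameters  # import path assumed; adjust if it lives elsewhere
from pydeequ.checks import Check, CheckLevel

try:
    # Per test_where above, the check built by kllSketchSatisfies is not
    # filterable, so chaining .where(...) onto it raises TypeError.
    Check(spark, CheckLevel.Warning, "not filterable").kllSketchSatisfies(
        "b", lambda x: x.parameters().apply(0) == 1.0, KLLParameters(spark, 2, 0.64, 2)
    ).where("d=5")
except TypeError as err:
    print(f"where() rejected: {err}")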
