|
3 | 3 |
|
4 | 4 | import pytest |
5 | 5 | from pyspark.sql import Row |
| 6 | +from pyspark.errors import AnalysisException |
6 | 7 |
|
7 | 8 | from pydeequ import PyDeequSession |
8 | 9 | from pydeequ.analyzers import ( |
@@ -119,7 +120,7 @@ def CustomSql(self, expression, disambiguator="*"): |
119 | 120 | result_json = AnalyzerContext.successMetricsAsJson(self.spark, result) |
120 | 121 | df_from_json = self.spark.read.json(self.sc.parallelize([result_json])) |
121 | 122 | self.assertEqual(df_from_json.select("value").collect(), result_df.select("value").collect()) |
122 | | - return result_df.select("value").collect() |
| 123 | + return result_df.select("value", "instance").collect() |
123 | 124 |
|
124 | 125 | def Datatype(self, column, where=None): |
125 | 126 | result = self.AnalysisRunner.onData(self.df).addAnalyzer(DataType(column, where)).run() |
@@ -309,14 +310,25 @@ def test_fail_CountDistinct(self): |
309 | 310 |
|
310 | 311 | def test_CustomSql(self): |
311 | 312 | self.df.createOrReplaceTempView("input_table") |
312 | | - self.assertEqual(self.CustomSql("SELECT SUM(b) FROM input_table"), [Row(value=6.0)]) |
313 | | - self.assertEqual(self.CustomSql("SELECT AVG(LENGTH(a)) FROM input_table"), [Row(value=3.0)]) |
314 | | - self.assertEqual(self.CustomSql("SELECT MAX(c) FROM input_table"), [Row(value=6.0)]) |
| 313 | + self.assertEqual(self.CustomSql("SELECT SUM(b) FROM input_table"), [Row(value=6.0, instance="*")]) |
| 314 | + self.assertEqual( |
| 315 | + self.CustomSql("SELECT AVG(LENGTH(a)) FROM input_table", disambiguator="foo"), |
| 316 | + [Row(value=3.0, instance="foo")] |
| 317 | + ) |
| 318 | + self.assertEqual( |
| 319 | + self.CustomSql("SELECT MAX(c) FROM input_table", disambiguator="bar"), |
| 320 | + [Row(value=6.0, instance="bar")] |
| 321 | + ) |
315 | 322 |
|
316 | 323 | @pytest.mark.xfail(reason="@unittest.expectedFailure") |
317 | 324 | def test_fail_CustomSql(self): |
318 | 325 | self.assertEqual(self.CustomSql("SELECT SUM(b) FROM input_table"), [Row(value=1.0)]) |
319 | 326 |
|
| 327 | + @pytest.mark.xfail(reason="@unittest.expectedFailure") |
| 328 | + def test_fail_CustomSql_incorrect_query(self): |
| 329 | + with self.assertRaises(AnalysisException): |
| 330 | + self.CustomSql("SELECT SUM(b)") |
| 331 | + |
320 | 332 | def test_DataType(self): |
321 | 333 | self.assertEqual( |
322 | 334 | self.Datatype("b"), |
|
0 commit comments