|
14 | 14 | Compliance, |
15 | 15 | Correlation, |
16 | 16 | CountDistinct, |
| 17 | + CustomSql, |
17 | 18 | DataType, |
18 | 19 | Distinctness, |
19 | 20 | Entropy, |
@@ -111,6 +112,14 @@ def CountDistinct(self, columns): |
111 | 112 | df_from_json = self.spark.read.json(self.sc.parallelize([result_json])) |
112 | 113 | self.assertEqual(df_from_json.select("value").collect(), result_df.select("value").collect()) |
113 | 114 | return result_df.select("value").collect() |
| 115 | + |
| 116 | + def CustomSql(self, expression, disambiguator="*"): |
| 117 | + result = self.AnalysisRunner.onData(self.df).addAnalyzer(CustomSql(expression, disambiguator)).run() |
| 118 | + result_df = AnalyzerContext.successMetricsAsDataFrame(self.spark, result) |
| 119 | + result_json = AnalyzerContext.successMetricsAsJson(self.spark, result) |
| 120 | + df_from_json = self.spark.read.json(self.sc.parallelize([result_json])) |
| 121 | + self.assertEqual(df_from_json.select("value").collect(), result_df.select("value").collect()) |
| 122 | + return result_df.select("value").collect() |
114 | 123 |
|
115 | 124 | def Datatype(self, column, where=None): |
116 | 125 | result = self.AnalysisRunner.onData(self.df).addAnalyzer(DataType(column, where)).run() |
@@ -298,6 +307,16 @@ def test_CountDistinct(self): |
298 | 307 | def test_fail_CountDistinct(self): |
299 | 308 | self.assertEqual(self.CountDistinct("b"), [Row(value=1.0)]) |
300 | 309 |
|
| 310 | + def test_CustomSql(self): |
| 311 | + self.df.createOrReplaceTempView("input_table") |
| 312 | + self.assertEqual(self.CustomSql("SELECT SUM(b) FROM input_table"), [Row(value=6.0)]) |
| 313 | + self.assertEqual(self.CustomSql("SELECT AVG(LENGTH(a)) FROM input_table"), [Row(value=3.0)]) |
| 314 | + self.assertEqual(self.CustomSql("SELECT MAX(c) FROM input_table"), [Row(value=6.0)]) |
| 315 | + |
| 316 | + @pytest.mark.xfail(reason="@unittest.expectedFailure") |
| 317 | + def test_fail_CustomSql(self): |
| 318 | + self.assertEqual(self.CustomSql("SELECT SUM(b) FROM input_table"), [Row(value=1.0)]) |
| 319 | + |
301 | 320 | def test_DataType(self): |
302 | 321 | self.assertEqual( |
303 | 322 | self.Datatype("b"), |
|
0 commit comments