Skip to content

Commit b5fe1c1

Browse files
authored
Translate SAFE_CAST to TRY_CAST in Spark SQL (opensearch-project#4788)
Signed-off-by: Lantao Jin <[email protected]>
1 parent e8e9a5b commit b5fe1c1

File tree

7 files changed

+75
-78
lines changed

7 files changed

+75
-78
lines changed

ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLCastFunctionTest.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ public void testCast() {
2828
verifyLogical(root, expectedLogical);
2929

3030
// TODO there is no SAFE_CAST() in Spark, the Spark CAST is always safe (return null).
31-
String expectedSparkSql = "SELECT SAFE_CAST(`MGR` AS STRING) `a`\nFROM `scott`.`EMP`";
31+
String expectedSparkSql = "SELECT TRY_CAST(`MGR` AS STRING) `a`\nFROM `scott`.`EMP`";
3232
verifyPPLToSparkSQL(root, expectedSparkSql);
3333
}
3434

@@ -40,7 +40,7 @@ public void testCastInsensitive() {
4040
"" + "LogicalProject(a=[SAFE_CAST($3)])\n" + " LogicalTableScan(table=[[scott, EMP]])\n";
4141
verifyLogical(root, expectedLogical);
4242

43-
String expectedSparkSql = "SELECT SAFE_CAST(`MGR` AS STRING) `a`\nFROM `scott`.`EMP`";
43+
String expectedSparkSql = "SELECT TRY_CAST(`MGR` AS STRING) `a`\nFROM `scott`.`EMP`";
4444
verifyPPLToSparkSQL(root, expectedSparkSql);
4545
}
4646

@@ -56,7 +56,7 @@ public void testCastOverriding() {
5656

5757
String expectedSparkSql =
5858
"SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`,"
59-
+ " SAFE_CAST(`MGR` AS STRING) `age`\n"
59+
+ " TRY_CAST(`MGR` AS STRING) `age`\n"
6060
+ "FROM `scott`.`EMP`";
6161
verifyPPLToSparkSQL(root, expectedSparkSql);
6262
}
@@ -83,7 +83,7 @@ public void testChainedCast() {
8383
verifyLogical(root, expectedLogical);
8484

8585
String expectedSparkSql =
86-
"" + "SELECT SAFE_CAST(SAFE_CAST(`MGR` AS STRING) AS INTEGER) `a`\n" + "FROM `scott`.`EMP`";
86+
"" + "SELECT TRY_CAST(TRY_CAST(`MGR` AS STRING) AS INTEGER) `a`\n" + "FROM `scott`.`EMP`";
8787
verifyPPLToSparkSQL(root, expectedSparkSql);
8888
}
8989

@@ -117,7 +117,7 @@ public void testChainedCast2() {
117117

118118
String expectedSparkSql =
119119
""
120-
+ "SELECT SAFE_CAST(CONCAT(SAFE_CAST(`MGR` AS STRING), '0') AS INTEGER) `a`\n"
120+
+ "SELECT TRY_CAST(CONCAT(TRY_CAST(`MGR` AS STRING), '0') AS INTEGER) `a`\n"
121121
+ "FROM `scott`.`EMP`";
122122
verifyPPLToSparkSQL(root, expectedSparkSql);
123123
}

ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLChartTest.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,14 +111,14 @@ public void testChartWithMultipleGroupKeys() {
111111
"SELECT `t2`.`gender`, CASE WHEN `t2`.`age` IS NULL THEN 'NULL' WHEN"
112112
+ " `t9`.`_row_number_chart_` <= 10 THEN `t2`.`age` ELSE 'OTHER' END `age`,"
113113
+ " AVG(`t2`.`avg(balance)`) `avg(balance)`\n"
114-
+ "FROM (SELECT `gender`, SAFE_CAST(`age` AS STRING) `age`, AVG(`balance`)"
114+
+ "FROM (SELECT `gender`, TRY_CAST(`age` AS STRING) `age`, AVG(`balance`)"
115115
+ " `avg(balance)`\n"
116116
+ "FROM `scott`.`bank`\n"
117117
+ "WHERE `gender` IS NOT NULL AND `balance` IS NOT NULL\n"
118118
+ "GROUP BY `gender`, `age`) `t2`\n"
119119
+ "LEFT JOIN (SELECT `age`, SUM(`avg(balance)`) `__grand_total__`, ROW_NUMBER() OVER"
120120
+ " (ORDER BY SUM(`avg(balance)`) DESC) `_row_number_chart_`\n"
121-
+ "FROM (SELECT SAFE_CAST(`age` AS STRING) `age`, AVG(`balance`) `avg(balance)`\n"
121+
+ "FROM (SELECT TRY_CAST(`age` AS STRING) `age`, AVG(`balance`) `avg(balance)`\n"
122122
+ "FROM `scott`.`bank`\n"
123123
+ "WHERE `gender` IS NOT NULL AND `balance` IS NOT NULL\n"
124124
+ "GROUP BY `gender`, `age`) `t6`\n"
@@ -139,14 +139,14 @@ public void testChartWithMultipleGroupKeysAlternativeSyntax() {
139139
"SELECT `t2`.`gender`, CASE WHEN `t2`.`age` IS NULL THEN 'NULL' WHEN"
140140
+ " `t9`.`_row_number_chart_` <= 10 THEN `t2`.`age` ELSE 'OTHER' END `age`,"
141141
+ " AVG(`t2`.`avg(balance)`) `avg(balance)`\n"
142-
+ "FROM (SELECT `gender`, SAFE_CAST(`age` AS STRING) `age`, AVG(`balance`)"
142+
+ "FROM (SELECT `gender`, TRY_CAST(`age` AS STRING) `age`, AVG(`balance`)"
143143
+ " `avg(balance)`\n"
144144
+ "FROM `scott`.`bank`\n"
145145
+ "WHERE `gender` IS NOT NULL AND `balance` IS NOT NULL\n"
146146
+ "GROUP BY `gender`, `age`) `t2`\n"
147147
+ "LEFT JOIN (SELECT `age`, SUM(`avg(balance)`) `__grand_total__`, ROW_NUMBER() OVER"
148148
+ " (ORDER BY SUM(`avg(balance)`) DESC) `_row_number_chart_`\n"
149-
+ "FROM (SELECT SAFE_CAST(`age` AS STRING) `age`, AVG(`balance`) `avg(balance)`\n"
149+
+ "FROM (SELECT TRY_CAST(`age` AS STRING) `age`, AVG(`balance`) `avg(balance)`\n"
150150
+ "FROM `scott`.`bank`\n"
151151
+ "WHERE `gender` IS NOT NULL AND `balance` IS NOT NULL\n"
152152
+ "GROUP BY `gender`, `age`) `t6`\n"

ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLEventstatsEarliestLatestTest.java

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public void testEventstatsEarliestWithoutSecondArgument() {
4848
verifyLogical(root, expectedLogical);
4949

5050
String expectedSparkSql =
51-
"SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MIN_BY (`message`,"
51+
"SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MIN_BY(`message`,"
5252
+ " `@timestamp`) OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)"
5353
+ " `earliest_message`\n"
5454
+ "FROM `POST`.`LOGS`";
@@ -66,7 +66,7 @@ public void testEventstatsLatestWithoutSecondArgument() {
6666
verifyLogical(root, expectedLogical);
6767

6868
String expectedSparkSql =
69-
"SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MAX_BY (`message`,"
69+
"SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MAX_BY(`message`,"
7070
+ " `@timestamp`) OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)"
7171
+ " `latest_message`\n"
7272
+ "FROM `POST`.`LOGS`";
@@ -84,7 +84,7 @@ public void testEventstatsEarliestByServerWithoutSecondArgument() {
8484
verifyLogical(root, expectedLogical);
8585

8686
String expectedSparkSql =
87-
"SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MIN_BY (`message`,"
87+
"SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MIN_BY(`message`,"
8888
+ " `@timestamp`) OVER (PARTITION BY `server` RANGE BETWEEN UNBOUNDED PRECEDING AND"
8989
+ " UNBOUNDED FOLLOWING) `earliest_message`\n"
9090
+ "FROM `POST`.`LOGS`";
@@ -102,7 +102,7 @@ public void testEventstatsLatestByServerWithoutSecondArgument() {
102102
verifyLogical(root, expectedLogical);
103103

104104
String expectedSparkSql =
105-
"SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MAX_BY (`message`,"
105+
"SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MAX_BY(`message`,"
106106
+ " `@timestamp`) OVER (PARTITION BY `server` RANGE BETWEEN UNBOUNDED PRECEDING AND"
107107
+ " UNBOUNDED FOLLOWING) `latest_message`\n"
108108
+ "FROM `POST`.`LOGS`";
@@ -122,7 +122,7 @@ public void testEventstatsEarliestWithOtherAggregatesWithoutSecondArgument() {
122122
verifyLogical(root, expectedLogical);
123123

124124
String expectedSparkSql =
125-
"SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MIN_BY (`message`,"
125+
"SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MIN_BY(`message`,"
126126
+ " `@timestamp`) OVER (PARTITION BY `server` RANGE BETWEEN UNBOUNDED PRECEDING AND"
127127
+ " UNBOUNDED FOLLOWING) `earliest_message`, COUNT(*) OVER (PARTITION BY `server` RANGE"
128128
+ " BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) `cnt`\n"
@@ -141,7 +141,7 @@ public void testEventstatsEarliestWithExplicitTimestampField() {
141141
verifyLogical(root, expectedLogical);
142142

143143
String expectedSparkSql =
144-
"SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MIN_BY (`message`,"
144+
"SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MIN_BY(`message`,"
145145
+ " `created_at`) OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)"
146146
+ " `earliest_message`\n"
147147
+ "FROM `POST`.`LOGS`";
@@ -159,7 +159,7 @@ public void testEventstatsLatestWithExplicitTimestampField() {
159159
verifyLogical(root, expectedLogical);
160160

161161
String expectedSparkSql =
162-
"SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MAX_BY (`message`,"
162+
"SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MAX_BY(`message`,"
163163
+ " `created_at`) OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)"
164164
+ " `latest_message`\n"
165165
+ "FROM `POST`.`LOGS`";
@@ -180,9 +180,9 @@ public void testEventstatsEarliestLatestCombined() {
180180
verifyLogical(root, expectedLogical);
181181

182182
String expectedSparkSql =
183-
"SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MIN_BY (`message`,"
183+
"SELECT `server`, `level`, `message`, `@timestamp`, `created_at`, MIN_BY(`message`,"
184184
+ " `@timestamp`) OVER (PARTITION BY `server` RANGE BETWEEN UNBOUNDED PRECEDING AND"
185-
+ " UNBOUNDED FOLLOWING) `earliest_msg`, MAX_BY (`message`, `@timestamp`) OVER"
185+
+ " UNBOUNDED FOLLOWING) `earliest_msg`, MAX_BY(`message`, `@timestamp`) OVER"
186186
+ " (PARTITION BY `server` RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)"
187187
+ " `latest_msg`\n"
188188
+ "FROM `POST`.`LOGS`";

0 commit comments

Comments
 (0)