Skip to content

Commit e92ef55

Browse files
uros-dbcloud-fan
authored and committed
[SPARK-55640][GEO][SQL] Propagate WKB parsing errors for Geometry and Geography
### What changes were proposed in this pull request? A WKB reader was already implemented for Geometry and Geography, but parsing failures were only reported through an internal exception (`WkbParseException`). This PR introduces a proper user-facing error condition, `WKB_PARSE_ERROR`, for WKB parsing and raises it from `Geometry.fromWkb` and `Geography.fromWkb`. ### Why are the changes needed? To propagate WKB parsing errors properly to the user. ### Does this PR introduce _any_ user-facing change? Yes, users now get a proper error (`WKB_PARSE_ERROR`) when parsing invalid WKB. ### How was this patch tested? Added new unit tests and end-to-end SQL tests for WKB parsing. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #54424 from uros-db/geo-wkb-parse-exceptions. Authored-by: Uros Bojanic <uros.bojanic@databricks.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent e3b6b10 commit e92ef55

File tree

18 files changed

+359
-81
lines changed

18 files changed

+359
-81
lines changed

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7754,6 +7754,12 @@
77547754
],
77557755
"sqlState" : "42601"
77567756
},
7757+
"WKB_PARSE_ERROR" : {
7758+
"message" : [
7759+
"Error parsing WKB: <parseError> at position <pos>"
7760+
],
7761+
"sqlState" : "22023"
7762+
},
77577763
"WRITE_STREAM_NOT_ALLOWED" : {
77587764
"message" : [
77597765
"`writeStream` can be called only on streaming Dataset/DataFrame."

python/pyspark/errors/error-conditions.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1488,6 +1488,12 @@
14881488
"Value for `<arg_name>` must be between <lower_bound> and <upper_bound> (inclusive), got <actual>"
14891489
]
14901490
},
1491+
"WKB_PARSE_ERROR" : {
1492+
"message" : [
1493+
"Error parsing WKB: <parseError> at position <pos>"
1494+
],
1495+
"sqlState" : "22023"
1496+
},
14911497
"WRONG_NUM_ARGS_FOR_HIGHER_ORDER_FUNCTION": {
14921498
"message": [
14931499
"Function `<func_name>` should take between 1 and 3 arguments, but the provided function takes <num_args>."

python/pyspark/sql/tests/test_functions.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import unittest
2727

2828
from pyspark.errors import PySparkTypeError, PySparkValueError, SparkRuntimeException
29+
from pyspark.errors.exceptions.base import IllegalArgumentException
2930
from pyspark.sql import Row, Window, functions as F, types
3031
from pyspark.sql.avro.functions import from_avro, to_avro
3132
from pyspark.sql.column import Column
@@ -3798,6 +3799,25 @@ def test_st_asbinary(self):
37983799
)
37993800
self.assertEqual(results, [expected])
38003801

3802+
def test_st_geogfromwkb(self):
3803+
df = self.spark.createDataFrame(
3804+
[(bytes.fromhex("0101000000000000000000F03F0000000000000040"),)],
3805+
["wkb"],
3806+
)
3807+
results = df.select(
3808+
F.hex(F.st_asbinary(F.st_geogfromwkb("wkb"))),
3809+
).collect()
3810+
expected = Row(
3811+
"0101000000000000000000F03F0000000000000040",
3812+
)
3813+
self.assertEqual(results, [expected])
3814+
# ST_GeogFromWKB with invalid WKB.
3815+
df = self.spark.createDataFrame([(bytearray(b"\x6f"),)], ["wkb"])
3816+
with self.assertRaises(IllegalArgumentException) as error_context:
3817+
df.select(F.st_geogfromwkb("wkb")).collect()
3818+
self.assertIn("[WKB_PARSE_ERROR]", str(error_context.exception))
3819+
self.assertIn("Unexpected end of WKB buffer", str(error_context.exception))
3820+
38013821
def test_st_geomfromwkb(self):
38023822
df = self.spark.createDataFrame(
38033823
[(bytes.fromhex("0101000000000000000000F03F0000000000000040"), 4326)],
@@ -3814,6 +3834,12 @@ def test_st_geomfromwkb(self):
38143834
"0101000000000000000000F03F0000000000000040",
38153835
)
38163836
self.assertEqual(results, [expected])
3837+
# ST_GeomFromWKB with invalid WKB.
3838+
df = self.spark.createDataFrame([(bytearray(b"\x6f"),)], ["wkb"])
3839+
with self.assertRaises(IllegalArgumentException) as error_context:
3840+
df.select(F.st_geomfromwkb("wkb")).collect()
3841+
self.assertIn("[WKB_PARSE_ERROR]", str(error_context.exception))
3842+
self.assertIn("Unexpected end of WKB buffer", str(error_context.exception))
38173843

38183844
def test_st_setsrid(self):
38193845
df = self.spark.createDataFrame(

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/Geography.java

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
package org.apache.spark.sql.catalyst.util;
1818

1919
import org.apache.spark.sql.catalyst.util.geo.GeometryModel;
20+
import org.apache.spark.sql.catalyst.util.geo.WkbParseException;
2021
import org.apache.spark.sql.catalyst.util.geo.WkbReader;
2122
import org.apache.spark.sql.catalyst.util.geo.WkbWriter;
23+
import org.apache.spark.sql.errors.QueryExecutionErrors;
2224
import org.apache.spark.unsafe.types.GeographyVal;
2325

2426
import java.nio.ByteBuffer;
@@ -81,13 +83,17 @@ public Geography copy() {
8183

8284
// Returns a Geography object with the specified SRID value by parsing the input WKB.
8385
public static Geography fromWkb(byte[] wkb, int srid) {
84-
WkbReader reader = new WkbReader(true);
85-
reader.read(wkb); // Validate WKB with geography coordinate bounds.
86-
87-
byte[] bytes = new byte[HEADER_SIZE + wkb.length];
88-
ByteBuffer.wrap(bytes).order(DEFAULT_ENDIANNESS).putInt(srid);
89-
System.arraycopy(wkb, 0, bytes, WKB_OFFSET, wkb.length);
90-
return fromBytes(bytes);
86+
try {
87+
WkbReader reader = new WkbReader(true);
88+
reader.read(wkb); // Validate WKB with geography coordinate bounds.
89+
90+
byte[] bytes = new byte[HEADER_SIZE + wkb.length];
91+
ByteBuffer.wrap(bytes).order(DEFAULT_ENDIANNESS).putInt(srid);
92+
System.arraycopy(wkb, 0, bytes, WKB_OFFSET, wkb.length);
93+
return fromBytes(bytes);
94+
} catch (WkbParseException e) {
95+
throw QueryExecutionErrors.wkbParseError(e.getParseError(), e.getPosition());
96+
}
9197
}
9298

9399
// Overload for the WKB reader where we use the default SRID for Geography.

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/Geometry.java

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
package org.apache.spark.sql.catalyst.util;
1818

1919
import org.apache.spark.sql.catalyst.util.geo.GeometryModel;
20+
import org.apache.spark.sql.catalyst.util.geo.WkbParseException;
2021
import org.apache.spark.sql.catalyst.util.geo.WkbReader;
2122
import org.apache.spark.sql.catalyst.util.geo.WkbWriter;
23+
import org.apache.spark.sql.errors.QueryExecutionErrors;
2224
import org.apache.spark.unsafe.types.GeometryVal;
2325

2426
import java.nio.ByteBuffer;
@@ -81,13 +83,17 @@ public Geometry copy() {
8183

8284
// Returns a Geometry object with the specified SRID value by parsing the input WKB.
8385
public static Geometry fromWkb(byte[] wkb, int srid) {
84-
WkbReader reader = new WkbReader();
85-
reader.read(wkb); // Validate WKB
86-
87-
byte[] bytes = new byte[HEADER_SIZE + wkb.length];
88-
ByteBuffer.wrap(bytes).order(DEFAULT_ENDIANNESS).putInt(srid);
89-
System.arraycopy(wkb, 0, bytes, WKB_OFFSET, wkb.length);
90-
return fromBytes(bytes);
86+
try {
87+
WkbReader reader = new WkbReader();
88+
reader.read(wkb); // Validate WKB
89+
90+
byte[] bytes = new byte[HEADER_SIZE + wkb.length];
91+
ByteBuffer.wrap(bytes).order(DEFAULT_ENDIANNESS).putInt(srid);
92+
System.arraycopy(wkb, 0, bytes, WKB_OFFSET, wkb.length);
93+
return fromBytes(bytes);
94+
} catch (WkbParseException e) {
95+
throw QueryExecutionErrors.wkbParseError(e.getParseError(), e.getPosition());
96+
}
9197
}
9298

9399
// Overload for the WKB reader where we use the default SRID for Geometry.

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/geo/WkbParseException.java

Lines changed: 12 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,40 +19,27 @@
1919
/**
2020
* Exception thrown when parsing WKB data fails.
2121
*/
22-
class WkbParseException extends RuntimeException {
22+
public class WkbParseException extends RuntimeException {
23+
private final String parseError;
2324
private final long position;
24-
private final String wkbString;
25+
private final byte[] wkb;
2526

26-
WkbParseException(String message, long position, byte[] wkb) {
27-
super(formatMessage(message, position, wkb));
27+
WkbParseException(String parseError, long position, byte[] wkb) {
28+
super();
29+
this.parseError = parseError;
2830
this.position = position;
29-
this.wkbString = wkb != null ? bytesToHex(wkb) : "";
31+
this.wkb = wkb;
3032
}
3133

32-
private static String formatMessage(String message, long position, byte[] wkb) {
33-
String baseMessage = message + " at position " + position;
34-
if (wkb != null && wkb.length > 0) {
35-
baseMessage += " in WKB: " + bytesToHex(wkb);
36-
}
37-
return baseMessage;
34+
public String getParseError() {
35+
return parseError;
3836
}
3937

40-
private static String bytesToHex(byte[] bytes) {
41-
if (bytes == null || bytes.length == 0) {
42-
return "";
43-
}
44-
StringBuilder sb = new StringBuilder(bytes.length * 2);
45-
for (byte b : bytes) {
46-
sb.append(String.format("%02X", b));
47-
}
48-
return sb.toString();
49-
}
50-
51-
long getPosition() {
38+
public long getPosition() {
5239
return position;
5340
}
5441

55-
String getWkbString() {
56-
return wkbString;
42+
public byte[] getWkb() {
43+
return wkb;
5744
}
5845
}

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/geo/WkbReader.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ private GeometryModel readGeometry(int defaultSrid) {
222222

223223
// Check that we have enough bytes for header (endianness byte + 4-byte type)
224224
if (currentWkb.length < WkbUtil.BYTE_SIZE + WkbUtil.TYPE_SIZE) {
225-
throw new WkbParseException("WKB data too short", 0, currentWkb);
225+
throw new WkbParseException("Unexpected end of WKB buffer", 0, currentWkb);
226226
}
227227

228228
// Create buffer wrapping the entire byte array

sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,15 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
676676
stInvalidSridValueError(srid.toString)
677677
}
678678

679+
def wkbParseError(msg: String, pos: String): SparkIllegalArgumentException = {
680+
new SparkIllegalArgumentException(errorClass = "WKB_PARSE_ERROR",
681+
messageParameters = Map("parseError" -> msg, "pos" -> pos))
682+
}
683+
684+
def wkbParseError(msg: String, pos: Long): SparkIllegalArgumentException = {
685+
wkbParseError(msg, pos.toString)
686+
}
687+
679688
def withSuggestionIntervalArithmeticOverflowError(
680689
suggestedFunc: String,
681690
context: QueryContext): ArithmeticException = {

sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/util/geo/WkbErrorHandlingTest.java

Lines changed: 23 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,12 @@ private void assertParseError(String hex, String expectedMessagePart, int valida
4646
WkbParseException ex = Assertions.assertThrows(
4747
WkbParseException.class, () -> reader.read(wkb),
4848
"Should throw WkbParseException for WKB: " + hex);
49-
Assertions.assertTrue(ex.getMessage().toUpperCase().contains(hex.toUpperCase()),
50-
"Exception message should contain the WKB hex: " + hex + ", actual: " + ex.getMessage());
49+
Assertions.assertSame(wkb, ex.getWkb());
5150
if (expectedMessagePart != null && !expectedMessagePart.isEmpty()) {
5251
Assertions.assertTrue(
53-
ex.getMessage().toLowerCase().contains(expectedMessagePart.toLowerCase()),
52+
ex.getParseError().toLowerCase().contains(expectedMessagePart.toLowerCase()),
5453
"Exception message should contain '" + expectedMessagePart + "', actual: " +
55-
ex.getMessage());
54+
ex.getParseError());
5655
}
5756
}
5857

@@ -63,69 +62,60 @@ public void testEmptyWkb() {
6362
WkbParseException ex = Assertions.assertThrows(
6463
WkbParseException.class, () -> reader.read(emptyWkb));
6564
// For empty WKB, just verify that the exception carries a parse error description
66-
Assertions.assertNotNull(ex.getMessage());
65+
Assertions.assertNotNull(ex.getParseError());
6766
}
6867

6968
@Test
7069
public void testTooShortWkb() {
7170
// Only endianness byte
72-
String hex = "01";
73-
byte[] tooShort = hexToBytes(hex);
71+
byte[] tooShort = hexToBytes("01");
7472
WkbReader reader = new WkbReader();
7573
WkbParseException ex = Assertions.assertThrows(
7674
WkbParseException.class, () -> reader.read(tooShort));
77-
Assertions.assertTrue(ex.getMessage().toUpperCase().contains(hex.toUpperCase()),
78-
"Exception message should contain the WKB hex: " + hex);
75+
Assertions.assertSame(tooShort, ex.getWkb());
7976
}
8077

8178
@Test
8279
public void testInvalidGeometryTypeZero() {
8380
// Type = 0 (invalid, should be 1-7)
84-
String hex = "0100000000000000000000F03F0000000000000040";
85-
byte[] invalidType = hexToBytes(hex);
81+
byte[] invalidType = hexToBytes("0100000000000000000000F03F0000000000000040");
8682
WkbReader reader = new WkbReader();
8783
WkbParseException ex = Assertions.assertThrows(
8884
WkbParseException.class, () -> reader.read(invalidType));
89-
Assertions.assertTrue(ex.getMessage().toUpperCase().contains(hex.toUpperCase()),
90-
"Exception message should contain the WKB hex: " + hex);
85+
Assertions.assertSame(invalidType, ex.getWkb());
9186
}
9287

9388
@Test
9489
public void testTruncatedPointCoordinates() {
9590
// Point WKB with truncated coordinates (missing Y coordinate)
96-
String hex = "0101000000000000000000F03F";
97-
byte[] truncated = hexToBytes(hex);
91+
byte[] truncated = hexToBytes("0101000000000000000000F03F");
9892
WkbReader reader = new WkbReader();
9993
WkbParseException ex = Assertions.assertThrows(
10094
WkbParseException.class, () -> reader.read(truncated));
101-
Assertions.assertTrue(ex.getMessage().toUpperCase().contains(hex.toUpperCase()),
102-
"Exception message should contain the WKB hex: " + hex);
95+
Assertions.assertSame(truncated, ex.getWkb());
10396
}
10497

10598
@Test
10699
public void testTruncatedByte() {
107100
// Only one byte (FF) of the 4-byte INT field.
108-
String hex = "0102000000ff";
109-
byte[] truncated = hexToBytes(hex);
101+
byte[] truncated = hexToBytes("0102000000ff");
110102
WkbReader reader = new WkbReader();
111103
WkbParseException ex = Assertions.assertThrows(
112104
WkbParseException.class, () -> reader.read(truncated));
113-
Assertions.assertTrue(ex.getMessage().toUpperCase().contains(hex.toUpperCase()),
114-
"Exception message should contain the WKB hex: " + hex);
105+
Assertions.assertSame(truncated, ex.getWkb());
115106
}
116107

117108
@Test
118109
public void testTruncatedLineString() {
119110
// LineString with declared 2 points but only 1 provided
120-
String hex = "010200000002000000" + // LineString with 2 points
121-
"0000000000000000" + // X of first point
122-
"0000000000000000"; // Y of first point (missing second point)
123-
byte[] truncated = hexToBytes(hex);
111+
byte[] truncated = hexToBytes(
112+
"010200000002000000" + // LineString with 2 points
113+
"0000000000000000" + // X of first point
114+
"0000000000000000"); // Y of first point (missing second point)
124115
WkbReader reader = new WkbReader();
125116
WkbParseException ex = Assertions.assertThrows(
126117
WkbParseException.class, () -> reader.read(truncated));
127-
Assertions.assertTrue(ex.getMessage().toUpperCase().contains(hex.toUpperCase()),
128-
"Exception message should contain the WKB hex: " + hex);
118+
Assertions.assertSame(truncated, ex.getWkb());
129119
}
130120

131121
@Test
@@ -164,8 +154,7 @@ public void testRingWithTooFewPoints() {
164154

165155
WkbParseException ex = Assertions.assertThrows(
166156
WkbParseException.class, () -> reader.read(invalidPolygon));
167-
Assertions.assertTrue(ex.getMessage().toUpperCase().contains(hex.toUpperCase()),
168-
"Exception message should contain the WKB hex: " + hex);
157+
Assertions.assertSame(invalidPolygon, ex.getWkb());
169158
}
170159

171160
@Test
@@ -174,20 +163,19 @@ public void testNonClosedRing() {
174163
WkbReader reader = new WkbReader(1);
175164

176165
// Polygon with ring where first and last points don't match
177-
String hex = "01" + // Little endian
166+
byte[] nonClosedRing = hexToBytes(
167+
"01" + // Little endian
178168
"03000000" + // Polygon type
179169
"01000000" + // 1 ring
180170
"04000000" + // 4 points
181171
"0000000000000000" + "0000000000000000" + // (0, 0)
182172
"000000000000F03F" + "0000000000000000" + // (1, 0)
183173
"000000000000F03F" + "000000000000F03F" + // (1, 1)
184-
"0000000000000040" + "0000000000000040"; // (2, 2) - doesn't match first point!
185-
byte[] nonClosedRing = hexToBytes(hex);
174+
"0000000000000040" + "0000000000000040"); // (2, 2) - doesn't match first point!
186175

187176
WkbParseException ex = Assertions.assertThrows(
188177
WkbParseException.class, () -> reader.read(nonClosedRing));
189-
Assertions.assertTrue(ex.getMessage().toUpperCase().contains(hex.toUpperCase()),
190-
"Exception message should contain the WKB hex: " + hex);
178+
Assertions.assertSame(nonClosedRing, ex.getWkb());
191179
}
192180

193181
@Test
@@ -197,7 +185,7 @@ public void testNullByteArray() {
197185
WkbParseException.class, () -> reader.read(null),
198186
"Should throw WKBParseException for null byte array");
199187
// For null WKB, just verify that the exception carries a parse error description
200-
Assertions.assertNotNull(ex.getMessage());
188+
Assertions.assertNotNull(ex.getParseError());
201189
}
202190

203191
// ========== Invalid Byte Order Tests ==========

sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/util/geo/WkbGeographyTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,7 @@ public void testGeographyErrorMessageContainsBoundsInfo() {
629629
byte[] wkb = makePointWkb2D(200.0, 0.0);
630630
WkbParseException ex = Assertions.assertThrows(
631631
WkbParseException.class, () -> geographyReader1().read(wkb));
632-
String msg = ex.getMessage();
632+
String msg = ex.getParseError();
633633
Assertions.assertTrue(msg.contains("Invalid coordinate value"));
634634
}
635635
}

0 commit comments

Comments
 (0)