Skip to content

Commit 2f94313

Browse files
uros-db authored and cloud-fan committed
[SPARK-54110][GEO][SQL] Introduce type encoders for Geography and Geometry types
### What changes were proposed in this pull request? This PR introduces type encoders for `Geography` and `Geometry`. Note that the server-side geospatial classes have already been introduced as part of: #52737; while client-side geospatial classes in external API have subsequently been introduced as part of: #52804. ### Why are the changes needed? These encoders are used to translate between (server) Spark Catalyst types and (client) Java/Scala types. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added new Scala unit test suites for data frames: - `GeographyDataFrameSuite` - `GeometryDataFrameSuite` Also, added appropriate test cases to: - `RowSuite` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #52813 from uros-db/geo-expression-encoders. Authored-by: Uros Bojanic <uros.bojanic@databricks.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent a5e866f commit 2f94313

File tree

17 files changed

+611
-9
lines changed

17 files changed

+611
-9
lines changed

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1888,6 +1888,12 @@
18881888
],
18891889
"sqlState" : "42623"
18901890
},
1891+
"GEO_ENCODER_SRID_MISMATCH_ERROR" : {
1892+
"message" : [
1893+
"Failed to encode <type> value because provided SRID <valueSrid> of a value to encode does not match type SRID: <typeSrid>."
1894+
],
1895+
"sqlState" : "42K09"
1896+
},
18911897
"GET_TABLES_BY_TYPE_UNSUPPORTED_BY_HIVE_VERSION" : {
18921898
"message" : [
18931899
"Hive 2.2 and lower versions don't support getTablesByType. Please use Hive 2.3 or higher version."

sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,20 @@ object Encoders {
162162
*/
163163
def BINARY: Encoder[Array[Byte]] = BinaryEncoder
164164

165+
/**
166+
* An encoder for Geometry data type.
167+
*
168+
* @since 4.1.0
169+
*/
170+
def GEOMETRY(dt: GeometryType): Encoder[Geometry] = GeometryEncoder(dt)
171+
172+
/**
173+
* An encoder for Geography data type.
174+
*
175+
* @since 4.1.0
176+
*/
177+
def GEOGRAPHY(dt: GeographyType): Encoder[Geography] = GeographyEncoder(dt)
178+
165179
/**
166180
* Creates an encoder that serializes instances of the `java.time.Duration` class to the
167181
* internal representation of nullable Catalyst's DayTimeIntervalType.

sql/api/src/main/scala/org/apache/spark/sql/Row.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,24 @@ trait Row extends Serializable {
302302
*/
303303
def getDecimal(i: Int): java.math.BigDecimal = getAs[java.math.BigDecimal](i)
304304

305+
/**
306+
* Returns the value at position i of geometry type as org.apache.spark.sql.types.Geometry.
307+
*
308+
* @throws ClassCastException
309+
* when data type does not match.
310+
*/
311+
def getGeometry(i: Int): org.apache.spark.sql.types.Geometry =
312+
getAs[org.apache.spark.sql.types.Geometry](i)
313+
314+
/**
315+
* Returns the value at position i of geography type as org.apache.spark.sql.types.Geography.
316+
*
317+
* @throws ClassCastException
318+
* when data type does not match.
319+
*/
320+
def getGeography(i: Int): org.apache.spark.sql.types.Geography =
321+
getAs[org.apache.spark.sql.types.Geography](i)
322+
305323
/**
306324
* Returns the value at position i of date type as java.sql.Date.
307325
*

sql/api/src/main/scala/org/apache/spark/sql/SQLImplicits.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,14 @@ trait EncoderImplicits extends LowPrioritySQLImplicits with Serializable {
104104
implicit def newScalaDecimalEncoder: Encoder[scala.math.BigDecimal] =
105105
DEFAULT_SCALA_DECIMAL_ENCODER
106106

107+
/** @since 4.1.0 */
108+
implicit def newGeometryEncoder: Encoder[org.apache.spark.sql.types.Geometry] =
109+
DEFAULT_GEOMETRY_ENCODER
110+
111+
/** @since 4.1.0 */
112+
implicit def newGeographyEncoder: Encoder[org.apache.spark.sql.types.Geography] =
113+
DEFAULT_GEOGRAPHY_ENCODER
114+
107115
/** @since 2.2.0 */
108116
implicit def newDateEncoder: Encoder[java.sql.Date] = Encoders.DATE
109117

sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import scala.reflect.ClassTag
2727
import org.apache.commons.lang3.reflect.{TypeUtils => JavaTypeUtils}
2828

2929
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
30-
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, DayTimeIntervalEncoder, DEFAULT_JAVA_DECIMAL_ENCODER, EncoderField, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaEnumEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, STRICT_DATE_ENCODER, STRICT_INSTANT_ENCODER, STRICT_LOCAL_DATE_ENCODER, STRICT_TIMESTAMP_ENCODER, StringEncoder, UDTEncoder, YearMonthIntervalEncoder}
30+
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, DayTimeIntervalEncoder, DEFAULT_GEOGRAPHY_ENCODER, DEFAULT_GEOMETRY_ENCODER, DEFAULT_JAVA_DECIMAL_ENCODER, EncoderField, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaEnumEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, STRICT_DATE_ENCODER, STRICT_INSTANT_ENCODER, STRICT_LOCAL_DATE_ENCODER, STRICT_TIMESTAMP_ENCODER, StringEncoder, UDTEncoder, YearMonthIntervalEncoder}
3131
import org.apache.spark.sql.errors.ExecutionErrors
3232
import org.apache.spark.sql.types._
3333
import org.apache.spark.util.ArrayImplicits._
@@ -86,6 +86,10 @@ object JavaTypeInference {
8686

8787
case c: Class[_] if c == classOf[java.lang.String] => StringEncoder
8888
case c: Class[_] if c == classOf[Array[Byte]] => BinaryEncoder
89+
case c: Class[_] if c == classOf[org.apache.spark.sql.types.Geometry] =>
90+
DEFAULT_GEOMETRY_ENCODER
91+
case c: Class[_] if c == classOf[org.apache.spark.sql.types.Geography] =>
92+
DEFAULT_GEOGRAPHY_ENCODER
8993
case c: Class[_] if c == classOf[java.math.BigDecimal] => DEFAULT_JAVA_DECIMAL_ENCODER
9094
case c: Class[_] if c == classOf[java.math.BigInteger] => JavaBigIntEncoder
9195
case c: Class[_] if c == classOf[java.time.LocalDate] => STRICT_LOCAL_DATE_ENCODER

sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,10 @@ object ScalaReflection extends ScalaReflection {
332332
case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) => LocalDateTimeEncoder
333333
case t if isSubtype(t, localTypeOf[java.time.LocalTime]) => LocalTimeEncoder
334334
case t if isSubtype(t, localTypeOf[VariantVal]) => VariantEncoder
335+
case t if isSubtype(t, localTypeOf[Geography]) =>
336+
DEFAULT_GEOGRAPHY_ENCODER
337+
case t if isSubtype(t, localTypeOf[Geometry]) =>
338+
DEFAULT_GEOMETRY_ENCODER
335339
case t if isSubtype(t, localTypeOf[Row]) => UnboundRowEncoder
336340

337341
// UDT encoders

sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,8 @@ object AgnosticEncoders {
246246
case object DayTimeIntervalEncoder extends LeafEncoder[Duration](DayTimeIntervalType())
247247
case object YearMonthIntervalEncoder extends LeafEncoder[Period](YearMonthIntervalType())
248248
case object VariantEncoder extends LeafEncoder[VariantVal](VariantType)
249+
case class GeographyEncoder(dt: GeographyType) extends LeafEncoder[Geography](dt)
250+
case class GeometryEncoder(dt: GeometryType) extends LeafEncoder[Geometry](dt)
249251
case class DateEncoder(override val lenientSerialization: Boolean)
250252
extends LeafEncoder[jsql.Date](DateType)
251253
case class LocalDateEncoder(override val lenientSerialization: Boolean)
@@ -277,6 +279,10 @@ object AgnosticEncoders {
277279
ScalaDecimalEncoder(DecimalType.SYSTEM_DEFAULT)
278280
val DEFAULT_JAVA_DECIMAL_ENCODER: JavaDecimalEncoder =
279281
JavaDecimalEncoder(DecimalType.SYSTEM_DEFAULT, lenientSerialization = false)
282+
val DEFAULT_GEOMETRY_ENCODER: GeometryEncoder =
283+
GeometryEncoder(GeometryType(Geometry.DEFAULT_SRID))
284+
val DEFAULT_GEOGRAPHY_ENCODER: GeographyEncoder =
285+
GeographyEncoder(GeographyType(Geography.DEFAULT_SRID))
280286

281287
/**
282288
* Encoder that transforms external data into a representation that can be further processed by

sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import scala.collection.mutable
2121
import scala.reflect.classTag
2222

2323
import org.apache.spark.sql.{AnalysisException, Row}
24-
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, NullEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder, TimestampEncoder, UDTEncoder, VarcharEncoder, VariantEncoder, YearMonthIntervalEncoder}
24+
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, GeographyEncoder, GeometryEncoder, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, NullEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder, TimestampEncoder, UDTEncoder, VarcharEncoder, VariantEncoder, YearMonthIntervalEncoder}
2525
import org.apache.spark.sql.errors.DataTypeErrorsBase
2626
import org.apache.spark.sql.internal.SqlApiConf
2727
import org.apache.spark.sql.types._
@@ -120,6 +120,8 @@ object RowEncoder extends DataTypeErrorsBase {
120120
field.nullable,
121121
field.metadata)
122122
}.toImmutableArraySeq)
123+
case g: GeographyType => GeographyEncoder(g)
124+
case g: GeometryType => GeometryEncoder(g)
123125

124126
case _ =>
125127
throw new AnalysisException(

sql/api/src/main/scala/org/apache/spark/sql/types/GeographyType.scala

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.sql.types
1919

2020
import org.json4s.JsonAST.{JString, JValue}
2121

22-
import org.apache.spark.SparkIllegalArgumentException
22+
import org.apache.spark.{SparkIllegalArgumentException, SparkRuntimeException}
2323
import org.apache.spark.annotation.Experimental
2424
import org.apache.spark.sql.internal.types.GeographicSpatialReferenceSystemMapper
2525

@@ -133,6 +133,27 @@ class GeographyType private (val crs: String, val algorithm: EdgeInterpolationAl
133133
// If the SRID is not mixed, we can only accept the same SRID.
134134
isMixedSrid || gt.srid == srid
135135
}
136+
137+
private[sql] def assertSridAllowedForType(otherSrid: Int): Unit = {
138+
// If SRID is not mixed, SRIDs must match.
139+
if (!isMixedSrid && otherSrid != srid) {
140+
throw new SparkRuntimeException(
141+
errorClass = "GEO_ENCODER_SRID_MISMATCH_ERROR",
142+
messageParameters = Map(
143+
"type" -> "GEOGRAPHY",
144+
"valueSrid" -> otherSrid.toString,
145+
"typeSrid" -> srid.toString))
146+
} else if (isMixedSrid) {
147+
// For fixed SRID geom types, we have a check that value matches the type srid.
148+
// For mixed SRID we need to do that check explicitly, as MIXED SRID can accept any SRID.
149+
// However it should accept only valid SRIDs.
150+
if (!GeographyType.isSridSupported(otherSrid)) {
151+
throw new SparkIllegalArgumentException(
152+
errorClass = "ST_INVALID_SRID_VALUE",
153+
messageParameters = Map("srid" -> otherSrid.toString))
154+
}
155+
}
156+
}
136157
}
137158

138159
@Experimental
@@ -157,6 +178,11 @@ object GeographyType extends SpatialType {
157178
private final val GEOGRAPHY_MIXED_TYPE: GeographyType =
158179
GeographyType(MIXED_CRS, GEOGRAPHY_DEFAULT_ALGORITHM)
159180

181+
/** Returns whether the given SRID is supported. */
182+
private[types] def isSridSupported(srid: Int): Boolean = {
183+
GeographicSpatialReferenceSystemMapper.getStringId(srid) != null
184+
}
185+
160186
/**
161187
* Constructors for GeographyType.
162188
*/

sql/api/src/main/scala/org/apache/spark/sql/types/GeometryType.scala

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.sql.types
1919

2020
import org.json4s.JsonAST.{JString, JValue}
2121

22-
import org.apache.spark.SparkIllegalArgumentException
22+
import org.apache.spark.{SparkIllegalArgumentException, SparkRuntimeException}
2323
import org.apache.spark.annotation.Experimental
2424
import org.apache.spark.sql.internal.types.CartesianSpatialReferenceSystemMapper
2525

@@ -130,6 +130,27 @@ class GeometryType private (val crs: String) extends AtomicType with Serializabl
130130
// If the SRID is not mixed, we can only accept the same SRID.
131131
isMixedSrid || gt.srid == srid
132132
}
133+
134+
private[sql] def assertSridAllowedForType(otherSrid: Int): Unit = {
135+
// If SRID is not mixed, SRIDs must match.
136+
if (!isMixedSrid && otherSrid != srid) {
137+
throw new SparkRuntimeException(
138+
errorClass = "GEO_ENCODER_SRID_MISMATCH_ERROR",
139+
messageParameters = Map(
140+
"type" -> "GEOMETRY",
141+
"valueSrid" -> otherSrid.toString,
142+
"typeSrid" -> srid.toString))
143+
} else if (isMixedSrid) {
144+
// For fixed SRID geom types, we have a check that value matches the type srid.
145+
// For mixed SRID we need to do that check explicitly, as MIXED SRID can accept any SRID.
146+
// However it should accept only valid SRIDs.
147+
if (!GeometryType.isSridSupported(otherSrid)) {
148+
throw new SparkIllegalArgumentException(
149+
errorClass = "ST_INVALID_SRID_VALUE",
150+
messageParameters = Map("srid" -> otherSrid.toString))
151+
}
152+
}
153+
}
133154
}
134155

135156
@Experimental
@@ -149,6 +170,11 @@ object GeometryType extends SpatialType {
149170
private final val GEOMETRY_MIXED_TYPE: GeometryType =
150171
GeometryType(MIXED_CRS)
151172

173+
/** Returns whether the given SRID is supported. */
174+
private[types] def isSridSupported(srid: Int): Boolean = {
175+
CartesianSpatialReferenceSystemMapper.getStringId(srid) != null
176+
}
177+
152178
/**
153179
* Constructors for GeometryType.
154180
*/

0 commit comments

Comments
 (0)