From 7ba24ca33dd87ca27fdf0bdee5e7d8ef285bf585 Mon Sep 17 00:00:00 2001 From: Steven Aerts Date: Thu, 2 Oct 2025 12:02:33 +0000 Subject: [PATCH] SPARK-53790 support generic beans in arrow encoder In the past the arrow encoder used lookup.findVirtual to lookup getters and setters for java beans. `findVirtual` is however very sensitive for the exact type used. This type was extrated from the encoder. For generic beans it could however be erased and as such be a different type from the expected type. This patch will try the super classes of the expected type when the expected signature is not found. --- .../client/arrow/ArrowEncoderSuite.scala | 36 +++++++++++++++++++ .../client/arrow/ArrowDeserializer.scala | 21 ++++++++--- .../client/arrow/ArrowSerializer.scala | 5 +-- 3 files changed, 53 insertions(+), 9 deletions(-) diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala index b29d73be359b5..d64cf6b671c88 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -858,6 +858,20 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { } } + test("SPARK-53790: bean encoders with specific generics") { + val encoder = JavaTypeInference.encoderFor(classOf[JavaBeanWithGenericsWrapper]) + roundTripAndCheckIdentical(encoder) { () => + val maybeNull = MaybeNull(3) + Iterator.tabulate(10)(i => { + val bean = new JavaBeanWithGenericsWrapper() + val inner = new JavaBeanWithGenerics[String]() + inner.setValue(maybeNull(i.toString)) + bean.setValue(maybeNull(inner)) + bean + }) + } + } + /* ******************************************************************** * * Arrow deserialization upcasting * ******************************************************************** */ @@ -1190,6 +1204,28 @@ class DummyBean { } } +class JavaBeanWithGenerics[T] { + @BeanProperty var value: T = _ + + override def hashCode(): Int = Objects.hashCode(value) + + override def equals(obj: Any): Boolean = obj match { + case bean: JavaBeanWithGenerics[_] => Objects.equals(value, bean.value) + case _ => false + } +} + +class JavaBeanWithGenericsWrapper { + @BeanProperty var value: JavaBeanWithGenerics[String] = _ + + override def hashCode(): Int = Objects.hashCode(value) + + override def equals(obj: Any): Boolean = obj match { + case bean: JavaBeanWithGenericsWrapper => Objects.equals(value, bean.value) + case _ => false + } +} + object FooEnum extends Enumeration { type FooEnum = Value val E1, E2 = Value diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala index 7597a0ceeb8cd..0da843ea40108 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.connect.client.arrow import java.io.{ByteArrayInputStream, IOException} -import java.lang.invoke.{MethodHandles, MethodType} +import java.lang.invoke.{MethodHandle, MethodHandles, MethodType} import java.lang.reflect.Modifier import java.math.{BigDecimal => JBigDecimal, BigInteger => JBigInteger} import java.time._ @@ -378,10 +378,8 @@ object ArrowDeserializers { .map { field => val vector = lookup(field.name) val deserializer = deserializerFor(field.enc, vector, timeZoneId) - val setter = methodLookup.findVirtual( - tag.runtimeClass, - field.writeMethod.get, - MethodType.methodType(classOf[Unit], field.enc.clsTag.runtimeClass)) + val setter = + findSetter(tag.runtimeClass, field.writeMethod.get, field.enc.clsTag.runtimeClass) (bean: Any, i: Int) => setter.invoke(bean, deserializer.get(i)) } new StructFieldSerializer[Any](struct) { @@ -408,6 +406,19 @@ object ArrowDeserializers { } } + private def findSetter(refc: Class[_], name: String, ftype: Class[_]): MethodHandle = { + try { + methodLookup.findVirtual( + refc, + name, + MethodType.methodType(classOf[Unit], ftype)) + } catch { + case e: NoSuchMethodException => + val superClass: Class[_] = ftype.getSuperclass + if (superClass != null) findSetter(refc, name, superClass) else throw e + } + } + private val methodLookup = MethodHandles.lookup() /** diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala index 4acb11f014d19..efb5d8b2c0ea4 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala @@ -489,10 +489,7 @@ object ArrowSerializer { case (JavaBeanEncoder(tag, fields), StructVectors(struct, vectors)) => structSerializerFor(fields, struct, vectors) { (field, _) => - val getter = methodLookup.findVirtual( - tag.runtimeClass, - field.readMethod.get, - MethodType.methodType(field.enc.clsTag.runtimeClass)) + val getter = methodLookup.unreflect(tag.runtimeClass.getMethod(field.readMethod.get)) o => getter.invoke(o) }