Skip to content

Commit 65b9da5

Browse files
szehon-ho authored and cloud-fan committed
[SPARK-46679][SQL] Fix for SparkUnsupportedOperationException Not found an encoder of the type T, when using Parameterized class
This pr is a rebase/continue of https://github.com/apache/spark/pull/48304/files, which was unexpectedly closed even though the issue is still there. ### What changes were proposed in this pull request? That pr fixes a bug in JavaTypeInference.encoderFor. In the parameterizableType case, it overrides the existing map of parameterizable type -> real type. So any nested parameterizable type will fail. ### Why are the changes needed? Fix nested parameterization cases for Java Encoder ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? The original pr adds unit test to JavaTypeInferenceSuite and JavaDatasetSuite. The test from the original pr are simplified a little to make it obvious which one is wrapping. ### Was this patch authored or co-authored using generative AI tooling? No Closes #52444 from szehon-ho/parameterized_encoders. Authored-by: Szehon Ho <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 0cad1cd commit 65b9da5

File tree

4 files changed

+108
-6
lines changed

4 files changed

+108
-6
lines changed

sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,9 @@ object JavaTypeInference {
139139
encoderFor(typeVariables(tv), seenTypeSet, typeVariables)
140140

141141
case pt: ParameterizedType =>
142-
encoderFor(pt.getRawType, seenTypeSet, JavaTypeUtils.getTypeArguments(pt).asScala.toMap)
142+
val newTvs = JavaTypeUtils.getTypeArguments(pt).asScala.toMap
143+
val allTvs = typeVariables ++ newTvs
144+
encoderFor(pt.getRawType, seenTypeSet, allTvs)
143145

144146
case c: Class[_] =>
145147
if (seenTypeSet.contains(c)) {

sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/JavaTypeInferenceBeans.java

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,5 +78,61 @@ static class JavaBeanWithGenericBase extends JavaBeanWithGenerics<String, String
7878
static class JavaBeanWithGenericHierarchy extends JavaBeanWithGenericsABC<Integer> {
7979

8080
}
81+
82+
// SPARK-46679: Test classes for nested parameterized types with multi-level inheritance
83+
static class Foo<T> {
84+
private T t;
85+
86+
public T getT() {
87+
return t;
88+
}
89+
90+
public void setT(T t) {
91+
this.t = t;
92+
}
93+
}
94+
95+
static class FooWrapper<U> {
96+
private Foo<U> foo;
97+
98+
public Foo<U> getFoo() {
99+
return foo;
100+
}
101+
102+
public void setFoo(Foo<U> foo) {
103+
this.foo = foo;
104+
}
105+
}
106+
107+
static class StringFooWrapper extends FooWrapper<String> {
108+
}
109+
110+
// SPARK-46679: Additional test classes for same type variable names at different levels
111+
static class StringBarWrapper extends BarWrapper<String> {
112+
}
113+
114+
static class BarWrapper<T> {
115+
private Bar<T> bar;
116+
117+
public Bar<T> getBar() {
118+
return bar;
119+
}
120+
121+
public void setBar(Bar<T> bar) {
122+
this.bar = bar;
123+
}
124+
}
125+
126+
static class Bar<T> {
127+
private T t;
128+
129+
public T getT() {
130+
return t;
131+
}
132+
133+
public void setT(T t) {
134+
this.t = t;
135+
}
136+
}
81137
}
82138

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import scala.beans.{BeanProperty, BooleanBeanProperty}
2424
import scala.reflect.{classTag, ClassTag}
2525

2626
import org.apache.spark.SparkFunSuite
27-
import org.apache.spark.sql.catalyst.JavaTypeInferenceBeans.{JavaBeanWithGenericBase, JavaBeanWithGenericHierarchy, JavaBeanWithGenericsABC}
27+
import org.apache.spark.sql.catalyst.JavaTypeInferenceBeans.{Bar, Foo, JavaBeanWithGenericBase, JavaBeanWithGenericHierarchy, JavaBeanWithGenericsABC, StringBarWrapper, StringFooWrapper}
2828
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, UDTCaseClass, UDTForCaseClass}
2929
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
3030
import org.apache.spark.sql.types.{DecimalType, MapType, Metadata, StringType, StructField, StructType}
@@ -279,4 +279,26 @@ class JavaTypeInferenceSuite extends SparkFunSuite {
279279
))
280280
assert(encoder === expected)
281281
}
282+
283+
test("SPARK-46679: resolve generics with multi-level inheritance") {
284+
val encoder = JavaTypeInference.encoderFor(classOf[StringFooWrapper])
285+
val expected = JavaBeanEncoder(ClassTag(classOf[StringFooWrapper]), Seq(
286+
encoderField("foo", JavaBeanEncoder(
287+
ClassTag(classOf[Foo[String]]),
288+
Seq(encoderField("t", StringEncoder))
289+
))
290+
))
291+
assert(encoder === expected)
292+
}
293+
294+
test("SPARK-46679: resolve generics with multi-level inheritance same type names") {
295+
val encoder = JavaTypeInference.encoderFor(classOf[StringBarWrapper])
296+
val expected = JavaBeanEncoder(ClassTag(classOf[StringBarWrapper]), Seq(
297+
encoderField("bar", JavaBeanEncoder(
298+
ClassTag(classOf[Bar[String]]),
299+
Seq(encoderField("t", StringEncoder))
300+
))
301+
))
302+
assert(encoder === expected)
303+
}
282304
}

sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -928,6 +928,7 @@ public static class SimpleJavaBean implements Serializable {
928928
private List<Long> f;
929929
private Map<Integer, String> g;
930930
private Map<List<Long>, Map<String, String>> h;
931+
private List<List<Long>> i;
931932

932933
public boolean isA() {
933934
return a;
@@ -993,6 +994,14 @@ public void setH(Map<List<Long>, Map<String, String>> h) {
993994
this.h = h;
994995
}
995996

997+
public List<List<Long>> getI() {
998+
return i;
999+
}
1000+
1001+
public void setI(List<List<Long>> i) {
1002+
this.i = i;
1003+
}
1004+
9961005
@Override
9971006
public boolean equals(Object o) {
9981007
if (this == o) return true;
@@ -1007,7 +1016,8 @@ public boolean equals(Object o) {
10071016
if (!e.equals(that.e)) return false;
10081017
if (!f.equals(that.f)) return false;
10091018
if (!g.equals(that.g)) return false;
1010-
return h.equals(that.h);
1019+
if (!h.equals(that.h)) return false;
1020+
return i.equals(that.i);
10111021

10121022
}
10131023

@@ -1021,6 +1031,7 @@ public int hashCode() {
10211031
result = 31 * result + f.hashCode();
10221032
result = 31 * result + g.hashCode();
10231033
result = 31 * result + h.hashCode();
1034+
result = 31 * result + i.hashCode();
10241035
return result;
10251036
}
10261037
}
@@ -1110,6 +1121,10 @@ public void testJavaBeanEncoder() {
11101121
Map<List<Long>, Map<String, String>> complexMap1 = new HashMap<>();
11111122
complexMap1.put(Arrays.asList(1L, 2L), nestedMap1);
11121123
obj1.setH(complexMap1);
1124+
List<Long> nestedList1 = List.of(1L, 2L, 3L);
1125+
List<Long> nestedList2 = List.of(4L, 5L, 6L);
1126+
List<List<Long>> complexList1 = List.of(nestedList1, nestedList2);
1127+
obj1.setI(complexList1);
11131128

11141129
SimpleJavaBean obj2 = new SimpleJavaBean();
11151130
obj2.setA(false);
@@ -1128,6 +1143,10 @@ public void testJavaBeanEncoder() {
11281143
Map<List<Long>, Map<String, String>> complexMap2 = new HashMap<>();
11291144
complexMap2.put(Arrays.asList(3L, 4L), nestedMap2);
11301145
obj2.setH(complexMap2);
1146+
List<Long> nestedList3 = List.of(1L, 2L, 7L);
1147+
List<Long> nestedList4 = List.of(4L, 5L, 8L);
1148+
List<List<Long>> complexList2 = List.of(nestedList3, nestedList4);
1149+
obj2.setI(complexList2);
11311150

11321151
List<SimpleJavaBean> data = Arrays.asList(obj1, obj2);
11331152
Dataset<SimpleJavaBean> ds = spark.createDataset(data, Encoders.bean(SimpleJavaBean.class));
@@ -1148,7 +1167,8 @@ public void testJavaBeanEncoder() {
11481167
Arrays.asList("a", "b"),
11491168
Arrays.asList(100L, null, 200L),
11501169
map1,
1151-
complexMap1});
1170+
complexMap1,
1171+
complexList1});
11521172
Row row2 = new GenericRow(new Object[]{
11531173
false,
11541174
30,
@@ -1157,7 +1177,8 @@ public void testJavaBeanEncoder() {
11571177
Arrays.asList("x", "y"),
11581178
Arrays.asList(300L, null, 400L),
11591179
map2,
1160-
complexMap2});
1180+
complexMap2,
1181+
complexList2});
11611182
StructType schema = new StructType()
11621183
.add("a", BooleanType, false)
11631184
.add("b", IntegerType, false)
@@ -1166,7 +1187,8 @@ public void testJavaBeanEncoder() {
11661187
.add("e", createArrayType(StringType))
11671188
.add("f", createArrayType(LongType))
11681189
.add("g", createMapType(IntegerType, StringType))
1169-
.add("h",createMapType(createArrayType(LongType), createMapType(StringType, StringType)));
1190+
.add("h", createMapType(createArrayType(LongType), createMapType(StringType, StringType)))
1191+
.add("i", createArrayType(createArrayType(LongType)));
11701192
Dataset<SimpleJavaBean> ds3 = spark.createDataFrame(Arrays.asList(row1, row2), schema)
11711193
.as(Encoders.bean(SimpleJavaBean.class));
11721194
Assertions.assertEquals(data, ds3.collectAsList());

0 commit comments

Comments (0)