Skip to content

Commit 2236d3d

Browse files
committed
add test
1 parent a4f03cf commit 2236d3d

File tree

6 files changed

+77
-21
lines changed

6 files changed

+77
-21
lines changed

paimon-core/src/main/java/org/apache/paimon/table/VectorSearchTable.java

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,20 @@
2424
import org.apache.paimon.table.source.InnerTableScan;
2525
import org.apache.paimon.types.RowType;
2626

27-
import javax.annotation.Nullable;
28-
2927
import java.util.List;
3028
import java.util.Map;
3129

3230
/**
33-
* A table wrapper to hold vector search information. This is used by Spark engine to pass vector
34-
* search pushdown information from logical plan optimization to physical plan execution.
31+
* A table wrapper to hold vector search information. This is used to pass vector search pushdown
32+
* information from logical plan optimization to physical plan execution. For now, it is only used
33+
* by internal for Spark engine.
3534
*/
3635
public class VectorSearchTable implements ReadonlyTable {
3736

3837
private final InnerTable origin;
3938
private final VectorSearch vectorSearch;
4039

41-
VectorSearchTable(InnerTable origin, VectorSearch vectorSearch) {
40+
private VectorSearchTable(InnerTable origin, VectorSearch vectorSearch) {
4241
this.origin = origin;
4342
this.vectorSearch = vectorSearch;
4443
}
@@ -47,7 +46,6 @@ public static VectorSearchTable create(InnerTable origin, VectorSearch vectorSea
4746
return new VectorSearchTable(origin, vectorSearch);
4847
}
4948

50-
@Nullable
5149
public VectorSearch vectorSearch() {
5250
return vectorSearch;
5351
}
@@ -93,7 +91,7 @@ public InnerTableRead newRead() {
9391

9492
@Override
9593
public InnerTableScan newScan() {
96-
return origin.newScan();
94+
throw new UnsupportedOperationException();
9795
}
9896

9997
@Override

paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ case class PaimonScan(
3939
pushedDataFilters: Seq[Predicate],
4040
override val pushedLimit: Option[Int],
4141
override val pushedTopN: Option[TopN],
42-
override val pushedVectorSearch: Option[VectorSearch] = None,
42+
override val pushedVectorSearch: Option[VectorSearch],
4343
bucketedScanDisabled: Boolean = false)
4444
extends PaimonBaseScan(table)
4545
with SupportsRuntimeFiltering

paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ case class PaimonScan(
4141
pushedDataFilters: Seq[Predicate],
4242
override val pushedLimit: Option[Int],
4343
override val pushedTopN: Option[TopN],
44-
override val pushedVectorSearch: Option[VectorSearch] = None,
44+
override val pushedVectorSearch: Option[VectorSearch],
4545
bucketedScanDisabled: Boolean = false)
4646
extends PaimonBaseScan(table)
4747
with SupportsReportPartitioning

paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,13 @@ class PaimonScanBuilder(val table: InnerTable)
130130
case None =>
131131
val (actualTable, vectorSearch) = table match {
132132
case vst: org.apache.paimon.table.VectorSearchTable =>
133-
(vst.origin(), Some(vst.vectorSearch()))
133+
val tableVectorSearch = Option(vst.vectorSearch())
134+
val vs = (tableVectorSearch, pushedVectorSearch) match {
135+
case (Some(_), _) => tableVectorSearch
136+
case (None, Some(_)) => pushedVectorSearch
137+
case (None, None) => None
138+
}
139+
(vst.origin(), vs)
134140
case _ => (table, pushedVectorSearch)
135141
}
136142

paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/PaimonTableValuedFunctions.scala

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ object PaimonTableValuedFunctions {
117117
argsWithoutTable: Seq[Expression]): LogicalPlan = {
118118
sparkTable match {
119119
case st @ SparkTable(innerTable: InnerTable) =>
120-
val vectorSearch = vsq.createVectorSearch(argsWithoutTable)
120+
val vectorSearch = vsq.createVectorSearch(innerTable, argsWithoutTable)
121121
val vectorSearchTable = VectorSearchTable.create(innerTable, vectorSearch)
122122
DataSourceV2Relation.create(
123123
st.copy(table = vectorSearchTable),
@@ -126,7 +126,8 @@ object PaimonTableValuedFunctions {
126126
CaseInsensitiveStringMap.empty())
127127
case _ =>
128128
throw new RuntimeException(
129-
s"vector_search only supports Paimon tables, got ${sparkTable.getClass.getName}")
129+
"vector_search only supports Paimon SparkTable backed by InnerTable, " +
130+
s"but got table implementation: ${sparkTable.getClass.getName}")
130131
}
131132
}
132133

@@ -258,21 +259,32 @@ case class VectorSearchQuery(override val args: Seq[Expression])
258259
Map.empty
259260
}
260261

261-
def createVectorSearch(argsWithoutTable: Seq[Expression]): VectorSearch = {
262-
assert(
263-
argsWithoutTable.size == 3,
264-
s"$VECTOR_SEARCH needs four parameters: table_name, column_name, query_vector, limit. " +
265-
s"Got ${argsWithoutTable.size + 1} parameters."
266-
)
267-
262+
def createVectorSearch(
263+
innerTable: InnerTable,
264+
argsWithoutTable: Seq[Expression]): VectorSearch = {
265+
if (argsWithoutTable.size != 3) {
266+
throw new RuntimeException(
267+
s"$VECTOR_SEARCH needs three parameters after table_name: column_name, query_vector, limit. " +
268+
s"Got ${argsWithoutTable.size} parameters after table_name."
269+
)
270+
}
268271
val columnName = argsWithoutTable.head.eval().toString
272+
if (!innerTable.rowType().containsField(columnName)) {
273+
throw new RuntimeException(
274+
s"Column $columnName does not exist in table ${innerTable.name()}"
275+
)
276+
}
269277
val queryVector = extractQueryVector(argsWithoutTable(1))
270278
val limit = argsWithoutTable(2).eval() match {
271279
case i: Int => i
272280
case l: Long => l.toInt
273281
case other => throw new RuntimeException(s"Invalid limit type: ${other.getClass.getName}")
274282
}
275-
283+
if (limit <= 0) {
284+
throw new IllegalArgumentException(
285+
s"Limit must be a positive integer, but got: $limit"
286+
)
287+
}
276288
new VectorSearch(queryVector, limit, columnName)
277289
}
278290

@@ -281,7 +293,7 @@ case class VectorSearchQuery(override val args: Seq[Expression])
281293
case Literal(arrayData, _) if arrayData != null =>
282294
val arr = arrayData.asInstanceOf[org.apache.spark.sql.catalyst.util.ArrayData]
283295
arr.toFloatArray()
284-
case CreateArray(elements, _) =>
296+
case CreateArray(elements, _) if elements != null =>
285297
elements.map {
286298
case Literal(v: Float, _) => v
287299
case Literal(v: Double, _) => v.toFloat

paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/BaseVectorSearchPushDownTest.scala

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,46 @@ class BaseVectorSearchPushDownTest extends PaimonSparkTestBase with StreamTest {
5555

5656
// Should return results (actual filtering depends on vector index)
5757
assert(result.nonEmpty)
58+
59+
// Test invalid limit (negative)
60+
val ex1 = intercept[Exception] {
61+
spark
62+
.sql("""
63+
|SELECT * FROM vector_search('T', 'v', array(1.0f, 0.0f, 0.0f), -3)
64+
|""".stripMargin)
65+
.collect()
66+
}
67+
assert(ex1.getMessage.contains("Limit must be a positive integer"))
68+
69+
// Test invalid limit (zero)
70+
val ex2 = intercept[Exception] {
71+
spark
72+
.sql("""
73+
|SELECT * FROM vector_search('T', 'v', array(1.0f, 0.0f, 0.0f), 0)
74+
|""".stripMargin)
75+
.collect()
76+
}
77+
assert(ex2.getMessage.contains("Limit must be a positive integer"))
78+
79+
// Test missing parameters
80+
val ex3 = intercept[Exception] {
81+
spark
82+
.sql("""
83+
|SELECT * FROM vector_search('T', 'v', array(1.0f, 0.0f, 0.0f))
84+
|""".stripMargin)
85+
.collect()
86+
}
87+
assert(ex3.getMessage.contains("vector_search needs three parameters after table_name"))
88+
89+
// Test non-existent column
90+
val ex4 = intercept[Exception] {
91+
spark
92+
.sql("""
93+
|SELECT * FROM vector_search('T', 'non_existent_col', array(1.0f, 0.0f, 0.0f), 3)
94+
|""".stripMargin)
95+
.collect()
96+
}
97+
assert(ex4.getMessage.nonEmpty)
5898
}
5999
}
60100
}

0 commit comments

Comments (0)