diff --git a/paimon-core/src/main/java/org/apache/paimon/table/VectorSearchTable.java b/paimon-core/src/main/java/org/apache/paimon/table/VectorSearchTable.java new file mode 100644 index 000000000000..cb98e2505581 --- /dev/null +++ b/paimon-core/src/main/java/org/apache/paimon/table/VectorSearchTable.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.table; + +import org.apache.paimon.fs.FileIO; +import org.apache.paimon.predicate.VectorSearch; +import org.apache.paimon.table.source.InnerTableRead; +import org.apache.paimon.table.source.InnerTableScan; +import org.apache.paimon.types.RowType; + +import java.util.List; +import java.util.Map; + +/** + * A table wrapper to hold vector search information. This is used to pass vector search pushdown + * information from logical plan optimization to physical plan execution. For now, it is only used + * internally by the Spark engine. + */ +public class VectorSearchTable implements ReadonlyTable { + + private final InnerTable origin; + private final VectorSearch vectorSearch; + + private VectorSearchTable(InnerTable origin, VectorSearch vectorSearch) { + this.origin = origin; + this.vectorSearch = vectorSearch; + } + + public static VectorSearchTable create(InnerTable origin, VectorSearch vectorSearch) { + return new VectorSearchTable(origin, vectorSearch); + } + + public VectorSearch vectorSearch() { + return vectorSearch; + } + + public InnerTable origin() { + return origin; + } + + @Override + public String name() { + return origin.name(); + } + + @Override + public RowType rowType() { + return origin.rowType(); + } + + @Override + public List<String> primaryKeys() { + return origin.primaryKeys(); + } + + @Override + public List<String> partitionKeys() { + return origin.partitionKeys(); + } + + @Override + public Map<String, String> options() { + return origin.options(); + } + + @Override + public FileIO fileIO() { + return origin.fileIO(); + } + + @Override + public InnerTableRead newRead() { + return origin.newRead(); + } + + @Override + public InnerTableScan newScan() { + throw new UnsupportedOperationException(); + } + + @Override + public Table copy(Map<String, String> dynamicOptions) { + return new VectorSearchTable((InnerTable) origin.copy(dynamicOptions), vectorSearch); + } +} diff --git a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala index 41a1d552f1ef..4f5451c95d5b 100644 --- a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala +++ b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala @@ -21,7 +21,7 @@ package org.apache.paimon.spark import
org.apache.paimon.CoreOptions import org.apache.paimon.partition.PartitionPredicate import org.apache.paimon.partition.PartitionPredicate.splitPartitionPredicatesAndDataPredicates -import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate} +import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate, TopN, VectorSearch} import org.apache.paimon.table.SpecialFields.rowTypeWithRowTracking import org.apache.paimon.table.Table import org.apache.paimon.types.RowType @@ -50,6 +50,9 @@ abstract class PaimonBaseScanBuilder protected var pushedPartitionFilters: Array[PartitionPredicate] = Array.empty protected var pushedDataFilters: Array[Predicate] = Array.empty + protected var pushedLimit: Option[Int] = None + protected var pushedTopN: Option[TopN] = None + protected var pushedVectorSearch: Option[VectorSearch] = None protected var requiredSchema: StructType = SparkTypeUtils.fromPaimonRowType(table.rowType()) diff --git a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScan.scala b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScan.scala index e9eaa7d6cc83..d6292ad8cf02 100644 --- a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScan.scala +++ b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScan.scala @@ -19,7 +19,7 @@ package org.apache.paimon.spark import org.apache.paimon.partition.PartitionPredicate -import org.apache.paimon.predicate.{Predicate, TopN} +import org.apache.paimon.predicate.{Predicate, TopN, VectorSearch} import org.apache.paimon.table.InnerTable import org.apache.spark.sql.PaimonUtils.fieldReference @@ -37,6 +37,7 @@ case class PaimonScan( pushedDataFilters: Seq[Predicate], override val pushedLimit: Option[Int] = None, override val pushedTopN: Option[TopN] = None, + override val pushedVectorSearch: Option[VectorSearch] = None, bucketedScanDisabled: Boolean = true) extends PaimonBaseScan(table) with SupportsRuntimeFiltering { diff --git a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala index 21ab46dabcaf..770bd8f802ba 100644 --- a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala +++ b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala @@ -25,6 +25,13 @@ import org.apache.spark.sql.connector.read.Scan class PaimonScanBuilder(val table: InnerTable) extends PaimonBaseScanBuilder { override def build(): Scan = { - PaimonScan(table, requiredSchema, pushedPartitionFilters, pushedDataFilters) + PaimonScan( + table, + requiredSchema, + pushedPartitionFilters, + pushedDataFilters, + pushedLimit, + pushedTopN, + pushedVectorSearch) } } diff --git a/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala b/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala index 3afca15303ad..8d06751f57ab 100644 --- a/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala +++ b/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala @@ -19,7 +19,7 @@ package org.apache.paimon.spark import org.apache.paimon.partition.PartitionPredicate -import org.apache.paimon.predicate.{Predicate, TopN} +import org.apache.paimon.predicate.{Predicate, TopN, VectorSearch} import org.apache.paimon.table.{BucketMode, FileStoreTable, 
InnerTable} import org.apache.paimon.table.source.{DataSplit, Split} @@ -39,6 +39,7 @@ case class PaimonScan( pushedDataFilters: Seq[Predicate], override val pushedLimit: Option[Int], override val pushedTopN: Option[TopN], + override val pushedVectorSearch: Option[VectorSearch], bucketedScanDisabled: Boolean = false) extends PaimonBaseScan(table) with SupportsRuntimeFiltering diff --git a/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/VectorSearchPushDownTest.scala b/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/VectorSearchPushDownTest.scala new file mode 100644 index 000000000000..7ac3c5df0d00 --- /dev/null +++ b/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/VectorSearchPushDownTest.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark.sql + +import org.apache.paimon.spark.PaimonScan + +/** Tests for vector search table-valued function with global vector index. */ +class VectorSearchPushDownTest extends BaseVectorSearchPushDownTest { + test("vector search with global index") { + withTable("T") { + spark.sql(""" + |CREATE TABLE T (id INT, v ARRAY<FLOAT>) + |TBLPROPERTIES ( + | 'bucket' = '-1', + | 'global-index.row-count-per-shard' = '10000', + | 'row-tracking.enabled' = 'true', + | 'data-evolution.enabled' = 'true') + |""".stripMargin) + + // Insert 100 rows with predictable vectors + val values = (0 until 100) + .map( + i => s"($i, array(cast($i as float), cast(${i + 1} as float), cast(${i + 2} as float)))") + .mkString(",") + spark.sql(s"INSERT INTO T VALUES $values") + + // Create vector index + val output = spark + .sql("CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => 'lucene-vector-knn', options => 'vector.dim=3')") + .collect() + .head + assert(output.getBoolean(0)) + + // Test vector search with table-valued function syntax + val result = spark + .sql(""" + |SELECT * FROM vector_search('T', 'v', array(50.0f, 51.0f, 52.0f), 5) + |""".stripMargin) + .collect() + + // The result should contain 5 rows + assert(result.length == 5) + + // Vector (50, 51, 52) should be most similar to the row with id=50 + assert(result.map(_.getInt(0)).contains(50)) + } + } + + test("vector search pushdown is applied in plan") { + withTable("T") { + spark.sql(""" + |CREATE TABLE T (id INT, v ARRAY<FLOAT>) + |TBLPROPERTIES ( + | 'bucket' = '-1', + | 'global-index.row-count-per-shard' = '10000', + | 'row-tracking.enabled' = 'true', + | 'data-evolution.enabled' = 'true') + |""".stripMargin) + + val values = (0 until 10) + .map( + i => s"($i, array(cast($i as float), cast(${i + 1} as float), cast(${i + 2} as float)))") + .mkString(",") + spark.sql(s"INSERT INTO T VALUES $values") + + // Create vector
index + spark + .sql("CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => 'lucene-vector-knn', options => 'vector.dim=3')") + .collect() + + // Check that vector search is pushed down with table function syntax + val df = spark.sql(""" + |SELECT * FROM vector_search('T', 'v', array(50.0f, 51.0f, 52.0f), 5) + |""".stripMargin) + + // Get the scan from the executed plan (physical plan) + val executedPlan = df.queryExecution.executedPlan + val batchScans = executedPlan.collect { + case scan: org.apache.spark.sql.execution.datasources.v2.BatchScanExec => scan + } + + assert(batchScans.nonEmpty, "Should have a BatchScanExec in executed plan") + val paimonScans = batchScans.filter(_.scan.isInstanceOf[PaimonScan]) + assert(paimonScans.nonEmpty, "Should have a PaimonScan in executed plan") + + val paimonScan = paimonScans.head.scan.asInstanceOf[PaimonScan] + assert(paimonScan.pushedVectorSearch.isDefined, "Vector search should be pushed down") + assert(paimonScan.pushedVectorSearch.get.fieldName() == "v", "Field name should be 'v'") + assert(paimonScan.pushedVectorSearch.get.limit() == 5, "Limit should be 5") + } + } + + test("vector search topk returns correct results") { + withTable("T") { + spark.sql(""" + |CREATE TABLE T (id INT, v ARRAY<FLOAT>) + |TBLPROPERTIES ( + | 'bucket' = '-1', + | 'global-index.row-count-per-shard' = '10000', + | 'row-tracking.enabled' = 'true', + | 'data-evolution.enabled' = 'true') + |""".stripMargin) + + // Insert rows with unit-length vectors + val values = (1 to 100) + .map { + i => + val v = math.sqrt(3.0 * i * i) + val normalized = i.toFloat / v.toFloat + s"($i, array($normalized, $normalized, $normalized))" + } + .mkString(",") + spark.sql(s"INSERT INTO T VALUES $values") + + // Create vector index + spark.sql( + "CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => 'lucene-vector-knn', options => 'vector.dim=3')") + + // Query for top 10 similar to (1, 1, 1) normalized + val result = spark + .sql(""" + |SELECT * FROM vector_search('T', 'v', array(0.577f, 0.577f, 0.577f), 10) + |""".stripMargin) + .collect() + + assert(result.length == 10) + } + } +} diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala index 8179f504b31f..47723171e4d6 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala @@ -21,7 +21,7 @@ package org.apache.paimon.spark import org.apache.paimon.CoreOptions import org.apache.paimon.partition.PartitionPredicate import org.apache.paimon.partition.PartitionPredicate.splitPartitionPredicatesAndDataPredicates -import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate, TopN} +import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate, TopN, VectorSearch} import org.apache.paimon.table.{SpecialFields, Table} import org.apache.paimon.types.RowType @@ -52,6 +52,7 @@ abstract class PaimonBaseScanBuilder protected var pushedDataFilters: Array[Predicate] = Array.empty protected var pushedLimit: Option[Int] = None protected var pushedTopN: Option[TopN] = None + protected var pushedVectorSearch: Option[VectorSearch] = None protected var requiredSchema: StructType = SparkTypeUtils.fromPaimonRowType(table.rowType()) diff --git
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala index 06a97ee8b41a..c9f0e9506e98 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala @@ -20,7 +20,7 @@ package org.apache.paimon.spark import org.apache.paimon.CoreOptions.BucketFunctionType import org.apache.paimon.partition.PartitionPredicate -import org.apache.paimon.predicate.{Predicate, TopN} +import org.apache.paimon.predicate.{Predicate, TopN, VectorSearch} import org.apache.paimon.spark.commands.BucketExpression.quote import org.apache.paimon.table.{BucketMode, FileStoreTable, InnerTable} import org.apache.paimon.table.source.{DataSplit, Split} @@ -41,6 +41,7 @@ case class PaimonScan( pushedDataFilters: Seq[Predicate], override val pushedLimit: Option[Int], override val pushedTopN: Option[TopN], + override val pushedVectorSearch: Option[VectorSearch], bucketedScanDisabled: Boolean = false) extends PaimonBaseScan(table) with SupportsReportPartitioning diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala index 6eeaaf7b93ed..de75bb823dde 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala @@ -128,13 +128,26 @@ class PaimonScanBuilder(val table: InnerTable) localScan match { case Some(scan) => scan case None => + val (actualTable, vectorSearch) = table match { + case vst: org.apache.paimon.table.VectorSearchTable => + val tableVectorSearch = Option(vst.vectorSearch()) + val vs = (tableVectorSearch, pushedVectorSearch) match { + case (Some(_), _) => tableVectorSearch + case (None, Some(_)) => pushedVectorSearch + case (None, None) => None + } + (vst.origin(), vs) + case _ => (table, pushedVectorSearch) + } + PaimonScan( - table, + actualTable, requiredSchema, pushedPartitionFilters, pushedDataFilters, pushedLimit, - pushedTopN) + pushedTopN, + vectorSearch) } } } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/PaimonTableValuedFunctions.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/PaimonTableValuedFunctions.scala index e4f5e7856c88..6bb6004db859 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/PaimonTableValuedFunctions.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/PaimonTableValuedFunctions.scala @@ -19,9 +19,10 @@ package org.apache.paimon.spark.catalyst.plans.logical import org.apache.paimon.CoreOptions +import org.apache.paimon.predicate.VectorSearch import org.apache.paimon.spark.SparkTable import org.apache.paimon.spark.catalyst.plans.logical.PaimonTableValuedFunctions._ -import org.apache.paimon.table.DataTable +import org.apache.paimon.table.{DataTable, InnerTable, VectorSearchTable} import org.apache.paimon.table.source.snapshot.TimeTravelUtil.InconsistentTagBucketException import org.apache.spark.sql.PaimonUtils.createDataset @@ -29,7 +30,7 @@ import org.apache.spark.sql.SparkSession import 
org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistryBase import org.apache.spark.sql.catalyst.analysis.TableFunctionRegistry.TableFunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, ExpressionInfo} +import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateArray, Expression, ExpressionInfo, Literal} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan} import org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation @@ -42,9 +43,10 @@ object PaimonTableValuedFunctions { val INCREMENTAL_QUERY = "paimon_incremental_query" val INCREMENTAL_BETWEEN_TIMESTAMP = "paimon_incremental_between_timestamp" val INCREMENTAL_TO_AUTO_TAG = "paimon_incremental_to_auto_tag" + val VECTOR_SEARCH = "vector_search" val supportedFnNames: Seq[String] = - Seq(INCREMENTAL_QUERY, INCREMENTAL_BETWEEN_TIMESTAMP, INCREMENTAL_TO_AUTO_TAG) + Seq(INCREMENTAL_QUERY, INCREMENTAL_BETWEEN_TIMESTAMP, INCREMENTAL_TO_AUTO_TAG, VECTOR_SEARCH) private type TableFunctionDescription = (FunctionIdentifier, ExpressionInfo, TableFunctionBuilder) @@ -56,6 +58,8 @@ object PaimonTableValuedFunctions { FunctionRegistryBase.build[IncrementalBetweenTimestamp](fnName, since = None) case INCREMENTAL_TO_AUTO_TAG => FunctionRegistryBase.build[IncrementalToAutoTag](fnName, since = None) + case VECTOR_SEARCH => + FunctionRegistryBase.build[VectorSearchQuery](fnName, since = None) case _ => throw new Exception(s"Function $fnName isn't a supported table valued function.") } @@ -85,17 +89,45 @@ object PaimonTableValuedFunctions { val sparkCatalog = catalogManager.catalog(catalogName).asInstanceOf[TableCatalog] val ident: Identifier = Identifier.of(Array(dbName), tableName) val sparkTable = sparkCatalog.loadTable(ident) - val options = tvf.parseArgs(args.tail) - usingSparkIncrementQuery(tvf, sparkTable, options) match { - case Some(snapshotIdPair: (Long, Long)) => - sparkIncrementQuery(spark, sparkTable, sparkCatalog, ident, options, snapshotIdPair) + // Handle vector_search specially + tvf match { + case vsq: VectorSearchQuery => + resolveVectorSearchQuery(sparkTable, sparkCatalog, ident, vsq, args.tail) case _ => + val options = tvf.parseArgs(args.tail) + usingSparkIncrementQuery(tvf, sparkTable, options) match { + case Some(snapshotIdPair: (Long, Long)) => + sparkIncrementQuery(spark, sparkTable, sparkCatalog, ident, options, snapshotIdPair) + case _ => + DataSourceV2Relation.create( + sparkTable, + Some(sparkCatalog), + Some(ident), + new CaseInsensitiveStringMap(options.asJava)) + } + } + } + + private def resolveVectorSearchQuery( + sparkTable: Table, + sparkCatalog: TableCatalog, + ident: Identifier, + vsq: VectorSearchQuery, + argsWithoutTable: Seq[Expression]): LogicalPlan = { + sparkTable match { + case st @ SparkTable(innerTable: InnerTable) => + val vectorSearch = vsq.createVectorSearch(innerTable, argsWithoutTable) + val vectorSearchTable = VectorSearchTable.create(innerTable, vectorSearch) DataSourceV2Relation.create( - sparkTable, + st.copy(table = vectorSearchTable), Some(sparkCatalog), Some(ident), - new CaseInsensitiveStringMap(options.asJava)) + CaseInsensitiveStringMap.empty()) + case _ => + throw new RuntimeException( + "vector_search only supports Paimon SparkTable backed by InnerTable, " + + s"but got table implementation: ${sparkTable.getClass.getName}") } } @@ -207,3 +239,70 @@ case class 
IncrementalToAutoTag(override val args: Seq[Expression]) Map(CoreOptions.INCREMENTAL_TO_AUTO_TAG.key -> endTagName) } } + +/** + * Plan for the [[VECTOR_SEARCH]] table-valued function. + * + * Usage: vector_search(table_name, column_name, query_vector, limit) + * - table_name: the Paimon table to search + * - column_name: the vector column name + * - query_vector: array of floats representing the query vector + * - limit: the number of top results to return + * + * Example: SELECT * FROM vector_search('T', 'v', array(50.0f, 51.0f, 52.0f), 5) + */ +case class VectorSearchQuery(override val args: Seq[Expression]) + extends PaimonTableValueFunction(VECTOR_SEARCH) { + + override def parseArgs(args: Seq[Expression]): Map[String, String] = { + // This method is not used for VectorSearchQuery as we handle it specially + Map.empty + } + + def createVectorSearch( + innerTable: InnerTable, + argsWithoutTable: Seq[Expression]): VectorSearch = { + if (argsWithoutTable.size != 3) { + throw new RuntimeException( + s"$VECTOR_SEARCH needs three parameters after table_name: column_name, query_vector, limit. " + + s"Got ${argsWithoutTable.size} parameters after table_name." + ) + } + val columnName = argsWithoutTable.head.eval().toString + if (!innerTable.rowType().containsField(columnName)) { + throw new RuntimeException( + s"Column $columnName does not exist in table ${innerTable.name()}" + ) + } + val queryVector = extractQueryVector(argsWithoutTable(1)) + val limit = argsWithoutTable(2).eval() match { + case i: Int => i + case l: Long => l.toInt + case other => throw new RuntimeException(s"Invalid limit type: ${other.getClass.getName}") + } + if (limit <= 0) { + throw new IllegalArgumentException( + s"Limit must be a positive integer, but got: $limit" + ) + } + new VectorSearch(queryVector, limit, columnName) + } + + private def extractQueryVector(expr: Expression): Array[Float] = { + expr match { + case Literal(arrayData, _) if arrayData != null => + val arr = arrayData.asInstanceOf[org.apache.spark.sql.catalyst.util.ArrayData] + arr.toFloatArray() + case CreateArray(elements, _) if elements != null => + elements.map { + case Literal(v: Float, _) => v + case Literal(v: Double, _) => v.toFloat + case Literal(v: java.lang.Float, _) if v != null => v.floatValue() + case Literal(v: java.lang.Double, _) if v != null => v.floatValue() + case other => throw new RuntimeException(s"Cannot extract float from: $other") + }.toArray + case _ => + throw new RuntimeException(s"Cannot extract query vector from expression: $expr") + } + } +} diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/scan/BaseScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/scan/BaseScan.scala index dcd3dda67af2..3e6a3f0319e2 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/scan/BaseScan.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/scan/BaseScan.scala @@ -20,7 +20,7 @@ package org.apache.paimon.spark.scan import org.apache.paimon.CoreOptions import org.apache.paimon.partition.PartitionPredicate -import org.apache.paimon.predicate.{Predicate, TopN} +import org.apache.paimon.predicate.{Predicate, TopN, VectorSearch} import org.apache.paimon.spark.{PaimonBatch, PaimonInputPartition, PaimonNumSplitMetric, PaimonPartitionSizeMetric, PaimonReadBatchTimeMetric, PaimonResultedTableFilesMetric, PaimonResultedTableFilesTaskMetric, SparkTypeUtils} import org.apache.paimon.spark.schema.PaimonMetadataColumn import 
org.apache.paimon.spark.schema.PaimonMetadataColumn._ @@ -49,6 +49,7 @@ trait BaseScan extends Scan with SupportsReportStatistics with Logging { def pushedDataFilters: Seq[Predicate] def pushedLimit: Option[Int] = None def pushedTopN: Option[TopN] = None + def pushedVectorSearch: Option[VectorSearch] = None // Input splits def inputSplits: Array[Split] @@ -104,6 +105,7 @@ trait BaseScan extends Scan with SupportsReportStatistics with Logging { } pushedLimit.foreach(_readBuilder.withLimit) pushedTopN.foreach(_readBuilder.withTopN) + pushedVectorSearch.foreach(_readBuilder.withVectorSearch) _readBuilder.dropStats() } @@ -173,6 +175,7 @@ trait BaseScan extends Scan with SupportsReportStatistics with Logging { pushedPartitionFiltersStr + pushedDataFiltersStr + pushedTopN.map(topN => s", TopN: [$topN]").getOrElse("") + - pushedLimit.map(limit => s", Limit: [$limit]").getOrElse("") + pushedLimit.map(limit => s", Limit: [$limit]").getOrElse("") + + pushedVectorSearch.map(vs => s", VectorSearch: [$vs]").getOrElse("") } } diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/BaseVectorSearchPushDownTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/BaseVectorSearchPushDownTest.scala new file mode 100644 index 000000000000..c283326cf3fa --- /dev/null +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/BaseVectorSearchPushDownTest.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark.sql + +import org.apache.paimon.spark.PaimonSparkTestBase + +import org.apache.spark.sql.streaming.StreamTest + +/** Tests for vector search table-valued function. 
*/ +class BaseVectorSearchPushDownTest extends PaimonSparkTestBase with StreamTest { + + test("vector_search table function basic syntax") { + withTable("T") { + spark.sql(""" + |CREATE TABLE T (id INT, v ARRAY<FLOAT>) + |TBLPROPERTIES ( + | 'bucket' = '-1', + | 'row-tracking.enabled' = 'true', + | 'data-evolution.enabled' = 'true') + |""".stripMargin) + + // Insert data with known vectors + spark.sql(""" + |INSERT INTO T VALUES + |(1, array(1.0, 0.0, 0.0)), + |(2, array(0.0, 1.0, 0.0)), + |(3, array(0.0, 0.0, 1.0)), + |(4, array(1.0, 1.0, 0.0)), + |(5, array(1.0, 1.0, 1.0)) + |""".stripMargin) + + // Test vector_search table function syntax + // Note: Without a global vector index, this will scan all rows + val result = spark + .sql(""" + |SELECT * FROM vector_search('T', 'v', array(1.0f, 0.0f, 0.0f), 3) + |""".stripMargin) + .collect() + + // Should return results (actual filtering depends on vector index) + assert(result.nonEmpty) + + // Test invalid limit (negative) + val ex1 = intercept[Exception] { + spark + .sql(""" + |SELECT * FROM vector_search('T', 'v', array(1.0f, 0.0f, 0.0f), -3) + |""".stripMargin) + .collect() + } + assert(ex1.getMessage.contains("Limit must be a positive integer")) + + // Test invalid limit (zero) + val ex2 = intercept[Exception] { + spark + .sql(""" + |SELECT * FROM vector_search('T', 'v', array(1.0f, 0.0f, 0.0f), 0) + |""".stripMargin) + .collect() + } + assert(ex2.getMessage.contains("Limit must be a positive integer")) + + // Test missing parameters + val ex3 = intercept[Exception] { + spark + .sql(""" + |SELECT * FROM vector_search('T', 'v', array(1.0f, 0.0f, 0.0f)) + |""".stripMargin) + .collect() + } + assert(ex3.getMessage.contains("vector_search needs three parameters after table_name")) + + // Test non-existent column + val ex4 = intercept[Exception] { + spark + .sql(""" + |SELECT * FROM vector_search('T', 'non_existent_col', array(1.0f, 0.0f, 0.0f), 3) + |""".stripMargin) + .collect() + } + assert(ex4.getMessage.contains("non_existent_col")) + } + } +}
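Illustrative end-to-end usage (a sketch distilled from the tests in this diff; the catalog/table name 'test.T', the vector column 'v', the index options, and the query vector values are examples taken from or modeled on those tests, not fixed requirements):

-- Create a table with a float vector column, then build a global Lucene KNN index on it.
CREATE TABLE T (id INT, v ARRAY<FLOAT>) TBLPROPERTIES (
  'bucket' = '-1',
  'row-tracking.enabled' = 'true',
  'data-evolution.enabled' = 'true');
CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => 'lucene-vector-knn', options => 'vector.dim=3');

-- Ask for the 5 rows nearest to the query vector. When the pushdown applies, the
-- PaimonScan description includes a "VectorSearch: [...]" entry (see BaseScan above).
SELECT * FROM vector_search('T', 'v', array(1.0f, 2.0f, 3.0f), 5);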