@@ -0,0 +1,101 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.paimon.table;

import org.apache.paimon.fs.FileIO;
import org.apache.paimon.predicate.VectorSearch;
import org.apache.paimon.table.source.InnerTableRead;
import org.apache.paimon.table.source.InnerTableScan;
import org.apache.paimon.types.RowType;

import java.util.List;
import java.util.Map;

/**
 * A table wrapper to hold vector search information. This is used to pass vector search pushdown
 * information from logical plan optimization to physical plan execution. For now, it is only used
 * internally by the Spark engine.
 */
public class VectorSearchTable implements ReadonlyTable {

    private final InnerTable origin;
    private final VectorSearch vectorSearch;

    private VectorSearchTable(InnerTable origin, VectorSearch vectorSearch) {
        this.origin = origin;
        this.vectorSearch = vectorSearch;
    }

    public static VectorSearchTable create(InnerTable origin, VectorSearch vectorSearch) {
        return new VectorSearchTable(origin, vectorSearch);
    }

    public VectorSearch vectorSearch() {
        return vectorSearch;
    }

    public InnerTable origin() {
        return origin;
    }

    @Override
    public String name() {
        return origin.name();
    }

    @Override
    public RowType rowType() {
        return origin.rowType();
    }

    @Override
    public List<String> primaryKeys() {
        return origin.primaryKeys();
    }

    @Override
    public List<String> partitionKeys() {
        return origin.partitionKeys();
    }

    @Override
    public Map<String, String> options() {
        return origin.options();
    }

    @Override
    public FileIO fileIO() {
        return origin.fileIO();
    }

    @Override
    public InnerTableRead newRead() {
        return origin.newRead();
    }

    @Override
    public InnerTableScan newScan() {
        throw new UnsupportedOperationException();
    }

    @Override
    public Table copy(Map<String, String> dynamicOptions) {
        return new VectorSearchTable((InnerTable) origin.copy(dynamicOptions), vectorSearch);
    }
}
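
For orientation, a minimal sketch of how this wrapper is meant to travel through planning, assuming a hypothetical optimizer hook; the helper names attach and recover below are illustrative, not part of this PR:

  // Illustrative only: wrap during logical optimization, unwrap at scan building.
  import org.apache.paimon.predicate.VectorSearch
  import org.apache.paimon.table.{InnerTable, VectorSearchTable}

  def attach(table: InnerTable, vs: VectorSearch): VectorSearchTable =
    VectorSearchTable.create(table, vs)

  def recover(table: InnerTable): (InnerTable, Option[VectorSearch]) =
    table match {
      case vst: VectorSearchTable => (vst.origin(), Option(vst.vectorSearch()))
      case other => (other, None) // plain scan, nothing to unwrap
    }

The scan builder change at the end of this diff performs exactly this unwrap before constructing PaimonScan.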
@@ -21,7 +21,7 @@ package org.apache.paimon.spark
import org.apache.paimon.CoreOptions
import org.apache.paimon.partition.PartitionPredicate
import org.apache.paimon.partition.PartitionPredicate.splitPartitionPredicatesAndDataPredicates
-import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate}
+import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate, TopN, VectorSearch}
import org.apache.paimon.table.SpecialFields.rowTypeWithRowTracking
import org.apache.paimon.table.Table
import org.apache.paimon.types.RowType
@@ -50,6 +50,9 @@ abstract class PaimonBaseScanBuilder

  protected var pushedPartitionFilters: Array[PartitionPredicate] = Array.empty
  protected var pushedDataFilters: Array[Predicate] = Array.empty
+  protected var pushedLimit: Option[Int] = None
+  protected var pushedTopN: Option[TopN] = None
+  protected var pushedVectorSearch: Option[VectorSearch] = None

  protected var requiredSchema: StructType = SparkTypeUtils.fromPaimonRowType(table.rowType())
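The VectorSearch predicate imported above is not itself part of this diff. From its use in the tests further down, which call fieldName() and limit() on the pushed instance, it carries at least the searched column, the query vector, and a top-k limit. A rough sketch of that shape, stated as an assumption rather than the actual org.apache.paimon.predicate.VectorSearch class:

  // Assumed shape, inferred from usage in this PR; the real Java class may
  // carry more state (e.g. a distance metric or index options).
  case class VectorSearch(fieldName: String, vector: Array[Float], limit: Int)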
@@ -19,7 +19,7 @@
package org.apache.paimon.spark

import org.apache.paimon.partition.PartitionPredicate
-import org.apache.paimon.predicate.{Predicate, TopN}
+import org.apache.paimon.predicate.{Predicate, TopN, VectorSearch}
import org.apache.paimon.table.InnerTable

import org.apache.spark.sql.PaimonUtils.fieldReference
@@ -37,6 +37,7 @@ case class PaimonScan(
    pushedDataFilters: Seq[Predicate],
    override val pushedLimit: Option[Int] = None,
    override val pushedTopN: Option[TopN] = None,
+   override val pushedVectorSearch: Option[VectorSearch] = None,
    bucketedScanDisabled: Boolean = true)
  extends PaimonBaseScan(table)
  with SupportsRuntimeFiltering {
@@ -25,6 +25,13 @@ import org.apache.spark.sql.connector.read.Scan
class PaimonScanBuilder(val table: InnerTable) extends PaimonBaseScanBuilder {

  override def build(): Scan = {
-    PaimonScan(table, requiredSchema, pushedPartitionFilters, pushedDataFilters)
+    PaimonScan(
+      table,
+      requiredSchema,
+      pushedPartitionFilters,
+      pushedDataFilters,
+      pushedLimit,
+      pushedTopN,
+      pushedVectorSearch)
  }
}
@@ -19,7 +19,7 @@
package org.apache.paimon.spark

import org.apache.paimon.partition.PartitionPredicate
-import org.apache.paimon.predicate.{Predicate, TopN}
+import org.apache.paimon.predicate.{Predicate, TopN, VectorSearch}
import org.apache.paimon.table.{BucketMode, FileStoreTable, InnerTable}
import org.apache.paimon.table.source.{DataSplit, Split}

@@ -39,6 +39,7 @@ case class PaimonScan(
    pushedDataFilters: Seq[Predicate],
    override val pushedLimit: Option[Int],
    override val pushedTopN: Option[TopN],
+   override val pushedVectorSearch: Option[VectorSearch],
    bucketedScanDisabled: Boolean = false)
  extends PaimonBaseScan(table)
  with SupportsRuntimeFiltering
@@ -0,0 +1,145 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.paimon.spark.sql

import org.apache.paimon.spark.PaimonScan

/** Tests for vector search table-valued function with global vector index. */
class VectorSearchPushDownTest extends BaseVectorSearchPushDownTest {

  test("vector search with global index") {
    withTable("T") {
      spark.sql("""
          |CREATE TABLE T (id INT, v ARRAY<FLOAT>)
          |TBLPROPERTIES (
          | 'bucket' = '-1',
          | 'global-index.row-count-per-shard' = '10000',
          | 'row-tracking.enabled' = 'true',
          | 'data-evolution.enabled' = 'true')
          |""".stripMargin)

      // Insert 100 rows with predictable vectors
      val values = (0 until 100)
        .map(i => s"($i, array(cast($i as float), cast(${i + 1} as float), cast(${i + 2} as float)))")
        .mkString(",")
      spark.sql(s"INSERT INTO T VALUES $values")

      // Create vector index
      val output = spark
        .sql("CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => 'lucene-vector-knn', options => 'vector.dim=3')")
        .collect()
        .head
      assert(output.getBoolean(0))

      // Test vector search with table-valued function syntax
      val result = spark
        .sql("""
            |SELECT * FROM vector_search('T', 'v', array(50.0f, 51.0f, 52.0f), 5)
            |""".stripMargin)
        .collect()

      // The result should contain 5 rows
      assert(result.length == 5)

      // Vector (50, 51, 52) should be most similar to the row with id=50
      assert(result.map(_.getInt(0)).contains(50))
    }
  }

test("vector search pushdown is applied in plan") {
withTable("T") {
spark.sql("""
|CREATE TABLE T (id INT, v ARRAY<FLOAT>)
|TBLPROPERTIES (
| 'bucket' = '-1',
| 'global-index.row-count-per-shard' = '10000',
| 'row-tracking.enabled' = 'true',
| 'data-evolution.enabled' = 'true')
|""".stripMargin)

val values = (0 until 10)
.map(
i => s"($i, array(cast($i as float), cast(${i + 1} as float), cast(${i + 2} as float)))")
.mkString(",")
spark.sql(s"INSERT INTO T VALUES $values")

// Create vector index
spark
.sql("CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => 'lucene-vector-knn', options => 'vector.dim=3')")
.collect()

// Check that vector search is pushed down with table function syntax
val df = spark.sql("""
|SELECT * FROM vector_search('T', 'v', array(50.0f, 51.0f, 52.0f), 5)
|""".stripMargin)

// Get the scan from the executed plan (physical plan)
val executedPlan = df.queryExecution.executedPlan
val batchScans = executedPlan.collect {
case scan: org.apache.spark.sql.execution.datasources.v2.BatchScanExec => scan
}

assert(batchScans.nonEmpty, "Should have a BatchScanExec in executed plan")
val paimonScans = batchScans.filter(_.scan.isInstanceOf[PaimonScan])
assert(paimonScans.nonEmpty, "Should have a PaimonScan in executed plan")

val paimonScan = paimonScans.head.scan.asInstanceOf[PaimonScan]
assert(paimonScan.pushedVectorSearch.isDefined, "Vector search should be pushed down")
assert(paimonScan.pushedVectorSearch.get.fieldName() == "v", "Field name should be 'v'")
assert(paimonScan.pushedVectorSearch.get.limit() == 5, "Limit should be 5")
}
}

test("vector search topk returns correct results") {
withTable("T") {
spark.sql("""
|CREATE TABLE T (id INT, v ARRAY<FLOAT>)
|TBLPROPERTIES (
| 'bucket' = '-1',
| 'global-index.row-count-per-shard' = '10000',
| 'row-tracking.enabled' = 'true',
| 'data-evolution.enabled' = 'true')
|""".stripMargin)

// Insert rows with distinct vectors
val values = (1 to 100)
.map {
i =>
val v = math.sqrt(3.0 * i * i)
val normalized = i.toFloat / v.toFloat
s"($i, array($normalized, $normalized, $normalized))"
}
.mkString(",")
spark.sql(s"INSERT INTO T VALUES $values")

// Create vector index
spark.sql(
"CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => 'lucene-vector-knn', options => 'vector.dim=3')")

// Query for top 10 similar to (1, 1, 1) normalized
val result = spark
.sql("""
|SELECT * FROM vector_search('T', 'v', array(0.577f, 0.577f, 0.577f), 10)
|""".stripMargin)
.collect()

assert(result.length == 10)
}
}
}
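
A quick sanity check of the first test's top-1 expectation: row i stores the vector (i, i+1, i+2), so its squared L2 distance to the query (50, 51, 52) is 3 * (i - 50)^2, which is minimized at i = 50. The snippet below reproduces that arithmetic outside Spark:

  // Nearest row to the query under L2 distance, computed directly.
  val query = Array(50.0, 51.0, 52.0)
  val nearest = (0 until 100).minBy {
    i =>
      val row = Array(i.toDouble, i + 1.0, i + 2.0)
      row.zip(query).map { case (a, b) => (a - b) * (a - b) }.sum
  }
  assert(nearest == 50) // matches result.map(_.getInt(0)).contains(50)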
@@ -21,7 +21,7 @@ package org.apache.paimon.spark
import org.apache.paimon.CoreOptions
import org.apache.paimon.partition.PartitionPredicate
import org.apache.paimon.partition.PartitionPredicate.splitPartitionPredicatesAndDataPredicates
-import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate, TopN}
+import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate, TopN, VectorSearch}
import org.apache.paimon.table.{SpecialFields, Table}
import org.apache.paimon.types.RowType

@@ -52,6 +52,7 @@
  protected var pushedDataFilters: Array[Predicate] = Array.empty
  protected var pushedLimit: Option[Int] = None
  protected var pushedTopN: Option[TopN] = None
+  protected var pushedVectorSearch: Option[VectorSearch] = None

  protected var requiredSchema: StructType = SparkTypeUtils.fromPaimonRowType(table.rowType())
@@ -20,7 +20,7 @@ package org.apache.paimon.spark

import org.apache.paimon.CoreOptions.BucketFunctionType
import org.apache.paimon.partition.PartitionPredicate
-import org.apache.paimon.predicate.{Predicate, TopN}
+import org.apache.paimon.predicate.{Predicate, TopN, VectorSearch}
import org.apache.paimon.spark.commands.BucketExpression.quote
import org.apache.paimon.table.{BucketMode, FileStoreTable, InnerTable}
import org.apache.paimon.table.source.{DataSplit, Split}

@@ -41,6 +41,7 @@ case class PaimonScan(
    pushedDataFilters: Seq[Predicate],
    override val pushedLimit: Option[Int],
    override val pushedTopN: Option[TopN],
+   override val pushedVectorSearch: Option[VectorSearch],
    bucketedScanDisabled: Boolean = false)
  extends PaimonBaseScan(table)
  with SupportsReportPartitioning
@@ -128,13 +128,26 @@ class PaimonScanBuilder(val table: InnerTable)
    localScan match {
      case Some(scan) => scan
      case None =>
+       val (actualTable, vectorSearch) = table match {
+         case vst: org.apache.paimon.table.VectorSearchTable =>
+           val tableVectorSearch = Option(vst.vectorSearch())
+           val vs = (tableVectorSearch, pushedVectorSearch) match {
+             case (Some(_), _) => tableVectorSearch
+             case (None, Some(_)) => pushedVectorSearch
+             case (None, None) => None
+           }
+           (vst.origin(), vs)
+         case _ => (table, pushedVectorSearch)
+       }
+
        PaimonScan(
-         table,
+         actualTable,
          requiredSchema,
          pushedPartitionFilters,
          pushedDataFilters,
          pushedLimit,
-         pushedTopN)
+         pushedTopN,
+         vectorSearch)
    }
  }
}
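
Note on the unwrap logic: the nested match over (tableVectorSearch, pushedVectorSearch) encodes a simple precedence, where a VectorSearch carried by the table wrapper wins over one pushed into the builder. The same decision, written with Option combinators:

  // Equivalent to the three match arms above.
  val vs = tableVectorSearch.orElse(pushedVectorSearch)

orElse returns the first option when it is defined and the second otherwise, which covers the (Some, _), (None, Some) and (None, None) cases exactly.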