Skip to content

Commit 2c88f05

Browse files
committed
Update vector index tests.
1 parent be72c6e commit 2c88f05

File tree

2 files changed

+24
-24
lines changed

2 files changed

+24
-24
lines changed

enginetest/queries/vector_index_queries.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ var VectorIndexQueries = []ScriptTest{
5353
{4, types.MustJSON(`[-2.0, 0.0]`)},
5454
{1, types.MustJSON(`[3.0, 4.0]`)},
5555
},
56-
ExpectedIndexes: []string{""},
56+
ExpectedIndexes: nil,
5757
},
5858
{
5959
Query: "select * from vectors order by VEC_DISTANCE_L2_SQUARED('[-2.0,0.0]', v) limit 4",

sql/analyzer/vector_index_test.go

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,11 @@ func jsonExpression(t *testing.T, val interface{}) sql.Expression {
3838
}
3939

4040
type vectorIndexTestCase struct {
41-
name string
42-
inputPlan sql.Node
43-
expectedPlan string
44-
expectedRows []sql.Row
41+
name string
42+
usesVectorIndex bool
43+
inputPlan sql.Node
44+
expectedPlan string
45+
expectedRows []sql.Row
4546
}
4647

4748
func vectorIndexTestCases(t *testing.T, db *memory.Database, table sql.IndexedTable) []vectorIndexTestCase {
@@ -50,13 +51,9 @@ func vectorIndexTestCases(t *testing.T, db *memory.Database, table sql.IndexedTa
5051
name: "without limit",
5152
inputPlan: plan.NewSort(
5253
sql.SortFields{
53-
{Column: vector.NewDistance(vector.DistanceL2Squared{}, jsonExpression(t, "[0.0, 0.0]"), expression.NewGetField(1, types.JSON, "v", false)), Order: sql.Ascending},
54+
{Column: vector.NewDistance(vector.DistanceL2Squared{}, jsonExpression(t, "[0.0, 0.0]"), expression.NewGetFieldWithTable(2, 1, types.JSON, "", "test-table", "v", false)), Order: sql.Ascending},
5455
}, plan.NewResolvedTable(table, db, nil)),
55-
expectedPlan: `
56-
IndexedTableAccess(test)
57-
├─ index: [v]
58-
└─ order: VEC_DISTANCE_L2_SQUARED([0, 0], v)
59-
`,
56+
usesVectorIndex: false,
6057
expectedRows: []sql.Row{
6158
sql.NewRow(int64(3), jsontests.ConvertToJson(t, "[1.0, 1.0]")),
6259
sql.NewRow(int64(2), jsontests.ConvertToJson(t, "[2.0, 2.0]")),
@@ -67,12 +64,13 @@ IndexedTableAccess(test)
6764
name: "with limit",
6865
inputPlan: plan.NewTopN(
6966
sql.SortFields{
70-
{Column: vector.NewDistance(vector.DistanceL2Squared{}, jsonExpression(t, "[0.0, 0.0]"), expression.NewGetField(1, types.JSON, "v", false)), Order: sql.Ascending},
67+
{Column: vector.NewDistance(vector.DistanceL2Squared{}, jsonExpression(t, "[0.0, 0.0]"), expression.NewGetFieldWithTable(2, 1, types.JSON, "", "test-table", "v", false)), Order: sql.Ascending},
7168
}, expression.NewLiteral(1, types.Int64), plan.NewResolvedTable(table, db, nil)),
69+
usesVectorIndex: true,
7270
expectedPlan: `
7371
IndexedTableAccess(test)
74-
├─ index: [v]
75-
└─ order: VEC_DISTANCE_L2_SQUARED([0, 0], v) LIMIT 1 (bigint)
72+
├─ index: [test-table.v]
73+
└─ order: VEC_DISTANCE_L2_SQUARED([0, 0], test-table.v) LIMIT 1 (bigint)
7674
`,
7775
expectedRows: []sql.Row{
7876
sql.NewRow(int64(3), jsontests.ConvertToJson(t, "[1.0, 1.0]")),
@@ -118,11 +116,14 @@ func TestVectorIndex(t *testing.T) {
118116
t.Run(testCase.name, func(t *testing.T) {
119117
res, same, err := replaceIdxOrderByDistanceHelper(nil, nil, testCase.inputPlan, nil)
120118
require.NoError(t, err)
121-
require.False(t, bool(same))
122-
require.Equal(t,
123-
strings.TrimSpace(testCase.expectedPlan),
124-
strings.TrimSpace(res.String()),
125-
"expected:\n%s,\nfound:\n%s\n", testCase.expectedPlan, res.String())
119+
require.Equal(t, testCase.usesVectorIndex, !bool(same))
120+
res = offsetAssignIndexes(res)
121+
if testCase.usesVectorIndex {
122+
require.Equal(t,
123+
strings.TrimSpace(testCase.expectedPlan),
124+
strings.TrimSpace(res.String()),
125+
"expected:\n%s,\nfound:\n%s\n", testCase.expectedPlan, res.String())
126+
}
126127

127128
iter, err := rowexec.DefaultBuilder.Build(ctx, res, nil)
128129
require.NoError(t, err)
@@ -167,7 +168,7 @@ func TestShowCreateTableWithVectorIndex(t *testing.T) {
167168
"CREATE TABLE `test-table` (\n `pk` int,\n"+
168169
" `v` json,\n"+
169170
" PRIMARY KEY (`pk`),\n"+
170-
" VECTOR KEY `test` (`v`),\n"+
171+
" VECTOR KEY `test` (`v`)\n"+
171172
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin",
172173
)
173174

@@ -205,8 +206,7 @@ func (i vectorIndexTable) Comment() string {
205206
}
206207

207208
func (i vectorIndexTable) Partitions(context *sql.Context) (sql.PartitionIter, error) {
208-
//TODO implement me
209-
panic("implement me")
209+
return i.underlying.Partitions(context)
210210
}
211211

212212
func (i vectorIndexTable) PartitionRows(context *sql.Context, partition sql.Partition) (sql.RowIter, error) {
@@ -230,9 +230,9 @@ var vectorIndex = memory.Index{
230230
DB: database,
231231
DriverName: "",
232232
Tbl: nil,
233-
TableName: "test",
233+
TableName: "test-table",
234234
Exprs: []sql.Expression{
235-
expression.NewGetField(1, types.JSON, "v", false),
235+
expression.NewGetFieldWithTable(1, 1, types.JSON, "", "test-table", "v", false),
236236
},
237237
Name: "test",
238238
Unique: false,

0 commit comments

Comments
 (0)