@@ -38,10 +38,11 @@ func jsonExpression(t *testing.T, val interface{}) sql.Expression {
3838}
3939
4040type vectorIndexTestCase struct {
41- name string
42- inputPlan sql.Node
43- expectedPlan string
44- expectedRows []sql.Row
41+ name string
42+ usesVectorIndex bool
43+ inputPlan sql.Node
44+ expectedPlan string
45+ expectedRows []sql.Row
4546}
4647
4748func vectorIndexTestCases (t * testing.T , db * memory.Database , table sql.IndexedTable ) []vectorIndexTestCase {
@@ -50,13 +51,9 @@ func vectorIndexTestCases(t *testing.T, db *memory.Database, table sql.IndexedTa
5051 name : "without limit" ,
5152 inputPlan : plan .NewSort (
5253 sql.SortFields {
53- {Column : vector .NewDistance (vector.DistanceL2Squared {}, jsonExpression (t , "[0.0, 0.0]" ), expression .NewGetField ( 1 , types .JSON , "v" , false )), Order : sql .Ascending },
54+ {Column : vector .NewDistance (vector.DistanceL2Squared {}, jsonExpression (t , "[0.0, 0.0]" ), expression .NewGetFieldWithTable ( 2 , 1 , types .JSON , "" , "test-table" , "v" , false )), Order : sql .Ascending },
5455 }, plan .NewResolvedTable (table , db , nil )),
55- expectedPlan : `
56- IndexedTableAccess(test)
57- ├─ index: [v]
58- └─ order: VEC_DISTANCE_L2_SQUARED([0, 0], v)
59- ` ,
56+ usesVectorIndex : false ,
6057 expectedRows : []sql.Row {
6158 sql .NewRow (int64 (3 ), jsontests .ConvertToJson (t , "[1.0, 1.0]" )),
6259 sql .NewRow (int64 (2 ), jsontests .ConvertToJson (t , "[2.0, 2.0]" )),
@@ -67,12 +64,13 @@ IndexedTableAccess(test)
6764 name : "with limit" ,
6865 inputPlan : plan .NewTopN (
6966 sql.SortFields {
70- {Column : vector .NewDistance (vector.DistanceL2Squared {}, jsonExpression (t , "[0.0, 0.0]" ), expression .NewGetField ( 1 , types .JSON , "v" , false )), Order : sql .Ascending },
67+ {Column : vector .NewDistance (vector.DistanceL2Squared {}, jsonExpression (t , "[0.0, 0.0]" ), expression .NewGetFieldWithTable ( 2 , 1 , types .JSON , "" , "test-table" , "v" , false )), Order : sql .Ascending },
7168 }, expression .NewLiteral (1 , types .Int64 ), plan .NewResolvedTable (table , db , nil )),
69+ usesVectorIndex : true ,
7270 expectedPlan : `
7371IndexedTableAccess(test)
74- ├─ index: [v]
75- └─ order: VEC_DISTANCE_L2_SQUARED([0, 0], v) LIMIT 1 (bigint)
72+ ├─ index: [test-table. v]
73+ └─ order: VEC_DISTANCE_L2_SQUARED([0, 0], test-table. v) LIMIT 1 (bigint)
7674` ,
7775 expectedRows : []sql.Row {
7876 sql .NewRow (int64 (3 ), jsontests .ConvertToJson (t , "[1.0, 1.0]" )),
@@ -118,11 +116,14 @@ func TestVectorIndex(t *testing.T) {
118116 t .Run (testCase .name , func (t * testing.T ) {
119117 res , same , err := replaceIdxOrderByDistanceHelper (nil , nil , testCase .inputPlan , nil )
120118 require .NoError (t , err )
121- require .False (t , bool (same ))
122- require .Equal (t ,
123- strings .TrimSpace (testCase .expectedPlan ),
124- strings .TrimSpace (res .String ()),
125- "expected:\n %s,\n found:\n %s\n " , testCase .expectedPlan , res .String ())
119+ require .Equal (t , testCase .usesVectorIndex , ! bool (same ))
120+ res = offsetAssignIndexes (res )
121+ if testCase .usesVectorIndex {
122+ require .Equal (t ,
123+ strings .TrimSpace (testCase .expectedPlan ),
124+ strings .TrimSpace (res .String ()),
125+ "expected:\n %s,\n found:\n %s\n " , testCase .expectedPlan , res .String ())
126+ }
126127
127128 iter , err := rowexec .DefaultBuilder .Build (ctx , res , nil )
128129 require .NoError (t , err )
@@ -167,7 +168,7 @@ func TestShowCreateTableWithVectorIndex(t *testing.T) {
167168 "CREATE TABLE `test-table` (\n `pk` int,\n " +
168169 " `v` json,\n " +
169170 " PRIMARY KEY (`pk`),\n " +
170- " VECTOR KEY `test` (`v`), \n " +
171+ " VECTOR KEY `test` (`v`)\n " +
171172 ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin" ,
172173 )
173174
@@ -205,8 +206,7 @@ func (i vectorIndexTable) Comment() string {
205206}
206207
207208func (i vectorIndexTable ) Partitions (context * sql.Context ) (sql.PartitionIter , error ) {
208- //TODO implement me
209- panic ("implement me" )
209+ return i .underlying .Partitions (context )
210210}
211211
212212func (i vectorIndexTable ) PartitionRows (context * sql.Context , partition sql.Partition ) (sql.RowIter , error ) {
@@ -230,9 +230,9 @@ var vectorIndex = memory.Index{
230230 DB : database ,
231231 DriverName : "" ,
232232 Tbl : nil ,
233- TableName : "test" ,
233+ TableName : "test-table " ,
234234 Exprs : []sql.Expression {
235- expression .NewGetField (1 , types .JSON , "v" , false ),
235+ expression .NewGetFieldWithTable (1 , 1 , types .JSON , "" , "test-table" , "v" , false ),
236236 },
237237 Name : "test" ,
238238 Unique : false ,
0 commit comments