Skip to content

Commit 5380678

Browse files
Add vector dimension option and update tests to support both HNSW and PHNSW indexes
1 parent 86b2df2 commit 5380678

File tree

9 files changed

+144
-72
lines changed

9 files changed

+144
-72
lines changed

posting/index.go

Lines changed: 48 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"fmt"
1515
"math"
1616
"os"
17+
"strconv"
1718
"strings"
1819
"sync/atomic"
1920
"time"
@@ -1373,37 +1374,42 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp
13731374
return err
13741375
}
13751376

1376-
numVectorsToCheck := 100
1377-
lenFreq := make(map[int]int, numVectorsToCheck)
1378-
maxFreq := 0
1379-
dimension := 0
1380-
MemLayerInstance.IterateDisk(ctx, IterateDiskArgs{
1381-
Prefix: pk.DataPrefix(),
1382-
ReadTs: rb.StartTs,
1383-
AllVersions: false,
1384-
Reverse: false,
1385-
CheckInclusion: func(uid uint64) error {
1386-
return nil
1387-
},
1388-
Function: func(l *List, pk x.ParsedKey) error {
1389-
val, err := l.Value(rb.StartTs)
1390-
if err != nil {
1391-
return err
1392-
}
1393-
inVec := types.BytesAsFloatArray(val.Value.([]byte))
1394-
lenFreq[len(inVec)] += 1
1395-
if lenFreq[len(inVec)] > maxFreq {
1396-
maxFreq = lenFreq[len(inVec)]
1397-
dimension = len(inVec)
1398-
}
1399-
numVectorsToCheck -= 1
1400-
if numVectorsToCheck <= 0 {
1401-
return ErrStopIteration
1402-
}
1403-
return nil
1404-
},
1405-
StartKey: x.DataKey(rb.Attr, 0),
1406-
})
1377+
dimension := indexer.Dimension()
1378+
if dimension == 0 {
1379+
numVectorsToCheck := 100
1380+
lenFreq := make(map[int]int, numVectorsToCheck)
1381+
maxFreq := 0
1382+
MemLayerInstance.IterateDisk(ctx, IterateDiskArgs{
1383+
Prefix: pk.DataPrefix(),
1384+
ReadTs: rb.StartTs,
1385+
AllVersions: false,
1386+
Reverse: false,
1387+
CheckInclusion: func(uid uint64) error {
1388+
return nil
1389+
},
1390+
Function: func(l *List, pk x.ParsedKey) error {
1391+
val, err := l.Value(rb.StartTs)
1392+
if err != nil {
1393+
return err
1394+
}
1395+
inVec := types.BytesAsFloatArray(val.Value.([]byte))
1396+
lenFreq[len(inVec)] += 1
1397+
if lenFreq[len(inVec)] > maxFreq {
1398+
maxFreq = lenFreq[len(inVec)]
1399+
dimension = len(inVec)
1400+
}
1401+
numVectorsToCheck -= 1
1402+
if numVectorsToCheck <= 0 {
1403+
return ErrStopIteration
1404+
}
1405+
return nil
1406+
},
1407+
StartKey: x.DataKey(rb.Attr, 0),
1408+
})
1409+
1410+
indexer.SetDimension(dimension)
1411+
addDimensionOptionInSchema(rb.CurrentSchema, dimension)
1412+
}
14071413

14081414
fmt.Println("Selecting vector dimension to be:", dimension)
14091415

@@ -1648,6 +1654,17 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp
16481654
// return nil
16491655
}
16501656

1657+
func addDimensionOptionInSchema(schema *pb.SchemaUpdate, dimension int) {
1658+
for _, vs := range schema.IndexSpecs {
1659+
if vs.Name == "partionedhnsw" {
1660+
vs.Options = append(vs.Options, &pb.OptionPair{
1661+
Key: "dimension",
1662+
Value: strconv.Itoa(dimension),
1663+
})
1664+
}
1665+
}
1666+
}
1667+
16511668
// rebuildTokIndex rebuilds index for a given attribute.
16521669
// We commit mutations with startTs and ignore the errors.
16531670
func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error {

schema/parse.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ func parseTokenOrVectorIndexSpec(
306306
tokenizer, has := tok.GetTokenizer(tokenOrFactoryName)
307307
if !has {
308308
return tokenOrFactoryName, nil, false,
309-
next.Errorf("Invalid tokenizer 1 %s", next.Val)
309+
next.Errorf("Invalid tokenizer %s", next.Val)
310310
}
311311
tokenizerType, ok := types.TypeForName(tokenizer.Type())
312312
x.AssertTrue(ok) // Type is validated during tokenizer loading.

systest/vector/backup_test.go

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import (
1212
"fmt"
1313
"slices"
1414
"strings"
15-
"testing"
1615
"time"
1716

1817
"github.com/stretchr/testify/require"
@@ -23,7 +22,8 @@ import (
2322
"github.com/hypermodeinc/dgraph/v25/x"
2423
)
2524

26-
func TestVectorIncrBackupRestore(t *testing.T) {
25+
func (vsuite *VectorTestSuite) TestVectorIncrBackupRestore() {
26+
t := vsuite.T()
2727
conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour)
2828
c, err := dgraphtest.NewLocalCluster(conf)
2929
require.NoError(t, err)
@@ -41,7 +41,7 @@ func TestVectorIncrBackupRestore(t *testing.T) {
4141
require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser,
4242
dgraphapi.DefaultPassword, x.RootNamespace))
4343

44-
require.NoError(t, gc.SetupSchema(testSchema))
44+
require.NoError(t, gc.SetupSchema(vsuite.schema))
4545

4646
numVectors := 500
4747
pred := "project_description_v"
@@ -100,7 +100,8 @@ func TestVectorIncrBackupRestore(t *testing.T) {
100100
}
101101
}
102102

103-
func TestVectorBackupRestore(t *testing.T) {
103+
func (vsuite *VectorTestSuite) TestVectorBackupRestore() {
104+
t := vsuite.T()
104105
conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour)
105106
c, err := dgraphtest.NewLocalCluster(conf)
106107
require.NoError(t, err)
@@ -118,7 +119,7 @@ func TestVectorBackupRestore(t *testing.T) {
118119
require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser,
119120
dgraphapi.DefaultPassword, x.RootNamespace))
120121

121-
require.NoError(t, gc.SetupSchema(testSchema))
122+
require.NoError(t, gc.SetupSchema(vsuite.schema))
122123

123124
numVectors := 1000
124125
pred := "project_description_v"
@@ -138,7 +139,8 @@ func TestVectorBackupRestore(t *testing.T) {
138139
testVectorQuery(t, gc, vectors, rdfs, pred, numVectors)
139140
}
140141

141-
func TestVectorBackupRestoreDropIndex(t *testing.T) {
142+
func (vsuite *VectorTestSuite) TestVectorBackupRestoreDropIndex() {
143+
t := vsuite.T()
142144
// setup cluster
143145
conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour)
144146
c, err := dgraphtest.NewLocalCluster(conf)
@@ -158,7 +160,7 @@ func TestVectorBackupRestoreDropIndex(t *testing.T) {
158160
dgraphapi.DefaultPassword, x.RootNamespace))
159161

160162
// add vector predicate + index
161-
require.NoError(t, gc.SetupSchema(testSchema))
163+
require.NoError(t, gc.SetupSchema(vsuite.schema))
162164
// add data to the vector predicate
163165
numVectors := 3
164166
pred := "project_description_v"
@@ -195,7 +197,7 @@ func TestVectorBackupRestoreDropIndex(t *testing.T) {
195197
require.NoError(t, hc.Backup(c, false, dgraphtest.DefaultBackupDir))
196198

197199
// add index
198-
require.NoError(t, gc.SetupSchema(testSchema))
200+
require.NoError(t, gc.SetupSchema(vsuite.schema))
199201

200202
t.Log("taking second incr backup \n")
201203
require.NoError(t, hc.Backup(c, false, dgraphtest.DefaultBackupDir))
@@ -227,7 +229,8 @@ func TestVectorBackupRestoreDropIndex(t *testing.T) {
227229
}
228230
}
229231

230-
func TestVectorBackupRestoreReIndexing(t *testing.T) {
232+
func (vsuite *VectorTestSuite) TestVectorBackupRestoreReIndexing() {
233+
t := vsuite.T()
231234
conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour)
232235
c, err := dgraphtest.NewLocalCluster(conf)
233236
require.NoError(t, err)
@@ -245,7 +248,7 @@ func TestVectorBackupRestoreReIndexing(t *testing.T) {
245248
require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser,
246249
dgraphapi.DefaultPassword, x.RootNamespace))
247250

248-
require.NoError(t, gc.SetupSchema(testSchema))
251+
require.NoError(t, gc.SetupSchema(vsuite.schema))
249252

250253
numVectors := 1000
251254
pred := "project_description_v"
@@ -271,7 +274,7 @@ func TestVectorBackupRestoreReIndexing(t *testing.T) {
271274
// drop index
272275
require.NoError(t, gc.SetupSchema(testSchemaWithoutIndex))
273276
// add index
274-
require.NoError(t, gc.SetupSchema(testSchema))
277+
require.NoError(t, gc.SetupSchema(vsuite.schema))
275278
}
276279
vectors = append(vectors, vectors2...)
277280
rdfs = rdfs + rdfs2

systest/vector/load_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,18 @@ type Node struct {
2727
Vtest []float32 `json:"vtest"`
2828
}
2929

30-
func TestLiveLoadAndExportRDFFormat(t *testing.T) {
30+
func (vsuite *VectorTestSuite) TestLiveLoadAndExportRDFFormat() {
31+
t := vsuite.T()
3132
conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour)
3233
c, err := dgraphtest.NewLocalCluster(conf)
3334
require.NoError(t, err)
3435
defer func() { c.Cleanup(t.Failed()) }()
3536
require.NoError(t, c.Start())
3637

37-
testExportAndLiveLoad(t, c, "rdf")
38+
testExportAndLiveLoad(t, c, "rdf", vsuite.schema)
3839
}
3940

40-
func testExportAndLiveLoad(t *testing.T, c *dgraphtest.LocalCluster, exportFormat string) {
41+
func testExportAndLiveLoad(t *testing.T, c *dgraphtest.LocalCluster, exportFormat string, schema string) {
4142
gc, cleanup, err := c.Client()
4243
require.NoError(t, err)
4344
defer cleanup()
@@ -49,7 +50,7 @@ func testExportAndLiveLoad(t *testing.T, c *dgraphtest.LocalCluster, exportForma
4950
require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser,
5051
dgraphapi.DefaultPassword, x.RootNamespace))
5152

52-
require.NoError(t, gc.SetupSchema(testSchema))
53+
require.NoError(t, gc.SetupSchema(schema))
5354

5455
numVectors := 100
5556
pred := "project_description_v"

0 commit comments

Comments
 (0)