Skip to content

Commit cea09ad

Browse files
store vectors in 0th HNSW partition if count < 1000 and store dimension in specs
1 parent 5380678 commit cea09ad

File tree

6 files changed

+293
-18
lines changed

6 files changed

+293
-18
lines changed

posting/index.go

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import (
3535
"github.com/hypermodeinc/dgraph/v25/tok"
3636
"github.com/hypermodeinc/dgraph/v25/tok/hnsw"
3737
tokIndex "github.com/hypermodeinc/dgraph/v25/tok/index"
38+
"github.com/hypermodeinc/dgraph/v25/tok/kmeans"
3839

3940
"github.com/hypermodeinc/dgraph/v25/types"
4041
"github.com/hypermodeinc/dgraph/v25/x"
@@ -1364,8 +1365,6 @@ func (rb *indexRebuildInfo) prefixesForTokIndexes() ([][]byte, error) {
13641365
return prefixes, nil
13651366
}
13661367

1367-
const numCentroids = 1000
1368-
13691368
func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSpec, rb *IndexRebuild) error {
13701369
pk := x.ParsedKey{Attr: rb.Attr}
13711370

@@ -1413,9 +1412,10 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp
14131412

14141413
fmt.Println("Selecting vector dimension to be:", dimension)
14151414

1415+
count := 0
1416+
14161417
if indexer.NumSeedVectors() > 0 {
1417-
count := 0
1418-
MemLayerInstance.IterateDisk(ctx, IterateDiskArgs{
1418+
err := MemLayerInstance.IterateDisk(ctx, IterateDiskArgs{
14191419
Prefix: pk.DataPrefix(),
14201420
ReadTs: rb.StartTs,
14211421
AllVersions: false,
@@ -1430,7 +1430,7 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp
14301430
}
14311431
inVec := types.BytesAsFloatArray(val.Value.([]byte))
14321432
if len(inVec) != dimension {
1433-
return nil
1433+
return fmt.Errorf("vector dimension mismatch expected dimension %d but got %d", dimension, len(inVec))
14341434
}
14351435
count += 1
14361436
indexer.AddSeedVector(inVec)
@@ -1441,6 +1441,9 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp
14411441
},
14421442
StartKey: x.DataKey(rb.Attr, 0),
14431443
})
1444+
if err != nil {
1445+
return err
1446+
}
14441447
}
14451448

14461449
txns := make([]*Txn, indexer.NumThreads())
@@ -1452,6 +1455,11 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp
14521455
caches[i] = hnsw.NewTxnCache(NewViTxn(txns[i]), rb.StartTs)
14531456
}
14541457

1458+
if count < indexer.NumSeedVectors() {
1459+
indexer.SetNumPasses(0)
1460+
1461+
}
1462+
14551463
for pass_idx := range indexer.NumBuildPasses() {
14561464
fmt.Println("Building pass", pass_idx)
14571465

@@ -1481,7 +1489,30 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp
14811489
indexer.EndBuild()
14821490
}
14831491

1484-
for pass_idx := range indexer.NumIndexPasses() {
1492+
txn := NewTxn(rb.StartTs)
1493+
centroids := indexer.GetCentroids()
1494+
1495+
bCentroids, err := json.Marshal(centroids)
1496+
if err != nil {
1497+
return err
1498+
}
1499+
1500+
if err := addCentroidInDB(ctx, rb.Attr, bCentroids, txn); err != nil {
1501+
return err
1502+
}
1503+
txn.Update()
1504+
writer := NewTxnWriter(pstore)
1505+
if err := txn.CommitToDisk(writer, rb.StartTs); err != nil {
1506+
return err
1507+
}
1508+
1509+
numIndexPasses := indexer.NumIndexPasses()
1510+
1511+
if count < indexer.NumSeedVectors() {
1512+
numIndexPasses = 1
1513+
}
1514+
1515+
for pass_idx := range numIndexPasses {
14851516
fmt.Println("Indexing pass", pass_idx)
14861517

14871518
indexer.StartBuild(caches)
@@ -1654,11 +1685,31 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp
16541685
// return nil
16551686
}
16561687

1688+
func addCentroidInDB(ctx context.Context, attr string, vec []byte, txn *Txn) error {
1689+
indexCountAttr := hnsw.ConcatStrings(attr, kmeans.CentroidPrefix)
1690+
countKey := x.DataKey(indexCountAttr, 1)
1691+
pl, err := txn.Get(countKey)
1692+
if err != nil {
1693+
return err
1694+
}
1695+
1696+
edge := &pb.DirectedEdge{
1697+
Entity: 1,
1698+
Attr: indexCountAttr,
1699+
Value: vec,
1700+
ValueType: pb.Posting_ValType(12),
1701+
}
1702+
if err := pl.addMutation(ctx, txn, edge); err != nil {
1703+
return err
1704+
}
1705+
return nil
1706+
}
1707+
16571708
func addDimensionOptionInSchema(schema *pb.SchemaUpdate, dimension int) {
16581709
for _, vs := range schema.IndexSpecs {
16591710
if vs.Name == "partionedhnsw" {
16601711
vs.Options = append(vs.Options, &pb.OptionPair{
1661-
Key: "dimension",
1712+
Key: "vectorDimension",
16621713
Value: strconv.Itoa(dimension),
16631714
})
16641715
}

systest/vector/vector_test.go

Lines changed: 105 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ const (
3131

3232
var schemas = map[string]string{
3333
"hnsw": `project_description_v: float32vector @index(hnsw(exponent: "5", metric: "euclidean")) .`,
34-
"partitionedhnsw": `project_description_v: float32vector @index(partitionedhnsw(numClusters: "100", partitionStratOpt: "kmeans", vectorDimention: "100", metric: "euclidean")) .`,
34+
"partitionedhnsw": `project_description_v: float32vector @index(partionedhnsw(numClusters: "100", partitionStratOpt: "kmeans", vectorDimension: "10", metric: "euclidean")) .`,
3535
}
3636

3737
func testVectorQuery(t *testing.T, gc *dgraphapi.GrpcClient, vectors [][]float32, rdfs, pred string, topk int) {
@@ -69,7 +69,7 @@ func (vsuite *VectorTestSuite) TestVectorDropAll() {
6969
require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser,
7070
dgraphapi.DefaultPassword, x.RootNamespace))
7171

72-
numVectors := 100
72+
numVectors := 10
7373

7474
testVectorSimilarTo := func(vectors [][]float32) {
7575
for _, vector := range vectors {
@@ -475,16 +475,118 @@ func (vsuite *VectorTestSuite) TestVectorIndexWithoutSchemaWithoutIndex() {
475475
require.NoError(t, err)
476476
require.JSONEq(t, fmt.Sprintf(`{"vector":[{"count":%v}]}`, numVectors), string(result.GetJson()))
477477
}
478+
func (vsuite *VectorTestSuite) TestPartitionedHNSWIndex() {
479+
if !vsuite.isForPartitionedIndex {
480+
return
481+
}
482+
t := vsuite.T()
483+
conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1)
484+
c, err := dgraphtest.NewLocalCluster(conf)
485+
486+
require.NoError(t, err)
487+
defer func() { c.Cleanup(t.Failed()) }()
488+
require.NoError(t, c.Start())
489+
490+
gc, cleanup, err := c.Client()
491+
defer cleanup()
492+
require.NoError(t, err)
493+
494+
schemaWithoutIndex := `project_description_v: float32vector .`
495+
pred := "project_description_v"
496+
schemaWithIndex := `project_description_v: float32vector @index(partionedhnsw` +
497+
`(numClusters:"1000", partitionStratOpt: "kmeans",metric: "euclidean",vectorDimension: "10")) .`
498+
499+
t.Run("with more than 1000 vectors", func(t *testing.T) {
500+
require.NoError(t, gc.DropAll())
501+
502+
numVectors := 5000
503+
504+
require.NoError(t, gc.SetupSchema(schemaWithoutIndex))
505+
rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 10, pred)
506+
mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true}
507+
_, err = gc.Mutate(mu)
508+
require.NoError(t, err)
509+
510+
err = gc.SetupSchema(schemaWithIndex)
511+
require.NoError(t, err)
512+
513+
testVectorQuery(t, gc, vectors, rdfs, pred, 5)
514+
})
515+
516+
t.Run("without providing vector dimension", func(t *testing.T) {
517+
require.NoError(t, gc.DropAll())
518+
519+
numVectors := 1001
520+
521+
require.NoError(t, gc.SetupSchema(schemaWithoutIndex))
522+
523+
rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 10, pred)
524+
mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true}
525+
_, err = gc.Mutate(mu)
526+
require.NoError(t, err)
527+
528+
s := `project_description_v: float32vector @index(partionedhnsw` +
529+
`(numClusters:"1000", partitionStratOpt: "kmeans",metric: "euclidean")) .`
530+
err = gc.SetupSchema(s)
531+
require.NoError(t, err)
532+
533+
testVectorQuery(t, gc, vectors, rdfs, pred, 1000)
534+
})
535+
536+
t.Run("with less than 1000 vectors", func(t *testing.T) {
537+
require.NoError(t, gc.DropAll())
538+
numVectors := 100
539+
require.NoError(t, gc.SetupSchema(schemaWithoutIndex))
540+
541+
rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 10, pred)
542+
mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true}
543+
_, err = gc.Mutate(mu)
544+
require.NoError(t, err)
545+
546+
err = gc.SetupSchema(schemaWithIndex)
547+
require.NoError(t, err)
548+
549+
testVectorQuery(t, gc, vectors, rdfs, pred, numVectors)
550+
})
551+
552+
t.Run("with different length of vectors", func(t *testing.T) {
553+
require.NoError(t, gc.DropAll())
554+
numVectors := 1100
555+
require.NoError(t, gc.SetupSchema(schemaWithoutIndex))
556+
557+
q := `schema {}`
558+
result, err := gc.Query(q)
559+
require.NoError(t, err)
560+
561+
rdfs, _ := dgraphapi.GenerateRandomVectors(0, numVectors, 8, pred)
562+
mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true}
563+
_, err = gc.Mutate(mu)
564+
require.NoError(t, err)
565+
566+
err = gc.SetupSchema(schemaWithIndex)
567+
require.NoError(t, err)
568+
569+
// here check schema it should not be changed
570+
q = `schema {}`
571+
result1, err := gc.Query(q)
572+
require.NoError(t, err)
573+
require.JSONEq(t, string(result.GetJson()), string(result1.GetJson()))
574+
})
575+
}
478576

479577
type VectorTestSuite struct {
480578
suite.Suite
481-
schema string
579+
schema string
580+
isForPartitionedIndex bool
482581
}
483582

484583
func TestVectorSuite(t *testing.T) {
485584
for _, schema := range schemas {
486585
var ssuite VectorTestSuite
487586
ssuite.schema = schema
587+
if strings.Contains(schema, "partionedhnsw") {
588+
ssuite.isForPartitionedIndex = true
589+
}
488590
suite.Run(t, &ssuite)
489591
if t.Failed() {
490592
x.Panic(errors.New("vector tests failed"))

tok/hnsw/persistent_hnsw.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,10 @@ func (ph *persistentHNSW[T]) NumBuildPasses() int {
117117
return 0
118118
}
119119

120+
func (ph *persistentHNSW[T]) SetNumPasses(int) {
121+
return
122+
}
123+
120124
func (ph *persistentHNSW[T]) Dimension() int {
121125
return 0
122126
}
@@ -317,6 +321,13 @@ type resultRow[T c.Float] struct {
317321
dist T
318322
}
319323

324+
// MergeResults takes a list of UIDs and returns the maxResults nearest neighbors
325+
// in order of increasing distance. It returns an error if any of the UIDs are
326+
// not present in the index.
327+
//
328+
// The filter parameter is not used by this method.
329+
//
330+
// This method is part of the index.MultipleIndex interface.
320331
func (ph *persistentHNSW[T]) MergeResults(ctx context.Context, c index.CacheType, list []uint64, query []T, maxResults int, filter index.SearchFilter[T]) ([]uint64, error) {
321332
var result []resultRow[T]
322333

@@ -342,6 +353,7 @@ func (ph *persistentHNSW[T]) MergeResults(ctx context.Context, c index.CacheType
342353
})
343354

344355
uids := []uint64{}
356+
// out of range error
345357
for i := range maxResults {
346358
if i > len(result) {
347359
break
@@ -498,6 +510,9 @@ func (ph *persistentHNSW[T]) Insert(ctx context.Context, c index.CacheType,
498510
_, edges, err := ph.insertHelper(ctx, tc, inUuid, inVec)
499511
return edges, err
500512
}
513+
func (ph *persistentHNSW[T]) GetCentroids() [][]T {
514+
return nil
515+
}
501516

502517
// InsertToPersistentStorage inserts a node into the hnsw graph and returns the
503518
// traversal path and the edges created

tok/index/index.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,13 @@ type VectorPartitionStrat[T c.Float] interface {
9393
FindIndexForSearch(vec []T) ([]int, error)
9494
FindIndexForInsert(vec []T) (int, error)
9595
NumPasses() int
96+
SetNumPasses(int)
9697
NumSeedVectors() int
9798
StartBuildPass()
9899
EndBuildPass()
99100
AddSeedVector(vec []T)
100101
AddVector(vec []T) error
102+
GetCentroids() [][]T
101103
}
102104

103105
// A VectorIndex can be used to Search for vectors and add vectors to an index.
@@ -132,8 +134,10 @@ type VectorIndex[T c.Float] interface {
132134
Insert(ctx context.Context, c CacheType, uuid uint64, vec []T) ([]*KeyValue, error)
133135

134136
BuildInsert(ctx context.Context, uuid uint64, vec []T) error
137+
GetCentroids() [][]T
135138
AddSeedVector(vec []T)
136139
NumBuildPasses() int
140+
SetNumPasses(int)
137141
NumIndexPasses() int
138142
NumSeedVectors() int
139143
StartBuild(caches []CacheType)

0 commit comments

Comments
 (0)