Skip to content

Commit cd926c3

Browse files
convert vectors from default to vfloat when rebuilding the index
1 parent 85cf5f9 commit cd926c3

File tree

4 files changed

+113
-8
lines changed

4 files changed

+113
-8
lines changed

posting/index.go

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,7 +1373,8 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp
13731373
}
13741374

13751375
dimension := indexer.Dimension()
1376-
if dimension == 0 {
1376+
// If dimension is -1, it means that the dimension is not set through options in case of partitioned hnsw.
1377+
if dimension == -1 {
13771378
numVectorsToCheck := 100
13781379
lenFreq := make(map[int]int, numVectorsToCheck)
13791380
maxFreq := 0
@@ -1410,6 +1411,48 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp
14101411

14111412
fmt.Println("Selecting vector dimension to be:", dimension)
14121413

1414+
norm := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs}
1415+
norm.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) {
1416+
val, err := pl.Value(rb.StartTs)
1417+
if err != nil {
1418+
return nil, err
1419+
}
1420+
if val.Tid == types.VFloatID {
1421+
return nil, nil
1422+
}
1423+
1424+
// Convert to VFloatID and persist as binary bytes.
1425+
sv, err := types.Convert(val, types.VFloatID)
1426+
if err != nil {
1427+
return nil, err
1428+
}
1429+
b := types.ValueForType(types.BinaryID)
1430+
if err = types.Marshal(sv, &b); err != nil {
1431+
return nil, err
1432+
}
1433+
1434+
edge := &pb.DirectedEdge{
1435+
Attr: rb.Attr,
1436+
Entity: uid,
1437+
Value: b.Value.([]byte),
1438+
ValueType: types.VFloatID.Enum(),
1439+
}
1440+
inKey := x.DataKey(edge.Attr, uid)
1441+
p, err := txn.Get(inKey)
1442+
if err != nil {
1443+
return []*pb.DirectedEdge{}, err
1444+
}
1445+
1446+
if err := p.addMutation(ctx, txn, edge); err != nil {
1447+
return []*pb.DirectedEdge{}, err
1448+
}
1449+
return nil, nil
1450+
}
1451+
1452+
if err := norm.RunWithoutTemp(ctx); err != nil {
1453+
return err
1454+
}
1455+
14131456
count := 0
14141457

14151458
if indexer.NumSeedVectors() > 0 {
@@ -1426,6 +1469,22 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp
14261469
if err != nil {
14271470
return err
14281471
}
1472+
1473+
if val.Tid != types.VFloatID {
1474+
// Here, we convert the defaultID type vector into vfloat.
1475+
sv, err := types.Convert(val, types.VFloatID)
1476+
if err != nil {
1477+
return err
1478+
}
1479+
b := types.ValueForType(types.BinaryID)
1480+
if err = types.Marshal(sv, &b); err != nil {
1481+
return err
1482+
}
1483+
1484+
val.Value = b.Value
1485+
val.Tid = types.VFloatID
1486+
}
1487+
14291488
inVec := types.BytesAsFloatArray(val.Value.([]byte))
14301489
if len(inVec) != dimension {
14311490
return fmt.Errorf("vector dimension mismatch expected dimension %d but got %d", dimension, len(inVec))
@@ -1464,7 +1523,6 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp
14641523

14651524
builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs}
14661525
builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) {
1467-
edges := []*pb.DirectedEdge{}
14681526
val, err := pl.Value(rb.StartTs)
14691527
if err != nil {
14701528
return []*pb.DirectedEdge{}, err
@@ -1475,7 +1533,7 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp
14751533
return []*pb.DirectedEdge{}, nil
14761534
}
14771535
indexer.BuildInsert(ctx, uid, inVec)
1478-
return edges, nil
1536+
return []*pb.DirectedEdge{}, nil
14791537
}
14801538

14811539
err := builder.RunWithoutTemp(ctx)
@@ -1519,21 +1577,22 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp
15191577

15201578
builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs}
15211579
builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) {
1522-
edges := []*pb.DirectedEdge{}
15231580
val, err := pl.Value(rb.StartTs)
15241581
if err != nil {
15251582
return []*pb.DirectedEdge{}, err
15261583
}
15271584

15281585
inVec := types.BytesAsFloatArray(val.Value.([]byte))
1529-
if len(inVec) != dimension {
1586+
if len(inVec) != dimension && centroids != nil {
15301587
if pass_idx == 0 {
15311588
glog.Warningf("Skipping vector with invalid dimension uid: %d, dimension: %d", uid, len(inVec))
15321589
}
15331590
return []*pb.DirectedEdge{}, nil
15341591
}
1592+
15351593
indexer.BuildInsert(ctx, uid, inVec)
1536-
return edges, nil
1594+
1595+
return []*pb.DirectedEdge{}, nil
15371596
}
15381597

15391598
err := builder.RunWithoutTemp(ctx)

systest/vector/vector_test.go

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ const (
3131

3232
var schemas = map[string]string{
3333
"hnsw": `project_description_v: float32vector @index(hnsw(exponent: "5", metric: "euclidean")) .`,
34-
"partitionedhnsw": `project_description_v: float32vector @index(partionedhnsw(numClusters: "100", partitionStratOpt: "kmeans", vectorDimension: "10", metric: "euclidean")) .`,
34+
"partitionedhnsw": `project_description_v: float32vector @index(partionedhnsw(numClusters: "1000", partitionStratOpt: "kmeans", vectorDimension: "100", metric: "euclidean")) .`,
3535
}
3636

3737
func testVectorQuery(t *testing.T, gc *dgraphapi.GrpcClient, vectors [][]float32, rdfs, pred string, topk int) {
@@ -431,6 +431,48 @@ func (vsuite *VectorTestSuite) TestVectorIndexWithoutSchema() {
431431
require.JSONEq(t, fmt.Sprintf(`{"vector":[{"count":%v}]}`, numVectors), string(result.GetJson()))
432432
}
433433

434+
func (vsuite *VectorTestSuite) TestIndexRebuildingWithoutSchema() {
435+
t := vsuite.T()
436+
conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour)
437+
c, err := dgraphtest.NewLocalCluster(conf)
438+
require.NoError(t, c.Start())
439+
440+
defer func() { c.Cleanup(t.Failed()) }()
441+
442+
gc, cleanup, err := c.Client()
443+
require.NoError(t, err)
444+
defer cleanup()
445+
446+
require.NoError(t, gc.LoginIntoNamespace(context.Background(),
447+
dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.RootNamespace))
448+
449+
require.NoError(t, gc.DropAll())
450+
require.NoError(t, gc.SetupSchema(testSchemaWithoutIndex))
451+
452+
numVectors := 1000
453+
rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 100, pred)
454+
mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true}
455+
_, err = gc.Mutate(mu)
456+
require.NoError(t, err)
457+
require.NoError(t, gc.SetupSchema(vsuite.schema))
458+
459+
query := `{
460+
vector(func: has(project_description_v)) {
461+
count(uid)
462+
}
463+
}`
464+
465+
result, err := gc.Query(query)
466+
require.NoError(t, err)
467+
require.JSONEq(t, fmt.Sprintf(`{"vector":[{"count":%v}]}`, numVectors), string(result.GetJson()))
468+
469+
for _, vect := range vectors {
470+
similarVects, err := gc.QueryMultipleVectorsUsingSimilarTo(vect, pred, 100)
471+
require.NoError(t, err)
472+
require.Equal(t, 100, len(similarVects))
473+
}
474+
}
475+
434476
func (vsuite *VectorTestSuite) TestVectorIndexWithoutSchemaWithoutIndex() {
435477
t := vsuite.T()
436478
conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour)

tok/kmeans/kmeans.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ type Kmeans[T c.Float] struct {
2525
func CreateKMeans[T c.Float](floatBits int, pred string, distFunc func(a, b []T, floatBits int) (T, error)) index.VectorPartitionStrat[T] {
2626
return &Kmeans[T]{
2727
floatBits: floatBits,
28+
numPasses: 5,
2829
centroids: &vectorCentroids[T]{
2930
distFunc: distFunc,
3031
floatBits: floatBits,
@@ -46,6 +47,9 @@ func (km *Kmeans[T]) GetCentroids() [][]T {
4647
}
4748

4849
func (km *Kmeans[T]) FindIndexForSearch(vec []T) ([]int, error) {
50+
if km.NumPasses() == 0 {
51+
return []int{0}, nil
52+
}
4953
res := make([]int, km.NumSeedVectors())
5054
for i := range res {
5155
res[i] = i

tok/partitioned_hnsw/partitioned_hnsw.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ type partitionedHNSW[T c.Float] struct {
3939

4040
func (ph *partitionedHNSW[T]) applyOptions(o opt.Options) error {
4141
ph.numClusters, _, _ = opt.GetOpt(o, NumClustersOpt, 1000)
42-
ph.vectorDimension, _, _ = opt.GetOpt(o, vectorDimension, 0)
42+
ph.vectorDimension, _, _ = opt.GetOpt(o, vectorDimension, -1)
4343
ph.partitionStrat, _, _ = opt.GetOpt(o, PartitionStratOpt, "kmeans")
4444

4545
if ph.partitionStrat != "kmeans" && ph.partitionStrat != "query" {

0 commit comments

Comments
 (0)