Skip to content

Commit 86b2df2

Browse files
added filter for vector lengths
1 parent 91b443b commit 86b2df2

File tree

2 files changed

+55
-0
lines changed

2 files changed

+55
-0
lines changed

posting/index.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,6 +1373,40 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp
13731373
return err
13741374
}
13751375

1376+
numVectorsToCheck := 100
1377+
lenFreq := make(map[int]int, numVectorsToCheck)
1378+
maxFreq := 0
1379+
dimension := 0
1380+
MemLayerInstance.IterateDisk(ctx, IterateDiskArgs{
1381+
Prefix: pk.DataPrefix(),
1382+
ReadTs: rb.StartTs,
1383+
AllVersions: false,
1384+
Reverse: false,
1385+
CheckInclusion: func(uid uint64) error {
1386+
return nil
1387+
},
1388+
Function: func(l *List, pk x.ParsedKey) error {
1389+
val, err := l.Value(rb.StartTs)
1390+
if err != nil {
1391+
return err
1392+
}
1393+
inVec := types.BytesAsFloatArray(val.Value.([]byte))
1394+
lenFreq[len(inVec)] += 1
1395+
if lenFreq[len(inVec)] > maxFreq {
1396+
maxFreq = lenFreq[len(inVec)]
1397+
dimension = len(inVec)
1398+
}
1399+
numVectorsToCheck -= 1
1400+
if numVectorsToCheck <= 0 {
1401+
return ErrStopIteration
1402+
}
1403+
return nil
1404+
},
1405+
StartKey: x.DataKey(rb.Attr, 0),
1406+
})
1407+
1408+
fmt.Println("Selecting vector dimension to be:", dimension)
1409+
13761410
if indexer.NumSeedVectors() > 0 {
13771411
count := 0
13781412
MemLayerInstance.IterateDisk(ctx, IterateDiskArgs{
@@ -1389,6 +1423,9 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp
13891423
return err
13901424
}
13911425
inVec := types.BytesAsFloatArray(val.Value.([]byte))
1426+
if len(inVec) != dimension {
1427+
return nil
1428+
}
13921429
count += 1
13931430
indexer.AddSeedVector(inVec)
13941431
if count == indexer.NumSeedVectors() {
@@ -1423,6 +1460,9 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp
14231460
}
14241461

14251462
inVec := types.BytesAsFloatArray(val.Value.([]byte))
1463+
if len(inVec) != dimension {
1464+
return []*pb.DirectedEdge{}, nil
1465+
}
14261466
indexer.BuildInsert(ctx, uid, inVec)
14271467
return edges, nil
14281468
}
@@ -1449,6 +1489,12 @@ func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSp
14491489
}
14501490

14511491
inVec := types.BytesAsFloatArray(val.Value.([]byte))
1492+
if len(inVec) != dimension {
1493+
if pass_idx == 0 {
1494+
glog.Warningf("Skipping vector with invalid dimension uid: %d, dimension: %d", uid, len(inVec))
1495+
}
1496+
return []*pb.DirectedEdge{}, nil
1497+
}
14521498
indexer.BuildInsert(ctx, uid, inVec)
14531499
return edges, nil
14541500
}

tok/kmeans/kmeans.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77

88
c "github.com/hypermodeinc/dgraph/v25/tok/constraints"
99
"github.com/hypermodeinc/dgraph/v25/tok/index"
10+
"github.com/hypermodeinc/dgraph/v25/x"
1011
)
1112

1213
type Kmeans[T c.Float] struct {
@@ -106,8 +107,13 @@ func (vc *vectorCentroids[T]) addVector(vec []T) error {
106107
}
107108

108109
func (vc *vectorCentroids[T]) updateCentroids() {
110+
x.AssertTrue(len(vc.centroids) == vc.numCenters)
111+
x.AssertTrue(len(vc.counts) == vc.numCenters)
112+
x.AssertTrue(len(vc.weights) == vc.numCenters)
109113
for i := 0; i < vc.numCenters; i++ {
110114
for j := 0; j < vc.dimension; j++ {
115+
x.AssertTrue(len(vc.centroids[i]) == vc.dimension)
116+
x.AssertTrue(len(vc.weights[i]) == vc.dimension)
111117
vc.centroids[i][j] = vc.weights[i][j] / T(vc.counts[i])
112118
vc.weights[i][j] = 0
113119
}
@@ -119,6 +125,9 @@ func (vc *vectorCentroids[T]) updateCentroids() {
119125

120126
func (vc *vectorCentroids[T]) randomInit() {
121127
vc.dimension = len(vc.centroids[0])
128+
for i := range vc.centroids {
129+
x.AssertTrue(len(vc.centroids[i]) == vc.dimension)
130+
}
122131
vc.numCenters = len(vc.centroids)
123132
vc.counts = make([]int64, vc.numCenters)
124133
vc.weights = make([][]T, vc.numCenters)

0 commit comments

Comments
 (0)